test_patricia_trie.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. #!/usr/bin/env python3
  2. """
  3. Tests for patricia_trie.py - CIDR Registry using py-radix
  4. """
  5. import uuid
  6. from gpustack.websocket_proxy.patricia_trie import CIDRRegistry
  7. class TestCIDRRegistry:
  8. """Tests for CIDRRegistry class."""
  9. def test_insert_and_find_exact_match(self):
  10. """Test basic insert and find for a single CIDR."""
  11. registry = CIDRRegistry()
  12. client_id = uuid.uuid4()
  13. registry.insert("10.0.0.0/8", client_id)
  14. result = registry.find_best_match("10.0.0.1")
  15. assert result == client_id
  16. def test_longest_prefix_match(self):
  17. """Test that more specific CIDRs take precedence over less specific ones."""
  18. registry = CIDRRegistry()
  19. client_class_a = uuid.uuid4()
  20. client_class_b = uuid.uuid4()
  21. registry.insert("10.0.0.0/8", client_class_a)
  22. registry.insert("10.1.0.0/16", client_class_b)
  23. # 10.1.x.x should match /16 (more specific)
  24. assert registry.find_best_match("10.1.0.1") == client_class_b
  25. # 10.0.x.x should match /8
  26. assert registry.find_best_match("10.0.0.1") == client_class_a
  27. # 10.5.x.x should match /8
  28. assert registry.find_best_match("10.5.5.5") == client_class_a
  29. def test_no_match(self):
  30. """Test that unmatched IPs return None."""
  31. registry = CIDRRegistry()
  32. client_id = uuid.uuid4()
  33. registry.insert("10.0.0.0/8", client_id)
  34. assert registry.find_best_match("192.168.1.1") is None
  35. assert registry.find_best_match("172.16.0.1") is None
  36. def test_default_route(self):
  37. """Test that 0.0.0.0/0 matches any IP."""
  38. registry = CIDRRegistry()
  39. default_client = uuid.uuid4()
  40. specific_client = uuid.uuid4()
  41. registry.insert("0.0.0.0/0", default_client)
  42. registry.insert("10.0.0.0/8", specific_client)
  43. assert registry.find_best_match("10.0.0.1") == specific_client
  44. assert registry.find_best_match("192.168.1.1") == default_client
  45. assert registry.find_best_match("8.8.8.8") == default_client
  46. def test_multiple_cidrs_same_client(self):
  47. """Test that the same client can have multiple CIDRs."""
  48. registry = CIDRRegistry()
  49. client_id = uuid.uuid4()
  50. registry.insert("10.0.0.0/8", client_id)
  51. registry.insert("172.16.0.0/12", client_id)
  52. assert registry.find_best_match("10.5.5.5") == client_id
  53. assert registry.find_best_match("172.16.0.1") == client_id
  54. assert registry.find_best_match("192.168.1.1") is None
  55. def test_exact_host_match(self):
  56. """Test /32 exact host match takes precedence over /24."""
  57. registry = CIDRRegistry()
  58. class_c_client = uuid.uuid4()
  59. host_client = uuid.uuid4()
  60. registry.insert("192.168.1.0/24", class_c_client)
  61. registry.insert("192.168.1.100/32", host_client)
  62. assert registry.find_best_match("192.168.1.100") == host_client
  63. assert registry.find_best_match("192.168.1.99") == class_c_client
  64. assert registry.find_best_match("192.168.1.101") == class_c_client
  65. def test_ipv6_support(self):
  66. """Test IPv6 CIDR matching."""
  67. registry = CIDRRegistry()
  68. client1 = uuid.uuid4()
  69. client2 = uuid.uuid4()
  70. registry.insert("2001:db8::/32", client1)
  71. registry.insert("2001:db8:1::/48", client2)
  72. assert registry.find_best_match("2001:db8::1") == client1
  73. assert registry.find_best_match("2001:db8:ffff::1") == client1
  74. assert registry.find_best_match("2001:db8:1::1") == client2
  75. assert registry.find_best_match("2001:db8:2::1") == client1
  76. assert registry.find_best_match("2001:dead::1") is None
  77. def test_remove_client(self):
  78. """Test removing a client's all CIDRs."""
  79. registry = CIDRRegistry()
  80. client1 = uuid.uuid4()
  81. client2 = uuid.uuid4()
  82. registry.insert("10.0.0.0/8", client1)
  83. registry.insert("10.1.0.0/16", client1)
  84. registry.insert("192.168.0.0/16", client2)
  85. # Verify both client1 CIDRs work
  86. assert registry.find_best_match("10.0.0.1") == client1
  87. assert registry.find_best_match("10.1.0.1") == client1
  88. # Remove client1
  89. registry.remove_client(client1)
  90. # client1's CIDRs should no longer match
  91. assert registry.find_best_match("10.0.0.1") is None
  92. assert registry.find_best_match("10.1.0.1") is None
  93. # client2 should still work
  94. assert registry.find_best_match("192.168.1.1") == client2
  95. def test_update_client(self):
  96. """Test updating a client's CIDRs."""
  97. registry = CIDRRegistry()
  98. client_id = uuid.uuid4()
  99. registry.insert("10.0.0.0/8", client_id)
  100. assert registry.find_best_match("10.0.0.1") == client_id
  101. assert registry.find_best_match("172.16.0.1") is None
  102. # Update client's CIDRs
  103. registry.update_client(client_id, ["172.16.0.0/12"])
  104. # Old CIDR should not match anymore
  105. assert registry.find_best_match("10.0.0.1") is None
  106. # New CIDR should match
  107. assert registry.find_best_match("172.16.0.1") == client_id
  108. def test_empty_registry(self):
  109. """Test that empty registry returns None for any IP."""
  110. registry = CIDRRegistry()
  111. assert registry.find_best_match("10.0.0.1") is None
  112. assert registry.find_best_match("192.168.1.1") is None
  113. assert registry.find_best_match("::1") is None
  114. def test_invalid_ip(self):
  115. """Test that invalid IP returns None."""
  116. registry = CIDRRegistry()
  117. client_id = uuid.uuid4()
  118. registry.insert("10.0.0.0/8", client_id)
  119. assert registry.find_best_match("not-an-ip") is None
  120. assert registry.find_best_match("") is None
  121. def test_complex_overlapping_cidrs(self):
  122. """Test complex overlapping CIDR scenarios."""
  123. registry = CIDRRegistry()
  124. c1 = uuid.uuid4()
  125. c2 = uuid.uuid4()
  126. c3 = uuid.uuid4()
  127. c4 = uuid.uuid4()
  128. registry.insert("0.0.0.0/0", c1)
  129. registry.insert("10.0.0.0/8", c2)
  130. registry.insert("10.1.0.0/16", c3)
  131. registry.insert("10.1.1.0/24", c4)
  132. tests = [
  133. ("1.1.1.1", c1),
  134. ("9.9.9.9", c1),
  135. ("10.0.0.1", c2),
  136. ("10.0.255.255", c2),
  137. ("10.1.0.1", c3),
  138. ("10.1.0.255", c3),
  139. ("10.1.1.0", c4),
  140. ("10.1.1.1", c4),
  141. ("10.1.1.255", c4),
  142. ("10.1.2.0", c3),
  143. ("10.2.0.0", c2),
  144. ]
  145. for ip, expected in tests:
  146. result = registry.find_best_match(ip)
  147. assert result == expected, f"IP {ip}: expected {expected}, got {result}"