patricia_trie.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #!/usr/bin/env python3
  2. """
  3. CIDR Registry using py-radix for efficient longest prefix match (LPM) lookups.
  4. Provides O(k) lookup where k = address bits (32 for IPv4, 128 for IPv6),
  5. using the py-radix library for production-ready Patricia Trie implementation.
  6. Radix Tree Organization (Patricia Trie):
  7. =========================================
  8. A radix tree is a compressed prefix tree that stores network prefixes.
  9. Each node represents a bit position in the binary representation of an IP address.
  10. Structure:
  11. -----------
  12. - Root node: represents the start of all addresses
  13. - Each edge: labeled with 0 or 1 (a single bit)
  14. - Each leaf node: represents a complete network prefix (CIDR)
  15. Example - inserting 10.0.0.0/8 and 10.1.0.0/16:
  16. ------------------------------------------------
  17. [root]
  18. / \\
  19. 0 1
  20. \\ \\
  21. [10.x.x.x] [other]
  22. \\
  23. ... (compressed path for /8)
  24. \\
  25. [node at bit 16, prefix=10.1.x.x, client_id=client2]
  26. \\
  27. ... (compressed path for /16)
  28. Key Properties:
  29. ----------------
  30. 1. Longest Prefix Match (LPM): When searching for an IP, the tree traversal
  31. continues until no matching child exists. The last node with a valid
  32. prefix that matches the search key is the best match.
  33. 2. Compression: Patricia trie compresses chains of single-child nodes into
  34. single nodes, reducing space complexity from O(k*n) to O(k) where k is
  35. address bits and n is number of prefixes.
  36. 3. search_best(): Traverses from root following bits of the IP address.
  37. Returns the most specific (longest) matching prefix.
  38. Lookup Example for IP 10.1.5.5:
  39. --------------------------------
  40. - Binary of 10.1.5.5: 00001010 00000001 00000101 00000101
  41. - Inserted prefixes: 10.0.0.0/8, 10.1.0.0/16
  42. 1. Start at root
  43. 2. Follow bit 0 (first bit of 10) -> child exists
  44. 3. Continue following bits 0,0,0,0,1,0,1,0 (first 8 bits = /8)
  45. - At position 8, /8 node has client_id=client1, but we continue...
  46. 4. Continue with bits for second octet (00000001 = 0,0,0,0,0,0,0,1)
  47. 5. At position 16, /16 node has client_id=client2 (more specific!)
  48. 6. Try to follow bit at position 16, but /16 is exact match, stop
  49. 7. Return client2 (the longest matching prefix)
  50. Memory Layout:
  51. ---------------
  52. RadixNode {
  53. prefix: str # e.g., "10.0.0.0/8"
  54. prefixlen: int # e.g., 8
  55. packed: bytes # binary representation of network address
  56. family: int # 2 for IPv4, 10 for IPv6
  57. data: dict # user data ({"client_id": uuid})
  58. children: dict # {0: child_node, 1: child_node}
  59. parent: node # pointer to parent node
  60. }
  61. """
  62. import radix
  63. import uuid
  64. from typing import Optional, Dict, List
  65. class CIDRRegistry:
  66. """
  67. Registry that maps CIDR ranges to client IDs using py-radix.
  68. This provides efficient longest-prefix-match lookups for IP addresses.
  69. """
  70. def __init__(self):
  71. self._tree = radix.Radix()
  72. # Track all CIDRs per client for rebuild purposes
  73. self._client_cidrs: Dict[uuid.UUID, List[str]] = {}
  74. def insert(self, cidr: str, client_id: uuid.UUID) -> None:
  75. """Insert a CIDR for a client."""
  76. node = self._tree.add(cidr)
  77. node.data["client_id"] = client_id
  78. if client_id not in self._client_cidrs:
  79. self._client_cidrs[client_id] = []
  80. if cidr not in self._client_cidrs[client_id]:
  81. self._client_cidrs[client_id].append(cidr)
  82. def remove_client(self, client_id: uuid.UUID) -> None:
  83. """Remove all CIDRs associated with a client."""
  84. if client_id in self._client_cidrs:
  85. del self._client_cidrs[client_id]
  86. self._rebuild()
  87. def update_client(self, client_id: uuid.UUID, cidrs: List[str]) -> None:
  88. """Update all CIDRs for a client."""
  89. self._client_cidrs[client_id] = list(cidrs)
  90. self._rebuild()
  91. def find_best_match(self, ip: str) -> Optional[uuid.UUID]:
  92. """Find the best matching client for an IP address."""
  93. try:
  94. node = self._tree.search_best(ip)
  95. if node:
  96. return node.data.get("client_id")
  97. except (ValueError, OSError):
  98. # Invalid IP format
  99. pass
  100. return None
  101. def _rebuild(self) -> None:
  102. """Rebuild the tree from the client_cidrs mapping."""
  103. self._tree = radix.Radix()
  104. for client_id, cidrs in self._client_cidrs.items():
  105. for cidr in cidrs:
  106. node = self._tree.add(cidr)
  107. node.data["client_id"] = client_id