message.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684
  1. import uuid
  2. import gzip
  3. import io
  4. import json
  5. import struct
  6. from dataclasses import dataclass, field
  7. from typing import Optional, List, Tuple, Type, Callable, TypeVar, Dict, Literal
  8. from enum import IntEnum
  9. # Protocol version
  10. PROTOCOL_VERSION = 0x01
  11. # TypeVar for message classes
  12. M = TypeVar('M', bound='BaseMessage')
  13. # ==================== Info Dataclasses ====================
  14. @dataclass
  15. class BaseClientInfo:
  16. """Base class for client information"""
  17. client_id: uuid.UUID
  18. cidrs: List[str] = field(default_factory=list)
  19. unix_sockets: List[str] = field(default_factory=list)
  20. def to_headers(self) -> Dict[str, str]:
  21. """Convert to headers dict for websockets client library."""
  22. return {
  23. 'x-client-id': str(self.client_id),
  24. 'x-cidrs': ','.join(self.cidrs),
  25. 'x-unix-sockets': ','.join(self.unix_sockets),
  26. }
  27. @dataclass
  28. class RegisteredClientInfo(BaseClientInfo):
  29. """Information about a registered client"""
  30. server_id: Optional[uuid.UUID] = None # Which server owns this client
  31. @classmethod
  32. def from_headers(cls, headers) -> Optional['RegisteredClientInfo']:
  33. """Create RegisteredClientInfo from headers dict or Starlette Headers object."""
  34. if hasattr(headers, 'get'):
  35. cidrs_str = headers.get('x-cidrs', '')
  36. unix_sockets_str = headers.get('x-unix-sockets', '')
  37. else:
  38. cidrs_str = None
  39. unix_sockets_str = None
  40. if not cidrs_str and not unix_sockets_str:
  41. return None
  42. client_id_str = headers.get('x-client-id', '')
  43. try:
  44. client_id = uuid.UUID(client_id_str) if client_id_str else uuid.uuid4()
  45. except ValueError:
  46. client_id = uuid.uuid4()
  47. cidr_list = [cidr.strip() for cidr in cidrs_str.split(',') if cidr.strip()]
  48. socket_list = [s.strip() for s in unix_sockets_str.split(',') if s.strip()]
  49. return cls(
  50. client_id=client_id,
  51. cidrs=cidr_list,
  52. unix_sockets=socket_list,
  53. )
  54. @dataclass
  55. class ServerInfo:
  56. """Information about a peer server"""
  57. server_id: uuid.UUID
  58. listen_address: Optional[str] = None
  59. listen_port: Optional[int] = None
  60. proxy_port: Optional[int] = None
  61. def to_headers(self) -> Dict[str, str]:
  62. """Convert to headers dict for websockets client library."""
  63. headers = {'x-server-id': str(self.server_id)}
  64. if self.listen_address is not None:
  65. headers['x-server-listen-address'] = self.listen_address
  66. if self.listen_port is not None:
  67. headers['x-server-listen-port'] = str(self.listen_port)
  68. if self.proxy_port is not None:
  69. headers['x-server-proxy-port'] = str(self.proxy_port)
  70. return headers
  71. def to_bytes_headers(self) -> List[Tuple[bytes, bytes]]:
  72. """Convert to headers list of tuples for Starlette WebSocket accept."""
  73. headers = [(b'x-server-id', str(self.server_id).encode())]
  74. if self.listen_address is not None:
  75. headers.append((b'x-server-listen-address', self.listen_address.encode()))
  76. if self.listen_port is not None:
  77. headers.append((b'x-server-listen-port', str(self.listen_port).encode()))
  78. if self.proxy_port is not None:
  79. headers.append((b'x-server-proxy-port', str(self.proxy_port).encode()))
  80. return headers
  81. @classmethod
  82. def from_headers(cls, headers) -> Optional['ServerInfo']:
  83. """Create ServerInfo from headers dict or Starlette Headers object."""
  84. try:
  85. def get_header(key: str) -> Optional[str]:
  86. # Starlette Headers object
  87. if hasattr(headers, 'get'):
  88. try:
  89. val = headers.get(key)
  90. if val is not None:
  91. return val
  92. except (TypeError, KeyError):
  93. pass
  94. # Dict with bytes or str keys
  95. if isinstance(headers, dict):
  96. val = (
  97. headers.get(key)
  98. or headers.get(key.encode())
  99. or headers.get(key.lower())
  100. )
  101. if val is not None:
  102. return val.decode() if isinstance(val, bytes) else val
  103. return None
  104. server_id_str = get_header('x-server-id')
  105. if not server_id_str:
  106. return None
  107. return cls(
  108. server_id=uuid.UUID(server_id_str),
  109. listen_address=get_header('x-server-listen-address'),
  110. listen_port=(
  111. int(v)
  112. if (v := get_header('x-server-listen-port')) is not None
  113. else None
  114. ),
  115. proxy_port=(
  116. int(v)
  117. if (v := get_header('x-server-proxy-port')) is not None
  118. else None
  119. ),
  120. )
  121. except (ValueError, TypeError):
  122. return None
  123. @dataclass
  124. class ServerPeer(ServerInfo):
  125. """Information about a connected peer server"""
  126. websocket: Optional[object] = None
  127. connected: bool = False
  128. # Binary message types
  129. class BinaryType(IntEnum):
  130. # Client <-> Server messages
  131. CONNECT_REQUEST = 0x01
  132. CONNECT_RESPONSE = 0x02
  133. DATA = 0x03
  134. DISCONNECT = 0x04
  135. HEARTBEAT = 0x05
  136. LIST_CLIENTS = 0x06
  137. LIST_CLIENTS_RESPONSE = 0x07
  138. # Server <-> Server messages
  139. CLIENT_UPDATE = 0x08
  140. # Protocol types
  141. class BinaryProtocol(IntEnum):
  142. TCP = 0x01
  143. UDP = 0x02
  144. UNIX = 0x03
  145. # Type strings
  146. TYPE_CONNECT_REQUEST = "connect_request"
  147. TYPE_CONNECT_RESPONSE = "connect_response"
  148. TYPE_DATA = "data"
  149. TYPE_DISCONNECT = "disconnect"
  150. TYPE_HEARTBEAT = "heartbeat"
  151. TYPE_LIST_CLIENTS = "list_clients"
  152. TYPE_LIST_CLIENTS_RESPONSE = "list_clients_response"
  153. TYPE_CLIENT_UPDATE = "client_update"
  154. # Compression flags
  155. DATA_COMPRESSION_NONE = 0x00
  156. DATA_COMPRESSION_GZIP = 0x01
  157. # ==================== Protocol Helpers ====================
  158. def protocol_to_bytes(protocol: str) -> int:
  159. if protocol == "tcp":
  160. return BinaryProtocol.TCP
  161. elif protocol == "udp":
  162. return BinaryProtocol.UDP
  163. elif protocol == "unix":
  164. return BinaryProtocol.UNIX
  165. else:
  166. return 0
  167. def bytes_to_protocol(b: int) -> str:
  168. if b == BinaryProtocol.TCP:
  169. return "tcp"
  170. elif b == BinaryProtocol.UDP:
  171. return "udp"
  172. elif b == BinaryProtocol.UNIX:
  173. return "unix"
  174. else:
  175. return ""
  176. # ==================== Message Registry ====================
  177. class MessageRegistry:
  178. """Registry for message types and their serialization"""
  179. _registry: dict[BinaryType, Type[M]] = {}
  180. _type_to_binary: dict[str, BinaryType] = {}
  181. @classmethod
  182. def register(
  183. cls, binary_type: BinaryType, msg_type: str
  184. ) -> Callable[[Type[M]], Type[M]]:
  185. def decorator(message_cls: Type[M]) -> Type[M]:
  186. cls._registry[binary_type] = message_cls
  187. cls._type_to_binary[msg_type] = binary_type
  188. return message_cls
  189. return decorator
  190. @classmethod
  191. def get_message_class(cls, binary_type: BinaryType) -> Optional[Type[M]]:
  192. return cls._registry.get(binary_type)
  193. @classmethod
  194. def get_binary_type(cls, msg_type: str) -> Optional[BinaryType]:
  195. return cls._type_to_binary.get(msg_type)
  196. # ==================== Base Message Class ====================
  197. class BaseMessage:
  198. """Base class for all messages with built-in serialization"""
  199. def get_type(self) -> str:
  200. raise NotImplementedError
  201. def pack(self) -> bytes:
  202. """Serialize message to binary format"""
  203. result = bytearray([PROTOCOL_VERSION])
  204. result.append(MessageRegistry.get_binary_type(self.get_type()))
  205. result.extend(self._pack_payload())
  206. return bytes(result)
  207. def _pack_payload(self) -> bytes:
  208. """Subclasses implement this to pack their payload"""
  209. raise NotImplementedError
  210. @classmethod
  211. def parse(cls, data: bytes) -> M:
  212. """Parse binary data into a message"""
  213. if len(data) < 2:
  214. raise ValueError("Message too short")
  215. version = data[0]
  216. if version != PROTOCOL_VERSION:
  217. raise ValueError(f"Unsupported protocol version: {version}")
  218. msg_type = data[1]
  219. payload = data[2:]
  220. # Check if msg_type is a valid BinaryType value
  221. if msg_type not in BinaryType._value2member_map_:
  222. raise ValueError(f"Unknown binary message type: {msg_type}")
  223. message_cls = MessageRegistry.get_message_class(BinaryType(msg_type))
  224. if not message_cls:
  225. raise ValueError(f"Unknown binary message type: {msg_type}")
  226. return message_cls._parse_payload(payload)
  227. @classmethod
  228. def _parse_payload(cls, payload: bytes) -> M:
  229. """Subclasses implement this to parse their payload"""
  230. raise NotImplementedError
  231. @dataclass
  232. class SessionBaseMessage(BaseMessage):
  233. """Base class for messages that have a session_id (used for dispatching to ConnectionManager)"""
  234. session_id: uuid.UUID
  235. # ==================== Message Definitions ====================
  236. DataCompressor = Callable[[bytes], bytes]
  237. def compress_gzip(data: bytes) -> bytes:
  238. buf = io.BytesIO()
  239. with gzip.GzipFile(fileobj=buf, mode="wb") as gz:
  240. gz.write(data)
  241. return buf.getvalue()
  242. def decompress_gzip(data: bytes) -> bytes:
  243. with gzip.GzipFile(fileobj=io.BytesIO(data), mode="rb") as gz:
  244. return gz.read()
  245. _DATA_COMPRESSORS: dict[int, DataCompressor] = {
  246. DATA_COMPRESSION_NONE: lambda x: x,
  247. DATA_COMPRESSION_GZIP: compress_gzip,
  248. }
  249. _DATA_DECOMPRESSORS: dict[int, DataCompressor] = {
  250. DATA_COMPRESSION_NONE: lambda x: x,
  251. DATA_COMPRESSION_GZIP: decompress_gzip,
  252. }
  253. def _pack_string_list(strings: List[str]) -> bytes:
  254. """Pack a list of strings"""
  255. result = struct.pack(">H", len(strings))
  256. for s in strings:
  257. encoded = s.encode()
  258. result += bytes([len(encoded)])
  259. result += encoded
  260. return result
  261. def _unpack_string_list(payload: bytes, pos: int) -> tuple[List[str], int]:
  262. """Unpack a list of strings, returns (list, new_pos)"""
  263. count = struct.unpack(">H", payload[pos : pos + 2])[0]
  264. pos += 2
  265. strings = []
  266. for _ in range(count):
  267. if pos >= len(payload):
  268. raise ValueError("Invalid message")
  269. slen = payload[pos]
  270. pos += 1
  271. if pos + slen > len(payload):
  272. raise ValueError("Invalid message")
  273. strings.append(payload[pos : pos + slen].decode())
  274. pos += slen
  275. return strings, pos
  276. def _pack_error(error: Optional[str]) -> bytes:
  277. if error:
  278. encoded = error.encode()
  279. return bytes([len(encoded)]) + encoded
  280. return bytes([0])
  281. def _unpack_error(payload: bytes, pos: int) -> tuple[Optional[str], int]:
  282. if pos >= len(payload):
  283. return None, pos
  284. err_len = payload[pos]
  285. pos += 1
  286. if err_len == 0:
  287. return None, pos
  288. if pos + err_len > len(payload):
  289. raise ValueError("Invalid message")
  290. return payload[pos : pos + err_len].decode(), pos + err_len
  291. @MessageRegistry.register(BinaryType.CONNECT_REQUEST, TYPE_CONNECT_REQUEST)
  292. @dataclass
  293. class ConnectRequestMessage(SessionBaseMessage):
  294. """Server tells client to connect to target
  295. URL format: tcp://host:port or unix:///path/to/socket
  296. """
  297. target_url: str
  298. def get_type(self) -> str:
  299. return TYPE_CONNECT_REQUEST
  300. def _pack_payload(self) -> bytes:
  301. url_bytes = self.target_url.encode()
  302. return self.session_id.bytes + struct.pack(">H", len(url_bytes)) + url_bytes
  303. @classmethod
  304. def _parse_payload(cls, payload: bytes) -> 'ConnectRequestMessage':
  305. if len(payload) < 18:
  306. raise ValueError("Invalid connect request message")
  307. session_id = uuid.UUID(bytes=payload[:16])
  308. url_len = struct.unpack(">H", payload[16:18])[0]
  309. if len(payload) < 18 + url_len:
  310. raise ValueError("Invalid connect request message")
  311. target_url = payload[18 : 18 + url_len].decode()
  312. return cls(session_id=session_id, target_url=target_url)
  313. @MessageRegistry.register(BinaryType.CONNECT_RESPONSE, TYPE_CONNECT_RESPONSE)
  314. @dataclass
  315. class ConnectResponseMessage(SessionBaseMessage):
  316. """Client response to connect request"""
  317. success: bool
  318. error: Optional[str] = None
  319. def get_type(self) -> str:
  320. return TYPE_CONNECT_RESPONSE
  321. def _pack_payload(self) -> bytes:
  322. result = self.session_id.bytes
  323. result += bytes([1 if self.success else 0])
  324. result += _pack_error(self.error)
  325. return result
  326. @classmethod
  327. def _parse_payload(cls, payload: bytes) -> 'ConnectResponseMessage':
  328. if len(payload) < 17:
  329. raise ValueError("Invalid connect response message")
  330. session_id = uuid.UUID(bytes=payload[:16])
  331. success = bool(payload[16])
  332. error, _ = _unpack_error(payload, 17)
  333. return cls(session_id=session_id, success=success, error=error)
  334. @MessageRegistry.register(BinaryType.DATA, TYPE_DATA)
  335. @dataclass
  336. class DataMessage(SessionBaseMessage):
  337. """Data transmission message"""
  338. data: bytes
  339. compression: int = DATA_COMPRESSION_NONE
  340. def get_type(self) -> str:
  341. return TYPE_DATA
  342. def _pack_payload(self) -> bytes:
  343. compressor = _DATA_COMPRESSORS.get(self.compression, lambda x: x)
  344. compressed = compressor(self.data)
  345. result = self.session_id.bytes
  346. result += bytes([self.compression])
  347. result += struct.pack(">I", len(compressed))
  348. result += compressed
  349. return result
  350. @classmethod
  351. def _parse_payload(cls, payload: bytes) -> 'DataMessage':
  352. if len(payload) < 21:
  353. raise ValueError("Invalid data message")
  354. session_id = uuid.UUID(bytes=payload[:16])
  355. compression = payload[16]
  356. data_len = struct.unpack(">I", payload[17:21])[0]
  357. if len(payload) < 21 + data_len:
  358. raise ValueError("Invalid data message")
  359. decompressor = _DATA_DECOMPRESSORS.get(compression, lambda x: x)
  360. data = decompressor(payload[21 : 21 + data_len])
  361. return cls(session_id=session_id, data=data, compression=compression)
  362. @MessageRegistry.register(BinaryType.DISCONNECT, TYPE_DISCONNECT)
  363. @dataclass
  364. class DisconnectMessage(SessionBaseMessage):
  365. """Connection close message"""
  366. error: Optional[str] = None
  367. def get_type(self) -> str:
  368. return TYPE_DISCONNECT
  369. def _pack_payload(self) -> bytes:
  370. result = self.session_id.bytes
  371. result += _pack_error(self.error)
  372. return result
  373. @classmethod
  374. def _parse_payload(cls, payload: bytes) -> 'DisconnectMessage':
  375. if len(payload) < 16:
  376. raise ValueError("Invalid disconnect message")
  377. session_id = uuid.UUID(bytes=payload[:16])
  378. error, _ = _unpack_error(payload, 16)
  379. return cls(session_id=session_id, error=error)
  380. @MessageRegistry.register(BinaryType.HEARTBEAT, TYPE_HEARTBEAT)
  381. @dataclass
  382. class HeartbeatMessage(BaseMessage):
  383. """Keep-alive heartbeat"""
  384. timestamp: int = 0
  385. def get_type(self) -> str:
  386. return TYPE_HEARTBEAT
  387. def _pack_payload(self) -> bytes:
  388. return struct.pack(">Q", self.timestamp)
  389. @classmethod
  390. def _parse_payload(cls, payload: bytes) -> 'HeartbeatMessage':
  391. if len(payload) < 8:
  392. raise ValueError("Invalid heartbeat message")
  393. timestamp = struct.unpack(">Q", payload[:8])[0]
  394. return cls(timestamp=timestamp)
  395. @MessageRegistry.register(BinaryType.LIST_CLIENTS, TYPE_LIST_CLIENTS)
  396. @dataclass
  397. class ListClientsMessage(BaseMessage):
  398. """Request list of connected clients"""
  399. def get_type(self) -> str:
  400. return TYPE_LIST_CLIENTS
  401. def _pack_payload(self) -> bytes:
  402. return b""
  403. @classmethod
  404. def _parse_payload(cls, payload: bytes) -> 'ListClientsMessage':
  405. return cls()
  406. @dataclass
  407. class ClientInfo(BaseClientInfo):
  408. """Information about a connected client (for LIST_CLIENTS_RESPONSE)"""
  409. pass
  410. @MessageRegistry.register(BinaryType.LIST_CLIENTS_RESPONSE, TYPE_LIST_CLIENTS_RESPONSE)
  411. @dataclass
  412. class ListClientsResponseMessage(BaseMessage):
  413. """Response with list of clients"""
  414. clients: List[ClientInfo] = field(default_factory=list)
  415. def get_type(self) -> str:
  416. return TYPE_LIST_CLIENTS_RESPONSE
  417. def _pack_payload(self) -> bytes:
  418. result = struct.pack(">H", len(self.clients))
  419. for client in self.clients:
  420. result += client.client_id.bytes
  421. result += _pack_string_list(client.cidrs)
  422. result += _pack_string_list(client.unix_sockets)
  423. return result
  424. @classmethod
  425. def _parse_payload(cls, payload: bytes) -> 'ListClientsResponseMessage':
  426. if len(payload) < 2:
  427. raise ValueError("Invalid list clients response message")
  428. count = struct.unpack(">H", payload[:2])[0]
  429. clients = []
  430. pos = 2
  431. for _ in range(count):
  432. if pos + 16 > len(payload):
  433. raise ValueError("Invalid list clients response message")
  434. client_id = uuid.UUID(bytes=payload[pos : pos + 16])
  435. pos += 16
  436. cidrs, pos = _unpack_string_list(payload, pos)
  437. unix_sockets, pos = _unpack_string_list(payload, pos)
  438. clients.append(
  439. ClientInfo(
  440. client_id=client_id,
  441. cidrs=cidrs,
  442. unix_sockets=unix_sockets,
  443. )
  444. )
  445. return cls(clients=clients)
  446. # ==================== Server <-> Server Messages ====================
  447. @dataclass
  448. class ClientUpdateInfo(BaseClientInfo):
  449. """Client information in an update message"""
  450. action: Literal["add", "remove", ""] = ""
  451. @MessageRegistry.register(BinaryType.CLIENT_UPDATE, TYPE_CLIENT_UPDATE)
  452. @dataclass
  453. class ClientUpdateMessage(BaseMessage):
  454. """Server notifies peers about client changes"""
  455. server_id: uuid.UUID
  456. updates: List[ClientUpdateInfo] = field(default_factory=list)
  457. def get_type(self) -> str:
  458. return TYPE_CLIENT_UPDATE
  459. def _pack_payload(self) -> bytes:
  460. result = self.server_id.bytes
  461. result += struct.pack(">H", len(self.updates))
  462. for update in self.updates:
  463. result += update.client_id.bytes
  464. result += bytes([1 if update.action == "add" else 0])
  465. result += _pack_string_list(update.cidrs)
  466. result += _pack_string_list(update.unix_sockets)
  467. return result
  468. @classmethod
  469. def _parse_payload(cls, payload: bytes) -> 'ClientUpdateMessage':
  470. if len(payload) < 18:
  471. raise ValueError("Invalid client update message")
  472. server_id = uuid.UUID(bytes=payload[:16])
  473. pos = 16
  474. count = struct.unpack(">H", payload[pos : pos + 2])[0]
  475. pos += 2
  476. updates = []
  477. for _ in range(count):
  478. if pos + 16 > len(payload):
  479. raise ValueError("Invalid client update message")
  480. client_id = uuid.UUID(bytes=payload[pos : pos + 16])
  481. pos += 16
  482. action = "add" if payload[pos] == 1 else "remove"
  483. pos += 1
  484. cidrs, pos = _unpack_string_list(payload, pos)
  485. unix_sockets, pos = _unpack_string_list(payload, pos)
  486. updates.append(
  487. ClientUpdateInfo(
  488. client_id=client_id,
  489. action=action,
  490. cidrs=cidrs,
  491. unix_sockets=unix_sockets,
  492. )
  493. )
  494. return cls(server_id=server_id, updates=updates)
  495. # ==================== Convenience Functions ====================
  496. def pack_message(msg: BaseMessage) -> bytes:
  497. """Pack a message into binary format (backward compatible)"""
  498. return msg.pack()
  499. def parse_message(data: bytes) -> BaseMessage:
  500. """Parse binary data into a message (backward compatible)"""
  501. return BaseMessage.parse(data)
  502. def message_to_json(msg: BaseMessage) -> str:
  503. """Convert message to JSON string (for debugging)"""
  504. def serialize_value(v):
  505. if isinstance(v, uuid.UUID):
  506. return str(v)
  507. if isinstance(v, bytes):
  508. return v.hex()
  509. if isinstance(v, list):
  510. return [serialize_value(x) for x in v]
  511. if isinstance(v, ClientInfo):
  512. return {
  513. "client_id": str(v.client_id),
  514. "cidrs": v.cidrs,
  515. "unix_sockets": v.unix_sockets,
  516. }
  517. return v
  518. result = {"type": msg.get_type()}
  519. for field_name in msg.__dataclass_fields__:
  520. value = getattr(msg, field_name)
  521. result[field_name] = serialize_value(value)
  522. return json.dumps(result)