message_client.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. #!/usr/bin/env python3
  2. """
  3. Message Client - Client that handles CONNECT_REQUEST from server
  4. """
  5. import asyncio
  6. import logging
  7. import random
  8. import uuid
  9. from websockets.asyncio.client import connect
  10. from websockets.exceptions import ConnectionClosed
  11. from typing import List, Optional
  12. from .connection_manager import ClientConnectionManager
  13. from .message import (
  14. BaseClientInfo,
  15. SessionBaseMessage,
  16. parse_message,
  17. )
  18. from .authenticator import Authenticator, create_authenticator
  19. from .constants import default_connect_path
  20. logger = logging.getLogger(__name__)
  21. # Reconnect constants
  22. INITIAL_RECONNECT_DELAY = 1.0
  23. MAX_RECONNECT_DELAY = 60.0
  24. RECONNECT_JITTER_FACTOR = 0.3
  25. class MessageClient:
  26. """Client that handles CONNECT_REQUEST from server using ClientConnectionManager"""
  27. _client_info: BaseClientInfo
  28. def __init__(
  29. self,
  30. server_endpoint: str,
  31. client_id: uuid.UUID,
  32. cidrs: Optional[List[str]] = None,
  33. unix_sockets: Optional[List[str]] = None,
  34. authenticator: Optional[Authenticator] = None,
  35. ) -> None:
  36. # replace http(s):// with ws(s):// and append connect path
  37. self.server_uri = (
  38. server_endpoint.replace('https://', 'wss://').replace('http://', 'ws://')
  39. + default_connect_path
  40. )
  41. self._client_info = BaseClientInfo(
  42. client_id=client_id,
  43. cidrs=cidrs or [],
  44. unix_sockets=unix_sockets or [],
  45. )
  46. self._authenticator = (
  47. authenticator if authenticator is not None else create_authenticator(None)
  48. )
  49. self._lock = asyncio.Lock()
  50. self._websocket = None
  51. async def update_cidrs(self, cidrs: List[str]) -> None:
  52. """Update CIDRs for the client (thread-safe)"""
  53. async with self._lock:
  54. self._client_info.cidrs = cidrs
  55. logger.debug(f"[Client] Updated CIDRs: {cidrs}")
  56. if self._websocket and not self._websocket.close_code:
  57. await self._websocket.close(
  58. code=1008, reason="CIDRs updated"
  59. ) # Trigger reconnect to update server with new CIDRs
  60. async def run(self) -> None:
  61. """Connect to server and handle incoming messages with automatic reconnect"""
  62. reconnect_delay = INITIAL_RECONNECT_DELAY
  63. while True:
  64. async with self._lock:
  65. headers = self._client_info.to_headers()
  66. self._authenticator.inject_headers(headers)
  67. try:
  68. self._websocket = await connect(
  69. self.server_uri,
  70. proxy=None,
  71. additional_headers=headers,
  72. )
  73. logger.debug(
  74. f"[Client] Connected to {self.server_uri} with client_id: {self._client_info.client_id}"
  75. )
  76. connection_manager = ClientConnectionManager(self._websocket)
  77. reconnect_delay = (
  78. INITIAL_RECONNECT_DELAY # Reset delay on successful connection
  79. )
  80. async for raw_data in self._websocket:
  81. msg = parse_message(raw_data)
  82. logger.trace(f"[Client] Received: {msg.get_type()}")
  83. if isinstance(msg, SessionBaseMessage):
  84. await connection_manager.dispatch(msg)
  85. except ConnectionClosed:
  86. # Suppress ConnectionClosed - reconnect automatically
  87. logger.debug("[Client] Server disconnected, reconnecting...")
  88. except asyncio.CancelledError:
  89. # Task was cancelled externally - exit gracefully
  90. logger.debug("[Client] Task was cancelled, stopping")
  91. return
  92. except Exception as e:
  93. logger.error(f"[Client] Unexpected error: {e}, reconnecting...")
  94. # Exponential backoff with jitter
  95. jitter = (
  96. reconnect_delay * RECONNECT_JITTER_FACTOR * (2 * random.random() - 1)
  97. )
  98. actual_delay = min(reconnect_delay + jitter, MAX_RECONNECT_DELAY)
  99. logger.debug(f"[Client] Reconnecting in {actual_delay:.2f} seconds")
  100. await asyncio.sleep(actual_delay)
  101. reconnect_delay = min(reconnect_delay * 2, MAX_RECONNECT_DELAY)