message_server.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659
  1. #!/usr/bin/env python3
  2. """
  3. Message Server - Handles WebSocket connections and message protocol
  4. """
  5. import uuid
  6. import asyncio
  7. import logging
  8. from typing import Optional, Dict, List, Callable, Coroutine, Any, TypeAlias
  9. from .patricia_trie import CIDRRegistry
  10. from fastapi import WebSocket, Depends, APIRouter
  11. from websockets.asyncio.client import connect as ws_connect
  12. from .connection_manager import ConnectionManager, RemoteConnectionManager
  13. from .message import (
  14. SessionBaseMessage,
  15. HeartbeatMessage,
  16. ClientUpdateMessage,
  17. ClientUpdateInfo,
  18. RegisteredClientInfo,
  19. ServerInfo,
  20. ServerPeer,
  21. parse_message,
  22. pack_message,
  23. )
  24. from .authenticator import Authenticator, NoOpAuthenticator
  25. from .constants import default_connect_path
  26. logger = logging.getLogger(__name__)
  27. VERSION = "0.1.0"
  28. # ==================== Type Aliases ====================
  29. ServerInfoGetter: TypeAlias = Callable[[WebSocket], Optional['ServerInfo']]
  30. RegisteredClientInfoGetter: TypeAlias = Callable[
  31. [WebSocket], Optional['RegisteredClientInfo']
  32. ]
  33. WSCallback: TypeAlias = Callable[
  34. [Optional['ServerInfo'], Optional['RegisteredClientInfo']],
  35. Coroutine[Any, Any, None],
  36. ]
  37. # ==================== Helper Functions ====================
  38. def default_server_info_getter(websocket: WebSocket) -> Optional['ServerInfo']:
  39. """Extract server info from headers if this is a server connection"""
  40. return ServerInfo.from_headers(websocket.headers)
  41. def default_client_info_getter(
  42. websocket: WebSocket,
  43. ) -> Optional['RegisteredClientInfo']:
  44. """Extract client info from headers if this is a client connection"""
  45. return RegisteredClientInfo.from_headers(websocket.headers)
# ==================== Message Server Handler ====================
  47. class MessageServerHandler:
  48. """Simple server that handles message protocol and connect-based proxy"""
  49. _server_info: ServerInfo
  50. def __init__(
  51. self,
  52. listen_address: str,
  53. listen_port: int,
  54. proxy_port: int,
  55. server_id: Optional[uuid.UUID] = None,
  56. client_info_getter: RegisteredClientInfoGetter = default_client_info_getter,
  57. server_info_getter: ServerInfoGetter = default_server_info_getter,
  58. authenticator: Authenticator = None,
  59. callback_on_connect: Optional[WSCallback] = None,
  60. callback_on_disconnect: Optional[WSCallback] = None,
  61. ):
  62. self._server_info = ServerInfo(
  63. server_id=server_id or uuid.uuid4(),
  64. listen_address=listen_address,
  65. listen_port=listen_port,
  66. proxy_port=proxy_port,
  67. )
  68. self._callback_on_connect = callback_on_connect
  69. self._callback_on_disconnect = callback_on_disconnect
  70. # Info extractors - can be overridden for customization
  71. self._client_info_getter = client_info_getter
  72. self._server_info_getter = server_info_getter
  73. self._authenticator = authenticator or NoOpAuthenticator()
  74. # Client management
  75. self.client_registry: Dict[uuid.UUID, RegisteredClientInfo] = (
  76. {}
  77. ) # client_id -> client_info
  78. self.connection_managers: Dict[uuid.UUID, ConnectionManager] = {}
  79. self._cidr_registry = CIDRRegistry()
  80. # Generation tracking for disconnect callbacks
  81. self._client_generations: Dict[uuid.UUID, int] = {} # client_id -> generation
  82. self._generation_lock = asyncio.Lock()
  83. # Server federation
  84. self.peers: Dict[uuid.UUID, ServerPeer] = {} # Outgoing: servers I connected to
  85. self.serving_peers: Dict[uuid.UUID, ServerPeer] = (
  86. {}
  87. ) # Incoming: servers that connected to me
  88. self.peer_tasks: Dict[uuid.UUID, asyncio.Task] = {} # server_id -> task
  89. async def _get_next_generation(self, client_id: uuid.UUID) -> int:
  90. """Get the next generation number for a client (thread-safe)."""
  91. async with self._generation_lock:
  92. gen = self._client_generations.get(client_id, 0) + 1
  93. self._client_generations[client_id] = gen
  94. return gen
  95. async def _safe_callback(
  96. self,
  97. callback: WSCallback,
  98. server_info: Optional[ServerInfo],
  99. client_info: Optional[RegisteredClientInfo],
  100. ) -> None:
  101. """Execute callback with error handling, does not raise exceptions."""
  102. try:
  103. await callback(server_info, client_info)
  104. except Exception as e:
  105. logger.error(f"[Server] callback_on_connect error: {e}", exc_info=True)
  106. async def _safe_disconnect_callback(
  107. self,
  108. callback: WSCallback,
  109. client_info: Optional[RegisteredClientInfo],
  110. generation: int,
  111. ) -> None:
  112. """Execute disconnect callback with error handling and stale callback filtering."""
  113. if client_info is None:
  114. return
  115. async with self._generation_lock:
  116. current_gen = self._client_generations.get(client_info.client_id, 0)
  117. if generation < current_gen:
  118. logger.debug(
  119. f"[Server] Stale disconnect callback ignored: "
  120. f"client={client_info.client_id}, callback_gen={generation}, "
  121. f"current_gen={current_gen}"
  122. )
  123. return
  124. try:
  125. await callback(None, client_info)
  126. except Exception as e:
  127. logger.error(f"[Server] callback_on_disconnect error: {e}", exc_info=True)
  128. def _find_peer_by_server_id(self, server_id: uuid.UUID) -> Optional[ServerPeer]:
  129. """Find a peer by server_id, checking both outgoing and incoming peers"""
  130. peer = self.peers.get(server_id)
  131. if peer:
  132. return peer
  133. return self.serving_peers.get(server_id)
  134. def get_connection_manager_by_ip_in_cidr(
  135. self, target_ip: str
  136. ) -> Optional[ConnectionManager]:
  137. """Find a ConnectionManager by matching IP against registered CIDRs using Patricia Trie.
  138. Returns local ConnectionManager for local clients, or RemoteConnectionManager
  139. for peer clients.
  140. """
  141. client_id = self._cidr_registry.find_best_match(target_ip)
  142. if not client_id:
  143. return None
  144. client_info = self.client_registry.get(client_id)
  145. if not client_info:
  146. return None
  147. # Check if this is a local client (registered on this server)
  148. if client_info.server_id == self._server_info.server_id:
  149. return self.connection_managers.get(client_id)
  150. # Peer client - find the peer and return RemoteConnectionManager
  151. peer = self._find_peer_by_server_id(client_info.server_id)
  152. if peer and peer.listen_address and peer.proxy_port:
  153. return RemoteConnectionManager(peer.listen_address, peer.proxy_port)
  154. return None
  155. def get_connection_manager(self, target_ip: str) -> Optional[ConnectionManager]:
  156. """Get connection manager for target IP (local or remote)."""
  157. return self.get_connection_manager_by_ip_in_cidr(target_ip)
  158. async def add_peer(self, address: str, port: int) -> Optional[uuid.UUID]:
  159. """Add a peer server and connect to it. Returns the peer_id when connected."""
  160. # Create a future that will be resolved when the peer connects
  161. future = asyncio.Future()
  162. asyncio.create_task(self.connect_to_peer(address, port, future))
  163. # Wait for the peer to connect (with timeout)
  164. try:
  165. peer_id = await asyncio.wait_for(future, timeout=10.0)
  166. return peer_id
  167. except asyncio.TimeoutError:
  168. logger.debug(
  169. f"[Server] Timeout waiting for peer to connect: {address}:{port}"
  170. )
  171. return None
  172. async def remove_peer(self, peer_id: uuid.UUID):
  173. """Remove a peer server by UUID"""
  174. logger.debug(f"[Server] Attempting to remove peer by UUID: {peer_id}")
  175. return await self._remove_peer_impl(peer_id)
  176. async def remove_peer_by_address(self, address: str):
  177. """Remove a peer server by address (host:port)"""
  178. logger.debug(f"[Server] Attempting to remove peer by address: {address}")
  179. target_peer_id = self._get_peer_id_by_address(address)
  180. if not target_peer_id:
  181. return False
  182. return await self._remove_peer_impl(target_peer_id)
  183. def _get_peer_id_by_address(self, address: str) -> Optional[uuid.UUID]:
  184. """Helper to find peer_id by address (host:port)"""
  185. for peer_id, peer in self.peers.items():
  186. peer_addr = f"{peer.listen_address}:{peer.listen_port}"
  187. if peer_addr == address:
  188. return peer_id
  189. for peer_id, peer in self.serving_peers.items():
  190. peer_addr = f"{peer.listen_address}:{peer.listen_port}"
  191. if peer_addr == address:
  192. return peer_id
  193. logger.debug(f"[Server] No peer found with address: {address}")
  194. return None
  195. async def _remove_peer_impl(self, peer_id: uuid.UUID) -> bool:
  196. """Internal implementation for removing a peer (checks both peers and serving_peers)"""
  197. # Check outgoing peers first
  198. peer = self.peers.pop(peer_id, None)
  199. if not peer:
  200. # Check incoming serving_peers
  201. peer = self.serving_peers.pop(peer_id, None)
  202. if peer:
  203. logger.debug(f"[Server] Found peer to remove: {peer.server_id}")
  204. if peer.websocket:
  205. await peer.websocket.close()
  206. logger.debug(f"[Server] Closed websocket for peer: {peer_id}")
  207. else:
  208. logger.debug(f"[Server] Peer not found: {peer_id}")
  209. return False
  210. task = self.peer_tasks.pop(peer_id, None)
  211. if task:
  212. task.cancel()
  213. logger.debug(f"[Server] Cancelled task for peer: {peer_id}")
  214. logger.debug(f"[Server] Removed peer: {peer_id}")
  215. return True
  216. async def connect_to_peer(self, host: str, port: int, future: asyncio.Future):
  217. """Connect to a peer server"""
  218. peer_key = f"{host}:{port}"
  219. try:
  220. ws_uri = f"ws://{host}:{port}{default_connect_path}"
  221. logger.debug(f"[Server] Connecting to peer: {ws_uri}")
  222. # Connect with server info in headers (header-based registration)
  223. headers = self._server_info.to_headers()
  224. self._authenticator.inject_headers(headers)
  225. websocket = await ws_connect(ws_uri, additional_headers=headers)
  226. # Get peer info from response headers
  227. peer_info = ServerInfo.from_headers(dict(websocket.response.headers))
  228. if not peer_info:
  229. logger.debug("[Server] Peer did not provide valid registration headers")
  230. await websocket.close()
  231. if not future.done():
  232. future.set_result(None)
  233. return
  234. peer_server_id = peer_info.server_id
  235. peer = ServerPeer(
  236. server_id=peer_info.server_id,
  237. listen_address=peer_info.listen_address,
  238. listen_port=peer_info.listen_port,
  239. proxy_port=peer_info.proxy_port,
  240. websocket=websocket,
  241. connected=True,
  242. )
  243. self.peers[peer_server_id] = peer
  244. logger.debug(f"[Server] Registered with peer: {peer_server_id}")
  245. # Resolve the future to notify add_peer that connection is complete
  246. if not future.done():
  247. future.set_result(peer_server_id)
  248. # Start handling messages from peer
  249. task = asyncio.create_task(self.handle_peer(websocket, peer_server_id))
  250. self.peer_tasks[peer_server_id] = task
  251. except Exception as e:
  252. # Connection failed - peer may not be running or rejected connection
  253. logger.debug(f"[Server] Failed to connect to peer {peer_key}: {e}")
  254. if not future.done():
  255. future.set_exception(e)
  256. async def handle_peer(self, websocket, peer_server_id: uuid.UUID):
  257. """Handle messages from a peer server"""
  258. try:
  259. # Check if this is a Starlette WebSocket or websockets WebSocket
  260. if hasattr(websocket, 'receive'):
  261. # Starlette WebSocket
  262. while True:
  263. message = await websocket.receive()
  264. if message.get("type") == "websocket.disconnect":
  265. break
  266. raw_data = message.get("bytes") or message.get("text", "").encode()
  267. msg = parse_message(raw_data)
  268. logger.trace(f"[Server] Received from peer: {msg.get_type()}")
  269. if isinstance(msg, ClientUpdateMessage):
  270. await self.handle_peer_client_update(msg)
  271. else:
  272. # websockets library WebSocket
  273. while True:
  274. raw_data = await websocket.recv()
  275. msg = parse_message(raw_data)
  276. logger.trace(f"[Server] Received from peer: {msg.get_type()}")
  277. if isinstance(msg, ClientUpdateMessage):
  278. await self.handle_peer_client_update(msg)
  279. except Exception as e:
  280. logger.debug(f"[Server] Peer connection error: {e}")
  281. finally:
  282. # Clean up peer on disconnect (check both peers and serving_peers)
  283. if peer_server_id in self.peers:
  284. del self.peers[peer_server_id]
  285. if peer_server_id in self.serving_peers:
  286. del self.serving_peers[peer_server_id]
  287. if peer_server_id in self.peer_tasks:
  288. del self.peer_tasks[peer_server_id]
  289. # Clean up clients registered through this peer
  290. await self._remove_clients_from_peer(peer_server_id)
  291. logger.debug(f"[Server] Peer disconnected: {peer_server_id}")
  292. async def handle_peer_client_update(self, msg: ClientUpdateMessage):
  293. """Handle client update from peer"""
  294. for update in msg.updates:
  295. if update.action == "add":
  296. # Add client's CIDRs to local registry
  297. client_info = RegisteredClientInfo(
  298. client_id=update.client_id,
  299. cidrs=update.cidrs,
  300. unix_sockets=update.unix_sockets,
  301. server_id=msg.server_id,
  302. )
  303. self.client_registry[update.client_id] = client_info
  304. # Index CIDRs for efficient lookup
  305. for cidr in update.cidrs:
  306. self._cidr_registry.insert(cidr, update.client_id)
  307. logger.debug(f"[Server] Added client from peer: {update.client_id}")
  308. elif update.action == "remove":
  309. # Remove client from local registry
  310. if update.client_id in self.client_registry:
  311. self._cidr_registry.remove_client(update.client_id)
  312. del self.client_registry[update.client_id]
  313. logger.debug(
  314. f"[Server] Removed client from peer: {update.client_id}"
  315. )
  316. async def _remove_clients_from_peer(self, peer_server_id: uuid.UUID):
  317. """Remove all clients that were registered through a peer server."""
  318. clients_to_remove = [
  319. client_id
  320. for client_id, client_info in self.client_registry.items()
  321. if client_info.server_id == peer_server_id
  322. ]
  323. for client_id in clients_to_remove:
  324. self._cidr_registry.remove_client(client_id)
  325. del self.client_registry[client_id]
  326. logger.debug(
  327. f"[Server] Removed client {client_id} from disconnected peer {peer_server_id}"
  328. )
  329. async def send_client_update_to_peer(self, websocket, action: str):
  330. """Send client updates to a peer"""
  331. updates = []
  332. for client_id, client_info in self.client_registry.items():
  333. # Only send clients owned by this server
  334. if client_info.server_id == self._server_info.server_id:
  335. updates.append(
  336. ClientUpdateInfo(
  337. client_id=client_id,
  338. action=action,
  339. cidrs=client_info.cidrs,
  340. unix_sockets=client_info.unix_sockets,
  341. )
  342. )
  343. if updates:
  344. msg = ClientUpdateMessage(
  345. server_id=self._server_info.server_id, updates=updates
  346. )
  347. msg_data = pack_message(msg)
  348. if hasattr(websocket, 'send_bytes'):
  349. await websocket.send_bytes(msg_data)
  350. else:
  351. await websocket.send(msg_data)
  352. async def broadcast_client_update(
  353. self,
  354. action: str,
  355. client_id: uuid.UUID,
  356. cidrs: List[str],
  357. unix_sockets: List[str],
  358. ):
  359. """Broadcast client update to all peers (both outgoing and incoming)"""
  360. update = ClientUpdateInfo(
  361. client_id=client_id,
  362. action=action,
  363. cidrs=cidrs,
  364. unix_sockets=unix_sockets,
  365. )
  366. msg = ClientUpdateMessage(
  367. server_id=self._server_info.server_id, updates=[update]
  368. )
  369. # Broadcast to outgoing peers
  370. for peer_id, peer in self.peers.items():
  371. if peer.connected and peer.websocket:
  372. try:
  373. msg_data = pack_message(msg)
  374. if hasattr(peer.websocket, 'send_bytes'):
  375. await peer.websocket.send_bytes(msg_data)
  376. else:
  377. await peer.websocket.send(msg_data)
  378. except Exception as e:
  379. logger.debug(
  380. f"[Server] Error sending update to peer {peer_id}: {e}"
  381. )
  382. # Broadcast to incoming serving_peers
  383. for peer_id, peer in self.serving_peers.items():
  384. if peer.connected and peer.websocket:
  385. try:
  386. msg_data = pack_message(msg)
  387. if hasattr(peer.websocket, 'send_bytes'):
  388. await peer.websocket.send_bytes(msg_data)
  389. else:
  390. await peer.websocket.send(msg_data)
  391. except Exception as e:
  392. logger.debug(
  393. f"[Server] Error sending update to serving_peer {peer_id}: {e}"
  394. )
  395. async def handle_client_connection(
  396. self, websocket: WebSocket, client_info: RegisteredClientInfo
  397. ):
  398. """Handle a client WebSocket connection"""
  399. # Accept first — if the handshake fails nothing needs cleanup.
  400. await websocket.accept()
  401. client_id = client_info.client_id
  402. cidr_list = client_info.cidrs
  403. socket_list = client_info.unix_sockets
  404. connection_manager = ConnectionManager(websocket)
  405. self.connection_managers[client_id] = connection_manager
  406. # Set server_id so send_client_update_to_peer can filter correctly
  407. client_info.server_id = self._server_info.server_id
  408. self.client_registry[client_id] = client_info
  409. # Index CIDRs for efficient lookup
  410. for cidr in cidr_list:
  411. self._cidr_registry.insert(cidr, client_id)
  412. logger.debug(
  413. f"[Server] Client registered via WS: {client_id}, CIDRs: {cidr_list}"
  414. )
  415. # Broadcast new client to peers
  416. await self.broadcast_client_update("add", client_id, cidr_list, socket_list)
  417. # Get generation for this connection (used to filter stale disconnect callbacks)
  418. generation = await self._get_next_generation(client_id)
  419. if self._callback_on_connect:
  420. await self._safe_callback(self._callback_on_connect, None, client_info)
  421. await self.handle_client(websocket, client_id, generation)
  422. async def handle_server_federation(
  423. self,
  424. websocket: WebSocket,
  425. server_info: ServerInfo,
  426. ):
  427. """Handle incoming server-to-server federation connection"""
  428. our_server_id = self._server_info.server_id
  429. logger.debug(
  430. f"[Server] handle_server_federation: incoming={server_info.server_id}, self={our_server_id}"
  431. )
  432. # Prevent adding self as peer
  433. if server_info.server_id == our_server_id:
  434. logger.debug("[Server] Ignoring self-connection attempt")
  435. await websocket.close()
  436. return
  437. # Accept with our server info in response headers
  438. await websocket.accept(headers=self._server_info.to_bytes_headers())
  439. # Add to serving_peers (incoming connections)
  440. peer = ServerPeer(
  441. server_id=server_info.server_id,
  442. listen_address=server_info.listen_address,
  443. listen_port=server_info.listen_port,
  444. proxy_port=server_info.proxy_port,
  445. websocket=websocket,
  446. connected=True,
  447. )
  448. self.serving_peers[server_info.server_id] = peer
  449. logger.debug(f"[Server] Serving peer connected: {server_info.server_id}")
  450. if self._callback_on_connect:
  451. await self._safe_callback(self._callback_on_connect, server_info, None)
  452. # Send existing clients to new peer
  453. await self.send_client_update_to_peer(websocket, "add")
  454. await self.handle_peer(websocket, server_info.server_id)
  455. async def handle_client(
  456. self, websocket: WebSocket, client_id: uuid.UUID, generation: int
  457. ):
  458. """Handle a client connection"""
  459. try:
  460. while True:
  461. try:
  462. message = await websocket.receive()
  463. if message.get("type") == "websocket.disconnect":
  464. break
  465. if "text" in message:
  466. raw_data = (
  467. message["text"].encode()
  468. if isinstance(message["text"], str)
  469. else message["text"]
  470. )
  471. elif "bytes" in message:
  472. raw_data = message["bytes"]
  473. else:
  474. continue
  475. msg = parse_message(raw_data)
  476. logger.trace(f"[Server] Received: {msg.get_type()}")
  477. if isinstance(msg, SessionBaseMessage):
  478. if client_id in self.connection_managers:
  479. await self.connection_managers[client_id].dispatch(msg)
  480. elif isinstance(msg, HeartbeatMessage):
  481. response = HeartbeatMessage(timestamp=msg.timestamp)
  482. await websocket.send_bytes(pack_message(response))
  483. except asyncio.CancelledError:
  484. # Client is being closed, exit gracefully
  485. break
  486. except Exception as e:
  487. logger.debug(f"[Server] Error processing message: {e}")
  488. except Exception as e:
  489. logger.debug(f"[Server] Client error: {e}")
  490. finally:
  491. # Get client info before removing
  492. client_info = self.client_registry.get(client_id)
  493. cidr_list = client_info.cidrs if client_info else []
  494. socket_list = client_info.unix_sockets if client_info else []
  495. if client_id and client_id in self.client_registry:
  496. self._cidr_registry.remove_client(client_id)
  497. del self.client_registry[client_id]
  498. if client_id and client_id in self.connection_managers:
  499. del self.connection_managers[client_id]
  500. # Broadcast client disconnection to peers
  501. await self.broadcast_client_update(
  502. "remove", client_id, cidr_list, socket_list
  503. )
  504. # Call disconnect callback (filtered by generation to avoid stale callbacks)
  505. if self._callback_on_disconnect:
  506. await self._safe_disconnect_callback(
  507. self._callback_on_disconnect, client_info, generation
  508. )
  509. logger.debug(f"[Server] Client disconnected: {client_id}")
  510. def handler_getter(websocket: WebSocket) -> MessageServerHandler:
  511. """FastAPI dependency: return the ``MessageServerHandler`` from ``app.state``."""
  512. return getattr(websocket.app.state, "message_server_handler", None)
  513. def authenticator_getter(websocket: WebSocket) -> Authenticator:
  514. """FastAPI dependency: return the ``Authenticator`` from ``app.state``.
  515. Falls back to ``NoOpAuthenticator`` (accepts all connections) when no
  516. authenticator has been configured on the application.
  517. """
  518. authenticator: Authenticator = getattr(
  519. websocket.app.state, "websocket_authenticator", NoOpAuthenticator()
  520. )
  521. return authenticator
  522. router = APIRouter()
  523. @router.websocket(default_connect_path)
  524. async def websocket_endpoint(
  525. websocket: WebSocket,
  526. handler: MessageServerHandler = Depends(handler_getter),
  527. authenticator: Authenticator = Depends(authenticator_getter),
  528. ):
  529. """WebSocket endpoint - handles both client and server connections"""
  530. try:
  531. if not await authenticator.authenticate(websocket):
  532. logger.debug("[Server] Authentication failed for connection")
  533. await websocket.close(code=4001, reason="Authentication failed")
  534. return
  535. except Exception as e:
  536. logger.debug(f"Server failed with: {e}")
  537. await websocket.close(code=1008, reason="Server Error")
  538. return
  539. # Check if this is a server federation connection
  540. server_info = handler._server_info_getter(websocket)
  541. if server_info:
  542. logger.debug(
  543. f"[Server] Detected server federation connection: {server_info.server_id}"
  544. )
  545. await handler.handle_server_federation(websocket, server_info)
  546. return
  547. client_info = handler._client_info_getter(websocket)
  548. if client_info:
  549. logger.debug(f"[Server] Detected client connection: {client_info.client_id}")
  550. await handler.handle_client_connection(websocket, client_info)
  551. return
  552. logger.debug(
  553. "[Server] No valid server or client info found in headers, rejecting connection"
  554. )
  555. await websocket.close(code=1008, reason="not a valid client or server connection")