from __future__ import annotations import asyncio import logging from fastapi import WebSocket from fastapi.websockets import WebSocketDisconnect logger = logging.getLogger(__name__) class WebSocketHub: def __init__(self) -> None: self._connections: set[WebSocket] = set() self._lock = asyncio.Lock() async def connect(self, ws: WebSocket) -> None: await ws.accept() async with self._lock: self._connections.add(ws) def disconnect(self, ws: WebSocket) -> None: self._connections.discard(ws) async def broadcast(self, message: dict) -> None: async with self._lock: targets = set(self._connections) dead: set[WebSocket] = set() for ws in targets: try: await ws.send_json(message) except WebSocketDisconnect: dead.add(ws) except Exception as exc: logger.warning("WebSocket send error: %s", exc) dead.add(ws) if dead: async with self._lock: self._connections -= dead hub = WebSocketHub()