| 1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
from __future__ import annotations

import asyncio
import json
import logging

from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class WebSocketHub:
    """Track accepted WebSocket connections and fan JSON messages out to them.

    Connections are registered by :meth:`connect`. They are removed either
    explicitly via :meth:`disconnect` or automatically when a send to them
    fails during :meth:`broadcast`.
    """

    def __init__(self) -> None:
        # Currently registered connections.
        self._connections: set[WebSocket] = set()
        # Serializes snapshot/mutation of ``_connections`` in the async methods.
        self._lock = asyncio.Lock()

    async def connect(self, ws: WebSocket) -> None:
        """Accept the handshake for *ws* and register it.

        If ``ws.accept()`` raises, the exception propagates and *ws* is not
        registered.
        """
        await ws.accept()
        async with self._lock:
            self._connections.add(ws)

    def disconnect(self, ws: WebSocket) -> None:
        """Unregister *ws*; a no-op if it is not registered.

        This never awaits, so on a single event loop it cannot interleave with
        the locked sections of the async methods. It does not close the socket.
        """
        self._connections.discard(ws)

    async def broadcast(self, message: dict) -> None:
        """Send *message* as JSON text to every registered connection.

        The message is encoded once, then sent to all connections concurrently.
        A slow client therefore does not delay delivery to the others. The call
        returns once every send has either finished or failed. Connections
        whose send fails are unregistered; failures other than
        ``WebSocketDisconnect`` are logged as warnings.

        Raises:
            TypeError, ValueError: if *message* is not JSON-serializable. This
                happens before any send is attempted, and no connection is
                unregistered.
        """
        # Encode once, outside the per-connection error handling. Previously
        # json encoding happened inside each send_json() call, so a bad message
        # raised for every client and evicted the whole hub. The compact,
        # non-ASCII-escaping form matches what Starlette's send_json(mode="text")
        # produces.
        payload = json.dumps(message, separators=(",", ":"), ensure_ascii=False)

        async with self._lock:
            targets = list(self._connections)
        if not targets:
            return

        # gather() preserves input order, so results line up with ``targets``.
        delivered = await asyncio.gather(
            *(self._send_one(ws, payload) for ws in targets)
        )
        dead = {ws for ws, ok in zip(targets, delivered) if not ok}
        if dead:
            async with self._lock:
                self._connections -= dead

    async def _send_one(self, ws: WebSocket, payload: str) -> bool:
        """Send *payload* to *ws*; return False if *ws* should be dropped."""
        try:
            await ws.send_text(payload)
        except WebSocketDisconnect:
            return False
        except Exception as exc:
            logger.warning("WebSocket send error: %s", exc)
            return False
        return True
# Shared module-level hub instance, created once at import time.
hub: WebSocketHub = WebSocketHub()
|