# ws_hub.py (1.1 KB)
from __future__ import annotations

import asyncio
import logging

from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect
  6. logger = logging.getLogger(__name__)
  7. class WebSocketHub:
  8. def __init__(self) -> None:
  9. self._connections: set[WebSocket] = set()
  10. self._lock = asyncio.Lock()
  11. async def connect(self, ws: WebSocket) -> None:
  12. await ws.accept()
  13. async with self._lock:
  14. self._connections.add(ws)
  15. def disconnect(self, ws: WebSocket) -> None:
  16. self._connections.discard(ws)
  17. async def broadcast(self, message: dict) -> None:
  18. async with self._lock:
  19. targets = set(self._connections)
  20. dead: set[WebSocket] = set()
  21. for ws in targets:
  22. try:
  23. await ws.send_json(message)
  24. except WebSocketDisconnect:
  25. dead.add(ws)
  26. except Exception as exc:
  27. logger.warning("WebSocket send error: %s", exc)
  28. dead.add(ws)
  29. if dead:
  30. async with self._lock:
  31. self._connections -= dead
  32. hub = WebSocketHub()