websocket.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import asyncio
  2. import logging
  3. from typing import Any
  4. from fastapi import APIRouter, WebSocket, WebSocketDisconnect
  5. from app.core.logging import logger
router = APIRouter()
# Connection registry: maps job_id -> set of WebSocket connections subscribed to that job.
_connections: dict[str, set[WebSocket]] = {}
# Every access to _connections in this module happens while holding this lock.
_lock = asyncio.Lock()
  10. async def _broadcast(job_id: str, message: dict[str, Any]) -> None:
  11. """向订阅该 job 的所有客户端广播消息。"""
  12. async with _lock:
  13. conns = _connections.get(job_id, set()).copy()
  14. if not conns:
  15. return
  16. import json
  17. payload = json.dumps(message)
  18. disconnected = set()
  19. for ws in conns:
  20. try:
  21. await ws.send_text(payload)
  22. except Exception:
  23. disconnected.add(ws)
  24. # 清理失效连接
  25. if disconnected:
  26. async with _lock:
  27. _connections.get(job_id, set()).difference_update(disconnected)
  28. async def send_progress(job_id: str, **kwargs: Any) -> None:
  29. await _broadcast(job_id, {"type": "progress", "job_id": job_id, **kwargs})
  30. async def send_epoch_done(job_id: str, **kwargs: Any) -> None:
  31. await _broadcast(job_id, {"type": "epoch_done", "job_id": job_id, **kwargs})
  32. async def send_completed(job_id: str, **kwargs: Any) -> None:
  33. await _broadcast(job_id, {"type": "completed", "job_id": job_id, **kwargs})
  34. async def send_error(job_id: str, message: str) -> None:
  35. await _broadcast(job_id, {"type": "error", "job_id": job_id, "message": message})
  36. async def send_heartbeat(job_id: str) -> None:
  37. await _broadcast(job_id, {"type": "heartbeat", "job_id": job_id, "timestamp": datetime_now()})
  38. def datetime_now() -> str:
  39. from datetime import datetime, timezone
  40. return datetime.now(timezone.utc).isoformat()
  41. @router.websocket("/ws/training/{job_id}")
  42. async def training_websocket(websocket: WebSocket, job_id: str) -> None:
  43. token = websocket.query_params.get("token")
  44. if token:
  45. try:
  46. from app.core.security import decode_access_token
  47. payload = decode_access_token(token)
  48. if payload.get("type") != "access":
  49. await websocket.close(code=4001, reason="Invalid token")
  50. return
  51. except Exception:
  52. await websocket.close(code=4001, reason="Token expired or invalid")
  53. return
  54. else:
  55. await websocket.close(code=4001, reason="Authentication required")
  56. return
  57. await websocket.accept()
  58. async with _lock:
  59. _connections.setdefault(job_id, set()).add(websocket)
  60. logger.info(f"客户端已连接到训练 WebSocket (job {job_id})")
  61. try:
  62. while True:
  63. # 保持连接;客户端可发送 "ping" 或取消请求
  64. data = await websocket.receive_text()
  65. if data == "ping":
  66. await websocket.send_text('{"type":"pong"}')
  67. except WebSocketDisconnect:
  68. logger.info(f"客户端已从训练 WebSocket 断开 (job {job_id})")
  69. finally:
  70. async with _lock:
  71. _connections.get(job_id, set()).discard(websocket)