"""WebSocket endpoint and broadcast helpers for streaming training-job events."""

import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Any

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.core.logging import logger

router = APIRouter()

# Connection registry: job_id -> set of WebSocket connections subscribed to it.
# An entry is deleted once its last connection is removed, so the dict does not
# keep an empty set around for every job that was ever watched.
_connections: dict[str, set[WebSocket]] = {}
# Guards every read and write of ``_connections``.
_lock = asyncio.Lock()


def _discard_connections(job_id: str, sockets: set[WebSocket]) -> None:
    """Unsubscribe *sockets* from *job_id*, dropping the job entry once empty.

    Must be called with ``_lock`` held. Sockets that are not (or no longer)
    subscribed are ignored.
    """
    conns = _connections.get(job_id)
    if conns is None:
        return
    conns.difference_update(sockets)
    if not conns:
        # Last subscriber gone: remove the key instead of leaving an empty set.
        del _connections[job_id]


async def _broadcast(job_id: str, message: dict[str, Any]) -> None:
    """Send *message* as JSON text to every client subscribed to *job_id*.

    The subscriber set is snapshotted under ``_lock`` and the sends happen
    after the lock is released. Clients whose send raises are unsubscribed
    afterwards. Does nothing if the job has no subscribers.

    Errors from ``json.dumps`` (e.g. ``TypeError`` for a non-serializable
    value in *message*) propagate to the caller.
    """
    async with _lock:
        conns = _connections.get(job_id, set()).copy()
    if not conns:
        return

    payload = json.dumps(message)
    disconnected: set[WebSocket] = set()
    for ws in conns:
        try:
            await ws.send_text(payload)
        except Exception as exc:
            # Best-effort delivery: one broken client must not abort the
            # broadcast, but record why it is being dropped.
            logger.debug(
                f"Dropping training WebSocket client for job {job_id}: {exc!r}"
            )
            disconnected.add(ws)

    # Unsubscribe the clients that failed to receive.
    if disconnected:
        async with _lock:
            _discard_connections(job_id, disconnected)


async def send_progress(job_id: str, **kwargs: Any) -> None:
    """Broadcast a ``progress`` event; *kwargs* are added as extra JSON fields."""
    await _broadcast(job_id, {"type": "progress", "job_id": job_id, **kwargs})


async def send_epoch_done(job_id: str, **kwargs: Any) -> None:
    """Broadcast an ``epoch_done`` event; *kwargs* are added as extra JSON fields."""
    await _broadcast(job_id, {"type": "epoch_done", "job_id": job_id, **kwargs})


async def send_completed(job_id: str, **kwargs: Any) -> None:
    """Broadcast a ``completed`` event; *kwargs* are added as extra JSON fields."""
    await _broadcast(job_id, {"type": "completed", "job_id": job_id, **kwargs})


async def send_error(job_id: str, message: str) -> None:
    """Broadcast an ``error`` event carrying *message*."""
    await _broadcast(job_id, {"type": "error", "job_id": job_id, "message": message})


async def send_heartbeat(job_id: str) -> None:
    """Broadcast a ``heartbeat`` event stamped with the current UTC time."""
    await _broadcast(
        job_id, {"type": "heartbeat", "job_id": job_id, "timestamp": datetime_now()}
    )


def datetime_now() -> str:
    """Return the current UTC time as an ISO-8601 string (``+00:00`` offset)."""
    return datetime.now(timezone.utc).isoformat()


@router.websocket("/ws/training/{job_id}")
async def training_websocket(websocket: WebSocket, job_id: str) -> None:
    """Subscribe a client to the training events of *job_id*.

    Accepts the connection, registers it in ``_connections``, and keeps
    reading until the client disconnects. A text frame equal to ``"ping"`` is
    answered with ``{"type":"pong"}``. The client is always unsubscribed on
    exit, including when an unexpected error ends the loop.
    """
    await websocket.accept()
    async with _lock:
        _connections.setdefault(job_id, set()).add(websocket)
    logger.info(f"客户端已连接到训练 WebSocket (job {job_id})")
    try:
        while True:
            # Keep the connection open. Only "ping" is handled here; any other
            # text (the original comment mentions cancel requests) is ignored.
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text('{"type":"pong"}')
    except WebSocketDisconnect:
        logger.info(f"客户端已从训练 WebSocket 断开 (job {job_id})")
    finally:
        async with _lock:
            _discard_connections(job_id, {websocket})