| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- import asyncio
- import logging
- from typing import Any
- from fastapi import APIRouter, WebSocket, WebSocketDisconnect
- from app.core.logging import logger
- router = APIRouter()
- # Connection registry: job_id -> set of WebSocket connections subscribed to that job.
- _connections: dict[str, set[WebSocket]] = {}
- # Held around every access to _connections in the code visible in this module.
- _lock = asyncio.Lock()
async def _broadcast(job_id: str, message: dict[str, Any]) -> None:
    """Send ``message`` as JSON text to every client subscribed to ``job_id``.

    The subscriber set is copied while holding the module lock, and the sends
    then run without it. Any connection whose send raises is treated as dead
    and removed from the registry afterwards. If that leaves the job with no
    subscribers, the job's entry is deleted so the registry does not
    accumulate empty sets.

    Args:
        job_id: Key into the connection registry.
        message: Payload to serialize with ``json.dumps``.

    Raises:
        TypeError, ValueError: Propagated from ``json.dumps`` when ``message``
            cannot be serialized.
    """
    import json

    async with _lock:
        targets = set(_connections.get(job_id, ()))
    if not targets:
        return

    payload = json.dumps(message)
    dead = set()
    for ws in targets:
        try:
            await ws.send_text(payload)
        except Exception:
            # Best effort: a failed send must not stop delivery to the other
            # subscribers; the connection is only marked for removal.
            dead.add(ws)

    if not dead:
        return

    # Prune failed connections, and drop the job entry once it is empty.
    async with _lock:
        remaining = _connections.get(job_id)
        if remaining is None:
            return
        remaining.difference_update(dead)
        if not remaining:
            del _connections[job_id]
async def send_progress(job_id: str, **kwargs: Any) -> None:
    """Broadcast a ``progress`` event for ``job_id`` to its subscribers.

    Keyword arguments are merged into the message after the ``type`` and
    ``job_id`` fields, so a keyword with one of those names replaces the
    default value.
    """
    event: dict[str, Any] = {"type": "progress", "job_id": job_id}
    event.update(kwargs)
    await _broadcast(job_id, event)
async def send_epoch_done(job_id: str, **kwargs: Any) -> None:
    """Broadcast an ``epoch_done`` event for ``job_id`` to its subscribers.

    Keyword arguments are merged into the message after the ``type`` and
    ``job_id`` fields, so a keyword with one of those names replaces the
    default value.
    """
    event: dict[str, Any] = {"type": "epoch_done", "job_id": job_id}
    event.update(kwargs)
    await _broadcast(job_id, event)
async def send_completed(job_id: str, **kwargs: Any) -> None:
    """Broadcast a ``completed`` event for ``job_id`` to its subscribers.

    Keyword arguments are merged into the message after the ``type`` and
    ``job_id`` fields, so a keyword with one of those names replaces the
    default value.
    """
    event: dict[str, Any] = {"type": "completed", "job_id": job_id}
    event.update(kwargs)
    await _broadcast(job_id, event)
async def send_error(job_id: str, message: str) -> None:
    """Broadcast an ``error`` event carrying ``message`` to the subscribers of ``job_id``."""
    event = dict(type="error", job_id=job_id, message=message)
    await _broadcast(job_id, event)
async def send_heartbeat(job_id: str) -> None:
    """Broadcast a ``heartbeat`` event stamped with the current time to the subscribers of ``job_id``."""
    stamp = datetime_now()
    event = {"type": "heartbeat", "job_id": job_id, "timestamp": stamp}
    await _broadcast(job_id, event)
def datetime_now() -> str:
    """Return the current UTC time as an ISO 8601 string with a ``+00:00`` offset."""
    import datetime as _dt

    current = _dt.datetime.now(tz=_dt.timezone.utc)
    return current.isoformat()
@router.websocket("/ws/training/{job_id}")
async def training_websocket(websocket: WebSocket, job_id: str) -> None:
    """Subscribe an authenticated client to the events broadcast for ``job_id``.

    The client supplies an access token in the ``token`` query parameter. If
    the token is missing, cannot be decoded, or its ``type`` claim is not
    ``"access"``, the socket is closed with code 4001 without being accepted.

    Otherwise the connection is accepted and registered under ``job_id``.
    The handler then loops on incoming text frames, replying to ``"ping"``
    with a pong message and ignoring any other text. On exit the connection
    is unregistered, and the job's registry entry is deleted when no
    subscribers remain.

    Args:
        websocket: The incoming WebSocket connection.
        job_id: Path parameter selecting which job's events to receive.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4001, reason="Authentication required")
        return

    # Only the decode step is guarded, so a failure while closing the socket
    # below cannot trigger a second close attempt from the except branch.
    try:
        from app.core.security import decode_access_token

        token_type = decode_access_token(token).get("type")
    except Exception:
        await websocket.close(code=4001, reason="Token expired or invalid")
        return
    if token_type != "access":
        await websocket.close(code=4001, reason="Invalid token")
        return

    await websocket.accept()
    async with _lock:
        _connections.setdefault(job_id, set()).add(websocket)
    logger.info(f"客户端已连接到训练 WebSocket (job {job_id})")

    try:
        while True:
            # Keep the connection open; only "ping" gets a reply, any other
            # text frame is read and ignored.
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text('{"type":"pong"}')
    except WebSocketDisconnect:
        logger.info(f"客户端已从训练 WebSocket 断开 (job {job_id})")
    finally:
        async with _lock:
            subscribers = _connections.get(job_id)
            if subscribers is not None:
                subscribers.discard(websocket)
                # Drop the job entry once empty so the registry does not
                # keep one empty set per job_id ever connected to.
                if not subscribers:
                    del _connections[job_id]
|