""" 审查进度SSE实时推送接口 """ import json import asyncio from typing import Dict from datetime import datetime from pydantic import BaseModel from fastapi import APIRouter, Query from .schemas.error_schemas import TaskProgressErrors from fastapi.responses import StreamingResponse from foundation.logger.loggering import server_logger as logger from foundation.trace.trace_context import TraceContext, auto_trace from core.base.progress_manager import ProgressManager, sse_callback_manager progress_manager = ProgressManager() task_progress_router = APIRouter(prefix="/sgsc", tags=["进度推送"]) async def sse_progress_callback(callback_task_id: str, current_data: dict): """SSE推送回调函数 - 接收进度更新并推送到客户端""" await sse_manager.send_progress(callback_task_id, current_data) class TaskProgressResponse(BaseModel): code: int data: dict class SimpleSSEManager: """SSE连接管理器 - 管理客户端SSE连接和消息推送""" def __init__(self): self.connections: Dict[str, asyncio.Queue] = {} async def connect(self, callback_task_id: str): """建立SSE连接 - 创建消息队列并发送连接确认""" queue = asyncio.Queue() self.connections[callback_task_id] = queue await queue.put({ "type": "connection_established", "callback_task_id": callback_task_id, "timestamp": datetime.now().isoformat() }) logger.info(f"SSE连接: {callback_task_id}") return queue async def disconnect(self, callback_task_id: str): """断开SSE连接 - 清理连接队列""" if callback_task_id in self.connections: del self.connections[callback_task_id] logger.info(f"SSE连接已断开: {callback_task_id}") async def send_progress(self, callback_task_id: str, current_data: dict): """发送进度更新 - 将进度数据放入队列推送给客户端""" queue = self.connections.get(callback_task_id) if queue: await queue.put({ "type": "progress_update", "data": current_data, "timestamp": datetime.now().isoformat() }) logger.debug(f"SSE进度已推送: {callback_task_id}") sse_manager = SimpleSSEManager() def format_sse_event(event_type: str, data: str) -> str: """格式化SSE事件 - 按照SSE协议格式化事件数据""" lines = [ f"event: {event_type}", f"data: {data}", "", "" ] return "\n".join(lines) + "\n" @task_progress_router.get("/sse/current/{callback_task_id}") @auto_trace("callback_task_id") async def sse_progress_stream( callback_task_id: str, user: str = Query(..., description="用户标识") ): """SSE实时进度推送接口 - 建立SSE连接并实时推送任务进度""" try: valid_users = {"user-001", "user-002", "user-003"} if user not in valid_users: raise TaskProgressErrors.invalid_user() sse_callback_manager.register_callback(callback_task_id, sse_progress_callback) queue = await sse_manager.connect(callback_task_id) async def generate_events(): """生成SSE事件流 - 处理连接确认、进度推送和任务完成检测""" try: logger.info(f"开始SSE事件流: {callback_task_id}") connected_data = json.dumps({ "callback_task_id": callback_task_id, "message": "SSE连接已建立,等待进度更新...", "timestamp": datetime.now().isoformat() }, ensure_ascii=False) yield format_sse_event("connected", connected_data) current_progress = await progress_manager.get_progress(callback_task_id) if current_progress: progress_json = json.dumps(current_progress, ensure_ascii=False) yield format_sse_event("current", progress_json) logger.debug(f"开始监听队列中的进度更新: {callback_task_id}") while True: try: message = await queue.get() if message.get("type") == "progress_update": current_data = message.get("data") if current_data: logger.info(f"总流程处理进度: {current_data.get("message")}") progress_json = json.dumps(current_data, ensure_ascii=False) yield format_sse_event("current", progress_json) overall_task_status = current_data.get("overall_task_status") if overall_task_status in ["completed", "failed"]: completion_data = { "callback_task_id": callback_task_id, "task_status": overall_task_status, "overall_progress": current_data.get("current", 100), "timestamp": datetime.now().isoformat(), "message": "全部任务完成!" } completion_json = json.dumps(completion_data, ensure_ascii=False) yield format_sse_event("completed", completion_json) logger.info(f"全部任务完成,断开SSE连接: {callback_task_id}, 状态: {overall_task_status}") break elif message.get("type") == "connection_established": pass except Exception as e: logger.error(f"队列消息处理异常: {callback_task_id}, {e}") break except Exception as e: logger.error(f"SSE事件流异常: {callback_task_id}, {e}") error_data = json.dumps({ "error": f"SSE异常: {str(e)}", "timestamp": datetime.now().isoformat() }, ensure_ascii=False) yield format_sse_event("error", error_data) finally: sse_callback_manager.unregister_callback(callback_task_id) await sse_manager.disconnect(callback_task_id) logger.debug(f"SSE流已结束: {callback_task_id}") return StreamingResponse( generate_events(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache, no-store, must-revalidate", "Connection": "keep-alive", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "Cache-Control, EventSource", "Access-Control-Allow-Methods": "GET, POST, OPTIONS", "X-Accel-Buffering": "no", "X-Content-Type-Options": "nosniff" } ) except Exception as e: logger.error(f"SSE连接失败: {callback_task_id}, {e}") raise TaskProgressErrors.server_internal_error(e) @task_progress_router.get("/sse/status") async def get_sse_status(): """获取SSE连接状态 - 返回当前活跃的SSE连接信息""" return { "active_connections": len(sse_manager.connections), "connections": list(sse_manager.connections.keys()), "timestamp": datetime.now().isoformat() }