| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- """
- 审查进度SSE实时推送接口
- """
import asyncio
import json
from datetime import datetime
from typing import Dict, Optional

from fastapi import APIRouter, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from .schemas.error_schemas import LaunchReviewErrors
from foundation.logger.loggering import server_logger as logger
from foundation.trace.trace_context import TraceContext, auto_trace
from core.base.progress_manager import ProgressManager, sse_callback_manager
# Module-level ProgressManager; the SSE endpoint reads a task's stored
# progress snapshot from it when a stream opens.
progress_manager: ProgressManager = ProgressManager()
# Router for the progress-push endpoints, mounted under the "/sgsc" prefix.
task_progress_router: APIRouter = APIRouter(prefix="/sgsc", tags=["进度推送"])
async def sse_progress_callback(callback_task_id: str, current_data: dict):
    """Forward one progress update to the SSE stream subscribed to a task.

    The SSE endpoint registers this coroutine with ``sse_callback_manager``.
    The update is handed to ``sse_manager.send_progress``, which drops it
    when no stream is currently connected for ``callback_task_id``.

    Args:
        callback_task_id: Task id the update belongs to.
        current_data: Progress payload to push to the client.
    """
    manager = sse_manager
    await manager.send_progress(
        callback_task_id=callback_task_id,
        current_data=current_data,
    )
# Response envelope with a numeric code and a payload dict.
# NOTE(review): not referenced by any endpoint in the visible part of this
# module; confirm it is used elsewhere. Comments are used instead of a
# docstring because pydantic would publish a docstring as the schema
# description.
class TaskProgressResponse(BaseModel):
    code: int  # result/status code; exact semantics are not defined here
    data: dict  # response payload
class SimpleSSEManager:
    """Registry of SSE connections: one ``asyncio.Queue`` per callback task id.

    ``connect`` registers the queue an SSE stream reads from, ``send_progress``
    enqueues progress updates for a task, and ``disconnect`` drops the
    registration. At most one queue is kept per task id, so connecting again
    with an id that is already registered replaces the earlier queue.
    """

    def __init__(self):
        # callback_task_id -> queue consumed by that task's SSE stream
        self.connections: Dict[str, asyncio.Queue] = {}

    async def connect(self, callback_task_id: str):
        """Register a fresh queue for ``callback_task_id`` and return it.

        The queue is seeded with a ``connection_established`` message. Any
        queue already registered for the id is replaced. Its reader will not
        receive further updates, so the replacement is logged as a warning
        instead of happening silently.

        Args:
            callback_task_id: Task id the new stream subscribes to.

        Returns:
            The newly registered queue.
        """
        queue = asyncio.Queue()
        if callback_task_id in self.connections:
            logger.warning(f"SSE连接已存在,旧连接将不再收到推送: {callback_task_id}")
        self.connections[callback_task_id] = queue
        await queue.put({
            "type": "connection_established",
            "callback_task_id": callback_task_id,
            "timestamp": datetime.now().isoformat()
        })
        logger.info(f"SSE连接: {callback_task_id}")
        return queue

    async def disconnect(self, callback_task_id: str, queue: Optional[asyncio.Queue] = None):
        """Drop the registration for ``callback_task_id`` if there is one.

        Args:
            callback_task_id: Task id whose connection should be removed.
            queue: When given, remove the registration only if it is still
                this exact queue. This lets a stream that was superseded by a
                newer connection for the same id clean up without tearing the
                newer connection down. ``None`` (the default) removes the
                registration unconditionally, as before.
        """
        registered = self.connections.get(callback_task_id)
        if registered is None:
            return
        if queue is not None and registered is not queue:
            # Superseded caller: the registration belongs to someone else now.
            return
        del self.connections[callback_task_id]
        logger.info(f"SSE连接已断开: {callback_task_id}")

    async def send_progress(self, callback_task_id: str, current_data: dict):
        """Enqueue a ``progress_update`` message for ``callback_task_id``.

        The update is dropped when no stream is connected for the id.

        Args:
            callback_task_id: Task id the update belongs to.
            current_data: Progress payload, stored under the message's "data" key.
        """
        queue = self.connections.get(callback_task_id)
        if queue is None:
            return
        await queue.put({
            "type": "progress_update",
            "data": current_data,
            "timestamp": datetime.now().isoformat()
        })
        logger.debug(f"SSE进度已推送: {callback_task_id}")
# Module-level connection registry shared by sse_progress_callback, the SSE
# stream endpoint and the /sse/status endpoint.
sse_manager: SimpleSSEManager = SimpleSSEManager()
def format_sse_event(event_type: str, data: str) -> str:
    """Serialize one Server-Sent Events frame.

    Per the SSE wire format, each line of the payload is sent as its own
    ``data:`` field (the client re-joins them with newlines). The frame is
    terminated by exactly one blank line.

    Args:
        event_type: Value of the ``event:`` field; must not contain line breaks.
        data: Payload text. LF, CRLF and CR line breaks are all supported.

    Returns:
        The encoded frame, ending in a blank line.
    """
    # Normalize CRLF / CR to LF and give every line its own "data:" prefix.
    # A raw line break inside a single data field would end that field early
    # and make the client parse the remainder as unrelated (bogus) fields.
    payload_lines = data.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    fields = [f"event: {event_type}"]
    fields.extend(f"data: {line}" for line in payload_lines)
    return "\n".join(fields) + "\n\n"
@task_progress_router.get("/sse/progress/{callback_task_id}")
@auto_trace("callback_task_id")
async def sse_progress_stream(
    callback_task_id: str,
    user: str = Query(..., description="用户标识")
):
    """Open an SSE stream that pushes real-time progress for one task.

    The stream emits, in order:

    * a ``connected`` event;
    * a ``current`` event with the progress snapshot stored in
      ``progress_manager``, if there is one;
    * a ``current`` event for every ``progress_update`` delivered to this
      task's queue in ``sse_manager``;
    * a final ``completed`` event once an update carries
      ``overall_task_status`` ``"completed"`` or ``"failed"``, after which the
      stream ends.

    Any failure while streaming is reported as an ``error`` event.

    Args:
        callback_task_id: Task whose progress is streamed (path parameter).
        user: Caller identifier (query parameter), checked against an allowlist.

    Raises:
        The error built by ``LaunchReviewErrors.invalid_user()`` for a user
        that is not allowed, or the one built by
        ``LaunchReviewErrors.internal_error(e)`` if the response cannot be built.
    """
    # Validate before entering any try block. Previously the invalid-user
    # error was caught by the generic ``except Exception`` and re-raised as an
    # internal error, hiding the real cause from the client.
    # NOTE(review): hard-coded allowlist, presumably a placeholder for real
    # authentication; confirm.
    valid_users = {"user-001", "user-002", "user-003"}
    if user not in valid_users:
        raise LaunchReviewErrors.invalid_user()

    async def generate_events():
        """Yield SSE frames for this task and release the connection on exit."""
        queue = None
        try:
            # Acquire the connection only once the body is actually streamed.
            # If it were acquired in the handler and this generator never
            # started, the finally below would never run and the queue and
            # callback would leak.
            queue = await sse_manager.connect(callback_task_id)
            sse_callback_manager.register_callback(callback_task_id, sse_progress_callback)
            logger.info(f"开始SSE事件流: {callback_task_id}")

            connected_data = json.dumps({
                "callback_task_id": callback_task_id,
                "message": "SSE连接已建立,等待进度更新...",
                "timestamp": datetime.now().isoformat()
            }, ensure_ascii=False)
            yield format_sse_event("connected", connected_data)

            # Send the stored snapshot first, so a client that subscribes late
            # sees the current state without waiting for the next update.
            current_progress = await progress_manager.get_progress(callback_task_id)
            if current_progress:
                yield format_sse_event("current", json.dumps(current_progress, ensure_ascii=False))

            logger.debug(f"开始监听队列中的进度更新: {callback_task_id}")
            while True:
                message = await queue.get()
                # Other message types (e.g. the "connection_established" marker
                # put by sse_manager.connect) need no client event of their own.
                if message.get("type") != "progress_update":
                    continue
                current_data = message.get("data")
                if not current_data:
                    continue

                # Single quotes inside the f-string keep this valid before Python 3.12.
                logger.info(f"总流程处理进度: {current_data.get('message')}")
                yield format_sse_event("current", json.dumps(current_data, ensure_ascii=False))

                overall_task_status = current_data.get("overall_task_status")
                if overall_task_status in ("completed", "failed"):
                    completion_json = json.dumps({
                        "callback_task_id": callback_task_id,
                        "task_status": overall_task_status,
                        "overall_progress": current_data.get("current", 100),
                        "timestamp": datetime.now().isoformat(),
                        "message": "全部任务完成!"
                    }, ensure_ascii=False)
                    yield format_sse_event("completed", completion_json)
                    logger.info(f"全部任务完成,断开SSE连接: {callback_task_id}, 状态: {overall_task_status}")
                    break
        except Exception as e:
            # Report failures (including message-processing errors, which used
            # to close the stream silently) to the client as an "error" event.
            logger.error(f"SSE事件流异常: {callback_task_id}, {e}")
            error_data = json.dumps({
                "error": f"SSE异常: {str(e)}",
                "timestamp": datetime.now().isoformat()
            }, ensure_ascii=False)
            yield format_sse_event("error", error_data)
        finally:
            # A newer stream for the same task id replaces our queue in
            # sse_manager. Only tear down while we still own the registration;
            # otherwise we would cut off that newer client.
            if queue is not None and sse_manager.connections.get(callback_task_id) is queue:
                sse_callback_manager.unregister_callback(callback_task_id)
                await sse_manager.disconnect(callback_task_id)
            logger.debug(f"SSE流已结束: {callback_task_id}")

    try:
        return StreamingResponse(
            generate_events(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "Cache-Control, EventSource",
                "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
                "X-Accel-Buffering": "no",
                "X-Content-Type-Options": "nosniff"
            }
        )
    except Exception as e:
        logger.error(f"SSE连接失败: {callback_task_id}, {e}")
        raise LaunchReviewErrors.internal_error(e) from e
@task_progress_router.get("/sse/status")
async def get_sse_status():
    """Report the SSE connections currently registered in ``sse_manager``.

    NOTE(review): this endpoint performs no user check and lists every
    connected task id; confirm that exposure is intended.

    Returns:
        A dict with the number of active connections, their callback task
        ids, and the current local time in ISO-8601 format.
    """
    open_task_ids = list(sse_manager.connections)
    snapshot = {
        "active_connections": len(open_task_ids),
        "connections": open_task_ids,
    }
    snapshot["timestamp"] = datetime.now().isoformat()
    return snapshot
|