task_progress.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. """
  2. 审查进度SSE实时推送接口
  3. """
  4. import json
  5. import asyncio
  6. from typing import Dict
  7. from datetime import datetime
  8. from pydantic import BaseModel
  9. from fastapi import APIRouter, Query
  10. from .schemas.error_schemas import TaskProgressErrors
  11. from fastapi.responses import StreamingResponse
  12. from foundation.logger.loggering import server_logger as logger
  13. from foundation.trace.trace_context import TraceContext, auto_trace
  14. from core.base.progress_manager import ProgressManager, sse_callback_manager
  15. progress_manager = ProgressManager()
  16. task_progress_router = APIRouter(prefix="/sgsc", tags=["进度推送"])
  17. async def sse_progress_callback(callback_task_id: str, current_data: dict):
  18. """SSE推送回调函数 - 接收进度更新并推送到客户端"""
  19. await sse_manager.send_progress(callback_task_id, current_data)
  20. class TaskProgressResponse(BaseModel):
  21. code: int
  22. data: dict
  23. class SimpleSSEManager:
  24. """SSE连接管理器 - 管理客户端SSE连接和消息推送"""
  25. def __init__(self):
  26. self.connections: Dict[str, asyncio.Queue] = {}
  27. async def connect(self, callback_task_id: str):
  28. """建立SSE连接 - 创建消息队列并发送连接确认"""
  29. queue = asyncio.Queue()
  30. self.connections[callback_task_id] = queue
  31. await queue.put({
  32. "type": "connection_established",
  33. "callback_task_id": callback_task_id,
  34. "timestamp": datetime.now().isoformat()
  35. })
  36. logger.info(f"SSE连接: {callback_task_id}")
  37. return queue
  38. async def disconnect(self, callback_task_id: str):
  39. """断开SSE连接 - 清理连接队列"""
  40. if callback_task_id in self.connections:
  41. del self.connections[callback_task_id]
  42. logger.info(f"SSE连接已断开: {callback_task_id}")
  43. async def send_progress(self, callback_task_id: str, current_data: dict):
  44. """发送进度更新 - 将进度数据放入队列推送给客户端"""
  45. queue = self.connections.get(callback_task_id)
  46. if queue:
  47. await queue.put({
  48. "type": "progress_update",
  49. "data": current_data,
  50. "timestamp": datetime.now().isoformat()
  51. })
  52. logger.debug(f"SSE进度已推送: {callback_task_id}")
  53. sse_manager = SimpleSSEManager()
  54. def format_sse_event(event_type: str, data: str) -> str:
  55. """格式化SSE事件 - 按照SSE协议格式化事件数据"""
  56. lines = [
  57. f"event: {event_type}",
  58. f"data: {data}",
  59. "",
  60. ""
  61. ]
  62. return "\n".join(lines) + "\n"
  63. @task_progress_router.get("/sse/current/{callback_task_id}")
  64. @auto_trace("callback_task_id")
  65. async def sse_progress_stream(
  66. callback_task_id: str,
  67. user: str = Query(..., description="用户标识")
  68. ):
  69. """SSE实时进度推送接口 - 建立SSE连接并实时推送任务进度"""
  70. try:
  71. valid_users = {"user-001", "user-002", "user-003"}
  72. if user not in valid_users:
  73. raise TaskProgressErrors.invalid_user()
  74. sse_callback_manager.register_callback(callback_task_id, sse_progress_callback)
  75. queue = await sse_manager.connect(callback_task_id)
  76. async def generate_events():
  77. """生成SSE事件流 - 处理连接确认、进度推送和任务完成检测"""
  78. try:
  79. logger.info(f"开始SSE事件流: {callback_task_id}")
  80. connected_data = json.dumps({
  81. "callback_task_id": callback_task_id,
  82. "message": "SSE连接已建立,等待进度更新...",
  83. "timestamp": datetime.now().isoformat()
  84. }, ensure_ascii=False)
  85. yield format_sse_event("connected", connected_data)
  86. current_progress = await progress_manager.get_progress(callback_task_id)
  87. if current_progress:
  88. progress_json = json.dumps(current_progress, ensure_ascii=False)
  89. yield format_sse_event("current", progress_json)
  90. logger.debug(f"开始监听队列中的进度更新: {callback_task_id}")
  91. while True:
  92. try:
  93. message = await queue.get()
  94. if message.get("type") == "progress_update":
  95. current_data = message.get("data")
  96. if current_data:
  97. logger.info(f"总流程处理进度: {current_data.get("message")}")
  98. progress_json = json.dumps(current_data, ensure_ascii=False)
  99. yield format_sse_event("current", progress_json)
  100. overall_task_status = current_data.get("overall_task_status")
  101. if overall_task_status in ["completed", "failed"]:
  102. completion_data = {
  103. "callback_task_id": callback_task_id,
  104. "task_status": overall_task_status,
  105. "overall_progress": current_data.get("current", 100),
  106. "timestamp": datetime.now().isoformat(),
  107. "message": "全部任务完成!"
  108. }
  109. completion_json = json.dumps(completion_data, ensure_ascii=False)
  110. yield format_sse_event("completed", completion_json)
  111. logger.info(f"全部任务完成,断开SSE连接: {callback_task_id}, 状态: {overall_task_status}")
  112. break
  113. elif message.get("type") == "connection_established":
  114. pass
  115. except Exception as e:
  116. logger.error(f"队列消息处理异常: {callback_task_id}, {e}")
  117. break
  118. except Exception as e:
  119. logger.error(f"SSE事件流异常: {callback_task_id}, {e}")
  120. error_data = json.dumps({
  121. "error": f"SSE异常: {str(e)}",
  122. "timestamp": datetime.now().isoformat()
  123. }, ensure_ascii=False)
  124. yield format_sse_event("error", error_data)
  125. finally:
  126. sse_callback_manager.unregister_callback(callback_task_id)
  127. await sse_manager.disconnect(callback_task_id)
  128. logger.debug(f"SSE流已结束: {callback_task_id}")
  129. return StreamingResponse(
  130. generate_events(),
  131. media_type="text/event-stream",
  132. headers={
  133. "Cache-Control": "no-cache, no-store, must-revalidate",
  134. "Connection": "keep-alive",
  135. "Access-Control-Allow-Origin": "*",
  136. "Access-Control-Allow-Headers": "Cache-Control, EventSource",
  137. "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
  138. "X-Accel-Buffering": "no",
  139. "X-Content-Type-Options": "nosniff"
  140. }
  141. )
  142. except Exception as e:
  143. logger.error(f"SSE连接失败: {callback_task_id}, {e}")
  144. raise TaskProgressErrors.server_internal_error(e)
  145. @task_progress_router.get("/sse/status")
  146. async def get_sse_status():
  147. """获取SSE连接状态 - 返回当前活跃的SSE连接信息"""
  148. return {
  149. "active_connections": len(sse_manager.connections),
  150. "connections": list(sse_manager.connections.keys()),
  151. "timestamp": datetime.now().isoformat()
  152. }