sse_manager.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File : sse_manager.py
  6. @IDE : VsCode
  7. @Author :
  8. @Date : 2025-12-04 10:58:00
  9. =================================
  10. 📋 统一SSE管理器 (Unified SSE Manager)
  11. 🏗️ 核心功能:
  12. ├── UnifiedSSEManager() # 统一SSE管理器(单例)
  13. ├── establish_connection() # 建立连接并注册回调
  14. ├── close_connection() # 关闭连接(清理连接和回调)
  15. ├── send_progress() # 发送进度消息
  16. └── trigger_callback() # 触发回调函数
  17. 📊 状态管理:
  18. ├── connections # 消息队列字典
  19. ├── callbacks # 回调函数字典
  20. └── get_connection_count() # 获取连接数统计
  21. 🔧 实用方法:
  22. ├── is_connected() # 检查连接是否存在
  23. ├── is_callback_registered() # 检查回调是否已注册
  24. ├── get_stats() # 获取详细统计信息
  25. └── clear_all() # 清理所有连接和回调
  26. '''
  27. import asyncio
  28. from typing import Dict, Any, Optional, Callable
  29. from datetime import datetime
  30. from foundation.observability.logger.loggering import review_logger as logger
  31. class UnifiedSSEManager:
  32. """
  33. 统一的SSE管理器 - 管理SSE连接、回调函数和消息推送
  34. 功能:
  35. 1. 管理SSE消息队列连接
  36. 2. 管理回调函数注册和触发
  37. 3. 提供统一的消息推送接口
  38. 4. 确保连接和回调状态同步
  39. """
  40. _instance = None
  41. def __new__(cls):
  42. """单例模式实现"""
  43. if cls._instance is None:
  44. cls._instance = super().__new__(cls)
  45. cls._instance.connections = {} # 消息队列字典
  46. cls._instance.callbacks = {} # 回调函数字典
  47. return cls._instance
  48. def __init__(self):
  49. """初始化统一SSE管理器"""
  50. pass # 在__new__中已完成初始化
  51. async def establish_connection(self, callback_task_id: str, callback_func: Optional[Callable] = None):
  52. """
  53. 建立SSE连接并注册回调函数
  54. Args:
  55. callback_task_id: 回调任务ID
  56. callback_func: 可选的回调函数
  57. Returns:
  58. asyncio.Queue: 消息队列,用于SSE事件流
  59. """
  60. try:
  61. # 创建消息队列
  62. queue = asyncio.Queue()
  63. self.connections[callback_task_id] = queue
  64. # 注册回调函数(如果提供)
  65. if callback_func:
  66. self.callbacks[callback_task_id] = callback_func
  67. # 发送连接建立确认消息
  68. await queue.put({
  69. "type": "connection_established",
  70. "callback_task_id": callback_task_id,
  71. "timestamp": datetime.now().isoformat()
  72. })
  73. logger.info(f"SSE连接已建立: {callback_task_id}")
  74. logger.info(f"当前连接数: {len(self.connections)}, 回调数: {len(self.callbacks)}")
  75. return queue
  76. except Exception as e:
  77. logger.error(f"建立SSE连接失败: {callback_task_id}, 错误: {str(e)}")
  78. raise
  79. async def close_connection(self, callback_task_id: str):
  80. """
  81. 关闭SSE连接(同时清理连接和回调)
  82. Args:
  83. callback_task_id: 回调任务ID
  84. """
  85. try:
  86. connection_existed = False
  87. callback_existed = False
  88. # 1. 先向队列发送结束信号,让SSE流能够正常结束
  89. if callback_task_id in self.connections:
  90. queue = self.connections[callback_task_id]
  91. try:
  92. await queue.put({
  93. "type": "connection_closed",
  94. "callback_task_id": callback_task_id,
  95. "timestamp": datetime.now().isoformat()
  96. })
  97. logger.info(f"已发送连接关闭信号到队列: {callback_task_id}")
  98. except Exception as queue_error:
  99. logger.warning(f"发送关闭信号失败,队列可能已关闭: {callback_task_id}, 错误: {str(queue_error)}")
  100. # 2. 清理连接
  101. if callback_task_id in self.connections:
  102. del self.connections[callback_task_id]
  103. connection_existed = True
  104. logger.info(f"SSE连接已断开: {callback_task_id}")
  105. # 3. 清理回调
  106. if callback_task_id in self.callbacks:
  107. del self.callbacks[callback_task_id]
  108. callback_existed = True
  109. logger.info(f"SSE回调已注销: {callback_task_id}")
  110. if not connection_existed and not callback_existed:
  111. logger.debug(f"SSE连接和回调均不存在: {callback_task_id}")
  112. else:
  113. logger.info(f"SSE连接清理完成: {callback_task_id}, 剩余连接数: {len(self.connections)}, 剩余回调数: {len(self.callbacks)}")
  114. except Exception as e:
  115. logger.error(f"关闭SSE连接时出错: {callback_task_id}, 错误: {str(e)}")
  116. async def send_progress(self, callback_task_id: str, current_data: dict):
  117. """
  118. 发送进度消息到指定连接
  119. Args:
  120. callback_task_id: 回调任务ID
  121. current_data: 进度数据
  122. """
  123. try:
  124. queue = self.connections.get(callback_task_id)
  125. if queue:
  126. # 确定事件类型
  127. event_type = current_data.get("event_type", "processing")
  128. # 处理特殊的单元审查事件
  129. if event_type == "unit_review" or (event_type == "processing" and current_data.get("status") == "unit_review_update"):
  130. event_type = "unit_review_update"
  131. # 添加时间戳
  132. message = {
  133. "type": event_type,
  134. "data": current_data,
  135. "timestamp": datetime.now().isoformat()
  136. }
  137. await queue.put(message)
  138. logger.debug(f"SSE进度已推送: {callback_task_id}, 事件类型: {event_type}")
  139. else:
  140. logger.warning(f"SSE连接不存在,跳过进度推送: {callback_task_id} - 任务继续执行")
  141. except Exception as e:
  142. logger.error(f"发送SSE进度消息失败: {callback_task_id}, 错误: {str(e)}")
  143. async def trigger_callback(self, callback_task_id: str, current_data: dict):
  144. """
  145. 触发指定任务的回调函数
  146. Args:
  147. callback_task_id: 回调任务ID
  148. current_data: 传递给回调的数据
  149. Returns:
  150. bool: 回调是否成功触发
  151. """
  152. try:
  153. callback_func = self.callbacks.get(callback_task_id)
  154. if callback_func:
  155. await callback_func(callback_task_id, current_data)
  156. logger.debug(f"SSE回调执行成功: {callback_task_id}")
  157. return True
  158. else:
  159. logger.debug(f"未找到SSE回调: {callback_task_id}, 已注册ID: {list(self.callbacks.keys())}")
  160. return False
  161. except Exception as e:
  162. logger.error(f"SSE回调执行失败: {callback_task_id}, 错误: {str(e)}")
  163. return False
  164. def is_connected(self, callback_task_id: str) -> bool:
  165. """检查SSE连接是否存在"""
  166. return callback_task_id in self.connections
  167. def is_callback_registered(self, callback_task_id: str) -> bool:
  168. """检查回调函数是否已注册"""
  169. return callback_task_id in self.callbacks
  170. def get_connection_count(self) -> int:
  171. """获取当前连接数"""
  172. return len(self.connections)
  173. def get_callback_count(self) -> int:
  174. """获取当前回调数"""
  175. return len(self.callbacks)
  176. def get_stats(self) -> Dict[str, Any]:
  177. """获取详细的统计信息"""
  178. return {
  179. "connections": {
  180. "count": len(self.connections),
  181. "ids": list(self.connections.keys())
  182. },
  183. "callbacks": {
  184. "count": len(self.callbacks),
  185. "ids": list(self.callbacks.keys())
  186. },
  187. "synchronized": len(self.connections) == len(self.callbacks)
  188. }
  189. async def clear_all(self):
  190. """清理所有连接和回调"""
  191. try:
  192. connection_count = len(self.connections)
  193. callback_count = len(self.callbacks)
  194. self.connections.clear()
  195. self.callbacks.clear()
  196. logger.info(f"已清理所有SSE连接和回调: {connection_count}个连接, {callback_count}个回调")
  197. except Exception as e:
  198. logger.error(f"清理所有SSE连接和回调时出错: {str(e)}")
  199. def register_callback_only(self, callback_task_id: str, callback_func: Callable):
  200. """
  201. 仅注册回调函数(不建立连接)
  202. Args:
  203. callback_task_id: 回调任务ID
  204. callback_func: 回调函数
  205. """
  206. self.callbacks[callback_task_id] = callback_func
  207. logger.info(f"SSE回调已注册: {callback_task_id}, 当前回调数: {len(self.callbacks)}")
  208. def unregister_callback_only(self, callback_task_id: str):
  209. """
  210. 仅注销回调函数(不关闭连接)
  211. Args:
  212. callback_task_id: 回调任务ID
  213. """
  214. if callback_task_id in self.callbacks:
  215. del self.callbacks[callback_task_id]
  216. logger.info(f"SSE回调已注销: {callback_task_id}, 剩余回调数: {len(self.callbacks)}")
  217. else:
  218. logger.debug(f"SSE回调不存在: {callback_task_id}")
  219. # 创建全局单例实例
  220. unified_sse_manager = UnifiedSSEManager()