progress_manager.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. import json
  2. import time
  3. import asyncio
  4. from typing import Dict, Any, Optional
  5. from datetime import datetime
  6. from foundation.logger.loggering import server_logger as logger
  7. from foundation.base.config import config_handler
  8. from core.base.sse_manager import unified_sse_manager
  9. class SSECallbackManager:
  10. """
  11. SSE回调管理器 - 兼容性包装器,委托给统一SSE管理器
  12. 注意: 此类保持向后兼容,建议直接使用 unified_sse_manager
  13. """
  14. def register_callback(self, callback_task_id: str, callback_func):
  15. """注册回调函数"""
  16. unified_sse_manager.register_callback_only(callback_task_id, callback_func)
  17. def unregister_callback(self, callback_task_id: str):
  18. """注销回调函数"""
  19. unified_sse_manager.unregister_callback_only(callback_task_id)
  20. def is_callback_registered(self, callback_task_id: str) -> bool:
  21. """检查回调是否已注册"""
  22. return unified_sse_manager.is_callback_registered(callback_task_id)
  23. async def trigger_callback(self, callback_task_id: str, current_data: dict):
  24. """触发回调函数"""
  25. return await unified_sse_manager.trigger_callback(callback_task_id, current_data)
  26. def get_callbacks_count(self):
  27. """获取回调数量"""
  28. return unified_sse_manager.get_callback_count()
  29. def clear_all_callbacks(self):
  30. """清空所有回调"""
  31. asyncio.create_task(unified_sse_manager.clear_all())
  32. async def force_close_sse(self, callback_task_id: str):
  33. """强制关闭SSE连接 - 同步调用close_connection(兼容性方法)"""
  34. await unified_sse_manager.close_connection(callback_task_id)
  35. # 创建兼容性实例
  36. sse_callback_manager = SSECallbackManager()
  37. class ProgressManager:
  38. """任务进度管理器"""
  39. def __init__(self):
  40. self.redis_client = None
  41. self.redis_connected = False
  42. self._init_redis()
  43. def _init_redis(self):
  44. try:
  45. import redis
  46. redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
  47. redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
  48. redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
  49. redis_db = config_handler.get('redis', 'REDIS_DB', '0')
  50. if redis_password:
  51. redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}"
  52. else:
  53. redis_url = f"redis://{redis_host}:{redis_port}/{redis_db}"
  54. logger.debug(f"ProgressManager连接Redis: {redis_url}")
  55. self.redis_client = redis.from_url(redis_url, decode_responses=True)
  56. self.redis_client.ping()
  57. self.redis_connected = True
  58. logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
  59. except Exception as e:
  60. logger.error(f"ProgressManager Redis连接失败: {e}")
  61. self.redis_connected = False
  62. logger.warning("ProgressManager将使用内存存储作为备选方案")
  63. self.current_data = {}
  64. async def _get_redis_key(self, callback_task_id: str) -> str:
  65. return f"current:{callback_task_id}"
  66. async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
  67. try:
  68. current_data = {
  69. "user_id": user_id,
  70. "current": 0,
  71. "stage_name": "",
  72. "status": "准备开始",
  73. "message": "任务开始",
  74. "updated_at": datetime.now().isoformat(),
  75. "overall_task_status": "pending"
  76. }
  77. if self.redis_connected:
  78. try:
  79. redis_key = await self._get_redis_key(callback_task_id)
  80. self.redis_client.setex(
  81. redis_key,
  82. 3600,
  83. json.dumps(current_data)
  84. )
  85. logger.info(f"初始化任务进度列表")
  86. except Exception as redis_e:
  87. logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")
  88. if not hasattr(self, 'current_data'):
  89. self.current_data = {}
  90. self.current_data[callback_task_id] = current_data
  91. logger.info(f"降级使用内存存储: {callback_task_id}")
  92. else:
  93. if not hasattr(self, 'current_data'):
  94. self.current_data = {}
  95. self.current_data[callback_task_id] = current_data
  96. logger.info(f"初始化任务进度到内存: {callback_task_id}")
  97. except Exception as e:
  98. logger.error(f"初始化进度失败: {str(e)}")
  99. raise
  100. async def update_stage_progress(self, callback_task_id: str, stage_name: str = None, current: int = None, status: str = None, message: str = None, issues=None, user_id: str = None, overall_task_status: str = None, event_type: str = "processing"):
  101. """更新阶段进度 - 除callback_task_id外,其他参数都可选
  102. Args:
  103. callback_task_id: 回调任务ID(必需)
  104. stage_name: 阶段名称(可选)
  105. current: 当前进度(可选)
  106. status: 状态(可选)
  107. message: 消息(可选)
  108. issues: 问题列表(可选)
  109. user_id: 用户ID(可选)
  110. overall_task_status: 整体任务状态(可选)
  111. event_type: SSE事件类型(可选,默认为"processing")
  112. """
  113. try:
  114. task_progress = None
  115. if self.redis_connected:
  116. # 从Redis读取
  117. redis_key = await self._get_redis_key(callback_task_id)
  118. progress_json = self.redis_client.get(redis_key)
  119. if progress_json:
  120. task_progress = json.loads(progress_json)
  121. else:
  122. logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
  123. return
  124. else:
  125. # 从内存读取
  126. if callback_task_id in self.current_data:
  127. task_progress = self.current_data[callback_task_id]
  128. else:
  129. logger.warning(f"内存中未找到任务进度: {callback_task_id}")
  130. return
  131. # 更新进度数据 - 只有非空参数才更新
  132. if current is not None:
  133. task_progress["current"] = current
  134. if stage_name is not None:
  135. task_progress["stage_name"] = stage_name
  136. if status is not None:
  137. task_progress["status"] = status
  138. if message is not None:
  139. task_progress["message"] = message
  140. task_progress["updated_at"] = datetime.now().isoformat()
  141. if issues is not None:
  142. task_progress["issues"] = issues
  143. else:
  144. task_progress["issues"] = []
  145. if user_id is not None:
  146. task_progress["user_id"] = user_id
  147. if overall_task_status is not None:
  148. task_progress["overall_task_status"] = overall_task_status
  149. elif "overall_task_status" not in task_progress:
  150. task_progress["overall_task_status"] = "processing"
  151. if event_type is not None:
  152. task_progress["event_type"] = event_type
  153. logger.debug(f"设置event_type: {event_type} for {callback_task_id}")
  154. else:
  155. logger.debug(f"event_type为None,不设置 for {callback_task_id}")
  156. try:
  157. if self.redis_connected:
  158. try:
  159. self.redis_client.setex(
  160. redis_key,
  161. 3600, # 1小时过期
  162. json.dumps(task_progress)
  163. )
  164. logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {current}%")
  165. except Exception as sync_e:
  166. logger.warning(f"同步Redis操作失败: {callback_task_id}, {sync_e}")
  167. # 同步操作也失败时,降级到内存存储
  168. if not hasattr(self, 'current_data'):
  169. self.current_data = {}
  170. self.current_data[callback_task_id] = task_progress
  171. logger.debug(f"降级使用内存存储: {callback_task_id}")
  172. else:
  173. if not hasattr(self, 'current_data'):
  174. self.current_data = {}
  175. self.current_data[callback_task_id] = task_progress
  176. logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {current}%")
  177. except Exception as e:
  178. logger.error(f"保存进度数据异常: {callback_task_id}, {e}")
  179. if not hasattr(self, 'current_data'):
  180. self.current_data = {}
  181. self.current_data[callback_task_id] = task_progress
  182. if task_progress.get("overall_task_status") == "completed":
  183. logger.info(f"任务完成,不触发redis更新SSE推送,将完成信号交由lanch_review上层sse推送: {callback_task_id}")
  184. else:
  185. logger.info(f"触发SSE推送: {callback_task_id}")
  186. updated_progress = await self.get_progress(callback_task_id)
  187. issues = task_progress.get("issues")
  188. event_type = task_progress.get("event_type", "processing")
  189. logger.info(f"触发SSE回调: {callback_task_id}, event_type: {event_type}")
  190. if updated_progress and issues and len(issues) > 0 and issues[0] != 'clear':
  191. await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
  192. elif updated_progress and not issues: # 空列表时也要推送
  193. await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
  194. except Exception as e:
  195. logger.error(f"更新阶段进度失败: {str(e)}")
  196. raise
  197. async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
  198. """获取任务进度"""
  199. try:
  200. logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
  201. task_progress = None
  202. if self.redis_connected:
  203. # 从Redis读取
  204. redis_key = await self._get_redis_key(callback_task_id)
  205. logger.debug(f"Redis键: {redis_key}")
  206. progress_json = self.redis_client.get(redis_key)
  207. logger.debug(f"从Redis读取数据: {progress_json is not None}")
  208. if progress_json:
  209. task_progress = json.loads(progress_json)
  210. else:
  211. logger.debug(f"Redis中未找到任务进度: {callback_task_id}")
  212. return None
  213. else:
  214. # 从内存读取
  215. if hasattr(self, 'current_data') and callback_task_id in self.current_data:
  216. task_progress = self.current_data[callback_task_id]
  217. else:
  218. logger.debug(f"内存中未找到任务进度: {callback_task_id}")
  219. return None
  220. # 获取overall_task_status,默认为"pending"
  221. overall_task_status = task_progress.get("overall_task_status", "pending")
  222. # 转换时间戳
  223. updated_at = task_progress["updated_at"]
  224. if isinstance(updated_at, str):
  225. updated_at_timestamp = int(datetime.fromisoformat(updated_at).timestamp())
  226. else:
  227. updated_at_timestamp = int(updated_at.timestamp())
  228. # 构建返回数据
  229. result = {
  230. "callback_task_id": callback_task_id,
  231. "user_id": task_progress["user_id"],
  232. "current": task_progress["current"],
  233. "stage_name": task_progress["stage_name"],
  234. "status": task_progress["status"],
  235. "message": task_progress["message"],
  236. "overall_task_status": overall_task_status,
  237. "updated_at": updated_at_timestamp
  238. }
  239. # 添加可选字段
  240. if "issues" in task_progress:
  241. result["issues"] = task_progress["issues"]
  242. if "event_type" in task_progress:
  243. result["event_type"] = task_progress["event_type"]
  244. return result
  245. except Exception as e:
  246. logger.error(f"获取进度失败: {str(e)}")
  247. return None
  248. async def complete_task(self, callback_task_id: str, user_id: str = None, current_data: dict = None):
  249. """标记任务完成 - 使用单一同步强制关闭逻辑"""
  250. # logger.info(f"审查任务已完成,任务准备关闭...: {callback_task_id}")
  251. # time.sleep(1)
  252. try:
  253. logger.info(f"取消注册任务: {callback_task_id}")
  254. await unified_sse_manager.close_connection(callback_task_id)
  255. logger.info(f"任务关闭: {callback_task_id}")
  256. except Exception as e:
  257. logger.error(f"标记任务完成失败: {str(e)}")
  258. raise