# progress_manager.py
import asyncio
import inspect
import json
from datetime import datetime
from typing import Any, Dict, Optional

from foundation.base.config import config_handler
from foundation.logger.loggering import server_logger as logger
  7. class SSECallbackManager:
  8. """SSE回调管理器 - 单例模式"""
  9. _instance = None
  10. _callbacks = {}
  11. def __new__(cls):
  12. if cls._instance is None:
  13. cls._instance = super().__new__(cls)
  14. return cls._instance
  15. def register_callback(self, callback_task_id: str, callback_func):
  16. self._callbacks[callback_task_id] = callback_func
  17. logger.info(f"SSE回调注册, 当前注册数: {len(self._callbacks)}")
  18. def unregister_callback(self, callback_task_id: str):
  19. if callback_task_id in self._callbacks:
  20. del self._callbacks[callback_task_id]
  21. logger.info(f"SSE回调注销, 剩余注册数: {len(self._callbacks)}")
  22. def is_callback_registered(self, callback_task_id: str) -> bool:
  23. """检查回调是否已注册"""
  24. return callback_task_id in self._callbacks
  25. async def trigger_callback(self, callback_task_id: str, current_data: dict):
  26. if callback_task_id in self._callbacks:
  27. try:
  28. await self._callbacks[callback_task_id](callback_task_id, current_data)
  29. logger.debug(f"SSE回调执行成功: {callback_task_id}")
  30. logger.debug(f"SSE回调已触发: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}")
  31. return True
  32. except Exception as e:
  33. logger.error(f"SSE回调执行失败: {callback_task_id}, {e}")
  34. return False
  35. else:
  36. logger.debug(f"未找到SSE回调: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}, 已注册ID: {list(self._callbacks.keys())}")
  37. return False
  38. def get_callbacks_count(self):
  39. return len(self._callbacks)
  40. def clear_all_callbacks(self):
  41. self._callbacks.clear()
  42. logger.info("已清空所有SSE回调")
  43. def force_close_sse(self, callback_task_id: str):
  44. """强制关闭SSE连接"""
  45. if callback_task_id in self._callbacks:
  46. del self._callbacks[callback_task_id]
  47. logger.info(f"强制关闭SSE连接: {callback_task_id}")
  48. else:
  49. logger.warning(f"SSE连接已不存在,无需关闭: {callback_task_id}")
  50. sse_callback_manager = SSECallbackManager()
  51. class ProgressManager:
  52. """任务进度管理器"""
  53. def __init__(self):
  54. self.redis_client = None
  55. self.redis_connected = False
  56. self._init_redis()
  57. def _init_redis(self):
  58. try:
  59. import redis
  60. redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
  61. redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
  62. redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
  63. redis_db = config_handler.get('redis', 'REDIS_DB', '0')
  64. if redis_password:
  65. redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}"
  66. else:
  67. redis_url = f"redis://{redis_host}:{redis_port}/{redis_db}"
  68. logger.debug(f"ProgressManager连接Redis: {redis_url}")
  69. self.redis_client = redis.from_url(redis_url, decode_responses=True)
  70. self.redis_client.ping()
  71. self.redis_connected = True
  72. logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
  73. except Exception as e:
  74. logger.error(f"ProgressManager Redis连接失败: {e}")
  75. self.redis_connected = False
  76. logger.warning("ProgressManager将使用内存存储作为备选方案")
  77. self.current_data = {}
  78. async def _get_redis_key(self, callback_task_id: str) -> str:
  79. return f"current:{callback_task_id}"
  80. async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
  81. try:
  82. current_data = {
  83. "user_id": user_id,
  84. "current": 0,
  85. "stage_name": "",
  86. "status": "准备开始",
  87. "message": "任务开始",
  88. "updated_at": datetime.now().isoformat(),
  89. "overall_task_status": "pending"
  90. }
  91. if self.redis_connected:
  92. try:
  93. redis_key = await self._get_redis_key(callback_task_id)
  94. self.redis_client.setex(
  95. redis_key,
  96. 3600,
  97. json.dumps(current_data)
  98. )
  99. logger.info(f"初始化任务进度列表")
  100. except Exception as redis_e:
  101. logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")
  102. if not hasattr(self, 'current_data'):
  103. self.current_data = {}
  104. self.current_data[callback_task_id] = current_data
  105. logger.info(f"降级使用内存存储: {callback_task_id}")
  106. else:
  107. if not hasattr(self, 'current_data'):
  108. self.current_data = {}
  109. self.current_data[callback_task_id] = current_data
  110. logger.info(f"初始化任务进度到内存: {callback_task_id}")
  111. except Exception as e:
  112. logger.error(f"初始化进度失败: {str(e)}")
  113. raise
  114. async def update_stage_progress(self, callback_task_id: str, stage_name: str = None, current: int = None, status: str = None, message: str = None, issues=None, user_id: str = None, overall_task_status: str = None, event_type: str = "processing"):
  115. """更新阶段进度 - 除callback_task_id外,其他参数都可选
  116. Args:
  117. callback_task_id: 回调任务ID(必需)
  118. stage_name: 阶段名称(可选)
  119. current: 当前进度(可选)
  120. status: 状态(可选)
  121. message: 消息(可选)
  122. issues: 问题列表(可选)
  123. user_id: 用户ID(可选)
  124. overall_task_status: 整体任务状态(可选)
  125. event_type: SSE事件类型(可选,默认为"processing")
  126. """
  127. try:
  128. task_progress = None
  129. if self.redis_connected:
  130. # 从Redis读取
  131. redis_key = await self._get_redis_key(callback_task_id)
  132. progress_json = self.redis_client.get(redis_key)
  133. if progress_json:
  134. task_progress = json.loads(progress_json)
  135. else:
  136. logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
  137. return
  138. else:
  139. # 从内存读取
  140. if callback_task_id in self.current_data:
  141. task_progress = self.current_data[callback_task_id]
  142. else:
  143. logger.warning(f"内存中未找到任务进度: {callback_task_id}")
  144. return
  145. # 更新进度数据 - 只有非空参数才更新
  146. if current is not None:
  147. task_progress["current"] = current
  148. if stage_name is not None:
  149. task_progress["stage_name"] = stage_name
  150. if status is not None:
  151. task_progress["status"] = status
  152. if message is not None:
  153. task_progress["message"] = message
  154. task_progress["updated_at"] = datetime.now().isoformat()
  155. if issues is not None:
  156. task_progress["issues"] = issues
  157. else:
  158. task_progress["issues"] = []
  159. if user_id is not None:
  160. task_progress["user_id"] = user_id
  161. if overall_task_status is not None:
  162. task_progress["overall_task_status"] = overall_task_status
  163. elif "overall_task_status" not in task_progress:
  164. task_progress["overall_task_status"] = "processing"
  165. if event_type is not None:
  166. task_progress["event_type"] = event_type
  167. logger.debug(f"设置event_type: {event_type} for {callback_task_id}")
  168. else:
  169. logger.debug(f"event_type为None,不设置 for {callback_task_id}")
  170. try:
  171. if self.redis_connected:
  172. try:
  173. self.redis_client.setex(
  174. redis_key,
  175. 3600, # 1小时过期
  176. json.dumps(task_progress)
  177. )
  178. logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {current}%")
  179. except Exception as sync_e:
  180. logger.warning(f"同步Redis操作失败: {callback_task_id}, {sync_e}")
  181. # 同步操作也失败时,降级到内存存储
  182. if not hasattr(self, 'current_data'):
  183. self.current_data = {}
  184. self.current_data[callback_task_id] = task_progress
  185. logger.debug(f"降级使用内存存储: {callback_task_id}")
  186. else:
  187. if not hasattr(self, 'current_data'):
  188. self.current_data = {}
  189. self.current_data[callback_task_id] = task_progress
  190. logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {current}%")
  191. except Exception as e:
  192. logger.error(f"保存进度数据异常: {callback_task_id}, {e}")
  193. if not hasattr(self, 'current_data'):
  194. self.current_data = {}
  195. self.current_data[callback_task_id] = task_progress
  196. if task_progress.get("overall_task_status") == "completed":
  197. logger.info(f"任务完成,不触发redis更新SSE推送,将完成信号交由lanch_review上层sse推送: {callback_task_id}")
  198. else:
  199. logger.info(f"触发SSE推送: {callback_task_id}")
  200. updated_progress = await self.get_progress(callback_task_id)
  201. issues = task_progress.get("issues")
  202. event_type = task_progress.get("event_type", "processing")
  203. logger.info(f"触发SSE回调: {callback_task_id}, event_type: {event_type}")
  204. if updated_progress and issues and len(issues) > 0 and issues[0] != 'clear':
  205. await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
  206. elif updated_progress and not issues: # 空列表时也要推送
  207. await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
  208. except Exception as e:
  209. logger.error(f"更新阶段进度失败: {str(e)}")
  210. raise
  211. async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
  212. """获取任务进度"""
  213. try:
  214. logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
  215. task_progress = None
  216. if self.redis_connected:
  217. # 从Redis读取
  218. redis_key = await self._get_redis_key(callback_task_id)
  219. logger.debug(f"Redis键: {redis_key}")
  220. progress_json = self.redis_client.get(redis_key)
  221. logger.debug(f"从Redis读取数据: {progress_json is not None}")
  222. if progress_json:
  223. task_progress = json.loads(progress_json)
  224. else:
  225. logger.debug(f"Redis中未找到任务进度: {callback_task_id}")
  226. return None
  227. else:
  228. # 从内存读取
  229. if hasattr(self, 'current_data') and callback_task_id in self.current_data:
  230. task_progress = self.current_data[callback_task_id]
  231. else:
  232. logger.debug(f"内存中未找到任务进度: {callback_task_id}")
  233. return None
  234. # 获取overall_task_status,默认为"pending"
  235. overall_task_status = task_progress.get("overall_task_status", "pending")
  236. # 转换时间戳
  237. updated_at = task_progress["updated_at"]
  238. if isinstance(updated_at, str):
  239. updated_at_timestamp = int(datetime.fromisoformat(updated_at).timestamp())
  240. else:
  241. updated_at_timestamp = int(updated_at.timestamp())
  242. # 构建返回数据
  243. result = {
  244. "callback_task_id": callback_task_id,
  245. "user_id": task_progress["user_id"],
  246. "current": task_progress["current"],
  247. "stage_name": task_progress["stage_name"],
  248. "status": task_progress["status"],
  249. "message": task_progress["message"],
  250. "overall_task_status": overall_task_status,
  251. "updated_at": updated_at_timestamp
  252. }
  253. # 添加可选字段
  254. if "issues" in task_progress:
  255. result["issues"] = task_progress["issues"]
  256. if "event_type" in task_progress:
  257. result["event_type"] = task_progress["event_type"]
  258. return result
  259. except Exception as e:
  260. logger.error(f"获取进度失败: {str(e)}")
  261. return None
  262. async def complete_task(self, callback_task_id: str, user_id: str = None, current_data: dict = None):
  263. """标记任务完成"""
  264. try:
  265. logger.info(f"保存审查结果: {callback_task_id}")
  266. # 使用update_stage_progress方法更新响应数据,但不推送SSE
  267. await self.update_stage_progress(
  268. callback_task_id=callback_task_id,
  269. user_id=user_id,
  270. current=current_data.get("current", 100) if current_data else 100,
  271. stage_name="审查完成",
  272. status="completed",
  273. message="施工审查方案处理完成!",
  274. overall_task_status='completed',
  275. issues=current_data.get("issues", []) if current_data else []
  276. )
  277. logger.info(f"取消注册任务: {callback_task_id}")
  278. # 先取消注册,再强制关闭,确保彻底清理
  279. sse_callback_manager.unregister_callback(callback_task_id)
  280. # 强制关闭SSE连接(防止残留)
  281. sse_callback_manager.force_close_sse(callback_task_id)
  282. logger.info(f"SSE连接已彻底关闭: {callback_task_id}")
  283. logger.info(f"任务关闭: {callback_task_id}")
  284. except Exception as e:
  285. logger.error(f"标记任务完成失败: {str(e)}")
  286. raise