progress_manager.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. """
  2. 任务进度管理器
  3. 负责任务进度的存储、更新和查询
  4. """
  5. import json
  6. import asyncio
  7. from typing import Dict, Any, Optional
  8. from datetime import datetime
  9. from foundation.logger.loggering import server_logger as logger
  10. from foundation.base.config import config_handler
  class SSECallbackManager:
      """Shared registry of SSE callbacks, keyed by callback task id.

      Instantiating the class always yields the same object, and the
      registry dict is a class attribute, so every caller works against
      one common set of callbacks.
      """

      _instance = None
      _callbacks = {}  # {callback_task_id: callback_function}

      def __new__(cls):
          # Create the shared instance on first use; hand back the same one afterwards.
          shared = cls._instance
          if shared is None:
              shared = cls._instance = super().__new__(cls)
          return shared

      def register_callback(self, callback_task_id: str, callback_func):
          """Store ``callback_func`` under ``callback_task_id``, replacing any earlier entry."""
          self._callbacks[callback_task_id] = callback_func
          logger.info(f"SSE回调注册, 当前注册数: {len(self._callbacks)}")

      def unregister_callback(self, callback_task_id: str):
          """Remove the callback stored for ``callback_task_id``; unknown ids are ignored."""
          self._callbacks.pop(callback_task_id, None)
          logger.info(f"SSE回调注销, 剩余注册数: {len(self._callbacks)}")

      async def trigger_callback(self, callback_task_id: str, current_data: dict):
          """Await the callback registered for ``callback_task_id``.

          The callback is invoked as ``callback(callback_task_id, current_data)``.

          Returns:
              True when the callback finished without raising. False when no
              callback is registered for the id, or when the callback raised
              (the exception is logged here and not propagated).
          """
          if callback_task_id not in self._callbacks:
              logger.debug(f"未找到SSE回调: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}, 已注册ID: {list(self._callbacks.keys())}")
              return False

          try:
              # Awaited inline rather than scheduled separately; the original
              # author's note says this keeps the trace context.
              await self._callbacks[callback_task_id](callback_task_id, current_data)
              logger.debug(f"SSE回调执行成功: {callback_task_id}")
              logger.debug(f"SSE回调已触发: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}")
          except Exception as e:
              logger.error(f"SSE回调执行失败: {callback_task_id}, {e}")
              return False
          return True

      def get_callbacks_count(self):
          """Return how many callbacks are currently registered."""
          registered = self._callbacks
          return len(registered)

      def clear_all_callbacks(self):
          """Remove every registered callback."""
          self._callbacks.clear()
          logger.info("已清空所有SSE回调")
  # Module-level SSE callback manager instance, used by ProgressManager to push updates.
  sse_callback_manager = SSECallbackManager()
  52. class ProgressManager:
  53. """任务进度管理器 - 增长型进度管理版本"""
  54. def __init__(self):
  55. self.redis_client = None
  56. self.redis_connected = False
  57. self._init_redis()
  def _init_redis(self):
      """Connect to Redis using the ``redis`` section of ``config_handler``.

      Reads REDIS_HOST, REDIS_PORT, REDIS_PASSWORD and REDIS_DB (defaults:
      'localhost', '6379', '' and '0'), builds a ``redis://`` URL, creates
      the client and pings the server.

      On success ``redis_connected`` is set to True. On any failure
      (including the ``redis`` package not being importable) the error is
      logged, ``redis_connected`` is set to False and an empty in-memory
      ``current_data`` dict is installed as the fallback store. Nothing is
      raised to the caller.
      """
      try:
          import redis

          redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
          redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
          redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
          redis_db = config_handler.get('redis', 'REDIS_DB', '0')

          # Build the connection URL, plus a copy that is safe to log.
          # NOTE(review): the password is interpolated as-is; one containing
          # URL-reserved characters would need percent-encoding in the config.
          redis_location = f"{redis_host}:{redis_port}/{redis_db}"
          if redis_password:
              redis_url = f"redis://:{redis_password}@{redis_location}"
              loggable_url = f"redis://:***@{redis_location}"
          else:
              redis_url = f"redis://{redis_location}"
              loggable_url = redis_url
          # Never write the real password to the log; log the masked URL.
          logger.debug(f"ProgressManager连接Redis: {loggable_url}")

          # Create the client, then ping so connection problems surface here.
          self.redis_client = redis.from_url(redis_url, decode_responses=True)
          self.redis_client.ping()
          self.redis_connected = True
          logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
      except Exception as e:
          logger.error(f"ProgressManager Redis连接失败: {e}")
          self.redis_connected = False
          logger.warning("ProgressManager将使用内存存储作为备选方案")
          self.current_data = {}  # in-memory fallback store
  83. async def _get_redis_key(self, callback_task_id: str) -> str:
  84. """获取Redis键名"""
  85. return f"current:{callback_task_id}"
  async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
      """Create the initial progress record for a task.

      The record starts with ``current`` 0, an empty ``stage_name`` and
      ``overall_task_status`` "pending". It is written to Redis with a
      3600-second expiry when Redis is connected; if Redis is not connected
      or the write fails, it is kept in the in-memory ``current_data`` dict.

      Args:
          callback_task_id: Task id the record is stored under.
          user_id: Owner of the task, stored in the record.
          stages: Accepted for interface compatibility; the stored record
              does not depend on it.

      Raises:
          Exception: Any unexpected error is logged and re-raised.
      """
      try:
          # NOTE(review): the earlier version read stages[0]["stage_name"] into
          # a local that was never used, which could raise on malformed
          # `stages` for no benefit. The stored stage_name has always been "".
          current_data = {
              "user_id": user_id,
              "current": 0,
              "stage_name": "",
              "status": "准备开始",
              "message": "任务开始",
              "updated_at": datetime.now().isoformat(),
              "overall_task_status": "pending"
          }

          stored_in_redis = False
          if self.redis_connected:
              # Synchronous Redis call on purpose (original note: avoids
              # problems with async tasks being destroyed).
              try:
                  redis_key = await self._get_redis_key(callback_task_id)
                  self.redis_client.setex(
                      redis_key,
                      3600,  # expire after one hour
                      json.dumps(current_data)
                  )
                  stored_in_redis = True
                  logger.info("初始化任务进度列表")
              except Exception as redis_e:
                  logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")

          if not stored_in_redis:
              # Either Redis is unavailable or the write failed: keep the
              # record in memory instead.
              if not hasattr(self, 'current_data'):
                  self.current_data = {}
              self.current_data[callback_task_id] = current_data
              if self.redis_connected:
                  logger.info(f"降级使用内存存储: {callback_task_id}")
              else:
                  logger.info(f"初始化任务进度到内存: {callback_task_id}")
      except Exception as e:
          logger.error(f"初始化进度失败: {str(e)}")
          raise
  async def update_stage_progress(self, callback_task_id: str, stage_name: str, current: int, status: str, message: str = ""):
      """Update a task's progress record and push it to the SSE callback.

      The record is looked up in Redis when Redis is connected, otherwise
      in the in-memory store. If Redis has no record (or the read fails)
      the in-memory store is consulted as well, because writes that failed
      against Redis are kept there. If no record exists anywhere, a warning
      is logged and the call returns without doing anything.

      Args:
          callback_task_id: Task whose record is updated.
          stage_name: Name of the stage now in progress.
          current: Progress value (logged as a percentage).
          status: Status text for the stage.
          message: Optional free-text message.

      Raises:
          Exception: Unexpected errors are logged and re-raised. A Redis
              read error is re-raised only when there is no in-memory
              record to fall back to.
      """
      try:
          memory_store = getattr(self, 'current_data', None)
          if memory_store is None:
              memory_store = self.current_data = {}
          redis_key = await self._get_redis_key(callback_task_id)

          task_progress = None
          if self.redis_connected:
              redis_error = None
              try:
                  progress_json = self.redis_client.get(redis_key)
              except Exception as read_e:
                  progress_json = None
                  redis_error = read_e
                  logger.warning(f"从Redis读取进度失败: {callback_task_id}, {read_e}")

              if progress_json:
                  task_progress = json.loads(progress_json)
              elif callback_task_id in memory_store:
                  # A previous Redis write failed and the record was parked
                  # in memory; without this lookup it could never be found.
                  task_progress = memory_store[callback_task_id]
              elif redis_error is not None:
                  # Nothing to fall back to: surface the Redis failure.
                  raise redis_error
              else:
                  logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
                  return
          elif callback_task_id in memory_store:
              task_progress = memory_store[callback_task_id]
          else:
              logger.warning(f"内存中未找到任务进度: {callback_task_id}")
              return

          # Apply the new values.
          task_progress["current"] = current
          task_progress["stage_name"] = stage_name
          task_progress["status"] = status
          task_progress["message"] = message
          task_progress["updated_at"] = datetime.now().isoformat()
          # Keep an existing overall_task_status; only fill it in when absent.
          if "overall_task_status" not in task_progress:
              task_progress["overall_task_status"] = "processing"

          saved_to_redis = False
          if self.redis_connected:
              try:
                  self.redis_client.setex(
                      redis_key,
                      3600,  # expire after one hour
                      json.dumps(task_progress)
                  )
                  saved_to_redis = True
                  logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {current}%")
              except Exception as sync_e:
                  logger.warning(f"同步Redis操作失败: {callback_task_id}, {sync_e}")

          if saved_to_redis:
              # Redis is authoritative again; drop any stale in-memory copy.
              memory_store.pop(callback_task_id, None)
          else:
              memory_store[callback_task_id] = task_progress
              if self.redis_connected:
                  logger.debug(f"降级使用内存存储: {callback_task_id}")
              else:
                  logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {current}%")

          # Push the refreshed record through the global SSE callback manager.
          logger.debug(f"触发SSE推送: {callback_task_id}")
          updated_progress = await self.get_progress(callback_task_id)
          if updated_progress:
              await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
      except Exception as e:
          logger.error(f"更新阶段进度失败: {str(e)}")
          raise
  async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
      """Return the current progress of a task as a plain dict.

      The record is read from Redis when Redis is connected, otherwise from
      the in-memory store. If Redis has no record (or the read fails) the
      in-memory store is consulted too, because writes that failed against
      Redis are kept there.

      Returns:
          A dict with ``callback_task_id``, ``user_id``, ``current``,
          ``stage_name``, ``status``, ``message``, ``overall_task_status``
          (defaults to "pending" when absent) and ``updated_at`` as an
          integer Unix timestamp; or None when no record exists or any
          error occurs (errors are logged, never raised).
      """
      try:
          logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
          memory_store = getattr(self, 'current_data', None) or {}

          task_progress = None
          if self.redis_connected:
              redis_key = await self._get_redis_key(callback_task_id)
              logger.debug(f"Redis键: {redis_key}")
              try:
                  progress_json = self.redis_client.get(redis_key)
              except Exception as read_e:
                  progress_json = None
                  logger.warning(f"从Redis读取进度失败: {callback_task_id}, {read_e}")
              logger.debug(f"从Redis读取数据: {progress_json is not None}")

              if progress_json:
                  task_progress = json.loads(progress_json)
              elif callback_task_id in memory_store:
                  # Record was parked in memory after a failed Redis write.
                  task_progress = memory_store[callback_task_id]
              else:
                  logger.debug(f"Redis中未找到任务进度: {callback_task_id}")
                  return None
          elif callback_task_id in memory_store:
              task_progress = memory_store[callback_task_id]
          else:
              logger.debug(f"内存中未找到任务进度: {callback_task_id}")
              return None

          overall_task_status = task_progress.get("overall_task_status", "pending")

          # updated_at is stored as an ISO-8601 string; convert to whole seconds.
          updated_at = task_progress["updated_at"]
          if isinstance(updated_at, str):
              updated_at_timestamp = int(datetime.fromisoformat(updated_at).timestamp())
          else:
              updated_at_timestamp = int(updated_at.timestamp())

          return {
              "callback_task_id": callback_task_id,
              "user_id": task_progress["user_id"],
              "current": task_progress["current"],
              "stage_name": task_progress["stage_name"],
              "status": task_progress["status"],
              "message": task_progress["message"],
              "overall_task_status": overall_task_status,
              "updated_at": updated_at_timestamp
          }
      except Exception as e:
          logger.error(f"获取进度失败: {str(e)}")
          return None
  async def complete_task(self, callback_task_id: str):
      """Mark a task as completed and push the final state to the SSE callback.

      Sets ``status`` and ``overall_task_status`` to "completed", updates
      ``message`` and ``updated_at``, saves the record and triggers the SSE
      callback with the result of ``get_progress``.

      The record is read from Redis when connected, with the in-memory
      store as a fallback on a miss or read error. A failed Redis write
      falls back to the in-memory store (matching ``update_stage_progress``)
      instead of aborting, so the completion push is still attempted. If no
      record exists anywhere, a warning is logged and the call returns.

      Raises:
          Exception: Unexpected errors are logged and re-raised. A Redis
              read error is re-raised only when there is no in-memory
              record to fall back to.
      """
      try:
          logger.info(f"通知sse连接关闭: {callback_task_id}")
          memory_store = getattr(self, 'current_data', None)
          if memory_store is None:
              memory_store = self.current_data = {}
          redis_key = await self._get_redis_key(callback_task_id)

          task_progress = None
          if self.redis_connected:
              redis_error = None
              try:
                  progress_json = self.redis_client.get(redis_key)
              except Exception as read_e:
                  progress_json = None
                  redis_error = read_e
                  logger.warning(f"从Redis读取进度失败: {callback_task_id}, {read_e}")

              if progress_json:
                  task_progress = json.loads(progress_json)
              elif callback_task_id in memory_store:
                  # Record was parked in memory after a failed Redis write.
                  task_progress = memory_store[callback_task_id]
              elif redis_error is not None:
                  # Nothing to fall back to: surface the Redis failure.
                  raise redis_error
              else:
                  logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
                  return
          elif callback_task_id in memory_store:
              task_progress = memory_store[callback_task_id]
          else:
              logger.warning(f"内存中未找到任务进度: {callback_task_id}")
              return

          task_progress["status"] = "completed"
          task_progress["overall_task_status"] = "completed"
          task_progress["message"] = "任务已全部完成"
          task_progress["updated_at"] = datetime.now().isoformat()

          # Save the updated record.
          saved_to_redis = False
          if self.redis_connected:
              try:
                  self.redis_client.setex(
                      redis_key,
                      3600,  # expire after one hour
                      json.dumps(task_progress)
                  )
                  saved_to_redis = True
              except Exception as write_e:
                  logger.warning(f"同步Redis操作失败: {callback_task_id}, {write_e}")

          if saved_to_redis:
              # Redis is authoritative again; drop any stale in-memory copy.
              memory_store.pop(callback_task_id, None)
          else:
              memory_store[callback_task_id] = task_progress

          # Push the completed state to the SSE callback.
          completed_progress = await self.get_progress(callback_task_id)
          if completed_progress:
              await sse_callback_manager.trigger_callback(callback_task_id, completed_progress)
              logger.debug(f"SSE完成进度已推送: {callback_task_id}")
          else:
              logger.warning(f"无法获取完成进度数据: {callback_task_id}")
      except Exception as e:
          logger.error(f"标记任务完成失败: {str(e)}")
          raise