progress_manager.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. import json
  2. import time
  3. import asyncio
  4. from typing import Dict, Any, Optional
  5. from datetime import datetime
  6. from foundation.observability.logger.loggering import review_logger as logger
  7. from foundation.infrastructure.config import config_handler
  8. class ProgressManager:
  9. """任务进度管理器"""
  10. def __init__(self):
  11. self.redis_client = None
  12. self.redis_connected = False
  13. self._init_redis()
  14. def _init_redis(self):
  15. try:
  16. import redis
  17. redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
  18. redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
  19. redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
  20. redis_db = config_handler.get('redis', 'REDIS_DB', '0')
  21. if redis_password:
  22. redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}"
  23. else:
  24. redis_url = f"redis://{redis_host}:{redis_port}/{redis_db}"
  25. logger.debug(f"ProgressManager连接Redis: {redis_url}")
  26. self.redis_client = redis.from_url(redis_url, decode_responses=True)
  27. self.redis_client.ping()
  28. self.redis_connected = True
  29. logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
  30. except Exception as e:
  31. logger.error(f"ProgressManager Redis连接失败: {e}")
  32. self.redis_connected = False
  33. logger.warning("ProgressManager将使用内存存储作为备选方案")
  34. self.current_data = {}
  35. async def _get_redis_key(self, callback_task_id: str) -> str:
  36. return f"current:{callback_task_id}"
  37. async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
  38. try:
  39. current_data = {
  40. "user_id": user_id,
  41. "current": 0,
  42. "stage_name": "",
  43. "status": "准备开始",
  44. "message": "任务开始",
  45. "updated_at": datetime.now().isoformat(),
  46. "overall_task_status": "pending"
  47. }
  48. if self.redis_connected:
  49. try:
  50. redis_key = await self._get_redis_key(callback_task_id)
  51. self.redis_client.setex(
  52. redis_key,
  53. 3600,
  54. json.dumps(current_data)
  55. )
  56. logger.info(f"初始化任务进度列表")
  57. except Exception as redis_e:
  58. logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")
  59. if not hasattr(self, 'current_data'):
  60. self.current_data = {}
  61. self.current_data[callback_task_id] = current_data
  62. logger.info(f"降级使用内存存储: {callback_task_id}")
  63. else:
  64. if not hasattr(self, 'current_data'):
  65. self.current_data = {}
  66. self.current_data[callback_task_id] = current_data
  67. logger.info(f"初始化任务进度到内存: {callback_task_id}")
  68. except Exception as e:
  69. logger.error(f"初始化进度失败: {str(e)}")
  70. raise
  71. async def update_stage_progress(self, callback_task_id: str, stage_name: str = None, current: int = None, status: str = None, message: str = None, issues=None, user_id: str = None, overall_task_status: str = None, event_type: str = "processing"):
  72. """更新阶段进度 - 除callback_task_id外,其他参数都可选
  73. Args:
  74. callback_task_id: 回调任务ID(必需)
  75. stage_name: 阶段名称(可选)
  76. current: 当前进度(可选)
  77. status: 状态(可选)
  78. message: 消息(可选)
  79. issues: 问题列表(可选)
  80. user_id: 用户ID(可选)
  81. overall_task_status: 整体任务状态(可选)
  82. event_type: SSE事件类型(可选,默认为"processing")
  83. """
  84. try:
  85. task_progress = None
  86. if self.redis_connected:
  87. # 从Redis读取
  88. redis_key = await self._get_redis_key(callback_task_id)
  89. progress_json = self.redis_client.get(redis_key)
  90. if progress_json:
  91. task_progress = json.loads(progress_json)
  92. else:
  93. logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
  94. return
  95. else:
  96. # 从内存读取
  97. if callback_task_id in self.current_data:
  98. task_progress = self.current_data[callback_task_id]
  99. else:
  100. logger.warning(f"内存中未找到任务进度: {callback_task_id}")
  101. return
  102. # 更新进度数据 - 只有非空参数才更新
  103. if current is not None:
  104. task_progress["current"] = current
  105. if stage_name is not None:
  106. task_progress["stage_name"] = stage_name
  107. if status is not None:
  108. task_progress["status"] = status
  109. if message is not None:
  110. task_progress["message"] = message
  111. task_progress["updated_at"] = datetime.now().isoformat()
  112. if issues is not None:
  113. task_progress["issues"] = issues
  114. else:
  115. task_progress["issues"] = []
  116. if user_id is not None:
  117. task_progress["user_id"] = user_id
  118. if overall_task_status is not None:
  119. task_progress["overall_task_status"] = overall_task_status
  120. elif "overall_task_status" not in task_progress:
  121. task_progress["overall_task_status"] = "processing"
  122. if event_type is not None:
  123. task_progress["event_type"] = event_type
  124. logger.debug(f"设置event_type: {event_type} for {callback_task_id}")
  125. else:
  126. logger.debug(f"event_type为None,不设置 for {callback_task_id}")
  127. try:
  128. if self.redis_connected:
  129. try:
  130. self.redis_client.setex(
  131. redis_key,
  132. 3600, # 1小时过期
  133. json.dumps(task_progress)
  134. )
  135. actual_current = task_progress.get("current")
  136. if current is not None:
  137. logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {actual_current}%")
  138. else:
  139. logger.debug(f"更新进度到Redis: {callback_task_id}, 进度保持: {actual_current}% (未传入)")
  140. except Exception as sync_e:
  141. logger.warning(f"同步Redis操作失败: {callback_task_id}, {sync_e}")
  142. # 同步操作也失败时,降级到内存存储
  143. if not hasattr(self, 'current_data'):
  144. self.current_data = {}
  145. self.current_data[callback_task_id] = task_progress
  146. logger.debug(f"降级使用内存存储: {callback_task_id}")
  147. else:
  148. if not hasattr(self, 'current_data'):
  149. self.current_data = {}
  150. self.current_data[callback_task_id] = task_progress
  151. actual_current = task_progress.get("current")
  152. if current is not None:
  153. logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {actual_current}%")
  154. else:
  155. logger.debug(f"更新进度到内存: {callback_task_id}, 进度保持: {actual_current}% (未传入)")
  156. except Exception as e:
  157. logger.error(f"保存进度数据异常: {callback_task_id}, {e}")
  158. if not hasattr(self, 'current_data'):
  159. self.current_data = {}
  160. self.current_data[callback_task_id] = task_progress
  161. # 进度已保存到 Redis,SSE 由主进程通过轮询获取
  162. actual_current = task_progress.get("current")
  163. if current is not None:
  164. logger.debug(f"进度已更新到Redis: {callback_task_id}, current={actual_current}%")
  165. else:
  166. logger.debug(f"进度已更新到Redis: {callback_task_id}, current={actual_current}% (保持)")
  167. except Exception as e:
  168. logger.error(f"更新阶段进度失败: {str(e)}")
  169. raise
  170. async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
  171. """获取任务进度"""
  172. try:
  173. #logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
  174. task_progress = None
  175. if self.redis_connected:
  176. # 从Redis读取
  177. redis_key = await self._get_redis_key(callback_task_id)
  178. #logger.debug(f"Redis键: {redis_key}")
  179. progress_json = self.redis_client.get(redis_key)
  180. #logger.debug(f"从Redis读取数据: {progress_json is not None}")
  181. if progress_json:
  182. task_progress = json.loads(progress_json)
  183. else:
  184. logger.debug(f"Redis中未找到任务进度: {callback_task_id}")
  185. return None
  186. else:
  187. # 从内存读取
  188. if hasattr(self, 'current_data') and callback_task_id in self.current_data:
  189. task_progress = self.current_data[callback_task_id]
  190. else:
  191. logger.debug(f"内存中未找到任务进度: {callback_task_id}")
  192. return None
  193. # 获取overall_task_status,默认为"pending"
  194. overall_task_status = task_progress.get("overall_task_status", "pending")
  195. # 转换时间戳
  196. updated_at = task_progress["updated_at"]
  197. if isinstance(updated_at, str):
  198. updated_at_timestamp = int(datetime.fromisoformat(updated_at).timestamp())
  199. else:
  200. updated_at_timestamp = int(updated_at.timestamp())
  201. # 构建返回数据
  202. result = {
  203. "callback_task_id": callback_task_id,
  204. "user_id": task_progress["user_id"],
  205. "current": task_progress["current"],
  206. "stage_name": task_progress["stage_name"],
  207. "status": task_progress["status"],
  208. "message": task_progress["message"],
  209. "overall_task_status": overall_task_status,
  210. "updated_at": updated_at_timestamp
  211. }
  212. # 添加可选字段
  213. if "issues" in task_progress:
  214. result["issues"] = task_progress["issues"]
  215. if "event_type" in task_progress:
  216. result["event_type"] = task_progress["event_type"]
  217. return result
  218. except Exception as e:
  219. logger.error(f"获取进度失败: {str(e)}")
  220. return None
  221. async def complete_task(self, callback_task_id: str, user_id: str = None, current_data: dict = None):
  222. """标记任务完成 - 使用单一同步强制关闭逻辑
  223. Args:
  224. callback_task_id: 回调任务ID
  225. user_id: 用户ID
  226. current_data: 包含 overall_task_status 等数据的字典
  227. """
  228. try:
  229. # 先更新 Redis 中的状态(让 SSE 轮询能检测到)
  230. if current_data and current_data.get("overall_task_status"):
  231. progress_data = await self.get_progress(callback_task_id)
  232. if progress_data:
  233. # 更新状态字段
  234. progress_data["overall_task_status"] = current_data.get("overall_task_status")
  235. progress_data["status"] = current_data.get("status", current_data.get("overall_task_status"))
  236. progress_data["message"] = current_data.get("message", "任务已完成")
  237. progress_data["updated_at"] = datetime.now().isoformat()
  238. if current_data.get("error"):
  239. progress_data["error"] = current_data.get("error")
  240. # 保存回 Redis
  241. if self.redis_connected:
  242. redis_key = await self._get_redis_key(callback_task_id)
  243. self.redis_client.setex(
  244. redis_key,
  245. 3600,
  246. json.dumps(progress_data)
  247. )
  248. logger.info(f"任务状态已更新到Redis: {callback_task_id}, status={current_data.get('overall_task_status')}")
  249. else:
  250. if not hasattr(self, 'current_data'):
  251. self.current_data = {}
  252. self.current_data[callback_task_id] = progress_data
  253. # SSE 连接由主进程管理(通过 Redis 轮询),Worker 无需关闭
  254. logger.info(f"任务状态已更新,等待主进程轮询检测: {callback_task_id}")
  255. except Exception as e:
  256. logger.error(f"标记任务完成失败: {str(e)}")
  257. raise