# progress_manager.py (16 KB)
  1. import json
  2. import time
  3. import asyncio
  4. from typing import Dict, Any, Optional, List
  5. from datetime import datetime
  6. from foundation.observability.logger.loggering import review_logger as logger
  7. from foundation.infrastructure.config import config_handler
  8. class ProgressManager:
  9. """任务进度管理器"""
  10. def __init__(self):
  11. self.redis_client = None
  12. self.redis_connected = False
  13. self._init_redis()
  14. def _init_redis(self):
  15. try:
  16. import redis
  17. redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
  18. redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
  19. redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
  20. redis_db = config_handler.get('redis', 'REDIS_DB', '0')
  21. if redis_password:
  22. redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}"
  23. else:
  24. redis_url = f"redis://{redis_host}:{redis_port}/{redis_db}"
  25. logger.debug(f"ProgressManager连接Redis: {redis_url}")
  26. self.redis_client = redis.from_url(redis_url, decode_responses=True)
  27. self.redis_client.ping()
  28. self.redis_connected = True
  29. logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
  30. except Exception as e:
  31. logger.error(f"ProgressManager Redis连接失败: {e}")
  32. self.redis_connected = False
  33. logger.warning("ProgressManager将使用内存存储作为备选方案")
  34. self.current_data = {}
  35. async def _get_redis_key(self, callback_task_id: str) -> str:
  36. return f"current:{callback_task_id}"
  37. async def _get_stream_event_key(self, callback_task_id: str) -> str:
  38. return f"stream_events:{callback_task_id}"
  39. async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
  40. try:
  41. current_data = {
  42. "user_id": user_id,
  43. "current": 0,
  44. "stage_name": "",
  45. "status": "准备开始",
  46. "message": "任务开始",
  47. "updated_at": datetime.now().isoformat(),
  48. "overall_task_status": "pending"
  49. }
  50. if self.redis_connected:
  51. try:
  52. redis_key = await self._get_redis_key(callback_task_id)
  53. self.redis_client.setex(
  54. redis_key,
  55. 3600,
  56. json.dumps(current_data)
  57. )
  58. stream_event_key = await self._get_stream_event_key(callback_task_id)
  59. self.redis_client.delete(stream_event_key)
  60. logger.info(f"初始化任务进度列表")
  61. except Exception as redis_e:
  62. logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")
  63. if not hasattr(self, 'current_data'):
  64. self.current_data = {}
  65. if not hasattr(self, 'stream_events'):
  66. self.stream_events = {}
  67. self.current_data[callback_task_id] = current_data
  68. self.stream_events[callback_task_id] = []
  69. logger.info(f"降级使用内存存储: {callback_task_id}")
  70. else:
  71. if not hasattr(self, 'current_data'):
  72. self.current_data = {}
  73. if not hasattr(self, 'stream_events'):
  74. self.stream_events = {}
  75. self.current_data[callback_task_id] = current_data
  76. self.stream_events[callback_task_id] = []
  77. logger.info(f"初始化任务进度到内存: {callback_task_id}")
  78. except Exception as e:
  79. logger.error(f"初始化进度失败: {str(e)}")
  80. raise
  81. async def append_stream_event(self, callback_task_id: str, event_data: Dict[str, Any]):
  82. """追加流式事件,供 SSE 轮询端按事件顺序吐给前端。"""
  83. try:
  84. if self.redis_connected:
  85. stream_event_key = await self._get_stream_event_key(callback_task_id)
  86. self.redis_client.rpush(
  87. stream_event_key,
  88. json.dumps(event_data, ensure_ascii=False)
  89. )
  90. self.redis_client.expire(stream_event_key, 3600)
  91. else:
  92. if not hasattr(self, 'stream_events'):
  93. self.stream_events = {}
  94. self.stream_events.setdefault(callback_task_id, []).append(event_data)
  95. except Exception as e:
  96. logger.error(f"追加流式事件失败: {callback_task_id}, {str(e)}")
  97. raise
  98. async def pop_stream_events(self, callback_task_id: str, max_events: int = 100) -> List[Dict[str, Any]]:
  99. """按 FIFO 顺序取出待发送的流式事件。"""
  100. events: List[Dict[str, Any]] = []
  101. try:
  102. if self.redis_connected:
  103. stream_event_key = await self._get_stream_event_key(callback_task_id)
  104. for _ in range(max_events):
  105. event_json = self.redis_client.lpop(stream_event_key)
  106. if not event_json:
  107. break
  108. try:
  109. events.append(json.loads(event_json))
  110. except Exception as parse_e:
  111. logger.warning(f"解析流式事件失败: {callback_task_id}, {parse_e}")
  112. else:
  113. if not hasattr(self, 'stream_events'):
  114. self.stream_events = {}
  115. queued_events = self.stream_events.get(callback_task_id, [])
  116. events = queued_events[:max_events]
  117. self.stream_events[callback_task_id] = queued_events[max_events:]
  118. except Exception as e:
  119. logger.error(f"获取流式事件失败: {callback_task_id}, {str(e)}")
  120. return events
  121. async def update_stage_progress(self, callback_task_id: str, stage_name: str = None, current: int = None, status: str = None, message: str = None, issues=None, user_id: str = None, overall_task_status: str = None, event_type: str = "processing"):
  122. """更新阶段进度 - 除callback_task_id外,其他参数都可选
  123. Args:
  124. callback_task_id: 回调任务ID(必需)
  125. stage_name: 阶段名称(可选)
  126. current: 当前进度(可选)
  127. status: 状态(可选)
  128. message: 消息(可选)
  129. issues: 问题列表(可选)
  130. user_id: 用户ID(可选)
  131. overall_task_status: 整体任务状态(可选)
  132. event_type: SSE事件类型(可选,默认为"processing")
  133. """
  134. try:
  135. task_progress = None
  136. if self.redis_connected:
  137. # 从Redis读取
  138. redis_key = await self._get_redis_key(callback_task_id)
  139. progress_json = self.redis_client.get(redis_key)
  140. if progress_json:
  141. task_progress = json.loads(progress_json)
  142. else:
  143. logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
  144. return
  145. else:
  146. # 从内存读取
  147. if callback_task_id in self.current_data:
  148. task_progress = self.current_data[callback_task_id]
  149. else:
  150. logger.warning(f"内存中未找到任务进度: {callback_task_id}")
  151. return
  152. # 更新进度数据 - 只有非空参数才更新
  153. if current is not None:
  154. task_progress["current"] = current
  155. if stage_name is not None:
  156. task_progress["stage_name"] = stage_name
  157. if status is not None:
  158. task_progress["status"] = status
  159. if message is not None:
  160. task_progress["message"] = message
  161. task_progress["updated_at"] = datetime.now().isoformat()
  162. if issues is not None:
  163. task_progress["issues"] = issues
  164. else:
  165. task_progress["issues"] = []
  166. if user_id is not None:
  167. task_progress["user_id"] = user_id
  168. if overall_task_status is not None:
  169. task_progress["overall_task_status"] = overall_task_status
  170. elif "overall_task_status" not in task_progress:
  171. task_progress["overall_task_status"] = "processing"
  172. if event_type is not None:
  173. task_progress["event_type"] = event_type
  174. logger.debug(f"设置event_type: {event_type} for {callback_task_id}")
  175. else:
  176. logger.debug(f"event_type为None,不设置 for {callback_task_id}")
  177. try:
  178. if self.redis_connected:
  179. try:
  180. self.redis_client.setex(
  181. redis_key,
  182. 3600, # 1小时过期
  183. json.dumps(task_progress)
  184. )
  185. actual_current = task_progress.get("current")
  186. if current is not None:
  187. logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {actual_current}%")
  188. else:
  189. logger.debug(f"更新进度到Redis: {callback_task_id}, 进度保持: {actual_current}% (未传入)")
  190. except Exception as sync_e:
  191. logger.warning(f"同步Redis操作失败: {callback_task_id}, {sync_e}")
  192. # 同步操作也失败时,降级到内存存储
  193. if not hasattr(self, 'current_data'):
  194. self.current_data = {}
  195. self.current_data[callback_task_id] = task_progress
  196. logger.debug(f"降级使用内存存储: {callback_task_id}")
  197. else:
  198. if not hasattr(self, 'current_data'):
  199. self.current_data = {}
  200. self.current_data[callback_task_id] = task_progress
  201. actual_current = task_progress.get("current")
  202. if current is not None:
  203. logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {actual_current}%")
  204. else:
  205. logger.debug(f"更新进度到内存: {callback_task_id}, 进度保持: {actual_current}% (未传入)")
  206. except Exception as e:
  207. logger.error(f"保存进度数据异常: {callback_task_id}, {e}")
  208. if not hasattr(self, 'current_data'):
  209. self.current_data = {}
  210. self.current_data[callback_task_id] = task_progress
  211. # 进度已保存到 Redis,SSE 由主进程通过轮询获取
  212. actual_current = task_progress.get("current")
  213. if current is not None:
  214. logger.debug(f"进度已更新到Redis: {callback_task_id}, current={actual_current}%")
  215. else:
  216. logger.debug(f"进度已更新到Redis: {callback_task_id}, current={actual_current}% (保持)")
  217. except Exception as e:
  218. logger.error(f"更新阶段进度失败: {str(e)}")
  219. raise
  220. async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
  221. """获取任务进度"""
  222. try:
  223. #logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
  224. task_progress = None
  225. if self.redis_connected:
  226. # 从Redis读取
  227. redis_key = await self._get_redis_key(callback_task_id)
  228. #logger.debug(f"Redis键: {redis_key}")
  229. progress_json = self.redis_client.get(redis_key)
  230. #logger.debug(f"从Redis读取数据: {progress_json is not None}")
  231. if progress_json:
  232. task_progress = json.loads(progress_json)
  233. else:
  234. logger.debug(f"Redis中未找到任务进度: {callback_task_id}")
  235. return None
  236. else:
  237. # 从内存读取
  238. if hasattr(self, 'current_data') and callback_task_id in self.current_data:
  239. task_progress = self.current_data[callback_task_id]
  240. else:
  241. logger.debug(f"内存中未找到任务进度: {callback_task_id}")
  242. return None
  243. # 获取overall_task_status,默认为"pending"
  244. overall_task_status = task_progress.get("overall_task_status", "pending")
  245. # 转换时间戳
  246. updated_at = task_progress["updated_at"]
  247. if isinstance(updated_at, str):
  248. updated_at_timestamp = int(datetime.fromisoformat(updated_at).timestamp())
  249. else:
  250. updated_at_timestamp = int(updated_at.timestamp())
  251. # 构建返回数据
  252. result = {
  253. "callback_task_id": callback_task_id,
  254. "user_id": task_progress["user_id"],
  255. "current": task_progress["current"],
  256. "stage_name": task_progress["stage_name"],
  257. "status": task_progress["status"],
  258. "message": task_progress["message"],
  259. "overall_task_status": overall_task_status,
  260. "updated_at": updated_at_timestamp
  261. }
  262. # 添加可选字段
  263. if "issues" in task_progress:
  264. result["issues"] = task_progress["issues"]
  265. if "event_type" in task_progress:
  266. result["event_type"] = task_progress["event_type"]
  267. return result
  268. except Exception as e:
  269. logger.error(f"获取进度失败: {str(e)}")
  270. return None
  271. async def complete_task(self, callback_task_id: str, user_id: str = None, current_data: dict = None):
  272. """标记任务完成 - 使用单一同步强制关闭逻辑
  273. Args:
  274. callback_task_id: 回调任务ID
  275. user_id: 用户ID
  276. current_data: 包含 overall_task_status 等数据的字典
  277. """
  278. try:
  279. # 先更新 Redis 中的状态(让 SSE 轮询能检测到)
  280. if current_data and current_data.get("overall_task_status"):
  281. progress_data = await self.get_progress(callback_task_id)
  282. if progress_data:
  283. # 更新状态字段
  284. progress_data["overall_task_status"] = current_data.get("overall_task_status")
  285. progress_data["status"] = current_data.get("status", current_data.get("overall_task_status"))
  286. progress_data["message"] = current_data.get("message", "任务已完成")
  287. progress_data["updated_at"] = datetime.now().isoformat()
  288. if current_data.get("error"):
  289. progress_data["error"] = current_data.get("error")
  290. # 保存回 Redis
  291. if self.redis_connected:
  292. redis_key = await self._get_redis_key(callback_task_id)
  293. self.redis_client.setex(
  294. redis_key,
  295. 3600,
  296. json.dumps(progress_data)
  297. )
  298. logger.info(f"任务状态已更新到Redis: {callback_task_id}, status={current_data.get('overall_task_status')}")
  299. else:
  300. if not hasattr(self, 'current_data'):
  301. self.current_data = {}
  302. self.current_data[callback_task_id] = progress_data
  303. # SSE 连接由主进程管理(通过 Redis 轮询),Worker 无需关闭
  304. logger.info(f"任务状态已更新,等待主进程轮询检测: {callback_task_id}")
  305. except Exception as e:
  306. logger.error(f"标记任务完成失败: {str(e)}")
  307. raise