progress_manager.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. import json
  2. import time
  3. import asyncio
  4. from typing import Dict, Any, Optional, List
  5. from datetime import datetime
  6. from urllib.parse import quote
  7. from foundation.observability.logger.loggering import write_logger as logger
  8. from foundation.infrastructure.config import config_handler
  9. class ProgressManager:
  10. """任务进度管理器"""
  11. def __init__(self):
  12. self.redis_client = None
  13. self.redis_connected = False
  14. self.current_data = {}
  15. self.stream_events = {}
  16. self._missing_progress_log_state = {}
  17. self._init_redis()
  18. def _log_missing_progress(self, callback_task_id: str, storage_name: str, log_func=logger.debug):
  19. """Avoid flooding logs while SSE polling waits for worker progress."""
  20. now = time.time()
  21. state = self._missing_progress_log_state.get(callback_task_id, {"count": 0, "last_log_at": 0})
  22. state["count"] += 1
  23. should_log = state["count"] == 1 or now - state["last_log_at"] >= 30
  24. if should_log:
  25. suffix = ""
  26. if state["count"] > 1:
  27. suffix = f" (已连续 {state['count']} 次,后续每30秒提示一次)"
  28. log_func(f"{storage_name}中未找到任务进度: {callback_task_id}{suffix}")
  29. state["last_log_at"] = now
  30. self._missing_progress_log_state[callback_task_id] = state
  31. def _clear_missing_progress_state(self, callback_task_id: str):
  32. self._missing_progress_log_state.pop(callback_task_id, None)
  33. def _init_redis(self):
  34. try:
  35. import redis
  36. redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
  37. redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
  38. redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
  39. redis_db = config_handler.get('redis', 'REDIS_DB', '0')
  40. if redis_password:
  41. redis_url = f"redis://:{quote(redis_password, safe='')}@{redis_host}:{redis_port}/{redis_db}"
  42. else:
  43. redis_url = f"redis://{redis_host}:{redis_port}/{redis_db}"
  44. logger.debug(f"ProgressManager connecting to Redis: {redis_host}:{redis_port}/{redis_db}")
  45. self.redis_client = redis.from_url(redis_url, decode_responses=True)
  46. self.redis_client.ping()
  47. self.redis_connected = True
  48. logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
  49. except Exception as e:
  50. logger.error(f"ProgressManager Redis连接失败: {e}")
  51. self.redis_connected = False
  52. logger.warning("ProgressManager将使用内存存储作为备选方案")
  53. self.current_data = {}
  54. async def _get_redis_key(self, callback_task_id: str) -> str:
  55. return f"current:{callback_task_id}"
  56. async def _get_stream_event_key(self, callback_task_id: str) -> str:
  57. return f"stream_events:{callback_task_id}"
  58. async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
  59. try:
  60. current_data = {
  61. "user_id": user_id,
  62. "current": 0,
  63. "stage_name": "",
  64. "status": "准备开始",
  65. "message": "任务开始",
  66. "updated_at": datetime.now().isoformat(),
  67. "overall_task_status": "pending"
  68. }
  69. if self.redis_connected:
  70. try:
  71. redis_key = await self._get_redis_key(callback_task_id)
  72. existing_progress_json = self.redis_client.get(redis_key)
  73. if existing_progress_json:
  74. try:
  75. existing_progress = json.loads(existing_progress_json)
  76. current_data.update(existing_progress)
  77. except Exception as parse_e:
  78. logger.warning(f"解析已有任务进度失败,将重新初始化: {callback_task_id}, {parse_e}")
  79. self.redis_client.setex(
  80. redis_key,
  81. 3600,
  82. json.dumps(current_data)
  83. )
  84. stream_event_key = await self._get_stream_event_key(callback_task_id)
  85. self.redis_client.delete(stream_event_key)
  86. logger.info(f"初始化任务进度列表")
  87. except Exception as redis_e:
  88. logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")
  89. self.current_data[callback_task_id] = current_data
  90. self.stream_events[callback_task_id] = []
  91. logger.info(f"降级使用内存存储: {callback_task_id}")
  92. else:
  93. if callback_task_id in self.current_data:
  94. current_data.update(self.current_data[callback_task_id])
  95. self.current_data[callback_task_id] = current_data
  96. self.stream_events[callback_task_id] = []
  97. logger.info(f"初始化任务进度到内存: {callback_task_id}")
  98. self._clear_missing_progress_state(callback_task_id)
  99. except Exception as e:
  100. logger.error(f"初始化进度失败: {str(e)}")
  101. raise
  102. async def append_stream_event(self, callback_task_id: str, event_data: Dict[str, Any]):
  103. """追加流式事件,供 SSE 轮询端按事件顺序吐给前端。"""
  104. try:
  105. if self.redis_connected:
  106. stream_event_key = await self._get_stream_event_key(callback_task_id)
  107. self.redis_client.rpush(
  108. stream_event_key,
  109. json.dumps(event_data, ensure_ascii=False)
  110. )
  111. self.redis_client.expire(stream_event_key, 3600)
  112. else:
  113. if not hasattr(self, 'stream_events'):
  114. self.stream_events = {}
  115. self.stream_events.setdefault(callback_task_id, []).append(event_data)
  116. except Exception as e:
  117. logger.error(f"追加流式事件失败: {callback_task_id}, {str(e)}")
  118. raise
  119. async def pop_stream_events(self, callback_task_id: str, max_events: int = 100) -> List[Dict[str, Any]]:
  120. """按 FIFO 顺序取出待发送的流式事件。"""
  121. events: List[Dict[str, Any]] = []
  122. try:
  123. if self.redis_connected:
  124. stream_event_key = await self._get_stream_event_key(callback_task_id)
  125. for _ in range(max_events):
  126. event_json = self.redis_client.lpop(stream_event_key)
  127. if not event_json:
  128. break
  129. try:
  130. events.append(json.loads(event_json))
  131. except Exception as parse_e:
  132. logger.warning(f"解析流式事件失败: {callback_task_id}, {parse_e}")
  133. else:
  134. if not hasattr(self, 'stream_events'):
  135. self.stream_events = {}
  136. queued_events = self.stream_events.get(callback_task_id, [])
  137. events = queued_events[:max_events]
  138. self.stream_events[callback_task_id] = queued_events[max_events:]
  139. except Exception as e:
  140. logger.error(f"获取流式事件失败: {callback_task_id}, {str(e)}")
  141. return events
  142. async def update_stage_progress(self, callback_task_id: str, stage_name: str = None, current: int = None, status: str = None, message: str = None, issues=None, user_id: str = None, overall_task_status: str = None, event_type: str = "processing"):
  143. """更新阶段进度 - 除callback_task_id外,其他参数都可选
  144. Args:
  145. callback_task_id: 回调任务ID(必需)
  146. stage_name: 阶段名称(可选)
  147. current: 当前进度(可选)
  148. status: 状态(可选)
  149. message: 消息(可选)
  150. issues: 问题列表(可选)
  151. user_id: 用户ID(可选)
  152. overall_task_status: 整体任务状态(可选)
  153. event_type: SSE事件类型(可选,默认为"processing")
  154. """
  155. try:
  156. task_progress = None
  157. if self.redis_connected:
  158. # 从Redis读取
  159. redis_key = await self._get_redis_key(callback_task_id)
  160. progress_json = self.redis_client.get(redis_key)
  161. if progress_json:
  162. task_progress = json.loads(progress_json)
  163. self._clear_missing_progress_state(callback_task_id)
  164. else:
  165. self._log_missing_progress(callback_task_id, "Redis", logger.warning)
  166. return
  167. else:
  168. # 从内存读取
  169. if callback_task_id in self.current_data:
  170. task_progress = self.current_data[callback_task_id]
  171. self._clear_missing_progress_state(callback_task_id)
  172. else:
  173. self._log_missing_progress(callback_task_id, "内存", logger.warning)
  174. return
  175. # 更新进度数据 - 只有非空参数才更新
  176. if current is not None:
  177. task_progress["current"] = current
  178. if stage_name is not None:
  179. task_progress["stage_name"] = stage_name
  180. if status is not None:
  181. task_progress["status"] = status
  182. if message is not None:
  183. task_progress["message"] = message
  184. task_progress["updated_at"] = datetime.now().isoformat()
  185. if issues is not None:
  186. task_progress["issues"] = issues
  187. else:
  188. task_progress["issues"] = []
  189. if user_id is not None:
  190. task_progress["user_id"] = user_id
  191. if overall_task_status is not None:
  192. task_progress["overall_task_status"] = overall_task_status
  193. elif "overall_task_status" not in task_progress:
  194. task_progress["overall_task_status"] = "processing"
  195. if event_type is not None:
  196. task_progress["event_type"] = event_type
  197. logger.debug(f"设置event_type: {event_type} for {callback_task_id}")
  198. else:
  199. logger.debug(f"event_type为None,不设置 for {callback_task_id}")
  200. try:
  201. if self.redis_connected:
  202. try:
  203. self.redis_client.setex(
  204. redis_key,
  205. 3600, # 1小时过期
  206. json.dumps(task_progress)
  207. )
  208. actual_current = task_progress.get("current")
  209. if current is not None:
  210. logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {actual_current}%")
  211. else:
  212. logger.debug(f"更新进度到Redis: {callback_task_id}, 进度保持: {actual_current}% (未传入)")
  213. except Exception as sync_e:
  214. logger.warning(f"同步Redis操作失败: {callback_task_id}, {sync_e}")
  215. # 同步操作也失败时,降级到内存存储
  216. if not hasattr(self, 'current_data'):
  217. self.current_data = {}
  218. self.current_data[callback_task_id] = task_progress
  219. logger.debug(f"降级使用内存存储: {callback_task_id}")
  220. else:
  221. if not hasattr(self, 'current_data'):
  222. self.current_data = {}
  223. self.current_data[callback_task_id] = task_progress
  224. actual_current = task_progress.get("current")
  225. if current is not None:
  226. logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {actual_current}%")
  227. else:
  228. logger.debug(f"更新进度到内存: {callback_task_id}, 进度保持: {actual_current}% (未传入)")
  229. except Exception as e:
  230. logger.error(f"保存进度数据异常: {callback_task_id}, {e}")
  231. if not hasattr(self, 'current_data'):
  232. self.current_data = {}
  233. self.current_data[callback_task_id] = task_progress
  234. # 进度已保存到 Redis,SSE 由主进程通过轮询获取
  235. actual_current = task_progress.get("current")
  236. if current is not None:
  237. logger.debug(f"进度已更新到Redis: {callback_task_id}, current={actual_current}%")
  238. else:
  239. logger.debug(f"进度已更新到Redis: {callback_task_id}, current={actual_current}% (保持)")
  240. except Exception as e:
  241. logger.error(f"更新阶段进度失败: {str(e)}")
  242. raise
  243. async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
  244. """获取任务进度"""
  245. try:
  246. #logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
  247. task_progress = None
  248. if self.redis_connected:
  249. # 从Redis读取
  250. redis_key = await self._get_redis_key(callback_task_id)
  251. #logger.debug(f"Redis键: {redis_key}")
  252. progress_json = self.redis_client.get(redis_key)
  253. #logger.debug(f"从Redis读取数据: {progress_json is not None}")
  254. if progress_json:
  255. task_progress = json.loads(progress_json)
  256. self._clear_missing_progress_state(callback_task_id)
  257. else:
  258. self._log_missing_progress(callback_task_id, "Redis", logger.debug)
  259. return None
  260. else:
  261. # 从内存读取
  262. if hasattr(self, 'current_data') and callback_task_id in self.current_data:
  263. task_progress = self.current_data[callback_task_id]
  264. self._clear_missing_progress_state(callback_task_id)
  265. else:
  266. self._log_missing_progress(callback_task_id, "内存", logger.debug)
  267. return None
  268. # 获取overall_task_status,默认为"pending"
  269. overall_task_status = task_progress.get("overall_task_status", "pending")
  270. # 转换时间戳
  271. updated_at = task_progress["updated_at"]
  272. if isinstance(updated_at, str):
  273. updated_at_timestamp = int(datetime.fromisoformat(updated_at).timestamp())
  274. elif isinstance(updated_at, (int, float)):
  275. updated_at_timestamp = int(updated_at)
  276. else:
  277. updated_at_timestamp = int(updated_at.timestamp())
  278. # 构建返回数据
  279. result = {
  280. "callback_task_id": callback_task_id,
  281. "user_id": task_progress["user_id"],
  282. "current": task_progress["current"],
  283. "stage_name": task_progress["stage_name"],
  284. "status": task_progress["status"],
  285. "message": task_progress["message"],
  286. "overall_task_status": overall_task_status,
  287. "updated_at": updated_at_timestamp
  288. }
  289. # 添加可选字段
  290. if "issues" in task_progress:
  291. result["issues"] = task_progress["issues"]
  292. if "event_type" in task_progress:
  293. result["event_type"] = task_progress["event_type"]
  294. return result
  295. except Exception as e:
  296. logger.error(f"获取进度失败: {str(e)}")
  297. return None
  298. async def complete_task(self, callback_task_id: str, user_id: str = None, current_data: dict = None):
  299. """标记任务完成 - 使用单一同步强制关闭逻辑
  300. Args:
  301. callback_task_id: 回调任务ID
  302. user_id: 用户ID
  303. current_data: 包含 overall_task_status 等数据的字典
  304. """
  305. try:
  306. # 先更新 Redis 中的状态(让 SSE 轮询能检测到)
  307. if current_data and current_data.get("overall_task_status"):
  308. progress_data = await self.get_progress(callback_task_id)
  309. if progress_data:
  310. # 更新状态字段
  311. progress_data["overall_task_status"] = current_data.get("overall_task_status")
  312. progress_data["status"] = current_data.get("status", current_data.get("overall_task_status"))
  313. progress_data["message"] = current_data.get("message", "任务已完成")
  314. progress_data["updated_at"] = datetime.now().isoformat()
  315. if current_data.get("error"):
  316. progress_data["error"] = current_data.get("error")
  317. # 保存回 Redis
  318. if self.redis_connected:
  319. redis_key = await self._get_redis_key(callback_task_id)
  320. self.redis_client.setex(
  321. redis_key,
  322. 3600,
  323. json.dumps(progress_data)
  324. )
  325. logger.info(f"任务状态已更新到Redis: {callback_task_id}, status={current_data.get('overall_task_status')}")
  326. else:
  327. if not hasattr(self, 'current_data'):
  328. self.current_data = {}
  329. self.current_data[callback_task_id] = progress_data
  330. # SSE 连接由主进程管理(通过 Redis 轮询),Worker 无需关闭
  331. logger.info(f"任务状态已更新,等待主进程轮询检测: {callback_task_id}")
  332. except Exception as e:
  333. logger.error(f"标记任务完成失败: {str(e)}")
  334. raise