workflow_manager.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
"""
基于LangGraph的工作流管理器
负责任务的创建、编排和执行,使用LangGraph进行状态管理
新增功能:
- 任务终止管理
- 终止信号设置和检测
"""
import asyncio
import time
from datetime import datetime
from typing import Any, Dict, Optional

from foundation.observability.logger.loggering import server_logger as logger
from foundation.observability.monitoring.time_statistics import track_execution_time
from foundation.infrastructure.cache.redis_connection import RedisConnectionFactory

from .progress_manager import ProgressManager
from .redis_duplicate_checker import RedisDuplicateChecker
from .task_models import TaskFileInfo, TaskChain
from ..construction_review.workflows import DocumentWorkflow, AIReviewWorkflow
  19. class ProgressManagerRegistry:
  20. """ProgressManager注册表 - 为每个任务管理独立的ProgressManager实例"""
  21. _registry = {} # {callback_task_id: ProgressManager}
  22. @classmethod
  23. def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
  24. """注册ProgressManager实例"""
  25. cls._registry[callback_task_id] = progress_manager
  26. logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")
  27. @classmethod
  28. def get_progress_manager(cls, callback_task_id: str) -> ProgressManager:
  29. """获取ProgressManager实例"""
  30. return cls._registry.get(callback_task_id)
  31. @classmethod
  32. def unregister_progress_manager(cls, callback_task_id: str):
  33. """注销ProgressManager实例"""
  34. if callback_task_id in cls._registry:
  35. del cls._registry[callback_task_id]
  36. logger.info(f"注销ProgressManager实例: {callback_task_id}")
  37. class WorkflowManager:
  38. """工作流管理器"""
  39. def __init__(self, max_concurrent_docs: int = 5, max_concurrent_reviews: int = 10):
  40. self.max_concurrent_docs = max_concurrent_docs
  41. self.max_concurrent_reviews = max_concurrent_reviews
  42. # 并发控制
  43. self.doc_semaphore = asyncio.Semaphore(max_concurrent_docs)
  44. self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews)
  45. # 服务组件
  46. self.progress_manager = ProgressManager()
  47. self.redis_duplicate_checker = RedisDuplicateChecker()
  48. # 活跃任务跟踪
  49. self.active_chains: Dict[str, TaskChain] = {}
  50. self._cleanup_task_started = False
  51. # 任务终止管理
  52. self._terminate_signal_prefix = "ai_review:terminate_signal:"
  53. self._task_expire_time = 7200 # 2小时
  54. async def submit_task_processing(self, file_info: dict) -> str:
  55. """异步提交任务处理(用于file_upload层)"""
  56. from foundation.infrastructure.messaging.tasks import submit_task_processing_task
  57. from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager
  58. try:
  59. logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")
  60. # 使用CeleryTraceManager提交任务,自动传递trace_id
  61. task = CeleryTraceManager.submit_celery_task(
  62. submit_task_processing_task,
  63. file_info
  64. )
  65. logger.info(f"Celery任务已提交,Task ID: {task.id}")
  66. return task.id
  67. except Exception as e:
  68. logger.error(f"提交Celery任务失败: {str(e)}")
  69. raise
  70. @track_execution_time
  71. def submit_task_processing_sync(self, file_info: dict) -> dict:
  72. """同步提交任务处理(用于Celery worker)"""
  73. try:
  74. logger.info(f"提交文档处理任务: {file_info['file_id']}")
  75. # 1. 创建TaskFileInfo对象(封装任务文件信息)
  76. task_file_info = TaskFileInfo(file_info)
  77. logger.info(f"创建任务文件信息: {task_file_info}")
  78. # 2. 生成任务链ID
  79. callback_task_id = task_file_info.callback_task_id
  80. # 3. 创建任务链(引用 TaskFileInfo,避免数据重复)
  81. task_chain = TaskChain(task_file_info)
  82. # 4. 标记任务开始
  83. task_chain.start_processing()
  84. # 5. 添加到活跃任务跟踪
  85. self.active_chains[callback_task_id] = task_chain
  86. # 5. 初始化进度管理
  87. asyncio.run(self.progress_manager.initialize_progress(
  88. callback_task_id=callback_task_id,
  89. user_id=task_file_info.user_id,
  90. stages=[]
  91. ))
  92. # 6. 启动处理流程(同步执行)
  93. self._process_task_chain_sync(task_chain, task_file_info, task_file_info.file_type)
  94. # logger.info(f"提交文档处理任务: {callback_task_id}")
  95. logger.info(f"施工方案审查任务已完成! ")
  96. logger.info(f"文件ID: {task_file_info.file_id}")
  97. logger.info(f"文件名:{task_file_info.file_name}")
  98. except Exception as e:
  99. logger.error(f"提交文档处理任务失败: {str(e)}")
  100. raise
  101. def _process_task_chain_sync(self, task_chain: TaskChain, task_file_info: TaskFileInfo, file_type: str):
  102. """同步处理文档任务链(用于Celery worker)"""
  103. try:
  104. file_content = task_file_info.file_content
  105. # 阶段1:文档处理(串行)
  106. document_workflow = DocumentWorkflow(
  107. task_file_info=task_file_info,
  108. progress_manager=self.progress_manager,
  109. redis_duplicate_checker=self.redis_duplicate_checker
  110. )
  111. # 同步执行文档处理
  112. loop = asyncio.new_event_loop()
  113. asyncio.set_event_loop(loop)
  114. doc_result = loop.run_until_complete(document_workflow.execute(file_content, file_type))
  115. loop.close()
  116. task_chain.results['document'] = doc_result
  117. # 阶段2:AI审查(内部并发)
  118. task_chain.update_stage("ai_review")
  119. structured_content = doc_result['structured_content']
  120. # 读取AI审查配置
  121. import configparser
  122. config = configparser.ConfigParser()
  123. config.read('config/config.ini', encoding='utf-8')
  124. max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None)
  125. if max_review_units == 0: # 如果配置为0,表示审查所有
  126. max_review_units = None
  127. review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all')
  128. logger.info(f"AI审查配置: 最大审查数量={max_review_units}, 审查模式={review_mode}")
  129. ai_workflow = AIReviewWorkflow(
  130. task_file_info=task_file_info,
  131. structured_content=structured_content,
  132. progress_manager=self.progress_manager,
  133. max_review_units=max_review_units,
  134. review_mode=review_mode
  135. )
  136. # 同步执行AI审查
  137. loop = asyncio.new_event_loop()
  138. asyncio.set_event_loop(loop)
  139. ai_result = loop.run_until_complete(ai_workflow.execute())
  140. loop.close()
  141. task_chain.results['ai_review'] = ai_result
  142. # # 阶段3:报告生成(串行)
  143. # task_chain.current_stage = "report_generation"
  144. # report_workflow = ReportWorkflow(
  145. # file_id=task_chain.file_id,
  146. # callback_task_id=task_chain.callback_task_id,
  147. # user_id=task_chain.user_id,
  148. # ai_review_results=ai_result,
  149. # progress_manager=self.progress_manager
  150. # )
  151. # # 同步执行报告生成
  152. # loop = asyncio.new_event_loop()
  153. # asyncio.set_event_loop(loop)
  154. # report_result = loop.run_until_complete(report_workflow.execute())
  155. # loop.close()
  156. # task_chain.results['report'] = report_result
  157. # 完成任务链
  158. task_chain.complete_processing()
  159. # 清理任务注册
  160. asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
  161. # 通知SSE连接任务完成
  162. asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id, task_chain.user_id))
  163. # 清理Redis文件缓存
  164. try:
  165. from foundation.utils.redis_utils import delete_file_info
  166. asyncio.run(delete_file_info(task_chain.file_id))
  167. logger.info(f"已清理Redis文件缓存: {task_chain.file_id}")
  168. except Exception as e:
  169. logger.warning(f"清理Redis文件缓存失败: {str(e)}")
  170. logger.info(f"文档处理任务链完成: {task_chain.callback_task_id}")
  171. return task_chain.results
  172. except Exception as e:
  173. # 标记任务失败
  174. task_chain.fail_processing(str(e))
  175. logger.error(f"文档处理任务链失败: {task_chain.callback_task_id}, 错误: {str(e)}")
  176. # 清理任务注册
  177. asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
  178. # 清理Redis文件缓存(即使失败也清理)
  179. try:
  180. from foundation.utils.redis_utils import delete_file_info
  181. asyncio.run(delete_file_info(task_chain.file_id))
  182. logger.info(f"已清理Redis文件缓存: {task_chain.file_id}")
  183. except Exception as cleanup_error:
  184. logger.warning(f"清理Redis文件缓存失败: {str(cleanup_error)}")
  185. # 通知SSE连接任务失败
  186. error_result = {
  187. "error": str(e),
  188. "status": "failed",
  189. "timestamp": datetime.now().isoformat()
  190. }
  191. current_data = {
  192. "status": "failed",
  193. "result": error_result
  194. }
  195. asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id, task_chain.user_id, current_data))
  196. raise
  197. finally:
  198. # 清理活跃任务
  199. if task_chain.callback_task_id in self.active_chains:
  200. del self.active_chains[task_chain.callback_task_id]
  201. # ==================== 任务终止管理方法 ====================
  202. async def set_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, any]:
  203. """
  204. 设置任务终止信号
  205. Args:
  206. callback_task_id: 任务回调ID
  207. operator: 操作人(用户ID或系统标识)
  208. Returns:
  209. Dict: 操作结果 {"success": bool, "message": str, "task_info": dict}
  210. Note:
  211. 将终止信号写入 Redis,支持跨进程检测
  212. AI审查节点在执行前会检查此信号
  213. """
  214. try:
  215. # 检查任务是否在活跃列表中
  216. if callback_task_id not in self.active_chains:
  217. return {
  218. "success": False,
  219. "message": f"任务不存在或已完成: {callback_task_id}",
  220. "task_info": None
  221. }
  222. task_chain = self.active_chains[callback_task_id]
  223. # 检查任务状态
  224. if task_chain.status != "processing":
  225. return {
  226. "success": False,
  227. "message": f"任务状态不是 processing,无需终止: {callback_task_id} (当前状态: {task_chain.status})",
  228. "task_info": {
  229. "callback_task_id": callback_task_id,
  230. "status": task_chain.status,
  231. "file_name": task_chain.file_name
  232. }
  233. }
  234. # 设置 Redis 终止信号
  235. redis_client = await RedisConnectionFactory.get_connection()
  236. terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
  237. # 存储终止信号和操作人、时间
  238. terminate_data = {
  239. "operator": operator,
  240. "terminate_time": str(time.time()),
  241. "task_id": callback_task_id
  242. }
  243. # 使用 hash 存储更多信息
  244. await redis_client.hset(terminate_key, mapping=terminate_data)
  245. # 设置过期时间(2小时)
  246. await redis_client.expire(terminate_key, self._task_expire_time)
  247. logger.info(f"已设置终止信号: {callback_task_id} (操作人: {operator}, 文件: {task_chain.file_name})")
  248. return {
  249. "success": True,
  250. "message": f"终止信号已设置,任务将在当前节点完成后终止",
  251. "task_info": {
  252. "callback_task_id": callback_task_id,
  253. "file_id": task_chain.file_id,
  254. "file_name": task_chain.file_name,
  255. "user_id": task_chain.user_id,
  256. "status": task_chain.status,
  257. "current_stage": task_chain.current_stage
  258. }
  259. }
  260. except Exception as e:
  261. logger.error(f"设置终止信号失败: {str(e)}", exc_info=True)
  262. return {
  263. "success": False,
  264. "message": f"设置终止信号失败: {str(e)}",
  265. "task_info": None
  266. }
  267. async def check_terminate_signal(self, callback_task_id: str) -> bool:
  268. """
  269. 检查是否有终止信号
  270. Args:
  271. callback_task_id: 任务回调ID
  272. Returns:
  273. bool: 有终止信号返回 True
  274. Note:
  275. 从 Redis 读取终止信号
  276. 工作流节点在执行前调用此方法检查是否需要终止
  277. """
  278. try:
  279. redis_client = await RedisConnectionFactory.get_connection()
  280. terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
  281. # 检查键是否存在
  282. exists = await redis_client.exists(terminate_key)
  283. if exists:
  284. # 读取终止信息
  285. terminate_info = await redis_client.hgetall(terminate_key)
  286. logger.warning(f"检测到终止信号: {callback_task_id}, 操作人: {terminate_info.get(b'operator', b'unknown').decode()}")
  287. return True
  288. return False
  289. except RuntimeError as e:
  290. # 事件循环相关的错误处理
  291. error_msg = str(e)
  292. if "Event loop is closed" in error_msg:
  293. # 事件循环关闭是正常情况(任务结束),不记录错误
  294. logger.debug(f"检查终止信号时事件循环已关闭: {callback_task_id}")
  295. return False
  296. elif "bound to a different event loop" in error_msg:
  297. # 事件循环不匹配,记录警告但不中断流程
  298. logger.warning(f"检查终止信号时检测到事件循环不匹配: {callback_task_id},将忽略本次检查")
  299. return False
  300. else:
  301. # 其他 RuntimeError 记录错误
  302. logger.error(f"检查终止信号失败(RuntimeError): {error_msg}", exc_info=True)
  303. return False
  304. except Exception as e:
  305. # 其他异常仍然记录错误
  306. logger.error(f"检查终止信号失败: {str(e)}", exc_info=True)
  307. return False
  308. async def clear_terminate_signal(self, callback_task_id: str):
  309. """
  310. 清理 Redis 中的终止信号
  311. Args:
  312. callback_task_id: 任务回调ID
  313. """
  314. try:
  315. redis_client = await RedisConnectionFactory.get_connection()
  316. terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
  317. await redis_client.delete(terminate_key)
  318. logger.debug(f"清理终止信号: {callback_task_id}")
  319. except Exception as e:
  320. logger.warning(f"清理终止信号失败: {str(e)}")
  321. async def get_active_tasks(self) -> list:
  322. """
  323. 获取活跃任务列表
  324. Returns:
  325. list: 活跃任务信息列表
  326. """
  327. try:
  328. active_tasks = []
  329. current_time = time.time()
  330. for task_id, task_chain in self.active_chains.items():
  331. if task_chain.status == "processing":
  332. task_info = {
  333. "callback_task_id": task_id,
  334. "file_id": task_chain.file_id,
  335. "file_name": task_chain.file_name,
  336. "user_id": task_chain.user_id,
  337. "status": task_chain.status,
  338. "current_stage": task_chain.current_stage,
  339. "start_time": task_chain.start_time,
  340. "running_duration": int(current_time - task_chain.start_time) if task_chain.start_time else 0
  341. }
  342. active_tasks.append(task_info)
  343. return active_tasks
  344. except Exception as e:
  345. logger.error(f"获取活跃任务列表失败: {str(e)}", exc_info=True)
  346. return []
  347. async def get_task_info(self, callback_task_id: str) -> Optional[Dict]:
  348. """
  349. 获取任务信息
  350. Args:
  351. callback_task_id: 任务回调ID
  352. Returns:
  353. Optional[Dict]: 任务信息字典,不存在返回 None
  354. """
  355. try:
  356. task_chain = self.active_chains.get(callback_task_id)
  357. if task_chain:
  358. current_time = time.time()
  359. return {
  360. "callback_task_id": callback_task_id,
  361. "file_id": task_chain.file_id,
  362. "file_name": task_chain.file_name,
  363. "user_id": task_chain.user_id,
  364. "status": task_chain.status,
  365. "current_stage": task_chain.current_stage,
  366. "start_time": task_chain.start_time,
  367. "running_duration": int(current_time - task_chain.start_time) if task_chain.start_time else 0,
  368. "results": task_chain.results
  369. }
  370. return None
  371. except Exception as e:
  372. logger.error(f"获取任务信息失败: {str(e)}", exc_info=True)
  373. return None