workflow_manager.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911
  1. """
  2. 基于LangGraph的工作流管理器
  3. 负责任务的创建、编排和执行,使用LangGraph进行状态管理
  4. 新增功能:
  5. - 任务终止管理
  6. - 终止信号设置和检测
  7. """
  8. import asyncio
  9. import time
  10. from typing import Dict, Optional, Any
  11. from datetime import datetime
  12. from langgraph.graph import StateGraph, END
  13. from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
  14. from foundation.observability.logger.loggering import server_logger as logger
  15. from foundation.observability.monitoring.time_statistics import track_execution_time
  16. from foundation.infrastructure.cache.redis_connection import RedisConnectionFactory
  17. from .progress_manager import ProgressManager
  18. from .redis_duplicate_checker import RedisDuplicateChecker
  19. from .task_models import TaskFileInfo, TaskChain
  20. from ..construction_review.workflows import DocumentWorkflow, AIReviewWorkflow, ReportWorkflow
  21. from ..construction_review.workflows.types import TaskChainState
  22. class ProgressManagerRegistry:
  23. """ProgressManager注册表 - 为每个任务管理独立的ProgressManager实例"""
  24. _registry = {} # {callback_task_id: ProgressManager}
  25. @classmethod
  26. def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
  27. """注册ProgressManager实例"""
  28. cls._registry[callback_task_id] = progress_manager
  29. logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")
  30. @classmethod
  31. def get_progress_manager(cls, callback_task_id: str) -> ProgressManager:
  32. """获取ProgressManager实例"""
  33. return cls._registry.get(callback_task_id)
  34. @classmethod
  35. def unregister_progress_manager(cls, callback_task_id: str):
  36. """注销ProgressManager实例"""
  37. if callback_task_id in cls._registry:
  38. del cls._registry[callback_task_id]
  39. logger.info(f"注销ProgressManager实例: {callback_task_id}")
  40. class WorkflowManager:
  41. """工作流管理器"""
  42. def __init__(self, max_concurrent_docs: int = 5, max_concurrent_reviews: int = 10):
  43. self.max_concurrent_docs = max_concurrent_docs
  44. self.max_concurrent_reviews = max_concurrent_reviews
  45. # 并发控制
  46. self.doc_semaphore = asyncio.Semaphore(max_concurrent_docs)
  47. self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews)
  48. # 服务组件
  49. self.progress_manager = ProgressManager()
  50. self.redis_duplicate_checker = RedisDuplicateChecker()
  51. # 活跃任务跟踪
  52. self.active_chains: Dict[str, TaskChain] = {}
  53. self._cleanup_task_started = False
  54. # 任务终止管理
  55. self._terminate_signal_prefix = "ai_review:terminate_signal:"
  56. self._task_expire_time = 7200 # 2小时
  57. # LangGraph 任务链工作流(方案D)
  58. self.task_chain_graph = None # 延迟初始化,避免循环导入
  59. async def submit_task_processing(self, file_info: dict) -> str:
  60. """异步提交任务处理(用于file_upload层)"""
  61. from foundation.infrastructure.messaging.tasks import submit_task_processing_task
  62. from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager
  63. try:
  64. logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")
  65. # 使用CeleryTraceManager提交任务,自动传递trace_id
  66. task = CeleryTraceManager.submit_celery_task(
  67. submit_task_processing_task,
  68. file_info
  69. )
  70. logger.info(f"Celery任务已提交,Task ID: {task.id}")
  71. return task.id
  72. except Exception as e:
  73. logger.error(f"提交Celery任务失败: {str(e)}")
  74. raise
  75. @track_execution_time
  76. def submit_task_processing_sync(self, file_info: dict) -> dict:
  77. """
  78. 同步提交任务处理(用于Celery worker)
  79. Note:
  80. 已切换到 LangGraph 任务链工作流(方案D)
  81. 使用统一的状态管理和嵌套子图架构
  82. """
  83. try:
  84. logger.info(f"提交文档处理任务(LangGraph方案D): {file_info['file_id']}")
  85. # 1. 创建TaskFileInfo对象(封装任务文件信息)
  86. task_file_info = TaskFileInfo(file_info)
  87. logger.info(f"创建任务文件信息: {task_file_info}")
  88. # 2. 生成任务链ID
  89. callback_task_id = task_file_info.callback_task_id
  90. # 3. 创建任务链(引用 TaskFileInfo,避免数据重复)
  91. task_chain = TaskChain(task_file_info)
  92. # 4. 标记任务开始
  93. task_chain.start_processing()
  94. # 5. 添加到活跃任务跟踪
  95. self.active_chains[callback_task_id] = task_chain
  96. # 6. 初始化进度管理
  97. asyncio.run(self.progress_manager.initialize_progress(
  98. callback_task_id=callback_task_id,
  99. user_id=task_file_info.user_id,
  100. stages=[]
  101. ))
  102. # 7. 构建 LangGraph 任务链工作流(延迟初始化)
  103. if self.task_chain_graph is None:
  104. self.task_chain_graph = self._build_task_chain_workflow()
  105. # 8. 构建初始状态
  106. initial_state = TaskChainState(
  107. file_id=task_file_info.file_id,
  108. callback_task_id=callback_task_id,
  109. user_id=task_file_info.user_id,
  110. file_name=task_file_info.file_name,
  111. file_type=task_file_info.file_type,
  112. file_content=task_file_info.file_content,
  113. current_stage="start",
  114. overall_task_status="processing",
  115. stage_status={
  116. "document": "pending",
  117. "ai_review": "pending",
  118. "report": "pending"
  119. },
  120. document_result=None,
  121. ai_review_result=None,
  122. report_result=None,
  123. error_message=None,
  124. progress_manager=self.progress_manager,
  125. task_file_info=task_file_info,
  126. messages=[HumanMessage(content=f"开始任务链: {task_file_info.file_id}")]
  127. )
  128. # 9. 执行 LangGraph 任务链工作流
  129. loop = asyncio.new_event_loop()
  130. asyncio.set_event_loop(loop)
  131. result = loop.run_until_complete(self.task_chain_graph.ainvoke(initial_state))
  132. loop.close()
  133. # 10. 清理任务注册
  134. asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
  135. logger.info(f"施工方案审查任务已完成(LangGraph方案D)!")
  136. logger.info(f"文件ID: {task_file_info.file_id}")
  137. logger.info(f"文件名: {task_file_info.file_name}")
  138. logger.info(f"整体状态: {result.get('overall_task_status', 'unknown')}")
  139. # 构建可序列化的返回结果(移除不可序列化的对象)
  140. serializable_result = {
  141. "file_id": result.get("file_id"),
  142. "callback_task_id": result.get("callback_task_id"),
  143. "user_id": result.get("user_id"),
  144. "file_name": result.get("file_name"),
  145. "current_stage": result.get("current_stage"),
  146. "overall_task_status": result.get("overall_task_status"),
  147. "stage_status": result.get("stage_status"),
  148. "error_message": result.get("error_message"),
  149. # 注意:不包含 progress_manager, task_file_info, messages 等不可序列化对象
  150. }
  151. return serializable_result
  152. except Exception as e:
  153. logger.error(f"提交文档处理任务失败: {str(e)}", exc_info=True)
  154. # 标记任务失败
  155. if callback_task_id in self.active_chains:
  156. self.active_chains[callback_task_id].fail_processing(str(e))
  157. # 清理任务注册
  158. asyncio.run(self.redis_duplicate_checker.unregister_task(task_file_info.file_id))
  159. # 通知SSE连接任务失败
  160. error_data = {
  161. "error": str(e),
  162. "status": "failed",
  163. "overall_task_status": "failed",
  164. "timestamp": datetime.now().isoformat()
  165. }
  166. asyncio.run(self.progress_manager.complete_task(callback_task_id, task_file_info.user_id, error_data))
  167. raise
  168. finally:
  169. # 清理活跃任务
  170. if callback_task_id in self.active_chains:
  171. del self.active_chains[callback_task_id]
  172. async def set_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, any]:
  173. """
  174. 设置任务终止信号
  175. Args:
  176. callback_task_id: 任务回调ID
  177. operator: 操作人(用户ID或系统标识)
  178. Returns:
  179. Dict: 操作结果 {"success": bool, "message": str, "task_info": dict}
  180. Note:
  181. 将终止信号写入 Redis,支持跨进程检测
  182. AI审查节点在执行前会检查此信号
  183. """
  184. try:
  185. # 检查任务是否在活跃列表中
  186. if callback_task_id not in self.active_chains:
  187. return {
  188. "success": False,
  189. "message": f"任务不存在或已完成: {callback_task_id}",
  190. "task_info": None
  191. }
  192. task_chain = self.active_chains[callback_task_id]
  193. # 检查任务状态
  194. if task_chain.status != "processing":
  195. return {
  196. "success": False,
  197. "message": f"任务状态不是 processing,无需终止: {callback_task_id} (当前状态: {task_chain.status})",
  198. "task_info": {
  199. "callback_task_id": callback_task_id,
  200. "status": task_chain.status,
  201. "file_name": task_chain.file_name
  202. }
  203. }
  204. # 设置 Redis 终止信号
  205. redis_client = await RedisConnectionFactory.get_connection()
  206. terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
  207. # 存储终止信号和操作人、时间
  208. terminate_data = {
  209. "operator": operator,
  210. "terminate_time": str(time.time()),
  211. "task_id": callback_task_id
  212. }
  213. # 使用 hash 存储更多信息
  214. await redis_client.hset(terminate_key, mapping=terminate_data)
  215. # 设置过期时间(2小时)
  216. await redis_client.expire(terminate_key, self._task_expire_time)
  217. logger.info(f"已设置终止信号: {callback_task_id} (操作人: {operator}, 文件: {task_chain.file_name})")
  218. return {
  219. "success": True,
  220. "message": f"终止信号已设置,任务将在当前节点完成后终止",
  221. "task_info": {
  222. "callback_task_id": callback_task_id,
  223. "file_id": task_chain.file_id,
  224. "file_name": task_chain.file_name,
  225. "user_id": task_chain.user_id,
  226. "status": task_chain.status,
  227. "current_stage": task_chain.current_stage
  228. }
  229. }
  230. except Exception as e:
  231. logger.error(f"设置终止信号失败: {str(e)}", exc_info=True)
  232. return {
  233. "success": False,
  234. "message": f"设置终止信号失败: {str(e)}",
  235. "task_info": None
  236. }
  237. async def check_terminate_signal(self, callback_task_id: str) -> bool:
  238. """
  239. 检查是否有终止信号
  240. Args:
  241. callback_task_id: 任务回调ID
  242. Returns:
  243. bool: 有终止信号返回 True
  244. Note:
  245. 从 Redis 读取终止信号
  246. 工作流节点在执行前调用此方法检查是否需要终止
  247. """
  248. try:
  249. redis_client = await RedisConnectionFactory.get_connection()
  250. terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
  251. # 检查键是否存在
  252. exists = await redis_client.exists(terminate_key)
  253. if exists:
  254. # 读取终止信息
  255. terminate_info = await redis_client.hgetall(terminate_key)
  256. logger.warning(f"检测到终止信号: {callback_task_id}, 操作人: {terminate_info.get(b'operator', b'unknown').decode()}")
  257. return True
  258. return False
  259. except RuntimeError as e:
  260. # 事件循环相关的错误处理
  261. error_msg = str(e)
  262. if "Event loop is closed" in error_msg:
  263. # 事件循环关闭是正常情况(任务结束),不记录错误
  264. logger.debug(f"检查终止信号时事件循环已关闭: {callback_task_id}")
  265. return False
  266. elif "bound to a different event loop" in error_msg:
  267. # 事件循环不匹配,记录警告但不中断流程
  268. logger.warning(f"检查终止信号时检测到事件循环不匹配: {callback_task_id},将忽略本次检查")
  269. return False
  270. else:
  271. # 其他 RuntimeError 记录错误
  272. logger.error(f"检查终止信号失败(RuntimeError): {error_msg}", exc_info=True)
  273. return False
  274. except Exception as e:
  275. # 其他异常仍然记录错误
  276. logger.error(f"检查终止信号失败: {str(e)}", exc_info=True)
  277. return False
  278. async def clear_terminate_signal(self, callback_task_id: str):
  279. """
  280. 清理 Redis 中的终止信号
  281. Args:
  282. callback_task_id: 任务回调ID
  283. """
  284. try:
  285. redis_client = await RedisConnectionFactory.get_connection()
  286. terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
  287. await redis_client.delete(terminate_key)
  288. logger.debug(f"清理终止信号: {callback_task_id}")
  289. except Exception as e:
  290. logger.warning(f"清理终止信号失败: {str(e)}")
  291. async def get_active_tasks(self) -> list:
  292. """
  293. 获取活跃任务列表
  294. Returns:
  295. list: 活跃任务信息列表
  296. """
  297. try:
  298. active_tasks = []
  299. current_time = time.time()
  300. for task_id, task_chain in self.active_chains.items():
  301. if task_chain.status == "processing":
  302. task_info = {
  303. "callback_task_id": task_id,
  304. "file_id": task_chain.file_id,
  305. "file_name": task_chain.file_name,
  306. "user_id": task_chain.user_id,
  307. "status": task_chain.status,
  308. "current_stage": task_chain.current_stage,
  309. "start_time": task_chain.start_time,
  310. "running_duration": int(current_time - task_chain.start_time) if task_chain.start_time else 0
  311. }
  312. active_tasks.append(task_info)
  313. return active_tasks
  314. except Exception as e:
  315. logger.error(f"获取活跃任务列表失败: {str(e)}", exc_info=True)
  316. return []
  317. async def get_task_info(self, callback_task_id: str) -> Optional[Dict]:
  318. """
  319. 获取任务信息
  320. Args:
  321. callback_task_id: 任务回调ID
  322. Returns:
  323. Optional[Dict]: 任务信息字典,不存在返回 None
  324. """
  325. try:
  326. task_chain = self.active_chains.get(callback_task_id)
  327. if task_chain:
  328. current_time = time.time()
  329. return {
  330. "callback_task_id": callback_task_id,
  331. "file_id": task_chain.file_id,
  332. "file_name": task_chain.file_name,
  333. "user_id": task_chain.user_id,
  334. "status": task_chain.status,
  335. "current_stage": task_chain.current_stage,
  336. "start_time": task_chain.start_time,
  337. "running_duration": int(current_time - task_chain.start_time) if task_chain.start_time else 0,
  338. "results": task_chain.results
  339. }
  340. return None
  341. except Exception as e:
  342. logger.error(f"获取任务信息失败: {str(e)}", exc_info=True)
  343. return None
  344. def _build_task_chain_workflow(self) -> StateGraph:
  345. """
  346. 构建 LangGraph 任务链工作流图(方案D)
  347. Returns:
  348. StateGraph: 配置完成的 LangGraph 任务链图实例
  349. Note:
  350. 创建包含文档处理、AI审查(嵌套子图)、报告生成的完整任务链
  351. 设置节点间的转换关系和条件边,支持终止检查和错误处理
  352. 工作流路径: start → document_processing → ai_review_subgraph → report_generation → complete → END
  353. """
  354. logger.info("开始构建 LangGraph 任务链工作流图")
  355. workflow = StateGraph(TaskChainState)
  356. # 添加节点
  357. workflow.add_node("start", self._start_chain_node)
  358. workflow.add_node("document_processing", self._document_processing_node)
  359. workflow.add_node("ai_review_subgraph", self._ai_review_subgraph_node)
  360. workflow.add_node("report_generation", self._report_generation_node)
  361. workflow.add_node("complete", self._complete_chain_node)
  362. workflow.add_node("error_handler", self._error_handler_chain_node)
  363. workflow.add_node("terminate", self._terminate_chain_node)
  364. # 设置入口点
  365. workflow.set_entry_point("start")
  366. # 添加边和条件边
  367. workflow.add_edge("start", "document_processing")
  368. # 文档处理后检查终止信号
  369. workflow.add_conditional_edges(
  370. "document_processing",
  371. self._should_terminate_or_error_chain,
  372. {
  373. "terminate": "terminate",
  374. "error": "error_handler",
  375. "continue": "ai_review_subgraph"
  376. }
  377. )
  378. # AI审查后检查终止信号
  379. workflow.add_conditional_edges(
  380. "ai_review_subgraph",
  381. self._should_terminate_or_error_chain,
  382. {
  383. "terminate": "terminate",
  384. "error": "error_handler",
  385. "continue": "report_generation"
  386. }
  387. )
  388. # 报告生成后检查终止信号
  389. workflow.add_conditional_edges(
  390. "report_generation",
  391. self._should_terminate_or_error_chain,
  392. {
  393. "terminate": "terminate",
  394. "error": "error_handler",
  395. "continue": "complete"
  396. }
  397. )
  398. # 完成节点直接结束
  399. workflow.add_edge("complete", END)
  400. workflow.add_edge("error_handler", END)
  401. workflow.add_edge("terminate", END)
  402. # 编译工作流图
  403. compiled_graph = workflow.compile()
  404. logger.info("LangGraph 任务链工作流图构建完成")
  405. return compiled_graph
  406. async def _start_chain_node(self, state: TaskChainState) -> TaskChainState:
  407. """
  408. 任务链开始节点
  409. Args:
  410. state: 任务链状态
  411. Returns:
  412. TaskChainState: 更新后的状态
  413. """
  414. logger.info(f"任务链工作流启动: {state['callback_task_id']}")
  415. return {
  416. "current_stage": "start",
  417. "overall_task_status": "processing",
  418. "stage_status": {
  419. "document": "pending",
  420. "ai_review": "pending",
  421. "report": "pending"
  422. },
  423. "messages": [AIMessage(content="任务链工作流启动")]
  424. }
  425. async def _document_processing_node(self, state: TaskChainState) -> TaskChainState:
  426. """
  427. 文档处理节点
  428. Args:
  429. state: 任务链状态
  430. Returns:
  431. TaskChainState: 更新后的状态,包含文档处理结果
  432. """
  433. try:
  434. logger.info(f"开始文档处理阶段: {state['callback_task_id']}")
  435. # 检查终止信号
  436. if await self.check_terminate_signal(state["callback_task_id"]):
  437. logger.warning(f"文档处理阶段检测到终止信号: {state['callback_task_id']}")
  438. return {
  439. "current_stage": "document_processing",
  440. "overall_task_status": "terminated",
  441. "stage_status": {**state["stage_status"], "document": "terminated"},
  442. "messages": [AIMessage(content="文档处理阶段检测到终止信号")]
  443. }
  444. # 获取 TaskFileInfo 实例
  445. task_file_info = state["task_file_info"]
  446. # 创建文档工作流实例
  447. document_workflow = DocumentWorkflow(
  448. task_file_info=task_file_info,
  449. progress_manager=state["progress_manager"],
  450. redis_duplicate_checker=self.redis_duplicate_checker
  451. )
  452. # 执行文档处理
  453. doc_result = await document_workflow.execute(
  454. state["file_content"],
  455. state["file_type"]
  456. )
  457. logger.info(f"文档处理完成: {state['callback_task_id']}")
  458. return {
  459. "current_stage": "document_processing",
  460. "overall_task_status": "processing",
  461. "stage_status": {**state["stage_status"], "document": "completed"},
  462. "document_result": doc_result,
  463. "messages": [AIMessage(content="文档处理完成")]
  464. }
  465. except Exception as e:
  466. logger.error(f"文档处理失败: {str(e)}", exc_info=True)
  467. return {
  468. "current_stage": "document_processing",
  469. "overall_task_status": "failed",
  470. "stage_status": {**state["stage_status"], "document": "failed"},
  471. "error_message": f"文档处理失败: {str(e)}",
  472. "messages": [AIMessage(content=f"文档处理失败: {str(e)}")]
  473. }
  474. async def _ai_review_subgraph_node(self, state: TaskChainState) -> TaskChainState:
  475. """
  476. AI审查子图节点(嵌套现有的 AIReviewWorkflow)
  477. Args:
  478. state: 任务链状态
  479. Returns:
  480. TaskChainState: 更新后的状态,包含AI审查结果
  481. Note:
  482. 这是方案D的核心实现:将现有的 AIReviewWorkflow 作为子图嵌套
  483. 无需修改 AIReviewWorkflow 的代码,保持其独立性
  484. """
  485. try:
  486. logger.info(f"开始AI审查阶段: {state['callback_task_id']}")
  487. # 检查终止信号
  488. if await self.check_terminate_signal(state["callback_task_id"]):
  489. logger.warning(f"AI审查阶段检测到终止信号: {state['callback_task_id']}")
  490. return {
  491. "current_stage": "ai_review",
  492. "overall_task_status": "terminated",
  493. "stage_status": {**state["stage_status"], "ai_review": "terminated"},
  494. "messages": [AIMessage(content="AI审查阶段检测到终止信号")]
  495. }
  496. # 获取文档处理结果中的结构化内容
  497. structured_content = state["document_result"].get("structured_content")
  498. if not structured_content:
  499. raise ValueError("文档处理结果中缺少结构化内容")
  500. # 获取 TaskFileInfo 实例
  501. task_file_info = state["task_file_info"]
  502. # 读取AI审查配置
  503. import configparser
  504. config = configparser.ConfigParser()
  505. config.read('config/config.ini', encoding='utf-8')
  506. max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None)
  507. if max_review_units == 0:
  508. max_review_units = None
  509. review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all')
  510. logger.info(f"AI审查配置: 最大审查数量={max_review_units}, 审查模式={review_mode}")
  511. # 创建AI审查工作流实例(作为嵌套子图)
  512. ai_workflow = AIReviewWorkflow(
  513. task_file_info=task_file_info,
  514. structured_content=structured_content,
  515. progress_manager=state["progress_manager"],
  516. max_review_units=max_review_units,
  517. review_mode=review_mode
  518. )
  519. # 执行AI审查(内部使用 LangGraph)
  520. ai_result = await ai_workflow.execute()
  521. logger.info(f"AI审查完成: {state['callback_task_id']}")
  522. return {
  523. "current_stage": "ai_review",
  524. "overall_task_status": "processing",
  525. "stage_status": {**state["stage_status"], "ai_review": "completed"},
  526. "ai_review_result": ai_result,
  527. "messages": [AIMessage(content="AI审查完成")]
  528. }
  529. except Exception as e:
  530. logger.error(f"AI审查失败: {str(e)}", exc_info=True)
  531. return {
  532. "current_stage": "ai_review",
  533. "overall_task_status": "failed",
  534. "stage_status": {**state["stage_status"], "ai_review": "failed"},
  535. "error_message": f"AI审查失败: {str(e)}",
  536. "messages": [AIMessage(content=f"AI审查失败: {str(e)}")]
  537. }
  538. async def _report_generation_node(self, state: TaskChainState) -> TaskChainState:
  539. """
  540. 报告生成节点
  541. Args:
  542. state: 任务链状态
  543. Returns:
  544. TaskChainState: 更新后的状态,包含报告生成结果
  545. Note:
  546. 调用ReportWorkflow生成审查报告摘要(基于高中风险问题,使用LLM)
  547. 根据决策2(方案A-方式1),在此阶段生成完整报告后一次性保存
  548. """
  549. try:
  550. logger.info(f"开始报告生成阶段: {state['callback_task_id']}")
  551. # 检查终止信号
  552. if await self.check_terminate_signal(state["callback_task_id"]):
  553. logger.warning(f"报告生成阶段检测到终止信号: {state['callback_task_id']}")
  554. return {
  555. "current_stage": "report_generation",
  556. "overall_task_status": "terminated",
  557. "stage_status": {**state["stage_status"], "report": "terminated"},
  558. "messages": [AIMessage(content="报告生成阶段检测到终止信号")]
  559. }
  560. # 获取AI审查结果
  561. ai_review_result = state.get("ai_review_result")
  562. if not ai_review_result:
  563. raise ValueError("AI审查结果缺失,无法生成报告")
  564. # 获取 TaskFileInfo 实例
  565. task_file_info = state["task_file_info"]
  566. # 创建报告生成工作流实例
  567. report_workflow = ReportWorkflow(
  568. file_id=state["file_id"],
  569. file_name=state["file_name"],
  570. callback_task_id=state["callback_task_id"],
  571. user_id=state["user_id"],
  572. ai_review_results=ai_review_result,
  573. progress_manager=state["progress_manager"]
  574. )
  575. # 执行报告生成
  576. report_result = await report_workflow.execute()
  577. logger.info(f"报告生成完成: {state['callback_task_id']}")
  578. # 保存完整结果(包含文档处理、AI审查、报告生成)
  579. await self._save_complete_results(state, report_result)
  580. return {
  581. "current_stage": "report_generation",
  582. "overall_task_status": "processing",
  583. "stage_status": {**state["stage_status"], "report": "completed"},
  584. "report_result": report_result,
  585. "messages": [AIMessage(content="报告生成完成")]
  586. }
  587. except Exception as e:
  588. logger.error(f"报告生成失败: {str(e)}", exc_info=True)
  589. return {
  590. "current_stage": "report_generation",
  591. "overall_task_status": "failed",
  592. "stage_status": {**state["stage_status"], "report": "failed"},
  593. "error_message": f"报告生成失败: {str(e)}",
  594. "messages": [AIMessage(content=f"报告生成失败: {str(e)}")]
  595. }
  596. async def _complete_chain_node(self, state: TaskChainState) -> TaskChainState:
  597. """
  598. 任务链完成节点
  599. Args:
  600. state: 任务链状态
  601. Returns:
  602. TaskChainState: 更新后的状态,标记整体任务已完成
  603. Note:
  604. 只有在所有阶段(文档处理、AI审查、报告生成)都完成后才标记 overall_task_status="completed"
  605. 这解决了原有的状态语义混乱问题(P0-1)
  606. """
  607. logger.info(f"任务链工作流完成: {state['callback_task_id']}")
  608. # 标记整体任务完成
  609. if state["progress_manager"]:
  610. await state["progress_manager"].complete_task(
  611. state["callback_task_id"],
  612. state["user_id"],
  613. {"overall_task_status": "completed", "message": "所有阶段已完成"}
  614. )
  615. # 清理 Redis 缓存
  616. try:
  617. from foundation.utils.redis_utils import delete_file_info
  618. await delete_file_info(state["file_id"])
  619. logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
  620. except Exception as e:
  621. logger.warning(f"清理 Redis 文件缓存失败: {str(e)}")
  622. return {
  623. "current_stage": "complete",
  624. "overall_task_status": "completed", # ⚠️ 关键:只有到这里才标记整体完成
  625. "messages": [AIMessage(content="任务链工作流完成")]
  626. }
  627. async def _error_handler_chain_node(self, state: TaskChainState) -> TaskChainState:
  628. """
  629. 任务链错误处理节点
  630. Args:
  631. state: 任务链状态
  632. Returns:
  633. TaskChainState: 更新后的状态,标记为失败
  634. """
  635. logger.error(f"任务链工作流错误: {state['callback_task_id']}, 错误: {state.get('error_message', '未知错误')}")
  636. # 通知失败
  637. if state["progress_manager"]:
  638. error_data = {
  639. "overall_task_status": "failed",
  640. "error": state.get("error_message", "未知错误"),
  641. "status": "failed",
  642. "timestamp": datetime.now().isoformat()
  643. }
  644. await state["progress_manager"].complete_task(
  645. state["callback_task_id"],
  646. state["user_id"],
  647. error_data
  648. )
  649. # 清理 Redis 缓存(即使失败也清理)
  650. try:
  651. from foundation.utils.redis_utils import delete_file_info
  652. await delete_file_info(state["file_id"])
  653. logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
  654. except Exception as e:
  655. logger.warning(f"清理 Redis 文件缓存失败: {str(e)}")
  656. return {
  657. "current_stage": "error_handler",
  658. "overall_task_status": "failed",
  659. "messages": [AIMessage(content=f"任务链错误: {state.get('error_message', '未知错误')}")]
  660. }
  661. async def _terminate_chain_node(self, state: TaskChainState) -> TaskChainState:
  662. """
  663. 任务链终止节点
  664. Args:
  665. state: 任务链状态
  666. Returns:
  667. TaskChainState: 更新后的状态,标记为已终止
  668. """
  669. logger.warning(f"任务链工作流已终止: {state['callback_task_id']}")
  670. # 通知终止
  671. if state["progress_manager"]:
  672. await state["progress_manager"].complete_task(
  673. state["callback_task_id"],
  674. state["user_id"],
  675. {"overall_task_status": "terminated", "message": "任务已被用户终止"}
  676. )
  677. # 清理 Redis 终止信号
  678. await self.clear_terminate_signal(state["callback_task_id"])
  679. # 清理 Redis 文件缓存
  680. try:
  681. from foundation.utils.redis_utils import delete_file_info
  682. await delete_file_info(state["file_id"])
  683. logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
  684. except Exception as e:
  685. logger.warning(f"清理 Redis 文件缓存失败: {str(e)}")
  686. return {
  687. "current_stage": "terminated",
  688. "overall_task_status": "terminated",
  689. "messages": [AIMessage(content="任务链已被终止")]
  690. }
  691. def _should_terminate_or_error_chain(self, state: TaskChainState) -> str:
  692. """
  693. 检查任务链是否应该终止或发生错误
  694. Args:
  695. state: 任务链状态
  696. Returns:
  697. str: "terminate", "error", 或 "continue"
  698. Note:
  699. 这是条件边判断方法,用于决定工作流的下一步走向
  700. 1. 优先检查终止信号
  701. 2. 检查是否有错误
  702. 3. 都没有则继续执行
  703. """
  704. # 检查终止状态
  705. if state.get("overall_task_status") == "terminated":
  706. return "terminate"
  707. # 检查错误状态
  708. if state.get("overall_task_status") == "failed" or state.get("error_message"):
  709. return "error"
  710. # 默认继续执行
  711. return "continue"
  712. async def _save_complete_results(self, state: TaskChainState, report_result: Dict[str, Any]):
  713. """
  714. 保存完整结果(方案A-方式1:一次性保存)
  715. Args:
  716. state: 任务链状态
  717. report_result: 报告生成结果
  718. Note:
  719. 根据决策2(方案A-方式1),在报告工作流完成后一次性保存完整结果
  720. 包含:文档处理结果 + AI审查结果 + 报告生成结果
  721. """
  722. try:
  723. import json
  724. import os
  725. logger.info(f"开始保存完整结果: {state['callback_task_id']}")
  726. # 创建 temp 目录
  727. temp_dir = "temp"
  728. os.makedirs(temp_dir, exist_ok=True)
  729. # 构建完整结果
  730. complete_results = {
  731. "callback_task_id": state["callback_task_id"],
  732. "file_id": state["file_id"],
  733. "file_name": state["file_name"],
  734. "user_id": state["user_id"],
  735. "overall_task_status": "processing", # 此时还在处理中,complete节点才标记completed
  736. "stage_status": state["stage_status"],
  737. "document_result": state.get("document_result"),
  738. "ai_review_result": state.get("ai_review_result"),
  739. "issues": state.get("ai_review_result").get("review_results"),
  740. "report_result": report_result,
  741. "timestamp": datetime.now().isoformat()
  742. }
  743. # 保存到文件
  744. file_path = os.path.join(temp_dir, f"{state['callback_task_id']}.json")
  745. with open(file_path, 'w', encoding='utf-8') as f:
  746. json.dump(complete_results, f, ensure_ascii=False, indent=2)
  747. logger.info(f"完整结果已保存到: {file_path}")
  748. except Exception as e:
  749. logger.error(f"保存完整结果失败: {str(e)}", exc_info=True)
  750. raise