workflow_manager.py 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561
  1. """
  2. 基于LangGraph的工作流管理器
  3. 负责任务的创建、编排和执行,使用LangGraph进行状态管理
  4. 新增功能:
  5. - 任务终止管理
  6. - 终止信号设置和检测
  7. """
  8. import asyncio
  9. import time
  10. import json
  11. import os
  12. from typing import Dict, Optional, Any
  13. from datetime import datetime
  14. from langgraph.graph import StateGraph, END
  15. from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
  16. from foundation.observability.logger.loggering import review_logger as logger
  17. from foundation.observability.monitoring.time_statistics import track_execution_time
  18. from foundation.infrastructure.cache.redis_connection import RedisConnectionFactory
  19. from .progress_manager import ProgressManager
  20. from .redis_duplicate_checker import RedisDuplicateChecker
  21. from .task_models import TaskFileInfo, TaskChain
  22. from ..construction_review.workflows import DocumentWorkflow, AIReviewWorkflow, ReportWorkflow
  23. from ..construction_review.workflows.types import TaskChainState
  class ProgressManagerRegistry:
      """Registry that keeps one ProgressManager instance per task.

      Instances live in a class-level dict keyed by ``callback_task_id``, so
      every caller that goes through this class sees the same mapping.
      """

      # Maps callback_task_id -> ProgressManager
      _registry = {}

      @classmethod
      def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
          """Store ``progress_manager`` under ``callback_task_id``.

          An entry already registered under the same id is overwritten.
          """
          cls._registry[callback_task_id] = progress_manager
          logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")

      @classmethod
      def get_progress_manager(cls, callback_task_id: str) -> ProgressManager:
          """Return the instance registered for ``callback_task_id``.

          ``None`` is returned when nothing is registered under that id.
          """
          registry = cls._registry
          return registry.get(callback_task_id)

      @classmethod
      def unregister_progress_manager(cls, callback_task_id: str):
          """Drop the entry for ``callback_task_id``; unknown ids are ignored."""
          if callback_task_id not in cls._registry:
              return
          del cls._registry[callback_task_id]
          logger.info(f"注销ProgressManager实例: {callback_task_id}")
  42. class WorkflowManager:
  43. """工作流管理器"""
  def __init__(self, max_concurrent_docs: int = 5, max_concurrent_reviews: int = 10):
      """Set up concurrency limits, service components and task bookkeeping.

      Args:
          max_concurrent_docs: Initial value of the document semaphore.
          max_concurrent_reviews: Initial value of the review semaphore.
      """
      # --- Concurrency control ---
      self.max_concurrent_docs = max_concurrent_docs
      self.max_concurrent_reviews = max_concurrent_reviews
      self.doc_semaphore = asyncio.Semaphore(max_concurrent_docs)
      self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews)

      # --- Service components ---
      self.progress_manager = ProgressManager()
      self.redis_duplicate_checker = RedisDuplicateChecker()

      # --- Review task chains ---
      # Active chains keyed by callback_task_id
      self.active_chains: Dict[str, TaskChain] = {}
      self._cleanup_task_started = False
      # Redis key prefix and TTL used for terminate signals
      self._terminate_signal_prefix = "ai_review:terminate_signal:"
      self._task_expire_time = 7200  # seconds (2 hours)
      # LangGraph task-chain graph ("plan D"); built lazily to avoid circular imports
      self.task_chain_graph = None

      # --- Construction-plan (outline) writing tasks ---
      # Active outline-generation tasks
      self.active_outline_tasks: Dict[str, Any] = {}
      # Redis key prefixes for outline-generation tasks
      self._outline_result_prefix = "outline_write:result:"
      self._outline_terminate_signal_prefix = "outline_write:terminate_signal:"
      # Outline-generation workflow graph (built lazily)
      self.outline_generation_graph = None
  async def submit_task_processing(self, file_info: dict) -> str:
      """Submit a document-processing job to Celery (used by the file_upload layer).

      Args:
          file_info: Task payload; its ``file_id`` entry is logged before submission.

      Returns:
          The ``id`` attribute of the task object returned by
          ``CeleryTraceManager.submit_celery_task``.

      Raises:
          Exception: Any failure is logged and then re-raised unchanged.
      """
      # Imported inside the method rather than at module level, as in the original code.
      from foundation.infrastructure.messaging.tasks import submit_task_processing_task
      from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager

      try:
          logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")
          # Submitting through CeleryTraceManager passes the trace_id along automatically.
          task = CeleryTraceManager.submit_celery_task(submit_task_processing_task, file_info)
          logger.info(f"Celery任务已提交,Task ID: {task.id}")
      except Exception as e:
          logger.error(f"提交Celery任务失败: {str(e)}")
          raise
      else:
          return task.id
  @track_execution_time
  def submit_construction_review_task_processing_sync(self, file_info: dict) -> dict:
      """Run the construction-review task chain synchronously (for Celery workers).

      Wraps ``file_info`` in a ``TaskFileInfo``/``TaskChain``, initializes progress
      tracking, lazily builds the LangGraph task-chain graph, runs it to completion
      on a dedicated event loop and returns a serializable summary of the final state.

      Args:
          file_info: Raw task payload; must contain ``file_id``.

      Returns:
          dict: ``file_id``, ``callback_task_id``, ``user_id``, ``file_name``,
          ``current_stage``, ``overall_task_status``, ``stage_status`` and
          ``error_message`` copied from the final graph state. Non-serializable
          objects (progress_manager, task_file_info, messages) are left out.

      Raises:
          Exception: Any failure is logged, reported through the progress manager
              when the task identifiers are known, and then re-raised.
      """
      # Bind these before the try block: the except/finally clauses read them, and a
      # failure before their assignment would otherwise raise UnboundLocalError and
      # hide the original exception.
      task_file_info = None
      callback_task_id = None
      try:
          logger.info(f"提交文档处理任务(LangGraph方案D): {file_info['file_id']}")

          # 1. Wrap the raw payload in a TaskFileInfo object
          task_file_info = TaskFileInfo(file_info)
          logger.info(f"创建任务文件信息: {task_file_info}")

          # 2. Task-chain id
          callback_task_id = task_file_info.callback_task_id

          # 3. Create the task chain (it references TaskFileInfo instead of copying data)
          task_chain = TaskChain(task_file_info)

          # 4. Mark the task as started
          task_chain.start_processing()

          # 5. Track it as an active task
          self.active_chains[callback_task_id] = task_chain

          # 6. Initialize progress management
          asyncio.run(self.progress_manager.initialize_progress(
              callback_task_id=callback_task_id,
              user_id=task_file_info.user_id,
              stages=[]
          ))

          # 7. Build the LangGraph task-chain workflow on first use
          if self.task_chain_graph is None:
              self.task_chain_graph = self._build_task_chain_workflow()

          # 8. Initial graph state
          initial_state = TaskChainState(
              file_id=task_file_info.file_id,
              callback_task_id=callback_task_id,
              user_id=task_file_info.user_id,
              file_name=task_file_info.file_name,
              file_type=task_file_info.file_type,
              file_content=task_file_info.file_content,
              current_stage="start",
              overall_task_status="processing",
              stage_status={
                  "document": "pending",
                  "ai_review": "pending",
                  "report": "pending"
              },
              document_result=None,
              ai_review_result=None,
              report_result=None,
              error_message=None,
              progress_manager=self.progress_manager,
              task_file_info=task_file_info,
              messages=[HumanMessage(content=f"开始任务链: {task_file_info.file_id}")]
          )

          # 9. Run the workflow on a dedicated loop; always close the loop, even
          #    when the workflow raises, so it is not leaked.
          loop = asyncio.new_event_loop()
          asyncio.set_event_loop(loop)
          try:
              result = loop.run_until_complete(self.task_chain_graph.ainvoke(initial_state))
          finally:
              loop.close()

          # 10. Remove the duplicate-check registration
          asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))

          logger.info("施工方案审查任务已完成(LangGraph方案D)!")
          logger.info(f"文件ID: {task_file_info.file_id}")
          logger.info(f"文件名: {task_file_info.file_name}")
          logger.info(f"整体状态: {result.get('overall_task_status', 'unknown')}")

          # Return only serializable fields of the final state.
          return {
              "file_id": result.get("file_id"),
              "callback_task_id": result.get("callback_task_id"),
              "user_id": result.get("user_id"),
              "file_name": result.get("file_name"),
              "current_stage": result.get("current_stage"),
              "overall_task_status": result.get("overall_task_status"),
              "stage_status": result.get("stage_status"),
              "error_message": result.get("error_message"),
          }
      except Exception as e:
          logger.error(f"提交文档处理任务失败: {str(e)}", exc_info=True)

          # Mark the chain as failed if it was already registered
          if callback_task_id is not None and callback_task_id in self.active_chains:
              self.active_chains[callback_task_id].fail_processing(str(e))

          if task_file_info is not None:
              # Remove the duplicate-check registration
              asyncio.run(self.redis_duplicate_checker.unregister_task(task_file_info.file_id))

              if callback_task_id is not None:
                  # Tell SSE listeners that the task failed
                  error_data = {
                      "error": str(e),
                      "status": "failed",
                      "overall_task_status": "failed",
                      "timestamp": datetime.now().isoformat()
                  }
                  asyncio.run(self.progress_manager.complete_task(
                      callback_task_id, task_file_info.user_id, error_data
                  ))
          raise
      finally:
          # Stop tracking the task regardless of the outcome
          if callback_task_id is not None and callback_task_id in self.active_chains:
              del self.active_chains[callback_task_id]
  async def set_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, Any]:
      """Write a terminate signal for an active task into Redis.

      The signal is only written when the task is tracked in ``self.active_chains``
      and its status is ``"processing"``. It is stored as a Redis hash under
      ``self._terminate_signal_prefix + callback_task_id`` and expires after
      ``self._task_expire_time`` seconds.

      Args:
          callback_task_id: Callback id of the task to terminate.
          operator: Who requested the termination (user id or system marker).

      Returns:
          Dict: ``{"success": bool, "message": str, "sgbx_task_info": dict | None}``.
          Failures are reported through this dict; exceptions are logged, not raised.
      """
      try:
          # The task must be tracked as active
          if callback_task_id not in self.active_chains:
              return {
                  "success": False,
                  "message": f"任务不存在或已完成: {callback_task_id}",
                  "sgbx_task_info": None
              }

          task_chain = self.active_chains[callback_task_id]

          # Only a task that is still processing can be terminated
          if task_chain.status != "processing":
              return {
                  "success": False,
                  "message": f"任务状态不是 processing,无需终止: {callback_task_id} (当前状态: {task_chain.status})",
                  "sgbx_task_info": {
                      "callback_task_id": callback_task_id,
                      "status": task_chain.status,
                      "file_name": task_chain.file_name
                  }
              }

          redis_client = await RedisConnectionFactory.get_connection()
          terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"

          # Keep operator and timestamp alongside the signal
          terminate_data = {
              "operator": operator,
              "terminate_time": str(time.time()),
              "task_id": callback_task_id
          }

          # NOTE(review): hmset is deprecated in redis-py in favour of
          # hset(key, mapping=...); confirm the client library before switching.
          await redis_client.hmset(terminate_key, terminate_data)
          # Let the signal expire so stale keys do not pile up
          await redis_client.expire(terminate_key, self._task_expire_time)

          logger.info(f"已设置终止信号: {callback_task_id} (操作人: {operator}, 文件: {task_chain.file_name})")

          return {
              "success": True,
              "message": "终止信号已设置,任务将在当前节点完成后终止",
              "sgbx_task_info": {
                  "callback_task_id": callback_task_id,
                  "file_id": task_chain.file_id,
                  "file_name": task_chain.file_name,
                  "user_id": task_chain.user_id,
                  "status": task_chain.status,
                  "current_stage": task_chain.current_stage
              }
          }
      except Exception as e:
          logger.error(f"设置终止信号失败: {str(e)}", exc_info=True)
          return {
              "success": False,
              "message": f"设置终止信号失败: {str(e)}",
              "sgbx_task_info": None
          }
  async def check_terminate_signal(self, callback_task_id: str) -> bool:
      """Check whether a terminate signal exists in Redis for the task.

      Workflow nodes call this before doing their work. The existence of the key
      alone decides the result; the stored operator is read only for logging, so a
      failure while reading or decoding it no longer hides an existing signal.

      Args:
          callback_task_id: Callback id of the task.

      Returns:
          bool: ``True`` when the signal key exists. ``False`` when it does not
          exist or when the existence check itself fails (errors are logged).
      """
      try:
          redis_client = await RedisConnectionFactory.get_connection()
          terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"

          exists = await redis_client.exists(terminate_key)
          if not exists:
              return False

          operator = "unknown"
          try:
              terminate_info = await redis_client.hgetall(terminate_key)
              # Keys/values are bytes or str depending on the client's
              # decode_responses setting, so accept both forms.
              raw_operator = terminate_info.get(
                  b"operator", terminate_info.get("operator", "unknown")
              )
              if isinstance(raw_operator, bytes):
                  operator = raw_operator.decode(errors="replace")
              else:
                  operator = str(raw_operator)
          except Exception as detail_error:
              # Informational only: the signal itself has already been found.
              logger.debug(f"读取终止信号详情失败: {callback_task_id}, 错误: {detail_error}")

          logger.warning(f"检测到终止信号: {callback_task_id}, 操作人: {operator}")
          return True
      except RuntimeError as e:
          # Event-loop related failures
          error_msg = str(e)
          if "Event loop is closed" in error_msg:
              # A closed loop is expected once the task has ended; not an error
              logger.debug(f"检查终止信号时事件循环已关闭: {callback_task_id}")
              return False
          elif "bound to a different event loop" in error_msg:
              # Loop mismatch: warn and skip this check without interrupting the flow
              logger.warning(f"检查终止信号时检测到事件循环不匹配: {callback_task_id},将忽略本次检查")
              return False
          else:
              # Any other RuntimeError is logged as an error
              logger.error(f"检查终止信号失败(RuntimeError): {error_msg}", exc_info=True)
              return False
      except Exception as e:
          logger.error(f"检查终止信号失败: {str(e)}", exc_info=True)
          return False
  async def clear_terminate_signal(self, callback_task_id: str):
      """Delete the task's terminate-signal key from Redis.

      Best effort: any failure is logged as a warning and not raised.

      Args:
          callback_task_id: Callback id of the task.
      """
      try:
          connection = await RedisConnectionFactory.get_connection()
          terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
          await connection.delete(terminate_key)
      except Exception as e:
          logger.warning(f"清理终止信号失败: {str(e)}")
      else:
          logger.debug(f"清理终止信号: {callback_task_id}")
  async def get_active_tasks(self) -> list:
      """List the tracked task chains whose status is ``"processing"``.

      Returns:
          list: One dict per processing chain with ``callback_task_id``, ``file_id``,
          ``file_name``, ``user_id``, ``status``, ``current_stage``, ``start_time``
          and ``running_duration`` (whole seconds since ``start_time``, or 0 when
          ``start_time`` is falsy). An empty list is returned if anything fails.
      """
      try:
          # One timestamp for the whole listing so all durations share a reference
          now = time.time()
          return [
              {
                  "callback_task_id": task_id,
                  "file_id": chain.file_id,
                  "file_name": chain.file_name,
                  "user_id": chain.user_id,
                  "status": chain.status,
                  "current_stage": chain.current_stage,
                  "start_time": chain.start_time,
                  "running_duration": int(now - chain.start_time) if chain.start_time else 0,
              }
              for task_id, chain in self.active_chains.items()
              if chain.status == "processing"
          ]
      except Exception as e:
          logger.error(f"获取活跃任务列表失败: {str(e)}", exc_info=True)
          return []
  async def get_sgbx_task_info(self, callback_task_id: str) -> Optional[Dict]:
      """Return a snapshot of one tracked task chain.

      Args:
          callback_task_id: Callback id of the task.

      Returns:
          Optional[Dict]: Task details including ``running_duration`` (whole seconds
          since ``start_time``, or 0 when ``start_time`` is falsy) and ``results``;
          ``None`` when the task is not tracked or the lookup fails.
      """
      try:
          chain = self.active_chains.get(callback_task_id)
          if not chain:
              return None

          now = time.time()
          elapsed = int(now - chain.start_time) if chain.start_time else 0
          return {
              "callback_task_id": callback_task_id,
              "file_id": chain.file_id,
              "file_name": chain.file_name,
              "user_id": chain.user_id,
              "status": chain.status,
              "current_stage": chain.current_stage,
              "start_time": chain.start_time,
              "running_duration": elapsed,
              "results": chain.results,
          }
      except Exception as e:
          logger.error(f"获取任务信息失败: {str(e)}", exc_info=True)
          return None
  def _build_task_chain_workflow(self) -> StateGraph:
      """Assemble and compile the LangGraph task-chain workflow ("plan D").

      Main path: start -> document_processing -> ai_review_subgraph ->
      report_generation -> complete -> END. After each of the three processing
      stages ``_should_terminate_or_error_chain`` routes to ``terminate``,
      ``error_handler`` or the next stage. The compiled graph is also handed to
      ``_save_workflow_graph`` to be written under ``temp/construction_review``.

      Returns:
          The graph object returned by ``workflow.compile()``.
      """
      logger.info("开始构建 LangGraph 任务链工作流图")
      graph_builder = StateGraph(TaskChainState)

      # Nodes, registered in the same order as before
      node_handlers = (
          ("start", self._start_chain_node),
          ("document_processing", self._document_processing_node),
          ("ai_review_subgraph", self._ai_review_subgraph_node),
          ("report_generation", self._report_generation_node),
          ("complete", self._complete_chain_node),
          ("error_handler", self._error_handler_chain_node),
          ("terminate", self._terminate_chain_node),
      )
      for node_name, handler in node_handlers:
          graph_builder.add_node(node_name, handler)

      graph_builder.set_entry_point("start")
      graph_builder.add_edge("start", "document_processing")

      # Every processing stage is followed by the same terminate/error check;
      # only the "continue" target differs.
      stage_successors = (
          ("document_processing", "ai_review_subgraph"),
          ("ai_review_subgraph", "report_generation"),
          ("report_generation", "complete"),
      )
      for stage, next_stage in stage_successors:
          graph_builder.add_conditional_edges(
              stage,
              self._should_terminate_or_error_chain,
              {
                  "terminate": "terminate",
                  "error": "error_handler",
                  "continue": next_stage,
              },
          )

      # All three terminal nodes lead straight to END
      for terminal_node in ("complete", "error_handler", "terminate"):
          graph_builder.add_edge(terminal_node, END)

      compiled_graph = graph_builder.compile()

      # Write a picture of the graph into temp/construction_review
      self._save_workflow_graph(compiled_graph, "temp/construction_review/task_chain_workflow.png")

      logger.info("LangGraph 任务链工作流图构建完成")
      return compiled_graph
  def _save_workflow_graph(self, compiled_graph: StateGraph, output_path: str):
      """Save a rendering of the compiled workflow graph to disk.

      First tries to write a PNG produced by ``draw_mermaid_png()``. If that
      fails, the graph's ``to_json()`` output is written next to it with a
      ``.json`` extension instead. Every failure is logged as a warning and
      never propagates to the caller.

      Args:
          compiled_graph: Compiled LangGraph workflow graph.
          output_path: Target path of the PNG file.
      """
      try:
          # Make sure the target directory exists
          output_dir = os.path.dirname(output_path)
          if output_dir and not os.path.exists(output_dir):
              os.makedirs(output_dir, exist_ok=True)
              logger.info(f"创建输出目录:{output_dir}")

          try:
              # Render through the graph's Mermaid PNG export
              graph = compiled_graph.get_graph()
              graph_image = graph.draw_mermaid_png()
              with open(output_path, "wb") as f:
                  f.write(graph_image)
              logger.info(f"工作流图已保存到:{output_path}")
          except Exception as render_error:
              logger.warning(f"Graphviz 保存失败:{str(render_error)}, 尝试使用 JSON 格式保存")
              # Fallback: dump the graph structure as JSON. Swap only the file
              # extension; str.replace would also rewrite ".png" inside directory
              # names and would leave paths with another extension unchanged.
              json_path = os.path.splitext(output_path)[0] + ".json"
              with open(json_path, "w", encoding="utf-8") as f:
                  json.dump(compiled_graph.get_graph().to_json(), f, indent=2, ensure_ascii=False)
              logger.info(f"工作流图已保存到 (JSON 格式): {json_path}")
      except Exception as e:
          logger.warning(f"保存工作流图失败:{str(e)}")
  async def _start_chain_node(self, state: TaskChainState) -> TaskChainState:
      """Entry node of the task chain.

      Args:
          state: Task-chain state; ``callback_task_id`` is read for logging.

      Returns:
          State update that sets the stage to ``"start"``, the overall status to
          ``"processing"`` and all three stage statuses to ``"pending"``.
      """
      logger.info(f"任务链工作流启动: {state['callback_task_id']}")

      pending_stages = {
          "document": "pending",
          "ai_review": "pending",
          "report": "pending"
      }
      start_update = {
          "current_stage": "start",
          "overall_task_status": "processing",
          "stage_status": pending_stages,
          "messages": [AIMessage(content="任务链工作流启动")]
      }
      return start_update
  async def _document_processing_node(self, state: TaskChainState) -> TaskChainState:
      """Document-processing stage of the task chain.

      Checks the terminate signal first, then runs ``DocumentWorkflow.execute``
      with the file content and file type from the state.

      Args:
          state: Task-chain state.

      Returns:
          State update: ``"terminated"`` when a terminate signal was found,
          ``"failed"`` with ``error_message`` when an exception occurred, otherwise
          ``"processing"`` together with ``document_result``.
      """
      def _stage_status(document_status: str) -> dict:
          # Copy of the current stage map with only the document entry replaced
          return {**state["stage_status"], "document": document_status}

      try:
          logger.info(f"开始文档处理阶段: {state['callback_task_id']}")

          # Stop early when termination was requested
          terminate_requested = await self.check_terminate_signal(state["callback_task_id"])
          if terminate_requested:
              logger.warning(f"文档处理阶段检测到终止信号: {state['callback_task_id']}")
              return {
                  "current_stage": "document_processing",
                  "overall_task_status": "terminated",
                  "stage_status": _stage_status("terminated"),
                  "messages": [AIMessage(content="文档处理阶段检测到终止信号")]
              }

          task_file_info = state["task_file_info"]

          document_workflow = DocumentWorkflow(
              task_file_info=task_file_info,
              progress_manager=state["progress_manager"],
              redis_duplicate_checker=self.redis_duplicate_checker
          )
          doc_result = await document_workflow.execute(state["file_content"], state["file_type"])

          logger.info(f"文档处理完成: {state['callback_task_id']}")
          return {
              "current_stage": "document_processing",
              "overall_task_status": "processing",
              "stage_status": _stage_status("completed"),
              "document_result": doc_result,
              "messages": [AIMessage(content="文档处理完成")]
          }
      except Exception as e:
          logger.error(f"文档处理失败: {str(e)}", exc_info=True)
          failure_text = f"文档处理失败: {str(e)}"
          return {
              "current_stage": "document_processing",
              "overall_task_status": "failed",
              "stage_status": _stage_status("failed"),
              "error_message": failure_text,
              "messages": [AIMessage(content=failure_text)]
          }
  async def _ai_review_subgraph_node(self, state: TaskChainState) -> TaskChainState:
      """AI-review stage: runs the existing ``AIReviewWorkflow`` as a nested step.

      Checks the terminate signal, takes ``structured_content`` from the document
      result, reads the review limits from ``config/config.ini`` and executes the
      review workflow.

      Args:
          state: Task-chain state.

      Returns:
          State update: ``"terminated"`` when a terminate signal was found,
          ``"failed"`` with ``error_message`` when an exception occurred (including
          missing structured content), otherwise ``"processing"`` together with
          ``ai_review_result``.
      """
      def _stage_status(ai_review_status: str) -> dict:
          # Copy of the current stage map with only the ai_review entry replaced
          return {**state["stage_status"], "ai_review": ai_review_status}

      try:
          logger.info(f"开始AI审查阶段: {state['callback_task_id']}")

          # Stop early when termination was requested
          terminate_requested = await self.check_terminate_signal(state["callback_task_id"])
          if terminate_requested:
              logger.warning(f"AI审查阶段检测到终止信号: {state['callback_task_id']}")
              return {
                  "current_stage": "ai_review",
                  "overall_task_status": "terminated",
                  "stage_status": _stage_status("terminated"),
                  "messages": [AIMessage(content="AI审查阶段检测到终止信号")]
              }

          # The review works on the structured content of the document stage
          structured_content = state["document_result"].get("structured_content")
          if not structured_content:
              raise ValueError("文档处理结果中缺少结构化内容")

          task_file_info = state["task_file_info"]

          # Review settings from the [ai_review] section of the config file
          import configparser
          config = configparser.ConfigParser()
          config.read('config/config.ini', encoding='utf-8')

          max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None)
          # A configured 0 is passed on as None, the same as an absent setting
          max_review_units = None if max_review_units == 0 else max_review_units
          review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all')

          logger.info(f"AI审查配置: 最大审查数量={max_review_units}, 审查模式={review_mode}")

          ai_workflow = AIReviewWorkflow(
              task_file_info=task_file_info,
              structured_content=structured_content,
              progress_manager=state["progress_manager"],
              max_review_units=max_review_units,
              review_mode=review_mode
          )
          ai_result = await ai_workflow.execute()

          logger.info(f"AI审查完成: {state['callback_task_id']}")
          return {
              "current_stage": "ai_review",
              "overall_task_status": "processing",
              "stage_status": _stage_status("completed"),
              "ai_review_result": ai_result,
              "messages": [AIMessage(content="AI审查完成")]
          }
      except Exception as e:
          logger.error(f"AI审查失败: {str(e)}", exc_info=True)
          failure_text = f"AI审查失败: {str(e)}"
          return {
              "current_stage": "ai_review",
              "overall_task_status": "failed",
              "stage_status": _stage_status("failed"),
              "error_message": failure_text,
              "messages": [AIMessage(content=failure_text)]
          }
  async def _report_generation_node(self, state: TaskChainState) -> TaskChainState:
      """Report-generation stage of the task chain.

      Checks the terminate signal, runs ``ReportWorkflow.execute`` on the AI review
      result and then stores the combined results via ``_save_complete_results``.

      Args:
          state: Task-chain state.

      Returns:
          State update: ``"terminated"`` when a terminate signal was found,
          ``"failed"`` with ``error_message`` when an exception occurred (including a
          missing AI review result), otherwise ``"processing"`` together with
          ``report_result``.
      """
      try:
          logger.info(f"开始报告生成阶段: {state['callback_task_id']}")

          # Stop early when termination was requested
          if await self.check_terminate_signal(state["callback_task_id"]):
              logger.warning(f"报告生成阶段检测到终止信号: {state['callback_task_id']}")
              return {
                  "current_stage": "report_generation",
                  "overall_task_status": "terminated",
                  "stage_status": {**state["stage_status"], "report": "terminated"},
                  "messages": [AIMessage(content="报告生成阶段检测到终止信号")]
              }

          # The report is built from the AI review result
          ai_review_result = state.get("ai_review_result")
          if not ai_review_result:
              raise ValueError("AI审查结果缺失,无法生成报告")

          report_workflow = ReportWorkflow(
              file_id=state["file_id"],
              file_name=state["file_name"],
              callback_task_id=state["callback_task_id"],
              user_id=state["user_id"],
              ai_review_results=ai_review_result,
              progress_manager=state["progress_manager"]
          )
          report_result = await report_workflow.execute()

          # The workflow flags a degraded (fallback) report through 'is_fallback'
          is_fallback = report_result.get('is_fallback', False)
          if is_fallback:
              logger.warning(f"报告生成使用了降级方案: {state['callback_task_id']}")
          else:
              logger.info(f"报告生成完成: {state['callback_task_id']}")

          # Persist everything in one go once the report exists
          await self._save_complete_results(state, report_result)

          return {
              "current_stage": "report_generation",
              "overall_task_status": "processing",
              "stage_status": {**state["stage_status"], "report": "completed"},
              "report_result": report_result,
              "messages": [AIMessage(content="报告生成完成")]
          }
      except Exception as e:
          logger.error(f"报告生成失败: {str(e)}", exc_info=True)
          return {
              "current_stage": "report_generation",
              "overall_task_status": "failed",
              "stage_status": {**state["stage_status"], "report": "failed"},
              "error_message": f"报告生成失败: {str(e)}",
              "messages": [AIMessage(content=f"报告生成失败: {str(e)}")]
          }
  async def _complete_chain_node(self, state: TaskChainState) -> TaskChainState:
      """Final node of a successful task chain.

      Reports completion through the progress manager (when one is present in the
      state), deletes the cached file info from Redis on a best-effort basis and
      marks the overall task as completed. In the graph this node is only reached
      after the report stage returned "continue".

      Args:
          state: Task-chain state.

      Returns:
          State update with ``current_stage="complete"`` and
          ``overall_task_status="completed"``.
      """
      logger.info(f"任务链工作流完成: {state['callback_task_id']}")

      progress_manager = state["progress_manager"]
      if progress_manager:
          completion_payload = {"overall_task_status": "completed", "message": "所有阶段已完成"}
          await progress_manager.complete_task(
              state["callback_task_id"],
              state["user_id"],
              completion_payload
          )

      # Best-effort cleanup of the cached file info; failures are only logged
      try:
          from foundation.utils.redis_utils import delete_file_info
          await delete_file_info(state["file_id"])
          logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
      except Exception as e:
          logger.warning(f"清理 Redis 文件缓存失败: {str(e)}")

      # This is the only place that sets the overall status to "completed"
      final_update = {
          "current_stage": "complete",
          "overall_task_status": "completed",
          "messages": [AIMessage(content="任务链工作流完成")]
      }
      return final_update
  async def _error_handler_chain_node(self, state: TaskChainState) -> TaskChainState:
      """Error node of the task chain.

      Logs the error, reports the failure through the progress manager (when one
      is present in the state) and deletes the cached file info from Redis on a
      best-effort basis.

      Args:
          state: Task-chain state; ``error_message`` is read if present.

      Returns:
          State update with ``current_stage="error_handler"`` and
          ``overall_task_status="failed"``.
      """
      logger.error(f"任务链工作流错误: {state['callback_task_id']}, 错误: {state.get('error_message', '未知错误')}")

      progress_manager = state["progress_manager"]
      if progress_manager:
          # Failure payload for the progress manager
          error_data = {
              "overall_task_status": "failed",
              "error": state.get("error_message", "未知错误"),
              "status": "failed",
              "timestamp": datetime.now().isoformat()
          }
          await progress_manager.complete_task(
              state["callback_task_id"],
              state["user_id"],
              error_data
          )

      # The cached file info is removed even though the task failed
      try:
          from foundation.utils.redis_utils import delete_file_info
          await delete_file_info(state["file_id"])
          logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
      except Exception as e:
          logger.warning(f"清理 Redis 文件缓存失败: {str(e)}")

      failure_update = {
          "current_stage": "error_handler",
          "overall_task_status": "failed",
          "messages": [AIMessage(content=f"任务链错误: {state.get('error_message', '未知错误')}")]
      }
      return failure_update
  async def _terminate_chain_node(self, state: TaskChainState) -> TaskChainState:
      """Termination node of the task chain.

      Reports the termination through the progress manager (when one is present in
      the state), removes the terminate signal and deletes the cached file info
      from Redis on a best-effort basis.

      Args:
          state: Task-chain state.

      Returns:
          State update with ``current_stage="terminated"`` and
          ``overall_task_status="terminated"``.
      """
      logger.warning(f"任务链工作流已终止: {state['callback_task_id']}")

      progress_manager = state["progress_manager"]
      if progress_manager:
          termination_payload = {"overall_task_status": "terminated", "message": "任务已被用户终止"}
          await progress_manager.complete_task(
              state["callback_task_id"],
              state["user_id"],
              termination_payload
          )

      # Remove the terminate signal from Redis
      await self.clear_terminate_signal(state["callback_task_id"])

      # Best-effort cleanup of the cached file info; failures are only logged
      try:
          from foundation.utils.redis_utils import delete_file_info
          await delete_file_info(state["file_id"])
          logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
      except Exception as e:
          logger.warning(f"清理 Redis 文件缓存失败: {str(e)}")

      terminated_update = {
          "current_stage": "terminated",
          "overall_task_status": "terminated",
          "messages": [AIMessage(content="任务链已被终止")]
      }
      return terminated_update
  738. def _should_terminate_or_error_chain(self, state: TaskChainState) -> str:
  739. """
  740. 检查任务链是否应该终止或发生错误
  741. Args:
  742. state: 任务链状态
  743. Returns:
  744. str: "terminate", "error", 或 "continue"
  745. Note:
  746. 这是条件边判断方法,用于决定工作流的下一步走向
  747. 1. 优先检查终止信号
  748. 2. 检查是否有错误
  749. 3. 都没有则继续执行
  750. """
  751. # 检查终止状态
  752. if state.get("overall_task_status") == "terminated":
  753. return "terminate"
  754. # 检查错误状态
  755. if state.get("overall_task_status") == "failed" or state.get("error_message"):
  756. return "error"
  757. # 默认继续执行
  758. return "continue"
  async def _save_complete_results(self, state: TaskChainState, report_result: Dict[str, Any]):
      """
      Persist the combined result of the whole chain in one write.

      The saved payload bundles the document-processing result, the AI review
      result (plus its ``review_results`` exposed as ``issues``) and the
      report result passed in. It is written through ``cache.save`` as
      ``<callback_task_id>.json`` under the ``final_result`` sub-directory of
      the construction-review cache.

      Args:
          state: Task-chain state; reads ``callback_task_id``, ``file_id``,
              ``file_name``, ``user_id``, ``stage_status`` and the optional
              ``document_result`` / ``ai_review_result``.
          report_result: Result of the report-generation stage.

      Raises:
          Exception: Any failure is logged with a traceback and re-raised.
      """
      try:
          import os
          from foundation.observability.cachefiles import cache, CacheBaseDir

          task_id = state["callback_task_id"]
          logger.info(f"开始保存完整结果: {task_id}")

          review = state.get("ai_review_result")
          issues = review.get("review_results") if review else None

          payload = {
              "callback_task_id": task_id,
              "file_id": state["file_id"],
              "file_name": state["file_name"],
              "user_id": state["user_id"],
              # Still "processing" here; the status is saved before the chain
              # reaches its completion step.
              "overall_task_status": "processing",
              "stage_status": state["stage_status"],
              "document_result": state.get("document_result"),
              "ai_review_result": review,
              "issues": issues,
              "report_result": report_result,
              "timestamp": datetime.now().isoformat(),
          }

          # Computed only for the log line below.
          expected_dir = os.path.join(CacheBaseDir.CONSTRUCTION_REVIEW.value, "final_result")
          logger.info(f"准备保存结果到目录: {expected_dir}")

          saved_path = cache.save(
              payload,
              subdir="final_result",
              filename=f"{task_id}.json",
              base_cache_dir=CacheBaseDir.CONSTRUCTION_REVIEW,
          )
          logger.info(f"完整结果已保存到: {saved_path}")

          # Sanity check on the returned path; only warns, never fails.
          if "final_result" not in str(saved_path):
              logger.warning(f"警告:结果文件可能未保存到正确的final_result目录: {saved_path}")
      except Exception as e:
          logger.error(f"保存完整结果失败: {e}", exc_info=True)
          raise
  804. # ==================== 施工方案编写任务管理方法 ====================
  async def submit_outline_generation_task(self, sgbx_task_info: dict) -> str:
      """
      Submit an outline-generation task to Celery.

      The task is first pre-registered in Redis so that a cancel request can
      find it before a worker picks it up. If the submission then fails, the
      pre-registered record is marked ``failed`` (best-effort) so it does not
      remain ``pending`` until it expires.

      Args:
          sgbx_task_info: Task information dictionary, e.g.
              {
                  "user_id": str,
                  "project_info": dict,
                  "template_id": str,
                  "outline_config": dict,
                  "similarity_config": dict (optional),
                  "knowledge_config": dict (optional)
              }
              ``callback_task_id`` is also read when present.

      Returns:
          str: Celery task ID.

      Raises:
          Exception: Any submission failure is logged and re-raised.
      """
      from foundation.infrastructure.messaging.tasks import submit_outline_generation_task
      from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager

      callback_task_id = None
      try:
          callback_task_id = sgbx_task_info.get('callback_task_id')
          user_id = sgbx_task_info.get('user_id', 'unknown')
          logger.info(f"提交大纲生成任务到Celery: callback_task_id={callback_task_id}, user_id={user_id}")

          # Write the task to Redis first so a cancel request can see it.
          await self._pre_register_outline_task(sgbx_task_info)

          # Submit via CeleryTraceManager (the original comment says this
          # forwards the trace_id automatically).
          task = CeleryTraceManager.submit_celery_task(
              submit_outline_generation_task,
              sgbx_task_info
          )
          logger.info(f"大纲生成Celery任务已提交,Task ID: {task.id}")
          return task.id
      except Exception as e:
          logger.error(f"提交大纲生成Celery任务失败: {str(e)}", exc_info=True)

          # The task will never run: flip the pre-registered record to
          # "failed". Best-effort only; the original error is what callers see.
          if callback_task_id:
              try:
                  redis_client = await RedisConnectionFactory.get_connection()
                  result_key = f"{self._outline_result_prefix}{callback_task_id}"
                  await redis_client.hmset(result_key, {
                      "overall_task_status": "failed",
                      "error_message": str(e)
                  })
                  # Re-apply the TTL in case pre-registration never created the key.
                  await redis_client.expire(result_key, self._task_expire_time)
              except Exception as mark_error:
                  logger.warning(f"更新预注册大纲任务失败状态时出错: {mark_error}")
          raise
  @track_execution_time
  def submit_outline_generation_sync(self, sgbx_task_info: dict) -> dict:
      """
      Run the outline-generation workflow synchronously (for the Celery worker).

      Steps, in order:
        1. Resolve the task id (``outline_<16 hex chars>`` when none is given).
        2. Return a ``terminated`` result right away if a terminate signal
           already exists for the task.
        3. Build an ``OutlineTaskInfo`` and track it in
           ``self.active_outline_tasks``.
        4. Initialise the progress stages and register the progress manager.
        5. Build the LangGraph graph on first use and invoke it.
        6. Mirror the final status onto the task-info object.
        7. Store the result in a Redis hash so other processes can read it.
        8. Return a plain-dict summary of the result.

      Each async phase runs on its own freshly created event loop, and every
      loop is closed even when the phase raises.

      Args:
          sgbx_task_info: Task information dictionary. Reads
              ``callback_task_id``, ``user_id``, ``project_info``,
              ``template_id``, ``generation_chapterenum``,
              ``generation_template``, ``similarity_config`` and
              ``knowledge_config``.

      Returns:
          dict: ``callback_task_id``, ``user_id``, ``overall_task_status``,
          the five result fields and ``error_message``.

      Raises:
          Exception: Any failure is logged, recorded on the tracked task (if
              it was already tracked) and re-raised.

      NOTE(review): on the exception path nothing is written to Redis, so a
      pre-registered record appears to keep its previous status until it
      expires - confirm whether that is intended.
      """
      import uuid
      from langchain_core.messages import HumanMessage
      from ..construction_write.component.state_models import OutlineGenerationState, OutlineTaskInfo
      from ..construction_write.workflows.outline_workflow import OutlineWorkflow

      # Result fields produced by the workflow and copied around below.
      result_fields = (
          "outline_structure",
          "key_points",
          "similar_cases",
          "similar_fragments",
          "knowledge_bases",
      )

      callback_task_id = None
      try:
          # 1. Resolve the task id (generate one if the caller gave none).
          callback_task_id = sgbx_task_info.get('callback_task_id') or f"outline_{uuid.uuid4().hex[:16]}"
          user_id = sgbx_task_info.get('user_id', 'unknown')
          logger.info(f"开始执行大纲生成任务(LangGraph): {callback_task_id}")

          # 2. Bail out if the task was cancelled before it started.
          check_loop = asyncio.new_event_loop()
          asyncio.set_event_loop(check_loop)
          try:
              is_cancelled = check_loop.run_until_complete(
                  self.check_outline_terminate_signal(callback_task_id)
              )
              if is_cancelled:
                  logger.warning(f"任务已被取消,直接返回: {callback_task_id}")
                  return {
                      "callback_task_id": callback_task_id,
                      "user_id": user_id,
                      "overall_task_status": "terminated",
                      **dict.fromkeys(result_fields),
                      "error_message": "任务在启动前被取消"
                  }
          finally:
              check_loop.close()

          # 3. Build the task-info object and track it as active.
          outline_sgbx_task_info = OutlineTaskInfo(
              callback_task_id=callback_task_id,
              user_id=user_id,
              project_info=sgbx_task_info.get('project_info', {}),
              template_id=sgbx_task_info.get('template_id', ''),
              generation_chapterenum=sgbx_task_info.get('generation_chapterenum', []),
              generation_template=sgbx_task_info.get('generation_template', []),
              similarity_config=sgbx_task_info.get('similarity_config', {
                  "topk_plans": 3,
                  "topk_fragments": 10,
                  "threshold": 0.75
              }),
              knowledge_config=sgbx_task_info.get('knowledge_config', {
                  "topk": 3,
                  "threshold": 0.75
              })
          )
          self.active_outline_tasks[callback_task_id] = outline_sgbx_task_info

          # 4-5. Progress initialisation and the graph run share one loop.
          # The try/finally guarantees that loop is closed on failure too.
          workflow_loop = asyncio.new_event_loop()
          asyncio.set_event_loop(workflow_loop)
          try:
              workflow_loop.run_until_complete(self.progress_manager.initialize_progress(
                  callback_task_id=callback_task_id,
                  user_id=user_id,
                  stages=[
                      {"stage": "start", "status": "pending"},
                      {"stage": "template_loading", "status": "pending"},
                      {"stage": "outline_generation", "status": "pending"},
                      {"stage": "similar_cases", "status": "pending"},
                      {"stage": "similar_fragments", "status": "pending"},
                      {"stage": "knowledge_bases", "status": "pending"},
                      {"stage": "complete", "status": "pending"}
                  ]
              ))

              # Make the progress manager reachable from workflow nodes.
              ProgressManagerRegistry.register_progress_manager(callback_task_id, self.progress_manager)
              outline_sgbx_task_info.start_processing()

              # Build the graph lazily, once per manager instance.
              if self.outline_generation_graph is None:
                  outline_workflow = OutlineWorkflow()
                  self.outline_generation_graph = outline_workflow.build_graph()

              # The progress manager and task-info object are deliberately
              # kept out of the state (the original notes they are not
              # serialisable).
              initial_state = OutlineGenerationState(
                  callback_task_id=callback_task_id,
                  user_id=user_id,
                  project_info=outline_sgbx_task_info.project_info,
                  template_id=outline_sgbx_task_info.template_id,
                  generation_chapterenum=outline_sgbx_task_info.generation_chapterenum,
                  generation_template=outline_sgbx_task_info.generation_template,
                  similarity_config=outline_sgbx_task_info.similarity_config,
                  knowledge_config=outline_sgbx_task_info.knowledge_config,
                  template=None,
                  outline_structure=None,
                  key_points=None,
                  similar_cases=None,
                  similar_fragments=None,
                  knowledge_bases=None,
                  current_stage="start",
                  overall_task_status="processing",
                  error_message=None,
                  messages=[HumanMessage(content=f"开始大纲生成任务: {callback_task_id}")]
              )

              # thread_id is passed in the config for the checkpointer.
              result = workflow_loop.run_until_complete(
                  self.outline_generation_graph.ainvoke(
                      initial_state,
                      config={"configurable": {"thread_id": callback_task_id}}
                  )
              )
          finally:
              workflow_loop.close()
          logger.info(f"大纲生成任务完成!callback_task_id={callback_task_id}")

          # 6. Mirror the final status onto the tracked task object.
          final_status = result.get("overall_task_status")
          if final_status == "completed":
              outline_sgbx_task_info.complete_processing(
                  {field: result.get(field) for field in result_fields}
              )
          elif final_status == "failed":
              outline_sgbx_task_info.fail_processing(result.get("error_message", "未知错误"))
          elif final_status == "terminated":
              outline_sgbx_task_info.cancel_processing()

          # 7. Store the result in Redis for cross-process access.
          async def save_result_to_redis():
              redis_client = await RedisConnectionFactory.get_connection()
              result_key = f"{self._outline_result_prefix}{callback_task_id}"

              def to_json(value):
                  # Falsy values are stored as "" because None cannot be
                  # written to the hash.
                  return json.dumps(value, ensure_ascii=False) if value else ""

              result_data = {
                  "callback_task_id": callback_task_id,
                  "user_id": user_id,
                  "overall_task_status": result.get("overall_task_status", ""),
                  **{field: to_json(result.get(field)) for field in result_fields},
                  "error_message": result.get("error_message") or "",
                  "completed_time": str(time.time())
              }
              await redis_client.hmset(result_key, result_data)
              # TTL comes from self._task_expire_time (the original comment
              # says 2 hours).
              await redis_client.expire(result_key, self._task_expire_time)
              logger.info(f"大纲生成结果已保存到 Redis: {callback_task_id}")

          save_loop = asyncio.new_event_loop()
          asyncio.set_event_loop(save_loop)
          try:
              save_loop.run_until_complete(save_result_to_redis())
          finally:
              save_loop.close()

          # 8. Return a serialisable summary.
          return {
              "callback_task_id": result.get("callback_task_id"),
              "user_id": result.get("user_id"),
              "overall_task_status": result.get("overall_task_status"),
              **{field: result.get(field) for field in result_fields},
              "error_message": result.get("error_message")
          }
      except Exception as e:
          logger.error(f"大纲生成任务失败: {str(e)}", exc_info=True)
          # Record the failure on the tracked task, if tracking had started.
          if callback_task_id and callback_task_id in self.active_outline_tasks:
              self.active_outline_tasks[callback_task_id].fail_processing(str(e))
          raise
      finally:
          # Drop the task from the in-memory active set.
          if callback_task_id and callback_task_id in self.active_outline_tasks:
              del self.active_outline_tasks[callback_task_id]
          # Remove the progress-manager registration.
          ProgressManagerRegistry.unregister_progress_manager(callback_task_id)
  async def _pre_register_outline_task(self, sgbx_task_info: dict):
      """
      Pre-register an outline-generation task in Redis.

      Writes a ``pending`` placeholder hash for the task so that the cancel
      endpoint can find it in the window between submission and the moment a
      Celery worker starts executing it. Failures are logged and swallowed so
      that submission can still proceed.

      Args:
          sgbx_task_info: Task information dictionary; reads
              ``callback_task_id``, ``user_id`` and ``project_info``.
      """
      try:
          task_id = sgbx_task_info.get('callback_task_id')
          owner_id = sgbx_task_info.get('user_id', 'unknown')
          project = sgbx_task_info.get('project_info', {})

          client = await RedisConnectionFactory.get_connection()
          redis_key = f"{self._outline_result_prefix}{task_id}"

          # Result fields start out empty; "pending" marks the task as
          # waiting for a worker.
          blank_fields = dict.fromkeys(
              (
                  "outline_structure",
                  "key_points",
                  "similar_cases",
                  "similar_fragments",
                  "knowledge_bases",
                  "error_message",
              ),
              "",
          )
          placeholder = {
              "callback_task_id": task_id,
              "user_id": owner_id,
              "project_name": project.get('project_name', ''),
              "project_type": project.get('engineering_type', ''),
              "overall_task_status": "pending",
              **blank_fields,
              "pre_registered": "true",
              "pre_registered_at": str(time.time()),
              "completed_time": "",
          }

          await client.hmset(redis_key, placeholder)
          # TTL comes from self._task_expire_time (the original comment says
          # 2 hours).
          await client.expire(redis_key, self._task_expire_time)
          logger.info(f"大纲任务已预注册到 Redis: {task_id}")
      except Exception as e:
          # Deliberately not re-raised: pre-registration must not block the
          # Celery submission.
          logger.error(f"预注册大纲任务失败: {e}", exc_info=True)
  async def set_outline_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, Any]:
      """
      Set the terminate signal for an outline-generation task.

      The task is looked up in ``self.active_outline_tasks`` first and, when
      absent, in the Redis result hash. Only tasks whose status is
      ``pending`` or ``processing`` can be cancelled. For a ``pending`` task
      the Redis result hash is additionally updated to ``terminated``.

      Args:
          callback_task_id: Task callback ID.
          operator: Who requested the termination.

      Returns:
          Dict: ``success`` (bool), ``message`` (str) and ``sgbx_task_info``
          (dict with task details, or ``None`` when the task was not found or
          an error occurred). Errors are reported through this dict and never
          raised.
      """
      try:
          redis_client = None
          task_status = None
          task_user_id = None
          project_name = ""

          if callback_task_id in self.active_outline_tasks:
              # Task is tracked in this process.
              active_task = self.active_outline_tasks[callback_task_id]
              task_status = active_task.status
              task_user_id = active_task.user_id
              project_name = active_task.project_name

              if task_status not in ["processing", "pending"]:
                  return {
                      "success": False,
                      "message": f"任务状态不是 processing/pending,无需终止: {callback_task_id} (当前状态: {task_status})",
                      "sgbx_task_info": {
                          "callback_task_id": callback_task_id,
                          "status": task_status,
                          "project_name": project_name
                      }
                  }
          else:
              # Not in memory: fall back to the Redis record.
              redis_client = await RedisConnectionFactory.get_connection()
              result_key = f"{self._outline_result_prefix}{callback_task_id}"
              result_data = await redis_client.hgetall(result_key)
              if not result_data:
                  return {
                      "success": False,
                      "message": f"任务不存在或已完成: {callback_task_id}",
                      "sgbx_task_info": None
                  }

              task_status = result_data.get("overall_task_status", "unknown")
              task_user_id = result_data.get("user_id", "unknown")
              project_name = result_data.get("project_name", "")

              if task_status not in ["pending", "processing"]:
                  status_mapping = {
                      "completed": "已完成",
                      "failed": "已失败",
                      "terminated": "已终止"
                  }
                  status_desc = status_mapping.get(task_status, task_status)
                  return {
                      "success": False,
                      "message": f"任务{status_desc},无法取消: {callback_task_id}",
                      "sgbx_task_info": {
                          "callback_task_id": callback_task_id,
                          "status": task_status,
                          "project_name": project_name
                      }
                  }

          # Reuse the connection from the Redis lookup when there is one.
          if redis_client is None:
              redis_client = await RedisConnectionFactory.get_connection()

          # Store the signal as a hash so it also records who and when.
          terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
          await redis_client.hmset(terminate_key, {
              "operator": operator,
              "terminate_time": str(time.time()),
              "task_id": callback_task_id
          })
          # TTL comes from self._task_expire_time (the original comment says
          # 2 hours).
          await redis_client.expire(terminate_key, self._task_expire_time)

          if task_status == "pending":
              # The task has not started: mark the stored result as
              # terminated right away.
              result_key = f"{self._outline_result_prefix}{callback_task_id}"
              await redis_client.hmset(result_key, {
                  "overall_task_status": "terminated",
                  "error_message": "任务在启动前被取消"
              })
              logger.info(f"预注册任务已被取消: {callback_task_id}")
              return {
                  "success": True,
                  "message": "任务已成功取消(未开始执行)",
                  "sgbx_task_info": {
                      "callback_task_id": callback_task_id,
                      "user_id": task_user_id,
                      "project_name": project_name,
                      "status": "cancelled"
                  }
              }

          logger.info(f"已设置大纲任务终止信号: {callback_task_id} (操作人: {operator}, 项目: {project_name})")
          return {
              "success": True,
              "message": "终止信号已设置,任务将在当前节点完成后终止",
              "sgbx_task_info": {
                  "callback_task_id": callback_task_id,
                  "user_id": task_user_id,
                  "project_name": project_name,
                  "status": task_status
              }
          }
      except Exception as e:
          logger.error(f"设置大纲任务终止信号失败: {str(e)}", exc_info=True)
          return {
              "success": False,
              "message": f"设置终止信号失败: {str(e)}",
              "sgbx_task_info": None
          }
  async def check_outline_terminate_signal(self, callback_task_id: str) -> bool:
      """
      Check whether a terminate signal exists for an outline-generation task.

      The decision depends only on whether the signal key exists. The stored
      operator is read purely for the log line, so a failure while reading or
      decoding it no longer hides an existing signal.

      Args:
          callback_task_id: Task callback ID.

      Returns:
          bool: ``True`` when the signal key exists; ``False`` when it does
          not or when Redis could not be queried.
      """
      try:
          redis_client = await RedisConnectionFactory.get_connection()
          terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"

          if not await redis_client.exists(terminate_key):
              return False

          # Best-effort lookup of who requested the termination. hgetall may
          # return str or bytes keys/values depending on client settings
          # (sibling methods read str keys), so both forms are handled.
          operator = "unknown"
          try:
              terminate_info = await redis_client.hgetall(terminate_key)
              raw_operator = terminate_info.get("operator", terminate_info.get(b"operator"))
              if isinstance(raw_operator, bytes):
                  raw_operator = raw_operator.decode(errors="replace")
              if raw_operator is not None:
                  operator = raw_operator
          except Exception as detail_error:
              logger.debug(f"读取大纲任务终止信息失败: {detail_error}")

          logger.warning(f"检测到大纲任务终止信号: {callback_task_id}, "
                         f"操作人: {operator}")
          return True
      except Exception as e:
          logger.error(f"检查大纲任务终止信号失败: {str(e)}", exc_info=True)
          return False
  async def clear_outline_terminate_signal(self, callback_task_id: str):
      """
      Delete the Redis terminate signal of an outline-generation task.

      Best-effort: any failure is logged as a warning and not raised.

      Args:
          callback_task_id: Task callback ID.
      """
      try:
          client = await RedisConnectionFactory.get_connection()
          signal_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
          await client.delete(signal_key)
      except Exception as e:
          logger.warning(f"清理大纲任务终止信号失败: {e}")
      else:
          logger.debug(f"清理大纲任务终止信号: {callback_task_id}")
  async def get_outline_active_tasks(self) -> list:
      """
      List the outline-generation tasks that are currently processing.

      Only tasks tracked in ``self.active_outline_tasks`` whose status is
      ``"processing"`` are included.

      Returns:
          list: One dict per task with ``callback_task_id``, ``user_id``,
          ``project_name``, ``project_type``, ``status``, ``start_time`` and
          ``running_duration`` (whole seconds, 0 when no start time is set).
          An empty list is returned if anything goes wrong.
      """
      try:
          now = time.time()
          return [
              {
                  "callback_task_id": task_id,
                  "user_id": info.user_id,
                  "project_name": info.project_name,
                  "project_type": info.project_type,
                  "status": info.status,
                  "start_time": info.start_time,
                  "running_duration": int(now - info.start_time) if info.start_time else 0,
              }
              for task_id, info in self.active_outline_tasks.items()
              if info.status == "processing"
          ]
      except Exception as e:
          logger.error(f"获取活跃大纲任务列表失败: {e}", exc_info=True)
          return []
  async def get_outline_sgbx_task_info(self, callback_task_id: str) -> Optional[Dict]:
      """
      Look up an outline-generation task.

      The in-memory active tasks are consulted first. If the task is not
      there, the Redis result hash is read instead, which lets other
      processes see results written by the worker.

      Args:
          callback_task_id: Task callback ID.

      Returns:
          Optional[Dict]: Task information, or ``None`` when the task is
          unknown or the lookup failed. For Redis-backed results the stored
          status ``"terminated"`` is reported as ``"cancelled"``, and
          pre-registered records carry ``is_pre_registered`` and
          ``pre_registered_at``.
      """
      try:
          # Fast path: the task is tracked in this process.
          live_task = self.active_outline_tasks.get(callback_task_id)
          if live_task:
              now = time.time()
              return {
                  "callback_task_id": callback_task_id,
                  "user_id": live_task.user_id,
                  "project_name": live_task.project_name,
                  "project_type": live_task.project_type,
                  "status": live_task.status,
                  "start_time": live_task.start_time,
                  "running_duration": int(now - live_task.start_time) if live_task.start_time else 0,
                  "results": live_task.results,
              }

          # Slow path: read the stored result hash.
          client = await RedisConnectionFactory.get_connection()
          stored = await client.hgetall(f"{self._outline_result_prefix}{callback_task_id}")
          if not stored:
              return None

          def decode_field(raw):
              # Empty/missing values and undecodable JSON both become None.
              if not raw:
                  return None
              try:
                  return json.loads(raw)
              except (json.JSONDecodeError, TypeError):
                  return None

          results = {
              name: decode_field(stored.get(name))
              for name in (
                  "outline_structure",
                  "key_points",
                  "similar_cases",
                  "similar_fragments",
                  "knowledge_bases",
              )
          }
          results["error"] = stored.get("error_message") or None

          # Only "terminated" is renamed; every other status passes through.
          raw_status = stored.get("overall_task_status", "unknown")
          status = "cancelled" if raw_status == "terminated" else raw_status

          info = {
              "callback_task_id": stored.get("callback_task_id"),
              "user_id": stored.get("user_id"),
              "project_name": stored.get("project_name", ""),
              "project_type": stored.get("project_type", ""),
              "status": status,
              "start_time": None,
              "running_duration": 0,
              "results": results,
          }

          # Expose the pre-registration marker when the record has one.
          if stored.get("pre_registered") == "true":
              info["is_pre_registered"] = True
              info["pre_registered_at"] = stored.get("pre_registered_at")
          return info
      except Exception as e:
          logger.error(f"获取大纲任务信息失败: {e}", exc_info=True)
          return None