workflow_manager.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422
  1. """
  2. 基于LangGraph的工作流管理器
  3. 负责任务的创建、编排和执行,使用LangGraph进行状态管理
  4. 新增功能:
  5. - 任务终止管理
  6. - 终止信号设置和检测
  7. """
  8. import asyncio
  9. import time
  10. import json
  11. import os
  12. from typing import Dict, Optional, Any
  13. from datetime import datetime
  14. from langgraph.graph import StateGraph, END
  15. from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
  16. from foundation.observability.logger.loggering import review_logger as logger
  17. from foundation.observability.monitoring.time_statistics import track_execution_time
  18. from foundation.infrastructure.cache.redis_connection import RedisConnectionFactory
  19. from .progress_manager import ProgressManager
  20. from .redis_duplicate_checker import RedisDuplicateChecker
  21. from .task_models import TaskFileInfo, TaskChain
  22. from ..construction_review.workflows import DocumentWorkflow, AIReviewWorkflow, ReportWorkflow
  23. from ..construction_review.workflows.types import TaskChainState
  class ProgressManagerRegistry:
      """Registry that keeps one ProgressManager instance per task.

      Entries live in a class-level dict keyed by ``callback_task_id``, so all
      callers that use this class share the same mapping.
      """

      _registry = {}  # {callback_task_id: ProgressManager}

      @classmethod
      def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
          """Store ``progress_manager`` under ``callback_task_id`` (replaces any existing entry)."""
          cls._registry[callback_task_id] = progress_manager
          logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")

      @classmethod
      def get_progress_manager(cls, callback_task_id: str) -> Optional[ProgressManager]:
          """Return the instance registered for ``callback_task_id``, or ``None`` if there is none."""
          return cls._registry.get(callback_task_id)

      @classmethod
      def unregister_progress_manager(cls, callback_task_id: str):
          """Remove the entry for ``callback_task_id``; ids that are not registered are ignored."""
          if callback_task_id not in cls._registry:
              return
          del cls._registry[callback_task_id]
          logger.info(f"注销ProgressManager实例: {callback_task_id}")
  42. class WorkflowManager:
  43. """工作流管理器"""
  def __init__(self, max_concurrent_docs: int = 5, max_concurrent_reviews: int = 10):
      """Initialise limits, service components and in-memory task tracking.

      Args:
          max_concurrent_docs: Capacity of ``doc_semaphore``.
          max_concurrent_reviews: Capacity of ``review_semaphore``.
      """
      # Concurrency control: keep the configured limits and size one semaphore from each.
      self.max_concurrent_docs = max_concurrent_docs
      self.max_concurrent_reviews = max_concurrent_reviews
      self.doc_semaphore = asyncio.Semaphore(self.max_concurrent_docs)
      self.review_semaphore = asyncio.Semaphore(self.max_concurrent_reviews)

      # Service components.
      self.progress_manager = ProgressManager()
      self.redis_duplicate_checker = RedisDuplicateChecker()

      # Review task chains tracked by this manager, keyed by callback_task_id.
      self.active_chains: Dict[str, TaskChain] = {}
      self._cleanup_task_started = False

      # Task termination: Redis key prefix of the terminate signal and its lifetime.
      self._terminate_signal_prefix = "ai_review:terminate_signal:"
      self._task_expire_time = 2 * 60 * 60  # seconds (2 hours)

      # Construction-plan writing (outline generation): tracked tasks and Redis key prefixes.
      self.active_outline_tasks: Dict[str, Any] = {}
      self._outline_result_prefix = "outline_write:result:"
      self._outline_terminate_signal_prefix = "outline_write:terminate_signal:"

      # Workflow graphs start as None and are built later (deferred initialisation).
      self.task_chain_graph = None
      self.outline_generation_graph = None
  async def submit_task_processing(self, file_info: dict) -> str:
      """Submit a document-processing task to Celery (used by the file_upload layer).

      Args:
          file_info: Task payload; ``file_info['file_id']`` is read for logging.

      Returns:
          The id of the submitted Celery task.

      Raises:
          Exception: Any submission failure is logged and re-raised unchanged.
      """
      # Both helpers are imported at call time rather than at module import.
      from foundation.infrastructure.messaging.tasks import submit_task_processing_task
      from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager

      try:
          logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")
          # Submitting through CeleryTraceManager passes the trace_id along automatically.
          celery_task = CeleryTraceManager.submit_celery_task(submit_task_processing_task, file_info)
          logger.info(f"Celery任务已提交,Task ID: {celery_task.id}")
          return celery_task.id
      except Exception as exc:
          logger.error(f"提交Celery任务失败: {str(exc)}")
          raise
  @track_execution_time
  def submit_construction_review_task_processing_sync(self, file_info: dict) -> dict:
      """Run the construction-review task chain synchronously (used by Celery workers).

      Wraps the payload in TaskFileInfo / TaskChain, initialises progress
      tracking, runs the LangGraph task-chain workflow to completion and returns
      a serialisable summary of the final state.

      Args:
          file_info: Raw task payload; ``file_info['file_id']`` is read for logging.

      Returns:
          dict with ``file_id``, ``callback_task_id``, ``user_id``, ``file_name``,
          ``current_stage``, ``overall_task_status``, ``stage_status`` and
          ``error_message`` taken from the final workflow state.

      Raises:
          Exception: Any failure is logged and re-raised. When setup had already
              produced a task id, the chain is marked failed and the failure is
              reported through the progress manager first.
      """
      # Bound before the try block so the except/finally handlers can tell how far
      # setup got. Without this, an early failure (for example a missing 'file_id')
      # was masked by an UnboundLocalError raised inside the handlers.
      task_file_info = None
      callback_task_id = None
      try:
          logger.info(f"提交文档处理任务(LangGraph方案D): {file_info['file_id']}")

          # 1. Wrap the raw payload in a TaskFileInfo object.
          task_file_info = TaskFileInfo(file_info)
          logger.info(f"创建任务文件信息: {task_file_info}")

          # 2. The callback task id is used as the task-chain id.
          callback_task_id = task_file_info.callback_task_id

          # 3. Create the task chain (it references the TaskFileInfo instead of copying data).
          task_chain = TaskChain(task_file_info)

          # 4. Mark the chain as started.
          task_chain.start_processing()

          # 5. Track it as an active task.
          self.active_chains[callback_task_id] = task_chain

          # 6. Initialise progress tracking.
          asyncio.run(self.progress_manager.initialize_progress(
              callback_task_id=callback_task_id,
              user_id=task_file_info.user_id,
              stages=[]
          ))

          # 7. Build the LangGraph task-chain workflow on first use.
          if self.task_chain_graph is None:
              self.task_chain_graph = self._build_task_chain_workflow()

          # 8. Build the initial workflow state.
          initial_state = TaskChainState(
              file_id=task_file_info.file_id,
              callback_task_id=callback_task_id,
              user_id=task_file_info.user_id,
              file_name=task_file_info.file_name,
              file_type=task_file_info.file_type,
              file_content=task_file_info.file_content,
              current_stage="start",
              overall_task_status="processing",
              stage_status={
                  "document": "pending",
                  "ai_review": "pending",
                  "report": "pending"
              },
              document_result=None,
              ai_review_result=None,
              report_result=None,
              error_message=None,
              progress_manager=self.progress_manager,
              task_file_info=task_file_info,
              messages=[HumanMessage(content=f"开始任务链: {task_file_info.file_id}")]
          )

          # 9. Run the workflow on a dedicated event loop; close the loop even when
          #    the workflow raises so it is not leaked.
          loop = asyncio.new_event_loop()
          asyncio.set_event_loop(loop)
          try:
              result = loop.run_until_complete(self.task_chain_graph.ainvoke(initial_state))
          finally:
              loop.close()

          # 10. Remove the duplicate-check registration for this file.
          asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))

          logger.info(f"施工方案审查任务已完成(LangGraph方案D)!")
          logger.info(f"文件ID: {task_file_info.file_id}")
          logger.info(f"文件名: {task_file_info.file_name}")
          logger.info(f"整体状态: {result.get('overall_task_status', 'unknown')}")

          # Return only serialisable fields; progress_manager, task_file_info and
          # messages are deliberately left out.
          serializable_result = {
              "file_id": result.get("file_id"),
              "callback_task_id": result.get("callback_task_id"),
              "user_id": result.get("user_id"),
              "file_name": result.get("file_name"),
              "current_stage": result.get("current_stage"),
              "overall_task_status": result.get("overall_task_status"),
              "stage_status": result.get("stage_status"),
              "error_message": result.get("error_message"),
          }
          return serializable_result

      except Exception as e:
          logger.error(f"提交文档处理任务失败: {str(e)}", exc_info=True)

          # Mark the chain as failed if it was registered.
          if callback_task_id is not None and callback_task_id in self.active_chains:
              self.active_chains[callback_task_id].fail_processing(str(e))

          # Remove the duplicate-check registration once the file info is known.
          if task_file_info is not None:
              asyncio.run(self.redis_duplicate_checker.unregister_task(task_file_info.file_id))

          # Notify listeners (SSE) about the failure once a task id is known.
          if callback_task_id is not None:
              error_data = {
                  "error": str(e),
                  "status": "failed",
                  "overall_task_status": "failed",
                  "timestamp": datetime.now().isoformat()
              }
              asyncio.run(self.progress_manager.complete_task(callback_task_id, task_file_info.user_id, error_data))
          raise

      finally:
          # Always drop the chain from the active set; a no-op when it was never added.
          self.active_chains.pop(callback_task_id, None)
  async def set_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, Any]:
      """Write a terminate signal for an active, processing task into Redis.

      The signal is stored as a Redis hash (operator, terminate time, task id)
      under the terminate-signal prefix and expires after ``_task_expire_time``
      seconds. Only tasks present in ``self.active_chains`` with status
      ``"processing"`` are accepted.

      Args:
          callback_task_id: Callback id of the task to terminate.
          operator: Who requested the termination (user id or system tag).

      Returns:
          ``{"success": bool, "message": str, "sgbx_task_info": dict | None}``.
          Errors are reported through this dict; no exception is raised.
      """
      try:
          # Unknown (or already finished and removed) task: nothing to terminate.
          if callback_task_id not in self.active_chains:
              return {
                  "success": False,
                  "message": f"任务不存在或已完成: {callback_task_id}",
                  "sgbx_task_info": None
              }

          task_chain = self.active_chains[callback_task_id]

          # Only a task that is still processing can be terminated.
          if task_chain.status != "processing":
              return {
                  "success": False,
                  "message": f"任务状态不是 processing,无需终止: {callback_task_id} (当前状态: {task_chain.status})",
                  "sgbx_task_info": {
                      "callback_task_id": callback_task_id,
                      "status": task_chain.status,
                      "file_name": task_chain.file_name
                  }
              }

          redis_client = await RedisConnectionFactory.get_connection()
          terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"

          # A hash is used so the operator and time travel with the signal.
          terminate_data = {
              "operator": operator,
              "terminate_time": str(time.time()),
              "task_id": callback_task_id
          }
          await redis_client.hset(terminate_key, mapping=terminate_data)
          # Let the signal expire on its own if nobody clears it.
          await redis_client.expire(terminate_key, self._task_expire_time)

          logger.info(f"已设置终止信号: {callback_task_id} (操作人: {operator}, 文件: {task_chain.file_name})")
          return {
              "success": True,
              "message": f"终止信号已设置,任务将在当前节点完成后终止",
              "sgbx_task_info": {
                  "callback_task_id": callback_task_id,
                  "file_id": task_chain.file_id,
                  "file_name": task_chain.file_name,
                  "user_id": task_chain.user_id,
                  "status": task_chain.status,
                  "current_stage": task_chain.current_stage
              }
          }
      except Exception as e:
          logger.error(f"设置终止信号失败: {str(e)}", exc_info=True)
          return {
              "success": False,
              "message": f"设置终止信号失败: {str(e)}",
              "sgbx_task_info": None
          }
  async def check_terminate_signal(self, callback_task_id: str) -> bool:
      """Report whether a terminate signal exists in Redis for the task.

      Workflow nodes call this before doing their work.

      Args:
          callback_task_id: Callback id of the task to check.

      Returns:
          ``True`` when the terminate-signal key exists; ``False`` otherwise,
          including when the check itself fails (failures are only logged).
      """
      try:
          redis_client = await RedisConnectionFactory.get_connection()
          terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"

          if not await redis_client.exists(terminate_key):
              return False

          terminate_info = await redis_client.hgetall(terminate_key)
          # The hash may come back with bytes or str keys/values depending on how
          # the client was configured; accept both so the operator is not lost.
          operator = terminate_info.get(b"operator", terminate_info.get("operator", "unknown"))
          if isinstance(operator, bytes):
              # errors="replace": a badly encoded operator name must not turn an
              # existing signal into a False result via the except branch below.
              operator = operator.decode(errors="replace")
          logger.warning(f"检测到终止信号: {callback_task_id}, 操作人: {operator}")
          return True

      except RuntimeError as e:
          # Event-loop related problems are reported at different log levels.
          error_msg = str(e)
          if "Event loop is closed" in error_msg:
              # Treated as normal (the task has ended): debug only.
              logger.debug(f"检查终止信号时事件循环已关闭: {callback_task_id}")
              return False
          elif "bound to a different event loop" in error_msg:
              # Loop mismatch: warn, skip this check, keep the workflow running.
              logger.warning(f"检查终止信号时检测到事件循环不匹配: {callback_task_id},将忽略本次检查")
              return False
          else:
              logger.error(f"检查终止信号失败(RuntimeError): {error_msg}", exc_info=True)
              return False
      except Exception as e:
          logger.error(f"检查终止信号失败: {str(e)}", exc_info=True)
          return False
  async def clear_terminate_signal(self, callback_task_id: str):
      """Delete the task's terminate-signal key from Redis.

      Args:
          callback_task_id: Callback id of the task whose signal is removed.

      Failures are logged as warnings and are not raised.
      """
      try:
          connection = await RedisConnectionFactory.get_connection()
          signal_key = f"{self._terminate_signal_prefix}{callback_task_id}"
          await connection.delete(signal_key)
          logger.debug(f"清理终止信号: {callback_task_id}")
      except Exception as exc:
          logger.warning(f"清理终止信号失败: {str(exc)}")
  async def get_active_tasks(self) -> list:
      """List the tracked task chains whose status is ``"processing"``.

      Returns:
          One dict per processing task with its ids, file name, status, current
          stage, start time and running duration in whole seconds (0 when the
          start time is unset). An empty list is returned if collecting fails.
      """
      try:
          now = time.time()
          return [
              {
                  "callback_task_id": chain_id,
                  "file_id": chain.file_id,
                  "file_name": chain.file_name,
                  "user_id": chain.user_id,
                  "status": chain.status,
                  "current_stage": chain.current_stage,
                  "start_time": chain.start_time,
                  "running_duration": int(now - chain.start_time) if chain.start_time else 0,
              }
              for chain_id, chain in self.active_chains.items()
              if chain.status == "processing"
          ]
      except Exception as exc:
          logger.error(f"获取活跃任务列表失败: {str(exc)}", exc_info=True)
          return []
  async def get_sgbx_task_info(self, callback_task_id: str) -> Optional[Dict]:
      """Return a snapshot of one tracked task chain.

      Args:
          callback_task_id: Callback id of the task to look up.

      Returns:
          A dict with the chain's ids, file name, status, current stage, start
          time, running duration in whole seconds and results; ``None`` when the
          task is not tracked or the lookup fails.
      """
      try:
          chain = self.active_chains.get(callback_task_id)
          if not chain:
              return None

          now = time.time()
          elapsed = int(now - chain.start_time) if chain.start_time else 0
          return {
              "callback_task_id": callback_task_id,
              "file_id": chain.file_id,
              "file_name": chain.file_name,
              "user_id": chain.user_id,
              "status": chain.status,
              "current_stage": chain.current_stage,
              "start_time": chain.start_time,
              "running_duration": elapsed,
              "results": chain.results,
          }
      except Exception as exc:
          logger.error(f"获取任务信息失败: {str(exc)}", exc_info=True)
          return None
  def _build_task_chain_workflow(self) -> StateGraph:
      """Assemble and compile the LangGraph task-chain workflow.

      Path: start -> document_processing -> ai_review_subgraph ->
      report_generation -> complete -> END. After each of the three processing
      stages a conditional edge may divert to ``terminate`` or ``error_handler``,
      both of which also lead to END. The compiled graph is additionally
      exported to ``temp/construction_review/task_chain_workflow.png``.

      Returns:
          The object returned by ``StateGraph.compile()``.
      """
      logger.info("开始构建 LangGraph 任务链工作流图")
      builder = StateGraph(TaskChainState)

      # Nodes, registered in the same order as the pipeline reads.
      node_handlers = (
          ("start", self._start_chain_node),
          ("document_processing", self._document_processing_node),
          ("ai_review_subgraph", self._ai_review_subgraph_node),
          ("report_generation", self._report_generation_node),
          ("complete", self._complete_chain_node),
          ("error_handler", self._error_handler_chain_node),
          ("terminate", self._terminate_chain_node),
      )
      for node_name, handler in node_handlers:
          builder.add_node(node_name, handler)

      builder.set_entry_point("start")
      builder.add_edge("start", "document_processing")

      # Every processing stage is followed by the same terminate/error check and
      # differs only in the node reached on "continue".
      stage_transitions = (
          ("document_processing", "ai_review_subgraph"),
          ("ai_review_subgraph", "report_generation"),
          ("report_generation", "complete"),
      )
      for stage, next_stage in stage_transitions:
          builder.add_conditional_edges(
              stage,
              self._should_terminate_or_error_chain,
              {
                  "terminate": "terminate",
                  "error": "error_handler",
                  "continue": next_stage,
              },
          )

      # The three closing nodes all end the graph.
      for closing_node in ("complete", "error_handler", "terminate"):
          builder.add_edge(closing_node, END)

      compiled = builder.compile()

      # Export a picture of the graph into temp/construction_review.
      self._save_workflow_graph(compiled, "temp/construction_review/task_chain_workflow.png")
      logger.info("LangGraph 任务链工作流图构建完成")
      return compiled
  def _save_workflow_graph(self, compiled_graph: StateGraph, output_path: str):
      """Best-effort export of the compiled workflow graph as a PNG image.

      The image comes from ``compiled_graph.get_graph().draw_mermaid_png()``.
      If rendering or writing it fails, the graph's JSON form is written to the
      same path with a ``.json`` extension instead. Every failure is logged as
      a warning; nothing is raised.

      Args:
          compiled_graph: Compiled LangGraph workflow graph.
          output_path: Destination path of the PNG file.
      """
      try:
          # Make sure the output directory exists.
          output_dir = os.path.dirname(output_path)
          if output_dir and not os.path.exists(output_dir):
              os.makedirs(output_dir, exist_ok=True)
              logger.info(f"创建输出目录:{output_dir}")

          # Rendering goes through Mermaid, not graphviz.
          # NOTE(review): langchain-core's draw_mermaid_png appears to default to
          # the mermaid.ink web API — confirm network access is acceptable here.
          try:
              graph = compiled_graph.get_graph()
              graph_image = graph.draw_mermaid_png()
              with open(output_path, "wb") as f:
                  f.write(graph_image)
              logger.info(f"工作流图已保存到:{output_path}")
          except Exception as graphviz_error:
              logger.warning(f"Graphviz 保存失败:{str(graphviz_error)}, 尝试使用 JSON 格式保存")
              # Fallback: JSON next to the intended PNG. Swap only the final
              # extension; str.replace would rewrite every ".png" in the path and
              # leave paths without that exact suffix unchanged.
              json_path = os.path.splitext(output_path)[0] + ".json"
              with open(json_path, "w", encoding="utf-8") as f:
                  json.dump(compiled_graph.get_graph().to_json(), f, indent=2, ensure_ascii=False)
              logger.info(f"工作流图已保存到 (JSON 格式): {json_path}")
      except Exception as e:
          logger.warning(f"保存工作流图失败:{str(e)}")
  async def _start_chain_node(self, state: TaskChainState) -> TaskChainState:
      """Entry node of the task chain.

      Args:
          state: Task-chain state; ``callback_task_id`` is read for logging.

      Returns:
          State update that marks the chain as processing at stage ``start``
          with all three stages pending.
      """
      logger.info(f"任务链工作流启动: {state['callback_task_id']}")
      all_pending = {stage: "pending" for stage in ("document", "ai_review", "report")}
      return {
          "current_stage": "start",
          "overall_task_status": "processing",
          "stage_status": all_pending,
          "messages": [AIMessage(content="任务链工作流启动")],
      }
  async def _document_processing_node(self, state: TaskChainState) -> TaskChainState:
      """Document-processing node.

      Checks for a terminate signal first, then runs a DocumentWorkflow on the
      file content and type held in the state.

      Args:
          state: Task-chain state.

      Returns:
          State update carrying ``document_result`` on success, a ``terminated``
          status when a terminate signal was found, or a ``failed`` status with
          ``error_message`` when anything raised.
      """
      stage_name = "document_processing"
      try:
          task_id = state["callback_task_id"]
          logger.info(f"开始文档处理阶段: {task_id}")

          # Stop before doing any work if termination was requested.
          terminate_requested = await self.check_terminate_signal(task_id)
          if terminate_requested:
              logger.warning(f"文档处理阶段检测到终止信号: {task_id}")
              return {
                  "current_stage": stage_name,
                  "overall_task_status": "terminated",
                  "stage_status": {**state["stage_status"], "document": "terminated"},
                  "messages": [AIMessage(content="文档处理阶段检测到终止信号")],
              }

          file_info = state["task_file_info"]
          workflow = DocumentWorkflow(
              task_file_info=file_info,
              progress_manager=state["progress_manager"],
              redis_duplicate_checker=self.redis_duplicate_checker,
          )
          processed = await workflow.execute(state["file_content"], state["file_type"])

          logger.info(f"文档处理完成: {task_id}")
          return {
              "current_stage": stage_name,
              "overall_task_status": "processing",
              "stage_status": {**state["stage_status"], "document": "completed"},
              "document_result": processed,
              "messages": [AIMessage(content="文档处理完成")],
          }
      except Exception as exc:
          failure_text = f"文档处理失败: {str(exc)}"
          logger.error(failure_text, exc_info=True)
          return {
              "current_stage": stage_name,
              "overall_task_status": "failed",
              "stage_status": {**state["stage_status"], "document": "failed"},
              "error_message": failure_text,
              "messages": [AIMessage(content=failure_text)],
          }
  async def _ai_review_subgraph_node(self, state: TaskChainState) -> TaskChainState:
      """AI-review node that runs the existing AIReviewWorkflow as a nested workflow.

      Checks for a terminate signal, takes ``structured_content`` from the
      document result, reads the review limits from ``config/config.ini`` and
      executes an AIReviewWorkflow with them.

      Args:
          state: Task-chain state.

      Returns:
          State update carrying ``ai_review_result`` on success, a ``terminated``
          status when a terminate signal was found, or a ``failed`` status with
          ``error_message`` when anything raised (including missing structured
          content).
      """
      stage_name = "ai_review"
      try:
          task_id = state["callback_task_id"]
          logger.info(f"开始AI审查阶段: {task_id}")

          # Stop before doing any work if termination was requested.
          terminate_requested = await self.check_terminate_signal(task_id)
          if terminate_requested:
              logger.warning(f"AI审查阶段检测到终止信号: {task_id}")
              return {
                  "current_stage": stage_name,
                  "overall_task_status": "terminated",
                  "stage_status": {**state["stage_status"], "ai_review": "terminated"},
                  "messages": [AIMessage(content="AI审查阶段检测到终止信号")],
              }

          # The review works on the structured content produced by document processing.
          structured_content = state["document_result"].get("structured_content")
          if not structured_content:
              raise ValueError("文档处理结果中缺少结构化内容")

          file_info = state["task_file_info"]

          # Review settings come from the [ai_review] section of config/config.ini.
          import configparser
          settings = configparser.ConfigParser()
          settings.read('config/config.ini', encoding='utf-8')
          # A configured 0 is handled the same way as an absent value: None.
          max_review_units = settings.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None) or None
          review_mode = settings.get('ai_review', 'REVIEW_MODE', fallback='all')
          logger.info(f"AI审查配置: 最大审查数量={max_review_units}, 审查模式={review_mode}")

          review_workflow = AIReviewWorkflow(
              task_file_info=file_info,
              structured_content=structured_content,
              progress_manager=state["progress_manager"],
              max_review_units=max_review_units,
              review_mode=review_mode,
          )
          review_outcome = await review_workflow.execute()

          logger.info(f"AI审查完成: {task_id}")
          return {
              "current_stage": stage_name,
              "overall_task_status": "processing",
              "stage_status": {**state["stage_status"], "ai_review": "completed"},
              "ai_review_result": review_outcome,
              "messages": [AIMessage(content="AI审查完成")],
          }
      except Exception as exc:
          failure_text = f"AI审查失败: {str(exc)}"
          logger.error(failure_text, exc_info=True)
          return {
              "current_stage": stage_name,
              "overall_task_status": "failed",
              "stage_status": {**state["stage_status"], "ai_review": "failed"},
              "error_message": failure_text,
              "messages": [AIMessage(content=failure_text)],
          }
  async def _report_generation_node(self, state: TaskChainState) -> TaskChainState:
      """Report-generation node.

      Checks for a terminate signal, runs a ReportWorkflow on the AI review
      result and then persists the combined results through
      ``_save_complete_results``.

      Args:
          state: Task-chain state.

      Returns:
          State update carrying ``report_result`` on success, a ``terminated``
          status when a terminate signal was found, or a ``failed`` status with
          ``error_message`` when anything raised (including a missing AI review
          result).
      """
      try:
          logger.info(f"开始报告生成阶段: {state['callback_task_id']}")

          # Stop before doing any work if termination was requested.
          if await self.check_terminate_signal(state["callback_task_id"]):
              logger.warning(f"报告生成阶段检测到终止信号: {state['callback_task_id']}")
              return {
                  "current_stage": "report_generation",
                  "overall_task_status": "terminated",
                  "stage_status": {**state["stage_status"], "report": "terminated"},
                  "messages": [AIMessage(content="报告生成阶段检测到终止信号")]
              }

          # The report is built from the AI review result; without it there is nothing to report.
          ai_review_result = state.get("ai_review_result")
          if not ai_review_result:
              raise ValueError("AI审查结果缺失,无法生成报告")

          report_workflow = ReportWorkflow(
              file_id=state["file_id"],
              file_name=state["file_name"],
              callback_task_id=state["callback_task_id"],
              user_id=state["user_id"],
              ai_review_results=ai_review_result,
              progress_manager=state["progress_manager"]
          )
          report_result = await report_workflow.execute()

          # The result flags whether a fallback (degraded) report was produced.
          is_fallback = report_result.get('is_fallback', False)
          if is_fallback:
              logger.warning(f"报告生成使用了降级方案: {state['callback_task_id']}")
          else:
              logger.info(f"报告生成完成: {state['callback_task_id']}")

          # Persist document, AI review and report results in one go.
          # NOTE(review): `state` still holds the stage_status from before this
          # node returns, so "report" is not yet "completed" in what gets passed
          # here — confirm the persisted snapshot is meant to look like that.
          await self._save_complete_results(state, report_result)

          return {
              "current_stage": "report_generation",
              "overall_task_status": "processing",
              "stage_status": {**state["stage_status"], "report": "completed"},
              "report_result": report_result,
              "messages": [AIMessage(content="报告生成完成")]
          }
      except Exception as e:
          logger.error(f"报告生成失败: {str(e)}", exc_info=True)
          return {
              "current_stage": "report_generation",
              "overall_task_status": "failed",
              "stage_status": {**state["stage_status"], "report": "failed"},
              "error_message": f"报告生成失败: {str(e)}",
              "messages": [AIMessage(content=f"报告生成失败: {str(e)}")]
          }
  async def _complete_chain_node(self, state: TaskChainState) -> TaskChainState:
      """Completion node of the task chain.

      Reports completion through the progress manager (when one is present in
      the state), calls ``delete_file_info`` for the file as best-effort cleanup
      and marks the overall task as completed.

      Args:
          state: Task-chain state.

      Returns:
          State update with ``overall_task_status`` set to ``"completed"``.
      """
      task_id = state["callback_task_id"]
      logger.info(f"任务链工作流完成: {task_id}")

      # Report overall completion.
      if state["progress_manager"]:
          completion_payload = {"overall_task_status": "completed", "message": "所有阶段已完成"}
          await state["progress_manager"].complete_task(task_id, state["user_id"], completion_payload)

      # Best-effort cleanup of the cached file info; a failure only produces a warning.
      try:
          from foundation.utils.redis_utils import delete_file_info
          await delete_file_info(state["file_id"])
          logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
      except Exception as exc:
          logger.warning(f"清理 Redis 文件缓存失败: {str(exc)}")

      return {
          "current_stage": "complete",
          # Key point: the overall task is marked completed only in this node.
          "overall_task_status": "completed",
          "messages": [AIMessage(content="任务链工作流完成")],
      }
  async def _error_handler_chain_node(self, state: TaskChainState) -> TaskChainState:
      """Error-handling node of the task chain.

      Logs the error, reports the failure through the progress manager (when
      one is present in the state), calls ``delete_file_info`` for the file as
      best-effort cleanup and marks the overall task as failed.

      Args:
          state: Task-chain state.

      Returns:
          State update with ``overall_task_status`` set to ``"failed"``.
      """
      # `or` instead of a .get() default: the key can be present with the value
      # None, in which case a .get() default would never be used.
      error_text = state.get("error_message") or "未知错误"
      logger.error(f"任务链工作流错误: {state['callback_task_id']}, 错误: {error_text}")

      # Report the failure.
      if state["progress_manager"]:
          error_data = {
              "overall_task_status": "failed",
              "error": error_text,
              "status": "failed",
              "timestamp": datetime.now().isoformat()
          }
          await state["progress_manager"].complete_task(
              state["callback_task_id"],
              state["user_id"],
              error_data
          )

      # Cleanup still runs on failure; a failure here only produces a warning.
      try:
          from foundation.utils.redis_utils import delete_file_info
          await delete_file_info(state["file_id"])
          logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
      except Exception as e:
          logger.warning(f"清理 Redis 文件缓存失败: {str(e)}")

      return {
          "current_stage": "error_handler",
          "overall_task_status": "failed",
          "messages": [AIMessage(content=f"任务链错误: {error_text}")]
      }
  async def _terminate_chain_node(self, state: TaskChainState) -> TaskChainState:
      """Termination node of the task chain.

      Reports the termination through the progress manager (when one is present
      in the state), clears the terminate signal, calls ``delete_file_info`` for
      the file as best-effort cleanup and marks the overall task as terminated.

      Args:
          state: Task-chain state.

      Returns:
          State update with ``overall_task_status`` set to ``"terminated"``.
      """
      task_id = state["callback_task_id"]
      logger.warning(f"任务链工作流已终止: {task_id}")

      # Report the termination.
      if state["progress_manager"]:
          termination_payload = {"overall_task_status": "terminated", "message": "任务已被用户终止"}
          await state["progress_manager"].complete_task(task_id, state["user_id"], termination_payload)

      # The signal has been acted on; remove it from Redis.
      await self.clear_terminate_signal(state["callback_task_id"])

      # Best-effort cleanup of the cached file info; a failure only produces a warning.
      try:
          from foundation.utils.redis_utils import delete_file_info
          await delete_file_info(state["file_id"])
          logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
      except Exception as exc:
          logger.warning(f"清理 Redis 文件缓存失败: {str(exc)}")

      return {
          "current_stage": "terminated",
          "overall_task_status": "terminated",
          "messages": [AIMessage(content="任务链已被终止")],
      }
  738. def _should_terminate_or_error_chain(self, state: TaskChainState) -> str:
  739. """
  740. 检查任务链是否应该终止或发生错误
  741. Args:
  742. state: 任务链状态
  743. Returns:
  744. str: "terminate", "error", 或 "continue"
  745. Note:
  746. 这是条件边判断方法,用于决定工作流的下一步走向
  747. 1. 优先检查终止信号
  748. 2. 检查是否有错误
  749. 3. 都没有则继续执行
  750. """
  751. # 检查终止状态
  752. if state.get("overall_task_status") == "terminated":
  753. return "terminate"
  754. # 检查错误状态
  755. if state.get("overall_task_status") == "failed" or state.get("error_message"):
  756. return "error"
  757. # 默认继续执行
  758. return "continue"
  async def _save_complete_results(self, state: TaskChainState, report_result: Dict[str, Any]):
      """
      Persist the combined result of the whole chain in a single write.

      The saved document bundles the document-processing result, the AI
      review result (plus its "review_results" list under "issues") and the
      report-generation result. It is written as
      "<callback_task_id>.json" under the "final_result" sub-directory of
      the construction-review cache.

      Args:
          state: Task-chain state.
          report_result: Result produced by the report-generation stage.

      Raises:
          Exception: Any failure while building or saving the result is
          logged and re-raised.
      """
      try:
          import os
          from foundation.observability.cachefiles import cache, CacheBaseDir

          logger.info(f"开始保存完整结果: {state['callback_task_id']}")

          review = state.get("ai_review_result")
          issues = review.get("review_results") if review else None

          # Status stays "processing" here; only the complete node marks
          # the chain as completed.
          payload = {
              "callback_task_id": state["callback_task_id"],
              "file_id": state["file_id"],
              "file_name": state["file_name"],
              "user_id": state["user_id"],
              "overall_task_status": "processing",
              "stage_status": state["stage_status"],
              "document_result": state.get("document_result"),
              "ai_review_result": review,
              "issues": issues,
              "report_result": report_result,
              "timestamp": datetime.now().isoformat(),
          }

          # The expected directory is computed for logging only.
          expected_dir = os.path.join(CacheBaseDir.CONSTRUCTION_REVIEW.value, "final_result")
          logger.info(f"准备保存结果到目录: {expected_dir}")

          saved_path = cache.save(
              payload,
              subdir="final_result",
              filename=f"{state['callback_task_id']}.json",
              base_cache_dir=CacheBaseDir.CONSTRUCTION_REVIEW,
          )
          logger.info(f"完整结果已保存到: {saved_path}")

          # Sanity check that the file landed under "final_result".
          landed_in_final_result = "final_result" in str(saved_path)
          if not landed_in_final_result:
              logger.warning(f"警告:结果文件可能未保存到正确的final_result目录: {saved_path}")
      except Exception as save_error:
          logger.error(f"保存完整结果失败: {str(save_error)}", exc_info=True)
          raise
  804. # ==================== 施工方案编写任务管理方法 ====================
  async def submit_outline_generation_task(self, sgbx_task_info: dict) -> str:
      """
      Submit an outline-generation task to Celery.

      The task-info dict is forwarded unchanged to the Celery task through
      CeleryTraceManager; this method itself only reads "user_id", for
      logging.

      Args:
          sgbx_task_info: Task information dict.

      Returns:
          str: Celery task ID.

      Raises:
          Exception: Any submission failure is logged and re-raised.
      """
      from foundation.infrastructure.messaging.tasks import (
          submit_outline_generation_task as outline_celery_task,
      )
      from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager

      try:
          requester = sgbx_task_info.get('user_id')
          logger.info(f"提交大纲生成任务到Celery: user_id={requester}")

          # Submission goes through the trace manager rather than calling
          # the Celery task directly.
          submitted = CeleryTraceManager.submit_celery_task(
              outline_celery_task,
              sgbx_task_info,
          )
          logger.info(f"大纲生成Celery任务已提交,Task ID: {submitted.id}")
          return submitted.id
      except Exception as submit_error:
          logger.error(f"提交大纲生成Celery任务失败: {str(submit_error)}")
          raise
  @track_execution_time
  def submit_outline_generation_sync(self, sgbx_task_info: dict) -> dict:
      """
      Run the outline-generation workflow synchronously (for a Celery worker).

      Steps: build the task-info object, track it in
      ``self.active_outline_tasks``, initialise progress stages, run the
      LangGraph outline workflow on a private event loop, record the final
      status on the task-info object, store the result in a Redis hash and
      return a serialisable summary.

      Args:
          sgbx_task_info: Task information dict. Keys read here:
              callback_task_id (optional; generated when missing), user_id,
              project_info, template_id, generation_chapterenum,
              generation_template, similarity_config, knowledge_config.

      Returns:
          dict: callback_task_id, user_id, overall_task_status,
          outline_structure, key_points, similar_cases, similar_fragments,
          knowledge_bases and error_message taken from the workflow result.

      Raises:
          Exception: Any failure is logged, recorded on the tracked task
          and re-raised.
      """
      import uuid
      from langchain_core.messages import HumanMessage
      from ..construction_write.component.state_models import OutlineGenerationState, OutlineTaskInfo
      from ..construction_write.workflows.outline_workflow import OutlineWorkflow

      # Workflow result fields that are JSON-encoded into the Redis hash.
      result_fields = (
          "outline_structure",
          "key_points",
          "similar_cases",
          "similar_fragments",
          "knowledge_bases",
      )

      callback_task_id = None
      try:
          logger.info("开始执行大纲生成任务(LangGraph)")

          # 1. Use the caller's task ID, or generate one.
          callback_task_id = sgbx_task_info.get('callback_task_id') or f"outline_{uuid.uuid4().hex[:16]}"
          user_id = sgbx_task_info.get('user_id', 'unknown')

          # 2. Build the task-info object from the submitted parameters.
          outline_sgbx_task_info = OutlineTaskInfo(
              callback_task_id=callback_task_id,
              user_id=user_id,
              project_info=sgbx_task_info.get('project_info', {}),
              template_id=sgbx_task_info.get('template_id', ''),
              generation_chapterenum=sgbx_task_info.get('generation_chapterenum', []),
              generation_template=sgbx_task_info.get('generation_template', []),
              similarity_config=sgbx_task_info.get('similarity_config', {
                  "topk_plans": 3,
                  "topk_fragments": 10,
                  "threshold": 0.75
              }),
              knowledge_config=sgbx_task_info.get('knowledge_config', {
                  "topk": 3,
                  "threshold": 0.75
              })
          )

          # 3. Track the task while it is running.
          self.active_outline_tasks[callback_task_id] = outline_sgbx_task_info

          # 4-7. Progress initialisation and the workflow share one private
          # event loop. The loop is closed in ``finally`` so that a failure
          # in either coroutine no longer leaks it.
          loop = asyncio.new_event_loop()
          asyncio.set_event_loop(loop)
          try:
              # 4. Initialise progress tracking.
              loop.run_until_complete(self.progress_manager.initialize_progress(
                  callback_task_id=callback_task_id,
                  user_id=user_id,
                  stages=[
                      {"stage": "start", "status": "pending"},
                      {"stage": "template_loading", "status": "pending"},
                      {"stage": "outline_generation", "status": "pending"},
                      {"stage": "similar_cases", "status": "pending"},
                      {"stage": "similar_fragments", "status": "pending"},
                      {"stage": "knowledge_bases", "status": "pending"},
                      {"stage": "complete", "status": "pending"}
                  ]
              ))

              # 4.1 Register the progress manager so workflow nodes can find it.
              ProgressManagerRegistry.register_progress_manager(callback_task_id, self.progress_manager)

              # 4.2 Mark the task as started.
              outline_sgbx_task_info.start_processing()

              # 5. Build the LangGraph workflow lazily, on first use.
              if self.outline_generation_graph is None:
                  outline_workflow = OutlineWorkflow()
                  self.outline_generation_graph = outline_workflow.build_graph()

              # 6. Initial state. The progress manager and the task-info
              # object are deliberately kept out of the state (they are
              # not serialisable).
              initial_state = OutlineGenerationState(
                  callback_task_id=callback_task_id,
                  user_id=user_id,
                  project_info=outline_sgbx_task_info.project_info,
                  template_id=outline_sgbx_task_info.template_id,
                  generation_chapterenum=outline_sgbx_task_info.generation_chapterenum,
                  generation_template=outline_sgbx_task_info.generation_template,
                  similarity_config=outline_sgbx_task_info.similarity_config,
                  knowledge_config=outline_sgbx_task_info.knowledge_config,
                  template=None,
                  outline_structure=None,
                  key_points=None,
                  similar_cases=None,
                  similar_fragments=None,
                  knowledge_bases=None,
                  current_stage="start",
                  overall_task_status="processing",
                  error_message=None,
                  messages=[HumanMessage(content=f"开始大纲生成任务: {callback_task_id}")]
              )

              # 7. Run the workflow; the thread_id config is supplied for
              # the graph's checkpointer.
              result = loop.run_until_complete(
                  self.outline_generation_graph.ainvoke(
                      initial_state,
                      config={"configurable": {"thread_id": callback_task_id}}
                  )
              )
          finally:
              loop.close()

          logger.info(f"大纲生成任务完成!callback_task_id={callback_task_id}")

          # 8. Mirror the workflow's final status onto the task-info object.
          final_status = result.get("overall_task_status")
          if final_status == "completed":
              outline_sgbx_task_info.complete_processing(
                  {field: result.get(field) for field in result_fields}
              )
          elif final_status == "failed":
              outline_sgbx_task_info.fail_processing(result.get("error_message", "未知错误"))
          elif final_status == "terminated":
              outline_sgbx_task_info.cancel_processing()

          # 8.5 Store the result in Redis so other processes can read it.
          async def save_result_to_redis():
              redis_client = await RedisConnectionFactory.get_connection()
              result_key = f"{self._outline_result_prefix}{callback_task_id}"

              # None is never written as a hash value: missing values
              # become "" and structured values become JSON text.
              result_data = {
                  "callback_task_id": callback_task_id,
                  "user_id": user_id,
                  "overall_task_status": result.get("overall_task_status") or "",
                  "error_message": result.get("error_message") or "",
                  "completed_time": str(time.time()),
              }
              for field in result_fields:
                  value = result.get(field)
                  result_data[field] = json.dumps(value, ensure_ascii=False) if value else ""

              # hset(mapping=...) replaces the deprecated hmset and matches
              # how the terminate signal is written elsewhere in this class.
              await redis_client.hset(result_key, mapping=result_data)
              # Expiry is taken from self._task_expire_time.
              await redis_client.expire(result_key, self._task_expire_time)
              logger.info(f"大纲生成结果已保存到 Redis: {callback_task_id}")

          # The first loop is closed, so the save runs on a fresh one.
          loop = asyncio.new_event_loop()
          asyncio.set_event_loop(loop)
          try:
              loop.run_until_complete(save_result_to_redis())
          finally:
              loop.close()

          # 9. Return only serialisable fields.
          summary = {
              "callback_task_id": result.get("callback_task_id"),
              "user_id": result.get("user_id"),
              "overall_task_status": result.get("overall_task_status"),
          }
          summary.update({field: result.get(field) for field in result_fields})
          summary["error_message"] = result.get("error_message")
          return summary

      except Exception as e:
          logger.error(f"大纲生成任务失败: {str(e)}", exc_info=True)
          # Record the failure on the tracked task, if it got that far.
          if callback_task_id and callback_task_id in self.active_outline_tasks:
              self.active_outline_tasks[callback_task_id].fail_processing(str(e))
          raise

      finally:
          # Nothing can be tracked or registered under a missing ID, so
          # cleanup is skipped when no ID was ever assigned.
          if callback_task_id:
              self.active_outline_tasks.pop(callback_task_id, None)
              ProgressManagerRegistry.unregister_progress_manager(callback_task_id)
  async def set_outline_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, Any]:
      """
      Set the terminate signal for an outline-generation task.

      The signal is only written when the task is present in
      ``self.active_outline_tasks`` and its status is "processing". It is
      stored as a Redis hash (operator, terminate_time, task_id) that
      expires after ``self._task_expire_time``.

      Args:
          callback_task_id: Task callback ID.
          operator: Who requested the termination.

      Returns:
          Dict[str, Any]: "success" (bool), "message" (str) and
          "sgbx_task_info" (dict with task details, or None when the task
          is unknown or an error occurred). Errors are reported through
          the return value rather than raised.
      """
      try:
          # The task must be tracked by this instance.
          sgbx_task_info = self.active_outline_tasks.get(callback_task_id)
          if callback_task_id not in self.active_outline_tasks:
              return {
                  "success": False,
                  "message": f"任务不存在或已完成: {callback_task_id}",
                  "sgbx_task_info": None
              }

          # Only a running task can be terminated.
          if sgbx_task_info.status != "processing":
              return {
                  "success": False,
                  "message": f"任务状态不是 processing,无需终止: {callback_task_id} (当前状态: {sgbx_task_info.status})",
                  "sgbx_task_info": {
                      "callback_task_id": callback_task_id,
                      "status": sgbx_task_info.status,
                      "project_name": sgbx_task_info.project_name
                  }
              }

          redis_client = await RedisConnectionFactory.get_connection()
          terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"

          # A hash is used so the operator and time travel with the signal.
          terminate_data = {
              "operator": operator,
              "terminate_time": str(time.time()),
              "task_id": callback_task_id
          }
          await redis_client.hset(terminate_key, mapping=terminate_data)
          await redis_client.expire(terminate_key, self._task_expire_time)

          logger.info(f"已设置大纲任务终止信号: {callback_task_id} (操作人: {operator}, 项目: {sgbx_task_info.project_name})")
          return {
              "success": True,
              "message": "终止信号已设置,任务将在当前节点完成后终止",
              "sgbx_task_info": {
                  "callback_task_id": callback_task_id,
                  "user_id": sgbx_task_info.user_id,
                  "project_name": sgbx_task_info.project_name,
                  "status": sgbx_task_info.status
              }
          }
      except Exception as e:
          logger.error(f"设置大纲任务终止信号失败: {str(e)}", exc_info=True)
          return {
              "success": False,
              "message": f"设置终止信号失败: {str(e)}",
              "sgbx_task_info": None
          }
  async def check_outline_terminate_signal(self, callback_task_id: str) -> bool:
      """
      Check whether a terminate signal exists for an outline-generation task.

      Args:
          callback_task_id: Task callback ID.

      Returns:
          bool: True when the signal key exists in Redis. False when it
          does not, and also when the check itself fails (the error is
          logged, not raised).
      """
      try:
          redis_client = await RedisConnectionFactory.get_connection()
          terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"

          if not await redis_client.exists(terminate_key):
              return False

          # The hash is read only to name the operator in the log line.
          terminate_info = await redis_client.hgetall(terminate_key)

          # Field names arrive as str or bytes depending on how the Redis
          # client is configured, so both spellings are looked up.
          operator = terminate_info.get("operator", terminate_info.get(b"operator", "unknown"))
          if isinstance(operator, bytes):
              # A malformed operator name must not hide an existing signal,
              # hence the lenient decode.
              operator = operator.decode("utf-8", errors="replace")

          logger.warning(f"检测到大纲任务终止信号: {callback_task_id}, "
                         f"操作人: {operator}")
          return True
      except Exception as e:
          logger.error(f"检查大纲任务终止信号失败: {str(e)}", exc_info=True)
          return False
  async def clear_outline_terminate_signal(self, callback_task_id: str):
      """
      Delete the terminate signal of an outline-generation task from Redis.

      Best effort: a failure is logged as a warning and not raised.

      Args:
          callback_task_id: Task callback ID.
      """
      signal_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
      try:
          connection = await RedisConnectionFactory.get_connection()
          await connection.delete(signal_key)
          logger.debug(f"清理大纲任务终止信号: {callback_task_id}")
      except Exception as clear_error:
          logger.warning(f"清理大纲任务终止信号失败: {str(clear_error)}")
  async def get_outline_active_tasks(self) -> list:
      """
      List the outline-generation tasks that are currently running.

      Only tasks tracked in ``self.active_outline_tasks`` whose status is
      "processing" are included.

      Returns:
          list: One dict per running task (callback_task_id, user_id,
          project_name, project_type, status, start_time,
          running_duration in whole seconds). An empty list is returned
          when building the list fails; the error is logged.
      """
      try:
          now = time.time()
          return [
              {
                  "callback_task_id": tracked_id,
                  "user_id": info.user_id,
                  "project_name": info.project_name,
                  "project_type": info.project_type,
                  "status": info.status,
                  "start_time": info.start_time,
                  # 0 when the task has no start time recorded.
                  "running_duration": int(now - info.start_time) if info.start_time else 0
              }
              for tracked_id, info in self.active_outline_tasks.items()
              if info.status == "processing"
          ]
      except Exception as list_error:
          logger.error(f"获取活跃大纲任务列表失败: {str(list_error)}", exc_info=True)
          return []
  async def get_outline_sgbx_task_info(self, callback_task_id: str) -> Optional[Dict]:
      """
      Get information about an outline-generation task.

      Tasks tracked in ``self.active_outline_tasks`` are answered from
      memory. Otherwise the result hash stored in Redis is read, which
      covers results written by another process.

      Args:
          callback_task_id: Task callback ID.

      Returns:
          Optional[Dict]: Task information dict, or None when the task is
          unknown or the lookup fails (the error is logged, not raised).
      """
      try:
          # In-memory tasks take priority.
          sgbx_task_info = self.active_outline_tasks.get(callback_task_id)
          if sgbx_task_info:
              current_time = time.time()
              return {
                  "callback_task_id": callback_task_id,
                  "user_id": sgbx_task_info.user_id,
                  "project_name": sgbx_task_info.project_name,
                  "project_type": sgbx_task_info.project_type,
                  "status": sgbx_task_info.status,
                  "start_time": sgbx_task_info.start_time,
                  "running_duration": int(current_time - sgbx_task_info.start_time) if sgbx_task_info.start_time else 0,
                  "results": sgbx_task_info.results
              }

          # Fall back to the result hash in Redis.
          redis_client = await RedisConnectionFactory.get_connection()
          result_key = f"{self._outline_result_prefix}{callback_task_id}"
          raw_hash = await redis_client.hgetall(result_key)
          if not raw_hash:
              return None

          # Hash fields arrive as str or bytes depending on how the Redis
          # client is configured; normalise to str so the key lookups
          # below work in both cases.
          def _as_text(item):
              return item.decode("utf-8", errors="replace") if isinstance(item, bytes) else item

          result_data = {_as_text(key): _as_text(value) for key, value in raw_hash.items()}

          # Structured fields are stored as JSON text; "" means "no value".
          parsed_results = {}
          for field in ("outline_structure", "key_points", "similar_cases", "similar_fragments", "knowledge_bases"):
              value = result_data.get(field)
              if not value:
                  parsed_results[field] = None
                  continue
              try:
                  parsed_results[field] = json.loads(value)
              except (json.JSONDecodeError, TypeError):
                  parsed_results[field] = None

          # Workflow status -> task status; unmapped values pass through.
          overall_status = result_data.get("overall_task_status", "unknown")
          status_mapping = {
              "completed": "completed",
              "failed": "failed",
              "terminated": "cancelled"
          }
          status = status_mapping.get(overall_status, overall_status)

          return {
              "callback_task_id": result_data.get("callback_task_id"),
              "user_id": result_data.get("user_id"),
              "project_name": result_data.get("project_name", ""),
              "project_type": result_data.get("project_type", ""),
              "status": status,
              # Timing is not available on the Redis path.
              "start_time": None,
              "running_duration": 0,
              "results": {
                  "outline_structure": parsed_results.get("outline_structure"),
                  "key_points": parsed_results.get("key_points"),
                  "similar_cases": parsed_results.get("similar_cases"),
                  "similar_fragments": parsed_results.get("similar_fragments"),
                  "knowledge_bases": parsed_results.get("knowledge_bases"),
                  "error": result_data.get("error_message") or None
              }
          }
      except Exception as e:
          logger.error(f"获取大纲任务信息失败: {str(e)}", exc_info=True)
          return None