"""
LangGraph-based workflow manager.

Creates, orchestrates and executes tasks, using LangGraph for state management.

Additional features:
- task termination management
- setting and detecting termination signals
"""

import asyncio
import time
import json
from typing import Dict, Optional, Any
from datetime import datetime

from langgraph.graph import StateGraph, END
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage

from foundation.observability.logger.loggering import review_logger as logger
from foundation.observability.monitoring.time_statistics import track_execution_time
from foundation.infrastructure.cache.redis_connection import RedisConnectionFactory
from .progress_manager import ProgressManager
from .redis_duplicate_checker import RedisDuplicateChecker
from .task_models import TaskFileInfo, TaskChain
from ..construction_review.workflows import DocumentWorkflow, AIReviewWorkflow, ReportWorkflow
from ..construction_review.workflows.types import TaskChainState


def _decode_redis_hash(raw: Optional[Dict[Any, Any]]) -> Dict[Any, Any]:
    """Return a copy of a Redis hash reply with bytes keys/values decoded to str.

    The Redis client may hand back ``bytes`` or ``str`` depending on how the
    connection was configured; normalising here lets callers use plain string
    keys in both cases.
    """

    def _as_text(value: Any) -> Any:
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    return {_as_text(key): _as_text(value) for key, value in (raw or {}).items()}


def _json_or_empty(value: Any) -> str:
    """Serialise a truthy value to JSON; falsy values become "" (Redis cannot store None)."""
    return json.dumps(value, ensure_ascii=False) if value else ""


class ProgressManagerRegistry:
    """Registry that maps each callback task id to its ProgressManager instance."""

    # {callback_task_id: ProgressManager}
    _registry: Dict[str, ProgressManager] = {}

    @classmethod
    def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
        """Register (or overwrite) the ProgressManager for ``callback_task_id``."""
        cls._registry[callback_task_id] = progress_manager
        logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")

    @classmethod
    def get_progress_manager(cls, callback_task_id: str) -> Optional[ProgressManager]:
        """Return the registered ProgressManager, or None when nothing is registered."""
        return cls._registry.get(callback_task_id)

    @classmethod
    def unregister_progress_manager(cls, callback_task_id: str):
        """Remove the ProgressManager for ``callback_task_id``; no-op if it is absent."""
        if callback_task_id in cls._registry:
            del cls._registry[callback_task_id]
            logger.info(f"注销ProgressManager实例: {callback_task_id}")


class WorkflowManager:
    """Workflow manager for document-review task chains and outline-generation tasks."""

    def __init__(self, max_concurrent_docs: int = 5, max_concurrent_reviews: int = 10):
        """Set up concurrency limits, service components and task tracking state.

        Args:
            max_concurrent_docs: size of the document-processing semaphore.
            max_concurrent_reviews: size of the review semaphore.
        """
        self.max_concurrent_docs = max_concurrent_docs
        self.max_concurrent_reviews = max_concurrent_reviews

        # Concurrency control
        self.doc_semaphore = asyncio.Semaphore(max_concurrent_docs)
        self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews)

        # Service components
        self.progress_manager = ProgressManager()
        self.redis_duplicate_checker = RedisDuplicateChecker()

        # Active task tracking
        self.active_chains: Dict[str, TaskChain] = {}
        self._cleanup_task_started = False

        # Task termination management
        self._terminate_signal_prefix = "ai_review:terminate_signal:"
        self._task_expire_time = 7200  # 2 hours, in seconds

        # LangGraph task-chain workflow (plan D); built lazily to avoid circular imports
        self.task_chain_graph = None

        # ==================== Construction-plan writing task management ====================
        # Active outline-generation tasks
        self.active_outline_tasks: Dict[str, Any] = {}

        # Redis key prefixes for outline-generation tasks
        self._outline_result_prefix = "outline_write:result:"
        self._outline_terminate_signal_prefix = "outline_write:terminate_signal:"

        # Outline-generation workflow graph (built lazily)
        self.outline_generation_graph = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _running_duration(start_time: Any, now: float) -> int:
        """Return whole seconds elapsed since ``start_time``, or 0 when it is unset."""
        return int(now - start_time) if start_time else 0

    async def _write_terminate_signal(self, terminate_key: str, callback_task_id: str, operator: str) -> None:
        """Store a termination signal hash under ``terminate_key`` with the task expiry."""
        redis_client = await RedisConnectionFactory.get_connection()

        # A hash is used so the operator and timestamp travel with the signal.
        terminate_data = {
            "operator": operator,
            "terminate_time": str(time.time()),
            "task_id": callback_task_id
        }
        await redis_client.hset(terminate_key, mapping=terminate_data)
        await redis_client.expire(terminate_key, self._task_expire_time)

    async def _read_terminate_operator(self, terminate_key: str) -> Optional[str]:
        """Return the operator recorded in a termination signal, or None if no signal exists."""
        redis_client = await RedisConnectionFactory.get_connection()
        if not await redis_client.exists(terminate_key):
            return None
        terminate_info = _decode_redis_hash(await redis_client.hgetall(terminate_key))
        return terminate_info.get("operator", "unknown")

    async def _cleanup_file_cache(self, state: TaskChainState) -> None:
        """Best-effort removal of the cached file info in Redis; failures are only logged."""
        try:
            from foundation.utils.redis_utils import delete_file_info
            await delete_file_info(state["file_id"])
            logger.info(f"已清理 Redis 文件缓存: {state['file_id']}")
        except Exception as e:
            logger.warning(f"清理 Redis 文件缓存失败: {str(e)}")

    async def _save_outline_result(self, callback_task_id: str, user_id: str,
                                   outline_task_info: Any, result: Dict[str, Any]) -> None:
        """Persist the outline-generation result to a Redis hash for cross-process access."""
        redis_client = await RedisConnectionFactory.get_connection()
        result_key = f"{self._outline_result_prefix}{callback_task_id}"

        # Redis cannot store None, so every missing value is written as "".
        result_data = {
            "callback_task_id": callback_task_id,
            "user_id": user_id,
            # Stored so get_outline_task_info() can return them after the task left memory.
            "project_name": str(getattr(outline_task_info, "project_name", "") or ""),
            "project_type": str(getattr(outline_task_info, "project_type", "") or ""),
            "overall_task_status": result.get("overall_task_status", ""),
            "outline_structure": _json_or_empty(result.get("outline_structure")),
            "key_points": _json_or_empty(result.get("key_points")),
            "similar_cases": _json_or_empty(result.get("similar_cases")),
            "similar_fragments": _json_or_empty(result.get("similar_fragments")),
            "knowledge_bases": _json_or_empty(result.get("knowledge_bases")),
            "error_message": result.get("error_message") or "",
            "completed_time": str(time.time())
        }

        # hset(mapping=...) replaces the deprecated hmset and matches the rest of this class.
        await redis_client.hset(result_key, mapping=result_data)
        await redis_client.expire(result_key, self._task_expire_time)
        logger.info(f"大纲生成结果已保存到 Redis: {callback_task_id}")

    # ------------------------------------------------------------------
    # Construction review tasks
    # ------------------------------------------------------------------

    async def submit_task_processing(self, file_info: dict) -> str:
        """Submit document processing to Celery (used by the file_upload layer).

        Args:
            file_info: task file information; must contain ``file_id``.

        Returns:
            str: the Celery task id.

        Raises:
            Exception: any submission error is logged and re-raised.
        """
        from foundation.infrastructure.messaging.tasks import submit_task_processing_task
        from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager

        try:
            logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")

            # Submit through CeleryTraceManager so the trace_id is passed along.
            task = CeleryTraceManager.submit_celery_task(
                submit_task_processing_task,
                file_info
            )

            logger.info(f"Celery任务已提交,Task ID: {task.id}")
            return task.id

        except Exception as e:
            logger.error(f"提交Celery任务失败: {str(e)}")
            raise

    @track_execution_time
    def submit_construction_review_task_processing_sync(self, file_info: dict) -> dict:
        """Synchronously run the construction-review task chain (used by the Celery worker).

        Runs the LangGraph task-chain workflow (plan D) and returns a serialisable
        summary of the final state.

        Args:
            file_info: raw task file information used to build a TaskFileInfo.

        Returns:
            dict: serialisable subset of the final workflow state.

        Raises:
            Exception: any failure is logged, reported through the progress manager
                (when the task was created far enough) and re-raised.
        """
        # Bound up front so the except/finally clauses never hit an unbound name
        # when the failure happens before these are assigned.
        callback_task_id = None
        task_file_info = None

        try:
            logger.info(f"提交文档处理任务(LangGraph方案D): {file_info['file_id']}")

            # 1. Wrap the raw file information
            task_file_info = TaskFileInfo(file_info)
            logger.info(f"创建任务文件信息: {task_file_info}")

            # 2. Task chain id
            callback_task_id = task_file_info.callback_task_id

            # 3. Create the task chain (references TaskFileInfo instead of copying data)
            task_chain = TaskChain(task_file_info)

            # 4. Mark the task as started
            task_chain.start_processing()

            # 5. Track it as active
            self.active_chains[callback_task_id] = task_chain

            # 6. Initialise progress management
            asyncio.run(self.progress_manager.initialize_progress(
                callback_task_id=callback_task_id,
                user_id=task_file_info.user_id,
                stages=[]
            ))

            # 7. Build the LangGraph task-chain workflow lazily
            if self.task_chain_graph is None:
                self.task_chain_graph = self._build_task_chain_workflow()

            # 8. Initial state
            initial_state = TaskChainState(
                file_id=task_file_info.file_id,
                callback_task_id=callback_task_id,
                user_id=task_file_info.user_id,
                file_name=task_file_info.file_name,
                file_type=task_file_info.file_type,
                file_content=task_file_info.file_content,
                current_stage="start",
                overall_task_status="processing",
                stage_status={
                    "document": "pending",
                    "ai_review": "pending",
                    "report": "pending"
                },
                document_result=None,
                ai_review_result=None,
                report_result=None,
                error_message=None,
                progress_manager=self.progress_manager,
                task_file_info=task_file_info,
                messages=[HumanMessage(content=f"开始任务链: {task_file_info.file_id}")]
            )

            # 9. Run the workflow on a dedicated loop; close it even when the run fails
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                result = loop.run_until_complete(self.task_chain_graph.ainvoke(initial_state))
            finally:
                loop.close()

            # 10. Release the duplicate-check registration
            asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))

            logger.info(f"施工方案审查任务已完成(LangGraph方案D)!")
            logger.info(f"文件ID: {task_file_info.file_id}")
            logger.info(f"文件名: {task_file_info.file_name}")
            logger.info(f"整体状态: {result.get('overall_task_status', 'unknown')}")

            # Only serialisable fields are returned; progress_manager, task_file_info
            # and messages are deliberately left out.
            serializable_result = {
                "file_id": result.get("file_id"),
                "callback_task_id": result.get("callback_task_id"),
                "user_id": result.get("user_id"),
                "file_name": result.get("file_name"),
                "current_stage": result.get("current_stage"),
                "overall_task_status": result.get("overall_task_status"),
                "stage_status": result.get("stage_status"),
                "error_message": result.get("error_message"),
            }

            return serializable_result

        except Exception as e:
            logger.error(f"提交文档处理任务失败: {str(e)}", exc_info=True)

            # Mark the task as failed
            if callback_task_id in self.active_chains:
                self.active_chains[callback_task_id].fail_processing(str(e))

            # Cleanup and notification need the task info; skip when it was never built.
            if task_file_info is not None:
                # Release the duplicate-check registration
                asyncio.run(self.redis_duplicate_checker.unregister_task(task_file_info.file_id))

                # Tell the SSE connection that the task failed
                if callback_task_id is not None:
                    error_data = {
                        "error": str(e),
                        "status": "failed",
                        "overall_task_status": "failed",
                        "timestamp": datetime.now().isoformat()
                    }
                    asyncio.run(self.progress_manager.complete_task(
                        callback_task_id, task_file_info.user_id, error_data
                    ))

            raise
        finally:
            # Drop the task from the active set
            self.active_chains.pop(callback_task_id, None)

    async def set_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, Any]:
        """Set the termination signal for a review task.

        The signal is written to Redis so it can be detected across processes;
        workflow nodes check it before they execute.

        Args:
            callback_task_id: task callback id.
            operator: who requested termination (user id or system identifier).

        Returns:
            Dict: ``{"success": bool, "message": str, "task_info": dict | None}``.
        """
        try:
            # The task must be tracked as active in this process
            if callback_task_id not in self.active_chains:
                return {
                    "success": False,
                    "message": f"任务不存在或已完成: {callback_task_id}",
                    "task_info": None
                }

            task_chain = self.active_chains[callback_task_id]

            # Only processing tasks can be terminated
            if task_chain.status != "processing":
                return {
                    "success": False,
                    "message": f"任务状态不是 processing,无需终止: {callback_task_id} (当前状态: {task_chain.status})",
                    "task_info": {
                        "callback_task_id": callback_task_id,
                        "status": task_chain.status,
                        "file_name": task_chain.file_name
                    }
                }

            terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
            await self._write_terminate_signal(terminate_key, callback_task_id, operator)

            logger.info(f"已设置终止信号: {callback_task_id} (操作人: {operator}, 文件: {task_chain.file_name})")

            return {
                "success": True,
                "message": f"终止信号已设置,任务将在当前节点完成后终止",
                "task_info": {
                    "callback_task_id": callback_task_id,
                    "file_id": task_chain.file_id,
                    "file_name": task_chain.file_name,
                    "user_id": task_chain.user_id,
                    "status": task_chain.status,
                    "current_stage": task_chain.current_stage
                }
            }

        except Exception as e:
            logger.error(f"设置终止信号失败: {str(e)}", exc_info=True)
            return {
                "success": False,
                "message": f"设置终止信号失败: {str(e)}",
                "task_info": None
            }

    async def check_terminate_signal(self, callback_task_id: str) -> bool:
        """Check whether a termination signal exists for a review task.

        Args:
            callback_task_id: task callback id.

        Returns:
            bool: True when a signal is present; False otherwise, including when
            the check itself fails (errors are logged, never raised).
        """
        try:
            terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
            operator = await self._read_terminate_operator(terminate_key)

            if operator is not None:
                logger.warning(f"检测到终止信号: {callback_task_id}, 操作人: {operator}")
                return True

            return False

        except RuntimeError as e:
            # Event-loop related errors get dedicated handling
            error_msg = str(e)
            if "Event loop is closed" in error_msg:
                # Treated as normal (task is ending); not logged as an error
                logger.debug(f"检查终止信号时事件循环已关闭: {callback_task_id}")
                return False
            elif "bound to a different event loop" in error_msg:
                # Loop mismatch: warn but do not interrupt the flow
                logger.warning(f"检查终止信号时检测到事件循环不匹配: {callback_task_id},将忽略本次检查")
                return False
            else:
                logger.error(f"检查终止信号失败(RuntimeError): {error_msg}", exc_info=True)
                return False
        except Exception as e:
            logger.error(f"检查终止信号失败: {str(e)}", exc_info=True)
            return False

    async def clear_terminate_signal(self, callback_task_id: str):
        """Delete the termination signal of a review task from Redis (best effort).

        Args:
            callback_task_id: task callback id.
        """
        try:
            redis_client = await RedisConnectionFactory.get_connection()
            terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
            await redis_client.delete(terminate_key)
            logger.debug(f"清理终止信号: {callback_task_id}")
        except Exception as e:
            logger.warning(f"清理终止信号失败: {str(e)}")

    async def get_active_tasks(self) -> list:
        """Return info dicts for all review tasks currently in ``processing`` state.

        Returns:
            list: active task info; an empty list when collection fails.
        """
        try:
            current_time = time.time()

            return [
                {
                    "callback_task_id": task_id,
                    "file_id": task_chain.file_id,
                    "file_name": task_chain.file_name,
                    "user_id": task_chain.user_id,
                    "status": task_chain.status,
                    "current_stage": task_chain.current_stage,
                    "start_time": task_chain.start_time,
                    "running_duration": self._running_duration(task_chain.start_time, current_time)
                }
                for task_id, task_chain in self.active_chains.items()
                if task_chain.status == "processing"
            ]

        except Exception as e:
            logger.error(f"获取活跃任务列表失败: {str(e)}", exc_info=True)
            return []

    async def get_task_info(self, callback_task_id: str) -> Optional[Dict]:
        """Return information about a tracked review task.

        Args:
            callback_task_id: task callback id.

        Returns:
            Optional[Dict]: task info, or None when the task is unknown or lookup fails.
        """
        try:
            task_chain = self.active_chains.get(callback_task_id)
            if not task_chain:
                return None

            return {
                "callback_task_id": callback_task_id,
                "file_id": task_chain.file_id,
                "file_name": task_chain.file_name,
                "user_id": task_chain.user_id,
                "status": task_chain.status,
                "current_stage": task_chain.current_stage,
                "start_time": task_chain.start_time,
                "running_duration": self._running_duration(task_chain.start_time, time.time()),
                "results": task_chain.results
            }

        except Exception as e:
            logger.error(f"获取任务信息失败: {str(e)}", exc_info=True)
            return None

    def _build_task_chain_workflow(self) -> StateGraph:
        """Build and compile the LangGraph task-chain workflow (plan D).

        Path: start -> document_processing -> ai_review_subgraph ->
        report_generation -> complete -> END. After each stage a conditional edge
        routes to ``terminate`` or ``error_handler`` when required.

        Returns:
            The compiled graph produced by ``workflow.compile()``.
        """
        logger.info("开始构建 LangGraph 任务链工作流图")

        workflow = StateGraph(TaskChainState)

        # Nodes
        workflow.add_node("start", self._start_chain_node)
        workflow.add_node("document_processing", self._document_processing_node)
        workflow.add_node("ai_review_subgraph", self._ai_review_subgraph_node)
        workflow.add_node("report_generation", self._report_generation_node)
        workflow.add_node("complete", self._complete_chain_node)
        workflow.add_node("error_handler", self._error_handler_chain_node)
        workflow.add_node("terminate", self._terminate_chain_node)

        # Entry point
        workflow.set_entry_point("start")
        workflow.add_edge("start", "document_processing")

        # Every stage is followed by the same terminate/error check; only the
        # "continue" target differs.
        stage_transitions = (
            ("document_processing", "ai_review_subgraph"),
            ("ai_review_subgraph", "report_generation"),
            ("report_generation", "complete"),
        )
        for stage_node, next_node in stage_transitions:
            workflow.add_conditional_edges(
                stage_node,
                self._should_terminate_or_error_chain,
                {
                    "terminate": "terminate",
                    "error": "error_handler",
                    "continue": next_node
                }
            )

        # Terminal nodes
        workflow.add_edge("complete", END)
        workflow.add_edge("error_handler", END)
        workflow.add_edge("terminate", END)

        compiled_graph = workflow.compile()

        logger.info("LangGraph 任务链工作流图构建完成")
        return compiled_graph

    async def _start_chain_node(self, state: TaskChainState) -> TaskChainState:
        """Start node: mark the chain as processing with all stages pending.

        Args:
            state: task-chain state.

        Returns:
            TaskChainState: state update.
        """
        logger.info(f"任务链工作流启动: {state['callback_task_id']}")

        return {
            "current_stage": "start",
            "overall_task_status": "processing",
            "stage_status": {
                "document": "pending",
                "ai_review": "pending",
                "report": "pending"
            },
            "messages": [AIMessage(content="任务链工作流启动")]
        }

    async def _document_processing_node(self, state: TaskChainState) -> TaskChainState:
        """Document-processing node.

        Args:
            state: task-chain state.

        Returns:
            TaskChainState: state update containing the document result, or a
            terminated/failed status (exceptions are converted to a failed update).
        """
        try:
            logger.info(f"开始文档处理阶段: {state['callback_task_id']}")

            # Honour a pending termination signal before doing any work
            if await self.check_terminate_signal(state["callback_task_id"]):
                logger.warning(f"文档处理阶段检测到终止信号: {state['callback_task_id']}")
                return {
                    "current_stage": "document_processing",
                    "overall_task_status": "terminated",
                    "stage_status": {**state["stage_status"], "document": "terminated"},
                    "messages": [AIMessage(content="文档处理阶段检测到终止信号")]
                }

            task_file_info = state["task_file_info"]

            document_workflow = DocumentWorkflow(
                task_file_info=task_file_info,
                progress_manager=state["progress_manager"],
                redis_duplicate_checker=self.redis_duplicate_checker
            )

            doc_result = await document_workflow.execute(
                state["file_content"],
                state["file_type"]
            )

            logger.info(f"文档处理完成: {state['callback_task_id']}")

            return {
                "current_stage": "document_processing",
                "overall_task_status": "processing",
                "stage_status": {**state["stage_status"], "document": "completed"},
                "document_result": doc_result,
                "messages": [AIMessage(content="文档处理完成")]
            }

        except Exception as e:
            logger.error(f"文档处理失败: {str(e)}", exc_info=True)
            return {
                "current_stage": "document_processing",
                "overall_task_status": "failed",
                "stage_status": {**state["stage_status"], "document": "failed"},
                "error_message": f"文档处理失败: {str(e)}",
                "messages": [AIMessage(content=f"文档处理失败: {str(e)}")]
            }

    async def _ai_review_subgraph_node(self, state: TaskChainState) -> TaskChainState:
        """AI-review node; runs AIReviewWorkflow as a nested sub-workflow.

        Args:
            state: task-chain state.

        Returns:
            TaskChainState: state update containing the AI review result, or a
            terminated/failed status (exceptions are converted to a failed update).
        """
        try:
            logger.info(f"开始AI审查阶段: {state['callback_task_id']}")

            # Honour a pending termination signal before doing any work
            if await self.check_terminate_signal(state["callback_task_id"]):
                logger.warning(f"AI审查阶段检测到终止信号: {state['callback_task_id']}")
                return {
                    "current_stage": "ai_review",
                    "overall_task_status": "terminated",
                    "stage_status": {**state["stage_status"], "ai_review": "terminated"},
                    "messages": [AIMessage(content="AI审查阶段检测到终止信号")]
                }

            # Structured content produced by the document stage
            structured_content = state["document_result"].get("structured_content")
            if not structured_content:
                raise ValueError("文档处理结果中缺少结构化内容")

            task_file_info = state["task_file_info"]

            # Read the AI review configuration
            import configparser
            config = configparser.ConfigParser()
            config.read('config/config.ini', encoding='utf-8')

            max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None)
            if max_review_units == 0:
                # A configured 0 is passed on as None
                max_review_units = None
            review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all')

            logger.info(f"AI审查配置: 最大审查数量={max_review_units}, 审查模式={review_mode}")

            ai_workflow = AIReviewWorkflow(
                task_file_info=task_file_info,
                structured_content=structured_content,
                progress_manager=state["progress_manager"],
                max_review_units=max_review_units,
                review_mode=review_mode
            )

            ai_result = await ai_workflow.execute()

            logger.info(f"AI审查完成: {state['callback_task_id']}")

            return {
                "current_stage": "ai_review",
                "overall_task_status": "processing",
                "stage_status": {**state["stage_status"], "ai_review": "completed"},
                "ai_review_result": ai_result,
                "messages": [AIMessage(content="AI审查完成")]
            }

        except Exception as e:
            logger.error(f"AI审查失败: {str(e)}", exc_info=True)
            return {
                "current_stage": "ai_review",
                "overall_task_status": "failed",
                "stage_status": {**state["stage_status"], "ai_review": "failed"},
                "error_message": f"AI审查失败: {str(e)}",
                "messages": [AIMessage(content=f"AI审查失败: {str(e)}")]
            }

    async def _report_generation_node(self, state: TaskChainState) -> TaskChainState:
        """Report-generation node.

        Runs ReportWorkflow on the AI review result and then saves the complete
        results (document + AI review + report) in one go.

        Args:
            state: task-chain state.

        Returns:
            TaskChainState: state update containing the report result, or a
            terminated/failed status (exceptions are converted to a failed update).
        """
        try:
            logger.info(f"开始报告生成阶段: {state['callback_task_id']}")

            # Honour a pending termination signal before doing any work
            if await self.check_terminate_signal(state["callback_task_id"]):
                logger.warning(f"报告生成阶段检测到终止信号: {state['callback_task_id']}")
                return {
                    "current_stage": "report_generation",
                    "overall_task_status": "terminated",
                    "stage_status": {**state["stage_status"], "report": "terminated"},
                    "messages": [AIMessage(content="报告生成阶段检测到终止信号")]
                }

            ai_review_result = state.get("ai_review_result")
            if not ai_review_result:
                raise ValueError("AI审查结果缺失,无法生成报告")

            task_file_info = state["task_file_info"]

            report_workflow = ReportWorkflow(
                file_id=state["file_id"],
                file_name=state["file_name"],
                callback_task_id=state["callback_task_id"],
                user_id=state["user_id"],
                ai_review_results=ai_review_result,
                progress_manager=state["progress_manager"]
            )

            report_result = await report_workflow.execute()

            logger.info(f"报告生成完成: {state['callback_task_id']}")

            # Persist document, AI review and report results together
            await self._save_complete_results(state, report_result)

            return {
                "current_stage": "report_generation",
                "overall_task_status": "processing",
                "stage_status": {**state["stage_status"], "report": "completed"},
                "report_result": report_result,
                "messages": [AIMessage(content="报告生成完成")]
            }

        except Exception as e:
            logger.error(f"报告生成失败: {str(e)}", exc_info=True)
            return {
                "current_stage": "report_generation",
                "overall_task_status": "failed",
                "stage_status": {**state["stage_status"], "report": "failed"},
                "error_message": f"报告生成失败: {str(e)}",
                "messages": [AIMessage(content=f"报告生成失败: {str(e)}")]
            }

    async def _complete_chain_node(self, state: TaskChainState) -> TaskChainState:
        """Completion node: the only place that sets overall_task_status to "completed".

        Args:
            state: task-chain state.

        Returns:
            TaskChainState: state update marking the whole task as completed.
        """
        logger.info(f"任务链工作流完成: {state['callback_task_id']}")

        # Notify completion
        if state["progress_manager"]:
            await state["progress_manager"].complete_task(
                state["callback_task_id"],
                state["user_id"],
                {"overall_task_status": "completed", "message": "所有阶段已完成"}
            )

        await self._cleanup_file_cache(state)

        return {
            "current_stage": "complete",
            "overall_task_status": "completed",
            "messages": [AIMessage(content="任务链工作流完成")]
        }

    async def _error_handler_chain_node(self, state: TaskChainState) -> TaskChainState:
        """Error-handler node: report the failure and clean up.

        Args:
            state: task-chain state.

        Returns:
            TaskChainState: state update marking the task as failed.
        """
        logger.error(f"任务链工作流错误: {state['callback_task_id']}, 错误: {state.get('error_message', '未知错误')}")

        # Notify failure
        if state["progress_manager"]:
            error_data = {
                "overall_task_status": "failed",
                "error": state.get("error_message", "未知错误"),
                "status": "failed",
                "timestamp": datetime.now().isoformat()
            }
            await state["progress_manager"].complete_task(
                state["callback_task_id"],
                state["user_id"],
                error_data
            )

        # The file cache is cleaned up on failure as well
        await self._cleanup_file_cache(state)

        return {
            "current_stage": "error_handler",
            "overall_task_status": "failed",
            "messages": [AIMessage(content=f"任务链错误: {state.get('error_message', '未知错误')}")]
        }

    async def _terminate_chain_node(self, state: TaskChainState) -> TaskChainState:
        """Termination node: report the termination and clean up.

        Args:
            state: task-chain state.

        Returns:
            TaskChainState: state update marking the task as terminated.
        """
        logger.warning(f"任务链工作流已终止: {state['callback_task_id']}")

        # Notify termination
        if state["progress_manager"]:
            await state["progress_manager"].complete_task(
                state["callback_task_id"],
                state["user_id"],
                {"overall_task_status": "terminated", "message": "任务已被用户终止"}
            )

        # Remove the consumed termination signal, then the file cache
        await self.clear_terminate_signal(state["callback_task_id"])
        await self._cleanup_file_cache(state)

        return {
            "current_stage": "terminated",
            "overall_task_status": "terminated",
            "messages": [AIMessage(content="任务链已被终止")]
        }

    def _should_terminate_or_error_chain(self, state: TaskChainState) -> str:
        """Conditional-edge router deciding the next step of the workflow.

        Termination is checked first, then failure; otherwise the chain continues.

        Args:
            state: task-chain state.

        Returns:
            str: "terminate", "error" or "continue".
        """
        overall_status = state.get("overall_task_status")

        if overall_status == "terminated":
            return "terminate"

        if overall_status == "failed" or state.get("error_message"):
            return "error"

        return "continue"

    async def _save_complete_results(self, state: TaskChainState, report_result: Dict[str, Any]):
        """Save document, AI review and report results to one JSON file.

        The file is written to temp/construction_review/final_result/<callback_task_id>.json.

        Args:
            state: task-chain state.
            report_result: report-generation result.

        Raises:
            Exception: any failure is logged and re-raised.
        """
        try:
            import os

            logger.info(f"开始保存完整结果: {state['callback_task_id']}")

            temp_dir = os.path.join("temp", "construction_review", "final_result")
            os.makedirs(temp_dir, exist_ok=True)

            # Guard against a missing review result instead of failing on None.get
            ai_review_result = state.get("ai_review_result")
            issues = (ai_review_result or {}).get("review_results")

            complete_results = {
                "callback_task_id": state["callback_task_id"],
                "file_id": state["file_id"],
                "file_name": state["file_name"],
                "user_id": state["user_id"],
                # Still processing here; only the complete node marks "completed"
                "overall_task_status": "processing",
                "stage_status": state["stage_status"],
                "document_result": state.get("document_result"),
                "ai_review_result": ai_review_result,
                "issues": issues,
                "report_result": report_result,
                "timestamp": datetime.now().isoformat()
            }

            file_path = os.path.join(temp_dir, f"{state['callback_task_id']}.json")
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(complete_results, f, ensure_ascii=False, indent=2)

            logger.info(f"完整结果已保存到: {file_path}")

        except Exception as e:
            logger.error(f"保存完整结果失败: {str(e)}", exc_info=True)
            raise

    # ==================== Construction-plan writing task management ====================

    async def submit_outline_generation_task(self, task_info: dict) -> str:
        """Submit an outline-generation task to Celery.

        Args:
            task_info: task information dict, e.g.
                {
                    "user_id": str,
                    "project_info": dict,
                    "template_id": str,
                    "outline_config": dict,
                    "similarity_config": dict (optional),
                    "knowledge_config": dict (optional)
                }

        Returns:
            str: the Celery task id.

        Raises:
            Exception: any submission error is logged and re-raised.
        """
        from foundation.infrastructure.messaging.tasks import submit_outline_generation_task
        from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager

        try:
            logger.info(f"提交大纲生成任务到Celery: user_id={task_info.get('user_id')}")

            # Submit through CeleryTraceManager so the trace_id is passed along.
            task = CeleryTraceManager.submit_celery_task(
                submit_outline_generation_task,
                task_info
            )

            logger.info(f"大纲生成Celery任务已提交,Task ID: {task.id}")
            return task.id

        except Exception as e:
            logger.error(f"提交大纲生成Celery任务失败: {str(e)}")
            raise

    @track_execution_time
    def submit_outline_generation_sync(self, task_info: dict) -> dict:
        """Synchronously run an outline-generation task (used by the Celery worker).

        Args:
            task_info: task information dict.

        Returns:
            dict: serialisable execution result.

        Raises:
            Exception: any failure is logged, recorded on the task and re-raised.
        """
        import uuid
        from ..construction_write.component.state_models import OutlineGenerationState, OutlineTaskInfo
        from ..construction_write.workflows.outline_workflow import OutlineWorkflow

        callback_task_id = None

        try:
            logger.info(f"开始执行大纲生成任务(LangGraph)")

            # 1. Task id (generated when the caller did not supply one)
            callback_task_id = task_info.get('callback_task_id') or f"outline_{uuid.uuid4().hex[:16]}"
            user_id = task_info.get('user_id', 'unknown')

            # 2. Task information object
            outline_task_info = OutlineTaskInfo(
                callback_task_id=callback_task_id,
                user_id=user_id,
                project_info=task_info.get('project_info', {}),
                template_id=task_info.get('template_id', ''),
                outline_config=task_info.get('outline_config', {}),
                similarity_config=task_info.get('similarity_config', {}),
                knowledge_config=task_info.get('knowledge_config', {})
            )

            # 3. Track it as active
            self.active_outline_tasks[callback_task_id] = outline_task_info

            # Progress initialisation and the workflow share one loop, which is
            # closed even when any step below raises.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                # 4. Initialise progress management
                loop.run_until_complete(self.progress_manager.initialize_progress(
                    callback_task_id=callback_task_id,
                    user_id=user_id,
                    stages=[
                        {"stage": "start", "status": "pending"},
                        {"stage": "template_loading", "status": "pending"},
                        {"stage": "outline_generation", "status": "pending"},
                        {"stage": "similar_cases", "status": "pending"},
                        {"stage": "similar_fragments", "status": "pending"},
                        {"stage": "knowledge_bases", "status": "pending"},
                        {"stage": "complete", "status": "pending"}
                    ]
                ))

                # 4.1 Register the ProgressManager so workflow nodes can look it up
                ProgressManagerRegistry.register_progress_manager(callback_task_id, self.progress_manager)

                # 4.2 Mark the task as started
                outline_task_info.start_processing()

                # 5. Build the outline-generation workflow lazily
                if self.outline_generation_graph is None:
                    outline_workflow = OutlineWorkflow()
                    self.outline_generation_graph = outline_workflow.build_graph()

                # 6. Initial state. progress_manager and task_info are kept out of
                #    the state because they are not serialisable.
                initial_state = OutlineGenerationState(
                    callback_task_id=callback_task_id,
                    user_id=user_id,
                    project_info=outline_task_info.project_info,
                    template_id=outline_task_info.template_id,
                    outline_config=outline_task_info.outline_config,
                    similarity_config=outline_task_info.similarity_config,
                    knowledge_config=outline_task_info.knowledge_config,
                    template=None,
                    outline_structure=None,
                    key_points=None,
                    similar_cases=None,
                    similar_fragments=None,
                    knowledge_bases=None,
                    current_stage="start",
                    overall_task_status="processing",
                    error_message=None,
                    messages=[HumanMessage(content=f"开始大纲生成任务: {callback_task_id}")]
                )

                # 7. Run the workflow; the checkpointer needs a thread_id in config
                result = loop.run_until_complete(
                    self.outline_generation_graph.ainvoke(
                        initial_state,
                        config={"configurable": {"thread_id": callback_task_id}}
                    )
                )
            finally:
                loop.close()

            logger.info(f"大纲生成任务完成!callback_task_id={callback_task_id}")

            # 8. Reflect the final status on the task object
            overall_status = result.get("overall_task_status")
            if overall_status == "completed":
                outline_task_info.complete_processing({
                    "outline_structure": result.get("outline_structure"),
                    "key_points": result.get("key_points"),
                    "similar_cases": result.get("similar_cases"),
                    "similar_fragments": result.get("similar_fragments"),
                    "knowledge_bases": result.get("knowledge_bases")
                })
            elif overall_status == "failed":
                outline_task_info.fail_processing(result.get("error_message", "未知错误"))
            elif overall_status == "terminated":
                outline_task_info.cancel_processing()

            # 8.5 Persist the result to Redis on a fresh loop (as the original did)
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(
                    self._save_outline_result(callback_task_id, user_id, outline_task_info, result)
                )
            finally:
                loop.close()

            # 9. Serialisable result
            return {
                "callback_task_id": result.get("callback_task_id"),
                "user_id": result.get("user_id"),
                "overall_task_status": result.get("overall_task_status"),
                "outline_structure": result.get("outline_structure"),
                "key_points": result.get("key_points"),
                "similar_cases": result.get("similar_cases"),
                "similar_fragments": result.get("similar_fragments"),
                "knowledge_bases": result.get("knowledge_bases"),
                "error_message": result.get("error_message")
            }

        except Exception as e:
            logger.error(f"大纲生成任务失败: {str(e)}", exc_info=True)

            # Mark the task as failed
            if callback_task_id and callback_task_id in self.active_outline_tasks:
                self.active_outline_tasks[callback_task_id].fail_processing(str(e))

            raise
        finally:
            if callback_task_id:
                # Drop the task from the active set and from the registry
                self.active_outline_tasks.pop(callback_task_id, None)
                ProgressManagerRegistry.unregister_progress_manager(callback_task_id)

    async def set_outline_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, Any]:
        """Set the termination signal for an outline-generation task.

        Args:
            callback_task_id: task callback id.
            operator: who requested termination.

        Returns:
            Dict: ``{"success": bool, "message": str, "task_info": dict | None}``.
        """
        try:
            # The task must be tracked as active in this process
            if callback_task_id not in self.active_outline_tasks:
                return {
                    "success": False,
                    "message": f"任务不存在或已完成: {callback_task_id}",
                    "task_info": None
                }

            task_info = self.active_outline_tasks[callback_task_id]

            # Only processing tasks can be terminated
            if task_info.status != "processing":
                return {
                    "success": False,
                    "message": f"任务状态不是 processing,无需终止: {callback_task_id} (当前状态: {task_info.status})",
                    "task_info": {
                        "callback_task_id": callback_task_id,
                        "status": task_info.status,
                        "project_name": task_info.project_name
                    }
                }

            terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
            await self._write_terminate_signal(terminate_key, callback_task_id, operator)

            logger.info(f"已设置大纲任务终止信号: {callback_task_id} (操作人: {operator}, 项目: {task_info.project_name})")

            return {
                "success": True,
                "message": f"终止信号已设置,任务将在当前节点完成后终止",
                "task_info": {
                    "callback_task_id": callback_task_id,
                    "user_id": task_info.user_id,
                    "project_name": task_info.project_name,
                    "status": task_info.status
                }
            }

        except Exception as e:
            logger.error(f"设置大纲任务终止信号失败: {str(e)}", exc_info=True)
            return {
                "success": False,
                "message": f"设置终止信号失败: {str(e)}",
                "task_info": None
            }

    async def check_outline_terminate_signal(self, callback_task_id: str) -> bool:
        """Check whether a termination signal exists for an outline-generation task.

        Args:
            callback_task_id: task callback id.

        Returns:
            bool: True when a signal is present; False otherwise, including when
            the check itself fails (errors are logged, never raised).
        """
        try:
            terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
            operator = await self._read_terminate_operator(terminate_key)

            if operator is not None:
                logger.warning(f"检测到大纲任务终止信号: {callback_task_id}, "
                               f"操作人: {operator}")
                return True

            return False

        except Exception as e:
            logger.error(f"检查大纲任务终止信号失败: {str(e)}", exc_info=True)
            return False

    async def clear_outline_terminate_signal(self, callback_task_id: str):
        """Delete the termination signal of an outline task from Redis (best effort).

        Args:
            callback_task_id: task callback id.
        """
        try:
            redis_client = await RedisConnectionFactory.get_connection()
            terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
            await redis_client.delete(terminate_key)
            logger.debug(f"清理大纲任务终止信号: {callback_task_id}")
        except Exception as e:
            logger.warning(f"清理大纲任务终止信号失败: {str(e)}")

    async def get_outline_active_tasks(self) -> list:
        """Return info dicts for all outline tasks currently in ``processing`` state.

        Returns:
            list: active task info; an empty list when collection fails.
        """
        try:
            current_time = time.time()

            return [
                {
                    "callback_task_id": task_id,
                    "user_id": task_info.user_id,
                    "project_name": task_info.project_name,
                    "project_type": task_info.project_type,
                    "status": task_info.status,
                    "start_time": task_info.start_time,
                    "running_duration": self._running_duration(task_info.start_time, current_time)
                }
                for task_id, task_info in self.active_outline_tasks.items()
                if task_info.status == "processing"
            ]

        except Exception as e:
            logger.error(f"获取活跃大纲任务列表失败: {str(e)}", exc_info=True)
            return []

    async def get_outline_task_info(self, callback_task_id: str) -> Optional[Dict]:
        """Return information about an outline-generation task.

        In-memory active tasks are consulted first; otherwise the result hash
        written to Redis by the worker is read.

        Args:
            callback_task_id: task callback id.

        Returns:
            Optional[Dict]: task info, or None when the task is unknown or lookup fails.
        """
        try:
            # Prefer the in-memory active task
            task_info = self.active_outline_tasks.get(callback_task_id)
            if task_info:
                return {
                    "callback_task_id": callback_task_id,
                    "user_id": task_info.user_id,
                    "project_name": task_info.project_name,
                    "project_type": task_info.project_type,
                    "status": task_info.status,
                    "start_time": task_info.start_time,
                    "running_duration": self._running_duration(task_info.start_time, time.time()),
                    "results": task_info.results
                }

            # Fall back to Redis (results produced by a Celery worker process)
            redis_client = await RedisConnectionFactory.get_connection()
            result_key = f"{self._outline_result_prefix}{callback_task_id}"
            # Normalise bytes/str so the string-key lookups below work either way
            result_data = _decode_redis_hash(await redis_client.hgetall(result_key))

            if not result_data:
                return None

            # JSON-encoded fields; empty or undecodable values become None
            parsed_results = {}
            for key in ("outline_structure", "key_points", "similar_cases",
                        "similar_fragments", "knowledge_bases"):
                value = result_data.get(key)
                if not value:
                    parsed_results[key] = None
                    continue
                try:
                    parsed_results[key] = json.loads(value)
                except (json.JSONDecodeError, TypeError):
                    parsed_results[key] = None

            # Map the workflow status onto the task status vocabulary
            overall_status = result_data.get("overall_task_status", "unknown")
            status_mapping = {
                "completed": "completed",
                "failed": "failed",
                "terminated": "cancelled"
            }
            status = status_mapping.get(overall_status, overall_status)

            return {
                "callback_task_id": result_data.get("callback_task_id"),
                "user_id": result_data.get("user_id"),
                "project_name": result_data.get("project_name", ""),
                "project_type": result_data.get("project_type", ""),
                "status": status,
                "start_time": None,
                "running_duration": 0,
                "results": {
                    "outline_structure": parsed_results.get("outline_structure"),
                    "key_points": parsed_results.get("key_points"),
                    "similar_cases": parsed_results.get("similar_cases"),
                    "similar_fragments": parsed_results.get("similar_fragments"),
                    "knowledge_bases": parsed_results.get("knowledge_bases"),
                    "error": result_data.get("error_message") or None
                }
            }

        except Exception as e:
            logger.error(f"获取大纲任务信息失败: {str(e)}", exc_info=True)
            return None