""" 基于LangGraph的工作流管理器 负责任务的创建、编排和执行,使用LangGraph进行状态管理 新增功能: - 任务终止管理 - 终止信号设置和检测 """ import asyncio import time from typing import Dict, Optional from datetime import datetime from foundation.observability.logger.loggering import server_logger as logger from foundation.observability.monitoring.time_statistics import track_execution_time from foundation.infrastructure.cache.redis_connection import RedisConnectionFactory from .progress_manager import ProgressManager from .redis_duplicate_checker import RedisDuplicateChecker from .task_models import TaskFileInfo, TaskChain from ..construction_review.workflows import DocumentWorkflow,AIReviewWorkflow class ProgressManagerRegistry: """ProgressManager注册表 - 为每个任务管理独立的ProgressManager实例""" _registry = {} # {callback_task_id: ProgressManager} @classmethod def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager): """注册ProgressManager实例""" cls._registry[callback_task_id] = progress_manager logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}") @classmethod def get_progress_manager(cls, callback_task_id: str) -> ProgressManager: """获取ProgressManager实例""" return cls._registry.get(callback_task_id) @classmethod def unregister_progress_manager(cls, callback_task_id: str): """注销ProgressManager实例""" if callback_task_id in cls._registry: del cls._registry[callback_task_id] logger.info(f"注销ProgressManager实例: {callback_task_id}") class WorkflowManager: """工作流管理器""" def __init__(self, max_concurrent_docs: int = 5, max_concurrent_reviews: int = 10): self.max_concurrent_docs = max_concurrent_docs self.max_concurrent_reviews = max_concurrent_reviews # 并发控制 self.doc_semaphore = asyncio.Semaphore(max_concurrent_docs) self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews) # 服务组件 self.progress_manager = ProgressManager() self.redis_duplicate_checker = RedisDuplicateChecker() # 活跃任务跟踪 self.active_chains: Dict[str, TaskChain] = {} self._cleanup_task_started = False # 任务终止管理 self._terminate_signal_prefix = "ai_review:terminate_signal:" self._task_expire_time = 7200 # 2小时 async def submit_task_processing(self, file_info: dict) -> str: """异步提交任务处理(用于file_upload层)""" from foundation.infrastructure.messaging.tasks import submit_task_processing_task from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager try: logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}") # 使用CeleryTraceManager提交任务,自动传递trace_id task = CeleryTraceManager.submit_celery_task( submit_task_processing_task, file_info ) logger.info(f"Celery任务已提交,Task ID: {task.id}") return task.id except Exception as e: logger.error(f"提交Celery任务失败: {str(e)}") raise @track_execution_time def submit_task_processing_sync(self, file_info: dict) -> dict: """同步提交任务处理(用于Celery worker)""" try: logger.info(f"提交文档处理任务: {file_info['file_id']}") # 1. 创建TaskFileInfo对象(封装任务文件信息) task_file_info = TaskFileInfo(file_info) logger.info(f"创建任务文件信息: {task_file_info}") # 2. 生成任务链ID callback_task_id = task_file_info.callback_task_id # 3. 创建任务链(引用 TaskFileInfo,避免数据重复) task_chain = TaskChain(task_file_info) # 4. 标记任务开始 task_chain.start_processing() # 5. 添加到活跃任务跟踪 self.active_chains[callback_task_id] = task_chain # 5. 初始化进度管理 asyncio.run(self.progress_manager.initialize_progress( callback_task_id=callback_task_id, user_id=task_file_info.user_id, stages=[] )) # 6. 启动处理流程(同步执行) self._process_task_chain_sync(task_chain, task_file_info, task_file_info.file_type) # logger.info(f"提交文档处理任务: {callback_task_id}") logger.info(f"施工方案审查任务已完成! ") logger.info(f"文件ID: {task_file_info.file_id}") logger.info(f"文件名:{task_file_info.file_name}") except Exception as e: logger.error(f"提交文档处理任务失败: {str(e)}") raise def _process_task_chain_sync(self, task_chain: TaskChain, task_file_info: TaskFileInfo, file_type: str): """同步处理文档任务链(用于Celery worker)""" try: file_content = task_file_info.file_content # 阶段1:文档处理(串行) document_workflow = DocumentWorkflow( task_file_info=task_file_info, progress_manager=self.progress_manager, redis_duplicate_checker=self.redis_duplicate_checker ) # 同步执行文档处理 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) doc_result = loop.run_until_complete(document_workflow.execute(file_content, file_type)) loop.close() task_chain.results['document'] = doc_result # 阶段2:AI审查(内部并发) task_chain.update_stage("ai_review") structured_content = doc_result['structured_content'] # 读取AI审查配置 import configparser config = configparser.ConfigParser() config.read('config/config.ini', encoding='utf-8') max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None) if max_review_units == 0: # 如果配置为0,表示审查所有 max_review_units = None review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all') logger.info(f"AI审查配置: 最大审查数量={max_review_units}, 审查模式={review_mode}") ai_workflow = AIReviewWorkflow( task_file_info=task_file_info, structured_content=structured_content, progress_manager=self.progress_manager, max_review_units=max_review_units, review_mode=review_mode ) # 同步执行AI审查 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) ai_result = loop.run_until_complete(ai_workflow.execute()) loop.close() task_chain.results['ai_review'] = ai_result # # 阶段3:报告生成(串行) # task_chain.current_stage = "report_generation" # report_workflow = ReportWorkflow( # file_id=task_chain.file_id, # callback_task_id=task_chain.callback_task_id, # user_id=task_chain.user_id, # ai_review_results=ai_result, # progress_manager=self.progress_manager # ) # # 同步执行报告生成 # loop = asyncio.new_event_loop() # asyncio.set_event_loop(loop) # report_result = loop.run_until_complete(report_workflow.execute()) # loop.close() # task_chain.results['report'] = report_result # 完成任务链 task_chain.complete_processing() # 清理任务注册 asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id)) # 通知SSE连接任务完成 asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id, task_chain.user_id)) # 清理Redis文件缓存 try: from foundation.utils.redis_utils import delete_file_info asyncio.run(delete_file_info(task_chain.file_id)) logger.info(f"已清理Redis文件缓存: {task_chain.file_id}") except Exception as e: logger.warning(f"清理Redis文件缓存失败: {str(e)}") logger.info(f"文档处理任务链完成: {task_chain.callback_task_id}") return task_chain.results except Exception as e: # 标记任务失败 task_chain.fail_processing(str(e)) logger.error(f"文档处理任务链失败: {task_chain.callback_task_id}, 错误: {str(e)}") # 清理任务注册 asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id)) # 清理Redis文件缓存(即使失败也清理) try: from foundation.utils.redis_utils import delete_file_info asyncio.run(delete_file_info(task_chain.file_id)) logger.info(f"已清理Redis文件缓存: {task_chain.file_id}") except Exception as cleanup_error: logger.warning(f"清理Redis文件缓存失败: {str(cleanup_error)}") # 通知SSE连接任务失败 error_result = { "error": str(e), "status": "failed", "timestamp": datetime.now().isoformat() } current_data = { "status": "failed", "result": error_result } asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id, task_chain.user_id, current_data)) raise finally: # 清理活跃任务 if task_chain.callback_task_id in self.active_chains: del self.active_chains[task_chain.callback_task_id] # ==================== 任务终止管理方法 ==================== async def set_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, any]: """ 设置任务终止信号 Args: callback_task_id: 任务回调ID operator: 操作人(用户ID或系统标识) Returns: Dict: 操作结果 {"success": bool, "message": str, "task_info": dict} Note: 将终止信号写入 Redis,支持跨进程检测 AI审查节点在执行前会检查此信号 """ try: # 检查任务是否在活跃列表中 if callback_task_id not in self.active_chains: return { "success": False, "message": f"任务不存在或已完成: {callback_task_id}", "task_info": None } task_chain = self.active_chains[callback_task_id] # 检查任务状态 if task_chain.status != "processing": return { "success": False, "message": f"任务状态不是 processing,无需终止: {callback_task_id} (当前状态: {task_chain.status})", "task_info": { "callback_task_id": callback_task_id, "status": task_chain.status, "file_name": task_chain.file_name } } # 设置 Redis 终止信号 redis_client = await RedisConnectionFactory.get_connection() terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}" # 存储终止信号和操作人、时间 terminate_data = { "operator": operator, "terminate_time": str(time.time()), "task_id": callback_task_id } # 使用 hash 存储更多信息 await redis_client.hset(terminate_key, mapping=terminate_data) # 设置过期时间(2小时) await redis_client.expire(terminate_key, self._task_expire_time) logger.info(f"已设置终止信号: {callback_task_id} (操作人: {operator}, 文件: {task_chain.file_name})") return { "success": True, "message": f"终止信号已设置,任务将在当前节点完成后终止", "task_info": { "callback_task_id": callback_task_id, "file_id": task_chain.file_id, "file_name": task_chain.file_name, "user_id": task_chain.user_id, "status": task_chain.status, "current_stage": task_chain.current_stage } } except Exception as e: logger.error(f"设置终止信号失败: {str(e)}", exc_info=True) return { "success": False, "message": f"设置终止信号失败: {str(e)}", "task_info": None } async def check_terminate_signal(self, callback_task_id: str) -> bool: """ 检查是否有终止信号 Args: callback_task_id: 任务回调ID Returns: bool: 有终止信号返回 True Note: 从 Redis 读取终止信号 工作流节点在执行前调用此方法检查是否需要终止 """ try: redis_client = await RedisConnectionFactory.get_connection() terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}" # 检查键是否存在 exists = await redis_client.exists(terminate_key) if exists: # 读取终止信息 terminate_info = await redis_client.hgetall(terminate_key) logger.warning(f"检测到终止信号: {callback_task_id}, 操作人: {terminate_info.get(b'operator', b'unknown').decode()}") return True return False except RuntimeError as e: # 事件循环关闭是正常情况(任务结束),不记录错误 if "Event loop is closed" in str(e): logger.debug(f"检查终止信号时事件循环已关闭: {callback_task_id}") return False else: logger.error(f"检查终止信号失败: {str(e)}", exc_info=True) return False except Exception as e: # 其他异常仍然记录错误 logger.error(f"检查终止信号失败: {str(e)}", exc_info=True) return False async def clear_terminate_signal(self, callback_task_id: str): """ 清理 Redis 中的终止信号 Args: callback_task_id: 任务回调ID """ try: redis_client = await RedisConnectionFactory.get_connection() terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}" await redis_client.delete(terminate_key) logger.debug(f"清理终止信号: {callback_task_id}") except Exception as e: logger.warning(f"清理终止信号失败: {str(e)}") async def get_active_tasks(self) -> list: """ 获取活跃任务列表 Returns: list: 活跃任务信息列表 """ try: active_tasks = [] current_time = time.time() for task_id, task_chain in self.active_chains.items(): if task_chain.status == "processing": task_info = { "callback_task_id": task_id, "file_id": task_chain.file_id, "file_name": task_chain.file_name, "user_id": task_chain.user_id, "status": task_chain.status, "current_stage": task_chain.current_stage, "start_time": task_chain.start_time, "running_duration": int(current_time - task_chain.start_time) if task_chain.start_time else 0 } active_tasks.append(task_info) return active_tasks except Exception as e: logger.error(f"获取活跃任务列表失败: {str(e)}", exc_info=True) return [] async def get_task_info(self, callback_task_id: str) -> Optional[Dict]: """ 获取任务信息 Args: callback_task_id: 任务回调ID Returns: Optional[Dict]: 任务信息字典,不存在返回 None """ try: task_chain = self.active_chains.get(callback_task_id) if task_chain: current_time = time.time() return { "callback_task_id": callback_task_id, "file_id": task_chain.file_id, "file_name": task_chain.file_name, "user_id": task_chain.user_id, "status": task_chain.status, "current_stage": task_chain.current_stage, "start_time": task_chain.start_time, "running_duration": int(current_time - task_chain.start_time) if task_chain.start_time else 0, "results": task_chain.results } return None except Exception as e: logger.error(f"获取任务信息失败: {str(e)}", exc_info=True) return None