"""
LangGraph-based workflow manager.

Responsible for creating, orchestrating and executing tasks, using
LangGraph for state management.
"""
import asyncio
import uuid
from typing import Dict, Optional, TypedDict, Annotated, List
from datetime import datetime
from dataclasses import dataclass

from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
import json

from foundation.logger.loggering import server_logger as logger
from foundation.utils.time_statistics import track_execution_time
from .progress_manager import ProgressManager
from .redis_duplicate_checker import RedisDuplicateChecker
from ..construction_review.workflows import DocumentWorkflow, AIReviewWorkflow, ReportWorkflow


class ProgressManagerRegistry:
    """Registry keeping one independent ProgressManager instance per task."""

    # {callback_task_id: ProgressManager}
    _registry = {}

    @classmethod
    def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
        """Register a ProgressManager instance under the given callback task id."""
        cls._registry[callback_task_id] = progress_manager
        logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")

    @classmethod
    def get_progress_manager(cls, callback_task_id: str) -> ProgressManager:
        """Return the ProgressManager for the task, or None if not registered."""
        return cls._registry.get(callback_task_id)

    @classmethod
    def unregister_progress_manager(cls, callback_task_id: str):
        """Remove the ProgressManager registered for the task, if present."""
        if callback_task_id in cls._registry:
            del cls._registry[callback_task_id]
            logger.info(f"注销ProgressManager实例: {callback_task_id}")


@dataclass
class TaskChain:
    """State record for one document-processing task chain."""
    callback_task_id: str
    file_id: str
    user_id: str
    status: str  # processing, completed, failed
    current_stage: str
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    results: Dict = None  # stage name -> stage result

    def __post_init__(self):
        # Give each instance its own dict (avoid a shared mutable default).
        if self.results is None:
            self.results = {}


class WorkflowManager:
    """Workflow manager: creates, orchestrates and runs task chains."""

    def __init__(self, max_concurrent_docs: int = 5, max_concurrent_reviews: int = 10):
        self.max_concurrent_docs = max_concurrent_docs
        self.max_concurrent_reviews = max_concurrent_reviews

        # Concurrency control
        self.doc_semaphore = asyncio.Semaphore(max_concurrent_docs)
        self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews)

        # Service components
        self.progress_manager = ProgressManager()  # simplified: use the instance directly
        self.redis_duplicate_checker = RedisDuplicateChecker()

        # Active task tracking
        self.active_chains: Dict[str, TaskChain] = {}
        self._cleanup_task_started = False

    @staticmethod
    def _run_coro(coro):
        """Run *coro* on a fresh event loop, closing the loop even on failure.

        The original inline pattern leaked the loop when the coroutine
        raised; the try/finally guarantees cleanup.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    def _cleanup_redis_file_cache(self, file_id: str):
        """Best-effort removal of the cached file info from Redis."""
        try:
            from foundation.utils.redis_utils import delete_file_info
            asyncio.run(delete_file_info(file_id))
            logger.info(f"已清理Redis文件缓存: {file_id}")
        except Exception as e:
            logger.warning(f"清理Redis文件缓存失败: {str(e)}")

    async def submit_task_processing(self, file_info: dict) -> str:
        """Submit task processing asynchronously (used by the file_upload layer).

        Returns the Celery task id.
        """
        from foundation.base.tasks import submit_task_processing_task
        from foundation.trace.celery_trace import CeleryTraceManager
        try:
            logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")
            # Submit through CeleryTraceManager so the trace_id propagates
            # into the Celery task automatically.
            task = CeleryTraceManager.submit_celery_task(
                submit_task_processing_task,
                file_info
            )
            logger.info(f"Celery任务已提交,Task ID: {task.id}")
            return task.id
        except Exception as e:
            logger.error(f"提交Celery任务失败: {str(e)}")
            raise

    @track_execution_time
    def submit_task_processing_sync(self, file_info: dict) -> dict:
        """Submit task processing synchronously (used inside a Celery worker).

        Returns the per-stage results dict produced by the task chain.
        Raises whatever the underlying chain raises.
        """
        try:
            logger.info(f"提交文档处理任务: {file_info['file_id']}")

            # 1. Task-chain id comes from the caller.
            callback_task_id = file_info['callback_task_id']

            # 2. Create the task chain record.
            task_chain = TaskChain(
                callback_task_id=callback_task_id,
                file_id=file_info.get('file_id', ''),
                user_id=file_info.get('user_id', 'default_user'),
                status="processing",
                current_stage="document_processing",
                created_at=datetime.now()
            )

            # 3. Track it as active.
            self.active_chains[callback_task_id] = task_chain

            # 4. Initialise progress tracking.
            asyncio.run(self.progress_manager.initialize_progress(
                callback_task_id=callback_task_id,
                user_id=file_info.get('user_id', 'default_user'),
                stages=[]
            ))

            # 5. Run the processing chain synchronously.
            results = self._process_task_chain_sync(
                task_chain, file_info['file_content'], file_info['file_type'])

            logger.info(f"施工方案审查任务已完成! ")
            logger.info(f"文件ID: {file_info['file_id']}")
            logger.info(f"文件名:{file_info['file_name']}")
            # Fix: the signature promises a dict; the original returned None.
            return results
        except Exception as e:
            logger.error(f"提交文档处理任务失败: {str(e)}")
            raise

    def _process_task_chain_sync(self, task_chain: TaskChain, file_content: bytes, file_type: str):
        """Run the document task chain synchronously (for a Celery worker).

        Stages: document processing (serial) -> AI review (internally
        concurrent). Report generation is currently disabled. Returns
        ``task_chain.results`` on success; re-raises on failure after
        cleanup.
        """
        try:
            task_chain.started_at = datetime.now()

            # Stage 1: document processing (serial).
            task_chain.current_stage = "document_processing"
            document_workflow = DocumentWorkflow(
                file_id=task_chain.file_id,
                callback_task_id=task_chain.callback_task_id,
                user_id=task_chain.user_id,
                progress_manager=self.progress_manager,
                redis_duplicate_checker=self.redis_duplicate_checker
            )
            doc_result = self._run_coro(document_workflow.execute(file_content, file_type))
            task_chain.results['document'] = doc_result

            # Stage 2: AI review (internally concurrent).
            task_chain.current_stage = "ai_review"
            structured_content = doc_result['structured_content']

            # Read AI-review configuration.
            import configparser
            config = configparser.ConfigParser()
            config.read('config/config.ini', encoding='utf-8')
            max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None)
            if max_review_units == 0:
                # 0 means "review everything".
                max_review_units = None
            review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all')
            logger.info(f"AI审查配置: 最大审查数量={max_review_units}, 审查模式={review_mode}")

            ai_workflow = AIReviewWorkflow(
                file_id=task_chain.file_id,
                callback_task_id=task_chain.callback_task_id,
                user_id=task_chain.user_id,
                structured_content=structured_content,
                progress_manager=self.progress_manager,
                max_review_units=max_review_units,
                review_mode=review_mode
            )
            ai_result = self._run_coro(ai_workflow.execute())
            task_chain.results['ai_review'] = ai_result

            # Stage 3 (report generation via ReportWorkflow) is currently
            # disabled; re-enable by instantiating ReportWorkflow with the
            # ai_result and running it through _run_coro.

            # Mark the chain complete.
            task_chain.status = "completed"
            task_chain.completed_at = datetime.now()

            # Deregister the task and notify SSE listeners of completion.
            asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
            asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id))

            # Drop the cached file payload from Redis.
            self._cleanup_redis_file_cache(task_chain.file_id)

            logger.info(f"文档处理任务链完成: {task_chain.callback_task_id}")
            return task_chain.results
        except Exception as e:
            task_chain.status = "failed"
            logger.error(f"文档处理任务链失败: {task_chain.callback_task_id}, 错误: {str(e)}")

            # Deregister and clear the Redis cache even on failure.
            asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
            self._cleanup_redis_file_cache(task_chain.file_id)

            # Notify SSE listeners that the task ended.
            # NOTE(review): complete_task takes no error payload, so the
            # failure detail is only in the log above — confirm listeners
            # learn the failure status some other way.
            asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id))
            raise
        finally:
            # Always drop the chain from active tracking.
            if task_chain.callback_task_id in self.active_chains:
                del self.active_chains[task_chain.callback_task_id]