- """
- 基于LangGraph的工作流管理器
- 负责任务的创建、编排和执行,使用LangGraph进行状态管理
- 新增功能:
- - 任务终止管理
- - 终止信号设置和检测
- """
import asyncio
import time
from datetime import datetime
from typing import Any, Dict, Optional

from foundation.observability.logger.loggering import server_logger as logger
from foundation.observability.monitoring.time_statistics import track_execution_time
from foundation.infrastructure.cache.redis_connection import RedisConnectionFactory

from .progress_manager import ProgressManager
from .redis_duplicate_checker import RedisDuplicateChecker
from .task_models import TaskFileInfo, TaskChain
from ..construction_review.workflows import DocumentWorkflow, AIReviewWorkflow
class ProgressManagerRegistry:
    """Registry mapping callback_task_id -> ProgressManager instance.

    Gives each task an independent ProgressManager while letting other
    components look one up by the task's callback id.
    """

    # Class-level shared registry: {callback_task_id: ProgressManager}
    _registry = {}

    @classmethod
    def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager) -> None:
        """Register a ProgressManager instance for the given task id."""
        cls._registry[callback_task_id] = progress_manager
        logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")

    @classmethod
    def get_progress_manager(cls, callback_task_id: str) -> Optional[ProgressManager]:
        """Return the registered ProgressManager, or None if not registered.

        Fix: the original annotation claimed ``-> ProgressManager`` even
        though ``dict.get`` yields None on a miss.
        """
        return cls._registry.get(callback_task_id)

    @classmethod
    def unregister_progress_manager(cls, callback_task_id: str) -> None:
        """Remove the ProgressManager registration for the task, if present."""
        # pop(..., None) removes-and-tests in one lookup instead of `in` + `del`.
        if cls._registry.pop(callback_task_id, None) is not None:
            logger.info(f"注销ProgressManager实例: {callback_task_id}")
class WorkflowManager:
    """工作流管理器"""

    def __init__(self, max_concurrent_docs: int = 5, max_concurrent_reviews: int = 10):
        """Create a workflow manager with bounded stage concurrency.

        Args:
            max_concurrent_docs: cap on concurrent document-processing workflows.
            max_concurrent_reviews: cap on concurrent AI-review workflows.
        """
        self.max_concurrent_docs = max_concurrent_docs
        self.max_concurrent_reviews = max_concurrent_reviews

        # Semaphores bounding the two pipeline stages.
        self.doc_semaphore = asyncio.Semaphore(max_concurrent_docs)
        self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews)

        # Collaborating service components.
        self.progress_manager = ProgressManager()
        self.redis_duplicate_checker = RedisDuplicateChecker()

        # In-memory tracking of task chains that are still running.
        self.active_chains: Dict[str, TaskChain] = {}
        self._cleanup_task_started = False

        # Task-termination bookkeeping; signals themselves live in Redis.
        self._terminate_signal_prefix = "ai_review:terminate_signal:"
        self._task_expire_time = 7200  # seconds (2 hours)
- async def submit_task_processing(self, file_info: dict) -> str:
- """异步提交任务处理(用于file_upload层)"""
- from foundation.infrastructure.messaging.tasks import submit_task_processing_task
- from foundation.infrastructure.tracing.celery_trace import CeleryTraceManager
- try:
- logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")
- # 使用CeleryTraceManager提交任务,自动传递trace_id
- task = CeleryTraceManager.submit_celery_task(
- submit_task_processing_task,
- file_info
- )
- logger.info(f"Celery任务已提交,Task ID: {task.id}")
- return task.id
- except Exception as e:
- logger.error(f"提交Celery任务失败: {str(e)}")
- raise
- @track_execution_time
- def submit_task_processing_sync(self, file_info: dict) -> dict:
- """同步提交任务处理(用于Celery worker)"""
- try:
- logger.info(f"提交文档处理任务: {file_info['file_id']}")
- # 1. 创建TaskFileInfo对象(封装任务文件信息)
- task_file_info = TaskFileInfo(file_info)
- logger.info(f"创建任务文件信息: {task_file_info}")
- # 2. 生成任务链ID
- callback_task_id = task_file_info.callback_task_id
- # 3. 创建任务链(引用 TaskFileInfo,避免数据重复)
- task_chain = TaskChain(task_file_info)
- # 4. 标记任务开始
- task_chain.start_processing()
- # 5. 添加到活跃任务跟踪
- self.active_chains[callback_task_id] = task_chain
- # 5. 初始化进度管理
- asyncio.run(self.progress_manager.initialize_progress(
- callback_task_id=callback_task_id,
- user_id=task_file_info.user_id,
- stages=[]
- ))
- # 6. 启动处理流程(同步执行)
- self._process_task_chain_sync(task_chain, task_file_info, task_file_info.file_type)
- # logger.info(f"提交文档处理任务: {callback_task_id}")
- logger.info(f"施工方案审查任务已完成! ")
- logger.info(f"文件ID: {task_file_info.file_id}")
- logger.info(f"文件名:{task_file_info.file_name}")
- except Exception as e:
- logger.error(f"提交文档处理任务失败: {str(e)}")
- raise
- def _process_task_chain_sync(self, task_chain: TaskChain, task_file_info: TaskFileInfo, file_type: str):
- """同步处理文档任务链(用于Celery worker)"""
- try:
- file_content = task_file_info.file_content
- # 阶段1:文档处理(串行)
- document_workflow = DocumentWorkflow(
- task_file_info=task_file_info,
- progress_manager=self.progress_manager,
- redis_duplicate_checker=self.redis_duplicate_checker
- )
- # 同步执行文档处理
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- doc_result = loop.run_until_complete(document_workflow.execute(file_content, file_type))
- loop.close()
- task_chain.results['document'] = doc_result
- # 阶段2:AI审查(内部并发)
- task_chain.update_stage("ai_review")
- structured_content = doc_result['structured_content']
- # 读取AI审查配置
- import configparser
- config = configparser.ConfigParser()
- config.read('config/config.ini', encoding='utf-8')
- max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None)
- if max_review_units == 0: # 如果配置为0,表示审查所有
- max_review_units = None
- review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all')
- logger.info(f"AI审查配置: 最大审查数量={max_review_units}, 审查模式={review_mode}")
- ai_workflow = AIReviewWorkflow(
- task_file_info=task_file_info,
- structured_content=structured_content,
- progress_manager=self.progress_manager,
- max_review_units=max_review_units,
- review_mode=review_mode
- )
- # 同步执行AI审查
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- ai_result = loop.run_until_complete(ai_workflow.execute())
- loop.close()
- task_chain.results['ai_review'] = ai_result
- # # 阶段3:报告生成(串行)
- # task_chain.current_stage = "report_generation"
- # report_workflow = ReportWorkflow(
- # file_id=task_chain.file_id,
- # callback_task_id=task_chain.callback_task_id,
- # user_id=task_chain.user_id,
- # ai_review_results=ai_result,
- # progress_manager=self.progress_manager
- # )
- # # 同步执行报告生成
- # loop = asyncio.new_event_loop()
- # asyncio.set_event_loop(loop)
- # report_result = loop.run_until_complete(report_workflow.execute())
- # loop.close()
- # task_chain.results['report'] = report_result
- # 完成任务链
- task_chain.complete_processing()
- # 清理任务注册
- asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
- # 通知SSE连接任务完成
- asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id, task_chain.user_id))
- # 清理Redis文件缓存
- try:
- from foundation.utils.redis_utils import delete_file_info
- asyncio.run(delete_file_info(task_chain.file_id))
- logger.info(f"已清理Redis文件缓存: {task_chain.file_id}")
- except Exception as e:
- logger.warning(f"清理Redis文件缓存失败: {str(e)}")
- logger.info(f"文档处理任务链完成: {task_chain.callback_task_id}")
- return task_chain.results
- except Exception as e:
- # 标记任务失败
- task_chain.fail_processing(str(e))
- logger.error(f"文档处理任务链失败: {task_chain.callback_task_id}, 错误: {str(e)}")
- # 清理任务注册
- asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
- # 清理Redis文件缓存(即使失败也清理)
- try:
- from foundation.utils.redis_utils import delete_file_info
- asyncio.run(delete_file_info(task_chain.file_id))
- logger.info(f"已清理Redis文件缓存: {task_chain.file_id}")
- except Exception as cleanup_error:
- logger.warning(f"清理Redis文件缓存失败: {str(cleanup_error)}")
- # 通知SSE连接任务失败
- error_result = {
- "error": str(e),
- "status": "failed",
- "timestamp": datetime.now().isoformat()
- }
- current_data = {
- "status": "failed",
- "result": error_result
- }
- asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id, task_chain.user_id, current_data))
- raise
- finally:
- # 清理活跃任务
- if task_chain.callback_task_id in self.active_chains:
- del self.active_chains[task_chain.callback_task_id]
- # ==================== 任务终止管理方法 ====================
- async def set_terminate_signal(self, callback_task_id: str, operator: str = "unknown") -> Dict[str, any]:
- """
- 设置任务终止信号
- Args:
- callback_task_id: 任务回调ID
- operator: 操作人(用户ID或系统标识)
- Returns:
- Dict: 操作结果 {"success": bool, "message": str, "task_info": dict}
- Note:
- 将终止信号写入 Redis,支持跨进程检测
- AI审查节点在执行前会检查此信号
- """
- try:
- # 检查任务是否在活跃列表中
- if callback_task_id not in self.active_chains:
- return {
- "success": False,
- "message": f"任务不存在或已完成: {callback_task_id}",
- "task_info": None
- }
- task_chain = self.active_chains[callback_task_id]
- # 检查任务状态
- if task_chain.status != "processing":
- return {
- "success": False,
- "message": f"任务状态不是 processing,无需终止: {callback_task_id} (当前状态: {task_chain.status})",
- "task_info": {
- "callback_task_id": callback_task_id,
- "status": task_chain.status,
- "file_name": task_chain.file_name
- }
- }
- # 设置 Redis 终止信号
- redis_client = await RedisConnectionFactory.get_connection()
- terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
- # 存储终止信号和操作人、时间
- terminate_data = {
- "operator": operator,
- "terminate_time": str(time.time()),
- "task_id": callback_task_id
- }
- # 使用 hash 存储更多信息
- await redis_client.hset(terminate_key, mapping=terminate_data)
- # 设置过期时间(2小时)
- await redis_client.expire(terminate_key, self._task_expire_time)
- logger.info(f"已设置终止信号: {callback_task_id} (操作人: {operator}, 文件: {task_chain.file_name})")
- return {
- "success": True,
- "message": f"终止信号已设置,任务将在当前节点完成后终止",
- "task_info": {
- "callback_task_id": callback_task_id,
- "file_id": task_chain.file_id,
- "file_name": task_chain.file_name,
- "user_id": task_chain.user_id,
- "status": task_chain.status,
- "current_stage": task_chain.current_stage
- }
- }
- except Exception as e:
- logger.error(f"设置终止信号失败: {str(e)}", exc_info=True)
- return {
- "success": False,
- "message": f"设置终止信号失败: {str(e)}",
- "task_info": None
- }
- async def check_terminate_signal(self, callback_task_id: str) -> bool:
- """
- 检查是否有终止信号
- Args:
- callback_task_id: 任务回调ID
- Returns:
- bool: 有终止信号返回 True
- Note:
- 从 Redis 读取终止信号
- 工作流节点在执行前调用此方法检查是否需要终止
- """
- try:
- redis_client = await RedisConnectionFactory.get_connection()
- terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
- # 检查键是否存在
- exists = await redis_client.exists(terminate_key)
- if exists:
- # 读取终止信息
- terminate_info = await redis_client.hgetall(terminate_key)
- logger.warning(f"检测到终止信号: {callback_task_id}, 操作人: {terminate_info.get(b'operator', b'unknown').decode()}")
- return True
- return False
- except RuntimeError as e:
- # 事件循环相关的错误处理
- error_msg = str(e)
- if "Event loop is closed" in error_msg:
- # 事件循环关闭是正常情况(任务结束),不记录错误
- logger.debug(f"检查终止信号时事件循环已关闭: {callback_task_id}")
- return False
- elif "bound to a different event loop" in error_msg:
- # 事件循环不匹配,记录警告但不中断流程
- logger.warning(f"检查终止信号时检测到事件循环不匹配: {callback_task_id},将忽略本次检查")
- return False
- else:
- # 其他 RuntimeError 记录错误
- logger.error(f"检查终止信号失败(RuntimeError): {error_msg}", exc_info=True)
- return False
- except Exception as e:
- # 其他异常仍然记录错误
- logger.error(f"检查终止信号失败: {str(e)}", exc_info=True)
- return False
- async def clear_terminate_signal(self, callback_task_id: str):
- """
- 清理 Redis 中的终止信号
- Args:
- callback_task_id: 任务回调ID
- """
- try:
- redis_client = await RedisConnectionFactory.get_connection()
- terminate_key = f"{self._terminate_signal_prefix}{callback_task_id}"
- await redis_client.delete(terminate_key)
- logger.debug(f"清理终止信号: {callback_task_id}")
- except Exception as e:
- logger.warning(f"清理终止信号失败: {str(e)}")
- async def get_active_tasks(self) -> list:
- """
- 获取活跃任务列表
- Returns:
- list: 活跃任务信息列表
- """
- try:
- active_tasks = []
- current_time = time.time()
- for task_id, task_chain in self.active_chains.items():
- if task_chain.status == "processing":
- task_info = {
- "callback_task_id": task_id,
- "file_id": task_chain.file_id,
- "file_name": task_chain.file_name,
- "user_id": task_chain.user_id,
- "status": task_chain.status,
- "current_stage": task_chain.current_stage,
- "start_time": task_chain.start_time,
- "running_duration": int(current_time - task_chain.start_time) if task_chain.start_time else 0
- }
- active_tasks.append(task_info)
- return active_tasks
- except Exception as e:
- logger.error(f"获取活跃任务列表失败: {str(e)}", exc_info=True)
- return []
- async def get_task_info(self, callback_task_id: str) -> Optional[Dict]:
- """
- 获取任务信息
- Args:
- callback_task_id: 任务回调ID
- Returns:
- Optional[Dict]: 任务信息字典,不存在返回 None
- """
- try:
- task_chain = self.active_chains.get(callback_task_id)
- if task_chain:
- current_time = time.time()
- return {
- "callback_task_id": callback_task_id,
- "file_id": task_chain.file_id,
- "file_name": task_chain.file_name,
- "user_id": task_chain.user_id,
- "status": task_chain.status,
- "current_stage": task_chain.current_stage,
- "start_time": task_chain.start_time,
- "running_duration": int(current_time - task_chain.start_time) if task_chain.start_time else 0,
- "results": task_chain.results
- }
- return None
- except Exception as e:
- logger.error(f"获取任务信息失败: {str(e)}", exc_info=True)
- return None