"""
基于LangGraph的工作流管理器
负责任务的创建、编排和执行,使用LangGraph进行状态管理
"""
import asyncio
import uuid
from typing import Dict, Optional, TypedDict, Annotated, List
from datetime import datetime
from dataclasses import dataclass
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
import json
from foundation.logger.loggering import server_logger as logger
from foundation.utils.time_statistics import track_execution_time
from .progress_manager import ProgressManager
from .redis_duplicate_checker import RedisDuplicateChecker
from ..construction_review.workflows import DocumentWorkflow, AIReviewWorkflow, ReportWorkflow
  19. class ProgressManagerRegistry:
  20. """ProgressManager注册表 - 为每个任务管理独立的ProgressManager实例"""
  21. _registry = {} # {callback_task_id: ProgressManager}
  22. @classmethod
  23. def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
  24. """注册ProgressManager实例"""
  25. cls._registry[callback_task_id] = progress_manager
  26. logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")
  27. @classmethod
  28. def get_progress_manager(cls, callback_task_id: str) -> ProgressManager:
  29. """获取ProgressManager实例"""
  30. return cls._registry.get(callback_task_id)
  31. @classmethod
  32. def unregister_progress_manager(cls, callback_task_id: str):
  33. """注销ProgressManager实例"""
  34. if callback_task_id in cls._registry:
  35. del cls._registry[callback_task_id]
  36. logger.info(f"注销ProgressManager实例: {callback_task_id}")
  37. @dataclass
  38. class TaskChain:
  39. """任务链"""
  40. callback_task_id: str
  41. file_id: str
  42. user_id: str
  43. status: str # processing, completed, failed
  44. current_stage: str
  45. created_at: datetime
  46. started_at: Optional[datetime] = None
  47. completed_at: Optional[datetime] = None
  48. results: Dict = None
  49. def __post_init__(self):
  50. if self.results is None:
  51. self.results = {}
  52. class WorkflowManager:
  53. """工作流管理器"""
  54. def __init__(self, max_concurrent_docs: int = 5, max_concurrent_reviews: int = 10):
  55. self.max_concurrent_docs = max_concurrent_docs
  56. self.max_concurrent_reviews = max_concurrent_reviews
  57. # 并发控制
  58. self.doc_semaphore = asyncio.Semaphore(max_concurrent_docs)
  59. self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews)
  60. # 服务组件
  61. self.progress_manager = ProgressManager() # 简化:直接使用实例
  62. self.redis_duplicate_checker = RedisDuplicateChecker()
  63. # 活跃任务跟踪
  64. self.active_chains: Dict[str, TaskChain] = {}
  65. self._cleanup_task_started = False
  66. async def submit_task_processing(self, file_info: dict) -> str:
  67. """异步提交任务处理(用于file_upload层)"""
  68. from foundation.base.tasks import submit_task_processing_task
  69. from foundation.trace.celery_trace import CeleryTraceManager
  70. try:
  71. logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")
  72. # 使用CeleryTraceManager提交任务,自动传递trace_id
  73. task = CeleryTraceManager.submit_celery_task(
  74. submit_task_processing_task,
  75. file_info
  76. )
  77. logger.info(f"Celery任务已提交,Task ID: {task.id}")
  78. return task.id
  79. except Exception as e:
  80. logger.error(f"提交Celery任务失败: {str(e)}")
  81. raise
  82. @track_execution_time
  83. def submit_task_processing_sync(self, file_info: dict) -> dict:
  84. """同步提交任务处理(用于Celery worker)"""
  85. try:
  86. logger.info(f"提交文档处理任务: {file_info['file_id']}")
  87. # 1. 生成任务链ID
  88. callback_task_id = file_info['callback_task_id']
  89. # 2. 创建任务链
  90. task_chain = TaskChain(
  91. callback_task_id=callback_task_id,
  92. file_id=file_info.get('file_id', ''),
  93. user_id=file_info.get('user_id', 'default_user'),
  94. status="processing",
  95. current_stage="document_processing",
  96. created_at=datetime.now()
  97. )
  98. # 4. 注册任务
  99. asyncio.run(self.redis_duplicate_checker.register_task(file_info, callback_task_id))
  100. self.active_chains[callback_task_id] = task_chain
  101. # 5. 初始化进度管理
  102. asyncio.run(self.progress_manager.initialize_progress(
  103. callback_task_id=callback_task_id,
  104. user_id=file_info.get('user_id', 'default_user'),
  105. stages=[]
  106. ))
  107. # 6. 启动处理流程(同步执行)
  108. self._process_task_chain_sync(task_chain, file_info['file_content'], file_info['file_type'])
  109. # logger.info(f"提交文档处理任务: {callback_task_id}")
  110. logger.info(f"施工方案审查任务已完成! ")
  111. logger.info(f"文件ID: {file_info['file_id']}")
  112. logger.info(f"文件名:{file_info['file_name']}")
  113. except Exception as e:
  114. logger.error(f"提交文档处理任务失败: {str(e)}")
  115. raise
  116. def _process_task_chain_sync(self, task_chain: TaskChain, file_content: bytes, file_type: str):
  117. """同步处理文档任务链(用于Celery worker)"""
  118. try:
  119. task_chain.started_at = datetime.now()
  120. # 阶段1:文档处理(串行)
  121. task_chain.current_stage = "document_processing"
  122. document_workflow = DocumentWorkflow(
  123. file_id=task_chain.file_id,
  124. callback_task_id=task_chain.callback_task_id,
  125. user_id=task_chain.user_id,
  126. progress_manager=self.progress_manager,
  127. redis_duplicate_checker=self.redis_duplicate_checker
  128. )
  129. # 同步执行文档处理
  130. loop = asyncio.new_event_loop()
  131. asyncio.set_event_loop(loop)
  132. doc_result = loop.run_until_complete(document_workflow.execute(file_content, file_type))
  133. loop.close()
  134. task_chain.results['document'] = doc_result
  135. # 阶段2:AI审查(内部并发)
  136. task_chain.current_stage = "ai_review"
  137. structured_content = doc_result['structured_content']
  138. # 读取AI审查配置
  139. import configparser
  140. config = configparser.ConfigParser()
  141. config.read('config/config.ini', encoding='utf-8')
  142. max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None)
  143. if max_review_units == 0: # 如果配置为0,表示审查所有
  144. max_review_units = None
  145. review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all')
  146. logger.info(f"AI审查配置: 最大审查数量={max_review_units}, 审查模式={review_mode}")
  147. ai_workflow = AIReviewWorkflow(
  148. file_id=task_chain.file_id,
  149. callback_task_id=task_chain.callback_task_id,
  150. user_id=task_chain.user_id,
  151. structured_content=structured_content,
  152. progress_manager=self.progress_manager,
  153. max_review_units=max_review_units,
  154. review_mode=review_mode
  155. )
  156. # 同步执行AI审查
  157. loop = asyncio.new_event_loop()
  158. asyncio.set_event_loop(loop)
  159. ai_result = loop.run_until_complete(ai_workflow.execute())
  160. loop.close()
  161. task_chain.results['ai_review'] = ai_result
  162. # 阶段3:报告生成(串行)
  163. task_chain.current_stage = "report_generation"
  164. report_workflow = ReportWorkflow(
  165. file_id=task_chain.file_id,
  166. callback_task_id=task_chain.callback_task_id,
  167. user_id=task_chain.user_id,
  168. ai_review_results=ai_result,
  169. progress_manager=self.progress_manager
  170. )
  171. # 同步执行报告生成
  172. loop = asyncio.new_event_loop()
  173. asyncio.set_event_loop(loop)
  174. report_result = loop.run_until_complete(report_workflow.execute())
  175. loop.close()
  176. task_chain.results['report'] = report_result
  177. # 完成任务链
  178. task_chain.status = "completed"
  179. task_chain.completed_at = datetime.now()
  180. # 清理任务注册
  181. asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
  182. # 通知SSE连接任务完成
  183. asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id))
  184. # 清理Redis文件缓存
  185. try:
  186. from foundation.utils.redis_utils import delete_file_info
  187. asyncio.run(delete_file_info(task_chain.file_id))
  188. logger.info(f"已清理Redis文件缓存: {task_chain.file_id}")
  189. except Exception as e:
  190. logger.warning(f"清理Redis文件缓存失败: {str(e)}")
  191. logger.info(f"文档处理任务链完成: {task_chain.callback_task_id}")
  192. return task_chain.results
  193. except Exception as e:
  194. task_chain.status = "failed"
  195. logger.error(f"文档处理任务链失败: {task_chain.callback_task_id}, 错误: {str(e)}")
  196. # 清理任务注册
  197. asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
  198. # 清理Redis文件缓存(即使失败也清理)
  199. try:
  200. from foundation.utils.redis_utils import delete_file_info
  201. asyncio.run(delete_file_info(task_chain.file_id))
  202. logger.info(f"已清理Redis文件缓存: {task_chain.file_id}")
  203. except Exception as cleanup_error:
  204. logger.warning(f"清理Redis文件缓存失败: {str(cleanup_error)}")
  205. # 通知SSE连接任务失败
  206. error_result = {
  207. "error": str(e),
  208. "status": "failed",
  209. "timestamp": datetime.now().isoformat()
  210. }
  211. asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id))
  212. raise
  213. finally:
  214. # 清理活跃任务
  215. if task_chain.callback_task_id in self.active_chains:
  216. del self.active_chains[task_chain.callback_task_id]
  217. async def update_task_status(self, callback_task_id: str) -> Optional[Dict]:
  218. """更新任务状态"""
  219. pass