import asyncio
import json
import time
import uuid
from typing import Any, Dict, Optional

from langchain_core.messages import HumanMessage

from foundation.infrastructure.cache.redis_connection import RedisConnectionFactory
from foundation.infrastructure.tracing import TraceContext
from foundation.observability.logger.loggering import write_logger as logger
from foundation.observability.monitoring.time_statistics import track_execution_time
from core.base.progress_manager import ProgressManager

# Workflow outputs that are JSON-encoded when persisted to Redis and that
# appear, in this order, in every result payload built by this module.
_OUTLINE_PAYLOAD_FIELDS = (
    "outline_structure",
    "key_points",
    "similar_cases",
    "similar_fragments",
    "knowledge_bases",
)

# Task states in which a terminate request is still accepted.
_CANCELLABLE_STATUSES = ("pending", "processing", "submitted")

# Persisted workflow states whose externally reported name differs; every
# other state is reported unchanged.
_REPORTED_STATUS_ALIASES = {
    "terminated": "cancelled",
    "submitted": "processing",
}


class ProgressManagerRegistry:
    """Registry used to share ProgressManager instances between workflow nodes.

    Instances are stored in a class-level dict keyed by ``callback_task_id``.
    """

    _registry: Dict[str, ProgressManager] = {}

    @classmethod
    def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
        """Register the progress manager for a task.

        Args:
            callback_task_id: Task callback ID.
            progress_manager: ProgressManager instance to store.
        """
        cls._registry[callback_task_id] = progress_manager
        logger.info(f"注册 ProgressManager: {callback_task_id}")

    @classmethod
    def get_progress_manager(cls, callback_task_id: str) -> Optional[ProgressManager]:
        """Return the progress manager registered for a task.

        Args:
            callback_task_id: Task callback ID.

        Returns:
            The registered ProgressManager, or None when nothing is registered.
        """
        return cls._registry.get(callback_task_id)

    @classmethod
    def unregister_progress_manager(cls, callback_task_id: Optional[str]):
        """Remove the progress manager registered for a task, if any.

        Args:
            callback_task_id: Task callback ID; falsy or unknown IDs are ignored.
        """
        if callback_task_id and callback_task_id in cls._registry:
            del cls._registry[callback_task_id]
            logger.info(f"注销 ProgressManager: {callback_task_id}")


class WorkflowManager:
    """Workflow manager for construction-plan outline writing tasks.

    Responsibilities:
        1. Celery submission and synchronous execution of outline generation tasks
        2. Building and invoking the LangGraph workflow graph
        3. Terminate-signal management (stored in Redis)
        4. Persisting task results to Redis
        5. Querying the state of active tasks
    """

    def __init__(self):
        self.progress_manager = ProgressManager()
        # Tasks currently being executed by this manager, keyed by callback_task_id.
        self.active_outline_tasks: Dict[str, Any] = {}
        self._outline_result_prefix = "outline_write:result:"
        self._outline_terminate_signal_prefix = "outline_write:terminate_signal:"
        # Expiry (seconds) applied to every Redis key written by this manager.
        self._task_expire_time = 7200
        # Built lazily on first execution and then reused.
        self.outline_generation_graph = None

    def _run_async(self, coro):
        """Run a coroutine to completion from synchronous code.

        A fresh event loop is created for the call, installed as the current
        loop while the coroutine runs, then closed and uninstalled.

        Args:
            coro: Coroutine object to execute.

        Returns:
            The coroutine's result.
        """
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coro)
        finally:
            try:
                # Finalize async generators the coroutine left open; once the
                # loop is closed they could no longer be cleaned up.
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                try:
                    loop.close()
                finally:
                    asyncio.set_event_loop(None)

    def _outline_progress_stages(self) -> list:
        """Return the progress stage definitions of an outline generation task.

        Returns:
            A list of ``{"stage": <name>, "status": "pending"}`` dicts, one per stage.
        """
        stage_names = (
            "start",
            "template_loading",
            "outline_generation",
            "similar_cases",
            "similar_fragments",
            "knowledge_bases",
            "complete",
        )
        return [{"stage": name, "status": "pending"} for name in stage_names]

    @staticmethod
    def _build_result_payload(
        callback_task_id,
        user_id,
        overall_task_status,
        error_message,
        outputs=None,
    ) -> dict:
        """Assemble the result dict shape shared by all outcomes.

        Args:
            callback_task_id: Task callback ID to report.
            user_id: User ID to report.
            overall_task_status: Status string to report.
            error_message: Error or termination message, or None.
            outputs: Mapping the five output fields are read from; every
                output is None when omitted.

        Returns:
            Dict with the identity fields, the five output fields and
            ``error_message``.
        """
        values = {} if outputs is None else outputs
        payload = {
            "callback_task_id": callback_task_id,
            "user_id": user_id,
            "overall_task_status": overall_task_status,
        }
        for field in _OUTLINE_PAYLOAD_FIELDS:
            payload[field] = values.get(field)
        payload["error_message"] = error_message
        return payload

    @staticmethod
    def _describe_active_task(callback_task_id: str, task_info, now: float) -> Dict[str, Any]:
        """Summarize an in-memory task record for the query APIs.

        Args:
            callback_task_id: Task callback ID.
            task_info: Task record taken from ``active_outline_tasks``.
            now: Timestamp used to compute the running duration.

        Returns:
            Dict with identity, project, status and timing fields.
        """
        started = task_info.start_time
        return {
            "callback_task_id": callback_task_id,
            "user_id": task_info.user_id,
            "project_name": task_info.project_name,
            "project_type": task_info.project_type,
            "status": task_info.status,
            "start_time": started,
            "running_duration": int(now - started) if started else 0,
        }

    @staticmethod
    def _decode_json_field(raw):
        """Decode a JSON string read from Redis; return None if empty or invalid."""
        if not raw:
            return None
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return None

    async def submit_outline_generation_task(self, sgbx_task_info: dict) -> str:
        """Submit an outline generation task to the Celery queue.

        Flow: pre-register the task -> initialize progress -> submit to
        Celery -> update status.

        Args:
            sgbx_task_info: Construction-plan task information dict.

        Returns:
            The Celery task ID.

        Raises:
            Exception: Re-raised unchanged when the Celery submission fails.
        """
        # Celery task object; inside this method the name shadows the method itself.
        from foundation.infrastructure.messaging.tasks import submit_outline_generation_task

        callback_task_id = sgbx_task_info.get("callback_task_id")
        user_id = sgbx_task_info.get("user_id", "unknown")
        logger.info(
            f"Submit outline generation task: callback_task_id={callback_task_id}, user_id={user_id}"
        )

        await self._pre_register_outline_task(sgbx_task_info)
        await self.progress_manager.initialize_progress(
            callback_task_id=callback_task_id,
            user_id=user_id,
            stages=self._outline_progress_stages(),
        )
        await self.progress_manager.update_stage_progress(
            callback_task_id=callback_task_id,
            user_id=user_id,
            current=5,
            stage_name="任务提交中",
            status="processing",
            message="正在提交大纲生成任务...",
            overall_task_status="processing",
        )

        # Forward the current trace id to the worker when one is set.
        task_kwargs = {}
        trace_id = TraceContext.get_trace_id()
        if trace_id and trace_id != "no-trace":
            task_kwargs["_system_trace_id"] = trace_id

        try:
            task = submit_outline_generation_task.apply_async(
                args=[sgbx_task_info],
                kwargs=task_kwargs,
                queue="construction_write",
            )
        except Exception as exc:
            # Without this the pre-registered record would stay "pending"
            # until it expires. The helper logs and swallows its own errors.
            await self._update_outline_result_status(
                callback_task_id,
                overall_task_status="failed",
                error_message=str(exc),
            )
            try:
                await self.progress_manager.update_stage_progress(
                    callback_task_id=callback_task_id,
                    user_id=user_id,
                    current=5,
                    stage_name="任务提交失败",
                    status="failed",
                    message=f"提交大纲生成任务失败: {exc}",
                    overall_task_status="failed",
                    event_type="failed",
                )
            except Exception as update_exc:
                # A reporting failure must not replace the submission error.
                logger.warning(
                    f"Failed to publish outline submit failure: {callback_task_id}, {update_exc}",
                    exc_info=True,
                )
            raise

        await self._update_outline_result_status(
            callback_task_id,
            overall_task_status="processing",
            celery_task_id=task.id,
            submitted_time=str(time.time()),
        )
        await self.progress_manager.update_stage_progress(
            callback_task_id=callback_task_id,
            user_id=user_id,
            current=10,
            stage_name="任务已提交",
            status="submitted",
            message="大纲生成任务已提交,正在执行...",
            overall_task_status="processing",
        )
        logger.info(f"Outline generation Celery task submitted: {task.id}")
        return task.id

    @track_execution_time
    def submit_outline_generation_sync(self, sgbx_task_info: dict) -> dict:
        """Run an outline generation task synchronously (called inside a Celery worker).

        A random ``outline_<hex>`` ID is generated when the task info carries
        no ``callback_task_id``.

        Args:
            sgbx_task_info: Construction-plan task information dict.

        Returns:
            Result dict produced by ``_execute_outline_generation``.
        """
        callback_task_id = sgbx_task_info.get("callback_task_id") or f"outline_{uuid.uuid4().hex[:16]}"
        return self._run_async(self._execute_outline_generation(sgbx_task_info, callback_task_id))

    async def _execute_outline_generation(self, sgbx_task_info: dict, callback_task_id: str) -> dict:
        """Execute the outline generation workflow.

        Flow: check the terminate signal -> build the initial state -> build
        or reuse the LangGraph graph -> invoke the workflow -> save the
        result to Redis.

        Args:
            sgbx_task_info: Construction-plan task information dict.
            callback_task_id: Task callback ID.

        Returns:
            Result dict with status, the five workflow outputs and
            ``error_message``.

        Raises:
            Exception: Any workflow error, re-raised after the failure state
                has been persisted on a best-effort basis.
        """
        from core.construction_write.component.state_models import OutlineGenerationState, OutlineTaskInfo
        from core.construction_write.workflows.outline_workflow import OutlineWorkflow

        user_id = sgbx_task_info.get("user_id", "unknown")
        outline_sgbx_task_info: Optional[OutlineTaskInfo] = None

        try:
            logger.info(f"Start outline generation workflow: {callback_task_id}")

            if await self.check_outline_terminate_signal(callback_task_id):
                return self._terminated_result(callback_task_id, user_id, "Task was cancelled before start")

            outline_sgbx_task_info = OutlineTaskInfo(
                callback_task_id=callback_task_id,
                user_id=user_id,
                project_info=sgbx_task_info.get("project_info", {}),
                template_id=sgbx_task_info.get("template_id", ""),
                generation_chapterenum=sgbx_task_info.get("generation_chapterenum", []),
                generation_template=sgbx_task_info.get("generation_template", []),
                similarity_config=sgbx_task_info.get(
                    "similarity_config",
                    {"topk_plans": 3, "topk_fragments": 10, "threshold": 0.75},
                ),
                knowledge_config=sgbx_task_info.get(
                    "knowledge_config",
                    {"topk": 3, "threshold": 0.75},
                ),
            )
            self.active_outline_tasks[callback_task_id] = outline_sgbx_task_info

            await self._update_outline_result_status(
                callback_task_id,
                overall_task_status="processing",
                worker_started_at=str(time.time()),
            )
            await self.progress_manager.initialize_progress(
                callback_task_id=callback_task_id,
                user_id=user_id,
                stages=self._outline_progress_stages(),
            )
            ProgressManagerRegistry.register_progress_manager(callback_task_id, self.progress_manager)
            outline_sgbx_task_info.start_processing()

            if self.outline_generation_graph is None:
                self.outline_generation_graph = OutlineWorkflow().build_graph()

            initial_state = OutlineGenerationState(
                callback_task_id=callback_task_id,
                user_id=user_id,
                project_info=outline_sgbx_task_info.project_info,
                template_id=outline_sgbx_task_info.template_id,
                generation_chapterenum=outline_sgbx_task_info.generation_chapterenum,
                generation_template=outline_sgbx_task_info.generation_template,
                similarity_config=outline_sgbx_task_info.similarity_config,
                knowledge_config=outline_sgbx_task_info.knowledge_config,
                template=None,
                outline_structure=None,
                key_points=None,
                similar_cases=None,
                similar_fragments=None,
                knowledge_bases=None,
                current_stage="start",
                overall_task_status="processing",
                error_message=None,
                messages=[HumanMessage(content=f"Start outline generation task {callback_task_id}")],
            )

            result = await self.outline_generation_graph.ainvoke(
                initial_state,
                config={"configurable": {"thread_id": callback_task_id}},
            )

            # Mirror the final workflow status onto the in-memory task record.
            status = result.get("overall_task_status")
            if status == "completed":
                outline_sgbx_task_info.complete_processing(
                    {field: result.get(field) for field in _OUTLINE_PAYLOAD_FIELDS}
                )
            elif status == "failed":
                outline_sgbx_task_info.fail_processing(result.get("error_message", "unknown error"))
            elif status == "terminated":
                outline_sgbx_task_info.cancel_processing()

            await self._save_result_to_redis(callback_task_id, user_id, result)

            return self._build_result_payload(
                result.get("callback_task_id"),
                result.get("user_id"),
                result.get("overall_task_status"),
                result.get("error_message"),
                outputs=result,
            )
        except Exception as exc:
            logger.error(f"Outline generation task failed: {exc}", exc_info=True)
            # Identity check: "was the record created", not "is it truthy".
            if outline_sgbx_task_info is not None:
                outline_sgbx_task_info.fail_processing(str(exc))

            error_message = str(exc)
            failed_result = self._build_result_payload(
                callback_task_id,
                user_id,
                "failed",
                error_message,
            )
            try:
                await self.progress_manager.update_stage_progress(
                    callback_task_id=callback_task_id,
                    user_id=user_id,
                    stage_name="任务失败",
                    status="failed",
                    message=f"大纲生成任务失败: {error_message}",
                    overall_task_status="failed",
                    event_type="failed",
                )
                await self._save_result_to_redis(callback_task_id, user_id, failed_result)
            except Exception as update_exc:
                # Best effort only: the original error is re-raised below.
                logger.warning(
                    f"Failed to persist outline failure state: {callback_task_id}, {update_exc}",
                    exc_info=True,
                )
            raise
        finally:
            self.active_outline_tasks.pop(callback_task_id, None)
            ProgressManagerRegistry.unregister_progress_manager(callback_task_id)

    def _terminated_result(self, callback_task_id: str, user_id: str, message: str) -> dict:
        """Build the result returned when a task is terminated.

        Args:
            callback_task_id: Task callback ID.
            user_id: User ID.
            message: Termination reason, reported as ``error_message``.

        Returns:
            Result dict with status ``"terminated"`` and all outputs set to None.
        """
        return self._build_result_payload(callback_task_id, user_id, "terminated", message)

    async def _save_result_to_redis(self, callback_task_id: str, user_id: str, result: dict):
        """Save an outline generation result to Redis.

        Output fields are stored as JSON strings; falsy outputs are stored as
        an empty string. The key expires after ``_task_expire_time`` seconds.

        Args:
            callback_task_id: Task callback ID.
            user_id: User ID.
            result: Workflow result dict.
        """
        redis_client = await RedisConnectionFactory.get_connection()
        result_key = f"{self._outline_result_prefix}{callback_task_id}"

        result_data = {
            "callback_task_id": callback_task_id,
            "user_id": user_id,
            "overall_task_status": result.get("overall_task_status", ""),
        }
        for field in _OUTLINE_PAYLOAD_FIELDS:
            value = result.get(field)
            result_data[field] = json.dumps(value, ensure_ascii=False) if value else ""
        result_data["error_message"] = result.get("error_message") or ""
        result_data["completed_time"] = str(time.time())

        # NOTE(review): redis-py deprecates `hmset` in favour of
        # `hset(name, mapping=...)`; confirm which client
        # RedisConnectionFactory returns before migrating.
        await redis_client.hmset(result_key, result_data)
        await redis_client.expire(result_key, self._task_expire_time)
        logger.info(f"Outline generation result saved to Redis: {callback_task_id}")

    async def _update_outline_result_status(self, callback_task_id: str, **fields):
        """Update fields of a task's result hash in Redis.

        Errors are logged and swallowed, so this method never raises.

        Args:
            callback_task_id: Task callback ID; nothing is written when falsy.
            **fields: Fields to update. Values are stored as strings and None
                becomes an empty string.
        """
        # Skip the Redis round trip when there is nothing to write.
        if not callback_task_id or not fields:
            return

        try:
            redis_client = await RedisConnectionFactory.get_connection()
            result_key = f"{self._outline_result_prefix}{callback_task_id}"
            encoded = {key: "" if value is None else str(value) for key, value in fields.items()}
            await redis_client.hmset(result_key, encoded)
            await redis_client.expire(result_key, self._task_expire_time)
        except Exception as exc:
            logger.warning(f"Update outline result status failed: {callback_task_id}, {exc}")

    async def _pre_register_outline_task(self, sgbx_task_info: dict):
        """Pre-register a task in Redis before it is submitted to Celery.

        The record gives progress queries something to read before a worker
        has picked the task up. Errors are logged and swallowed.

        Args:
            sgbx_task_info: Construction-plan task information dict.
        """
        try:
            callback_task_id = sgbx_task_info.get("callback_task_id")
            user_id = sgbx_task_info.get("user_id", "unknown")
            project_info = sgbx_task_info.get("project_info", {})

            redis_client = await RedisConnectionFactory.get_connection()
            result_key = f"{self._outline_result_prefix}{callback_task_id}"
            await redis_client.hmset(
                result_key,
                {
                    "callback_task_id": callback_task_id,
                    "user_id": user_id,
                    "project_name": project_info.get("project_name", ""),
                    "project_type": project_info.get("engineering_type", ""),
                    "overall_task_status": "pending",
                    "outline_structure": "",
                    "key_points": "",
                    "similar_cases": "",
                    "similar_fragments": "",
                    "knowledge_bases": "",
                    "error_message": "",
                    "celery_task_id": "",
                    "pre_registered": "true",
                    "pre_registered_at": str(time.time()),
                    "completed_time": "",
                },
            )
            await redis_client.expire(result_key, self._task_expire_time)
        except Exception as exc:
            logger.error(f"Pre-register outline task failed: {exc}", exc_info=True)

    async def set_outline_terminate_signal(
        self,
        callback_task_id: str,
        operator: str = "unknown",
        reason: str = "",
    ) -> Dict[str, Any]:
        """Set the terminate signal for an outline generation task.

        The task is looked up in memory first, then in Redis. The signal is
        written only while the task is pending, processing or submitted.

        Args:
            callback_task_id: Task callback ID.
            operator: Identifier of whoever requested the termination.
            reason: Termination reason.

        Returns:
            Dict with ``success``, ``message`` and ``sgbx_task_info``. Errors
            are reported through this dict rather than raised.
        """
        try:
            task_info = self.active_outline_tasks.get(callback_task_id)
            if callback_task_id in self.active_outline_tasks:
                task_status = task_info.status
                task_user_id = task_info.user_id
                project_name = task_info.project_name
            else:
                redis_client = await RedisConnectionFactory.get_connection()
                result_key = f"{self._outline_result_prefix}{callback_task_id}"
                result_data = await redis_client.hgetall(result_key)
                if not result_data:
                    return {
                        "success": False,
                        "message": f"Task not found: {callback_task_id}",
                        "sgbx_task_info": None,
                    }
                task_status = result_data.get("overall_task_status", "unknown")
                task_user_id = result_data.get("user_id", "unknown")
                project_name = result_data.get("project_name", "")

            if task_status not in _CANCELLABLE_STATUSES:
                return {
                    "success": False,
                    "message": f"Task cannot be cancelled in status {task_status}: {callback_task_id}",
                    "sgbx_task_info": {
                        "callback_task_id": callback_task_id,
                        "status": task_status,
                        "project_name": project_name,
                    },
                }

            redis_client = await RedisConnectionFactory.get_connection()
            terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
            await redis_client.hmset(
                terminate_key,
                {
                    "operator": operator,
                    "reason": reason,
                    "terminate_time": str(time.time()),
                    "task_id": callback_task_id,
                },
            )
            await redis_client.expire(terminate_key, self._task_expire_time)

            await self._update_outline_result_status(
                callback_task_id,
                overall_task_status="terminated",
                error_message=reason or "Task cancellation requested",
                terminated_by=operator,
                terminated_time=str(time.time()),
            )

            # Only the message differs between the pending and started cases.
            if task_status == "pending":
                message = "Task cancelled before start"
            else:
                message = "Terminate signal set"
            return {
                "success": True,
                "message": message,
                "sgbx_task_info": {
                    "callback_task_id": callback_task_id,
                    "user_id": task_user_id,
                    "project_name": project_name,
                    "status": "cancelled",
                },
            }
        except Exception as exc:
            logger.error(f"Set terminate signal failed: {exc}", exc_info=True)
            return {"success": False, "message": f"Set terminate signal failed: {exc}", "sgbx_task_info": None}

    async def check_outline_terminate_signal(self, callback_task_id: str) -> bool:
        """Check whether a terminate signal exists for a task.

        Args:
            callback_task_id: Task callback ID.

        Returns:
            True when the terminate key exists in Redis. False when it does
            not, or when the check itself fails (the error is logged).
        """
        try:
            redis_client = await RedisConnectionFactory.get_connection()
            terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
            if await redis_client.exists(terminate_key):
                logger.warning(f"Detected outline terminate signal: {callback_task_id}")
                return True
            return False
        except Exception as exc:
            logger.error(f"Check terminate signal failed: {exc}", exc_info=True)
            return False

    async def clear_outline_terminate_signal(self, callback_task_id: str):
        """Delete the terminate signal of a task; errors are logged and swallowed.

        Args:
            callback_task_id: Task callback ID.
        """
        try:
            redis_client = await RedisConnectionFactory.get_connection()
            terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
            await redis_client.delete(terminate_key)
        except Exception as exc:
            logger.warning(f"Clear terminate signal failed: {exc}")

    async def get_outline_active_tasks(self) -> list:
        """Return the outline tasks currently in the ``processing`` state.

        Returns:
            List of task summaries with task ID, user, project name and type,
            status, start time and running duration in seconds.
        """
        now = time.time()
        return [
            self._describe_active_task(task_id, task_info, now)
            for task_id, task_info in self.active_outline_tasks.items()
            if task_info.status == "processing"
        ]

    async def get_outline_sgbx_task_info(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
        """Return detailed information about one outline task.

        The in-memory active task is used when present; otherwise the
        persisted result is read from Redis.

        Args:
            callback_task_id: Task callback ID.

        Returns:
            Task information dict, or None when the task is unknown or the
            Redis lookup fails.
        """
        task_info = self.active_outline_tasks.get(callback_task_id)
        if task_info is not None:
            details = self._describe_active_task(callback_task_id, task_info, time.time())
            details["results"] = task_info.results
            return details

        try:
            redis_client = await RedisConnectionFactory.get_connection()
            result_key = f"{self._outline_result_prefix}{callback_task_id}"
            result_data = await redis_client.hgetall(result_key)
            if not result_data:
                return None

            results = {
                field: self._decode_json_field(result_data.get(field))
                for field in _OUTLINE_PAYLOAD_FIELDS
            }
            results["error"] = result_data.get("error_message") or None

            overall_status = result_data.get("overall_task_status", "unknown")
            response = {
                "callback_task_id": result_data.get("callback_task_id"),
                "user_id": result_data.get("user_id"),
                "project_name": result_data.get("project_name", ""),
                "project_type": result_data.get("project_type", ""),
                "status": _REPORTED_STATUS_ALIASES.get(overall_status, overall_status),
                "start_time": None,
                "running_duration": 0,
                "results": results,
            }
            if result_data.get("pre_registered") == "true":
                response["is_pre_registered"] = True
                response["pre_registered_at"] = result_data.get("pre_registered_at")
            return response
        except Exception as exc:
            logger.error(f"Get outline task info failed: {exc}", exc_info=True)
            return None


workflow_manager = WorkflowManager()