| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
 |
- import asyncio
- import json
- import time
- import uuid
- from typing import Any, Dict, Optional
- from langchain_core.messages import HumanMessage
- from foundation.infrastructure.cache.redis_connection import RedisConnectionFactory
- from foundation.infrastructure.tracing import TraceContext
- from foundation.observability.logger.loggering import write_logger as logger
- from foundation.observability.monitoring.time_statistics import track_execution_time
- from core.base.progress_manager import ProgressManager
class ProgressManagerRegistry:
    """Registry of ProgressManager instances keyed by callback task ID.

    Used to share a task's ProgressManager between LangGraph workflow nodes.
    The mapping is a class attribute, so every caller in the process sees the
    same entries.
    """

    _registry: Dict[str, ProgressManager] = {}

    @classmethod
    def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
        """Register the progress manager that belongs to a task.

        An entry already stored under the same ID is overwritten.

        Args:
            callback_task_id: Callback ID of the task.
            progress_manager: ProgressManager instance to store.
        """
        cls._registry.update({callback_task_id: progress_manager})
        logger.info(f"注册 ProgressManager: {callback_task_id}")

    @classmethod
    def get_progress_manager(cls, callback_task_id: str) -> Optional[ProgressManager]:
        """Look up the progress manager registered for a task.

        Args:
            callback_task_id: Callback ID of the task.

        Returns:
            The registered ProgressManager, or None when nothing is registered.
        """
        try:
            return cls._registry[callback_task_id]
        except KeyError:
            return None

    @classmethod
    def unregister_progress_manager(cls, callback_task_id: Optional[str]):
        """Remove the progress manager registered for a task.

        Falsy or unknown IDs are ignored silently.

        Args:
            callback_task_id: Callback ID of the task.
        """
        if not callback_task_id or callback_task_id not in cls._registry:
            return
        cls._registry.pop(callback_task_id)
        logger.info(f"注销 ProgressManager: {callback_task_id}")
class WorkflowManager:
    """Workflow manager for construction-plan outline writing tasks.

    Responsibilities:
        1. Submitting outline generation tasks to Celery and running them
           synchronously.
        2. Building and invoking the LangGraph workflow graph.
        3. Managing task terminate signals (stored in Redis).
        4. Persisting task results to Redis.
        5. Reporting the state of active tasks.
    """

    # Result fields that carry generated payloads; stored JSON-encoded in Redis.
    _RESULT_PAYLOAD_KEYS = (
        "outline_structure",
        "key_points",
        "similar_cases",
        "similar_fragments",
        "knowledge_bases",
    )
    # Statuses in which a terminate signal is still accepted.
    _CANCELLABLE_STATUSES = ("pending", "processing", "submitted")
    # Status stored in Redis -> status reported by get_outline_sgbx_task_info.
    _REPORTED_STATUS = {
        "completed": "completed",
        "failed": "failed",
        "terminated": "cancelled",
        "pending": "pending",
        "processing": "processing",
        "submitted": "processing",
    }

    def __init__(self):
        self.progress_manager = ProgressManager()
        # Tasks currently executing in this process, keyed by callback task ID.
        self.active_outline_tasks: Dict[str, Any] = {}
        self._outline_result_prefix = "outline_write:result:"
        self._outline_terminate_signal_prefix = "outline_write:terminate_signal:"
        # Value passed to redis `expire` for result and terminate-signal keys.
        self._task_expire_time = 7200
        # Built on first use in _execute_outline_generation and then reused.
        self.outline_generation_graph = None

    def _run_async(self, coro):
        """Run a coroutine to completion from synchronous code.

        A fresh event loop is created for the call and closed afterwards.

        Args:
            coro: Coroutine object to execute.

        Returns:
            The coroutine's result.
        """
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coro)
        finally:
            try:
                # Finalize async generators the coroutine left open, as
                # asyncio.run() does, before the loop becomes unusable.
                loop.run_until_complete(loop.shutdown_asyncgens())
            except Exception as exc:
                logger.warning(f"Shutdown async generators failed: {exc}")
            finally:
                try:
                    loop.close()
                finally:
                    asyncio.set_event_loop(None)

    def _outline_progress_stages(self) -> list:
        """Return the progress stage definitions of an outline generation task.

        Returns:
            List of stages, each a dict with a ``stage`` name and a ``status``.
        """
        stage_names = (
            "start",
            "template_loading",
            "outline_generation",
            "similar_cases",
            "similar_fragments",
            "knowledge_bases",
            "complete",
        )
        return [{"stage": name, "status": "pending"} for name in stage_names]

    def _blank_result(
        self,
        callback_task_id: str,
        user_id: str,
        overall_task_status: str,
        error_message: Optional[str],
    ) -> dict:
        """Build a result dict whose payload fields are all None."""
        blank = {
            "callback_task_id": callback_task_id,
            "user_id": user_id,
            "overall_task_status": overall_task_status,
        }
        blank.update(dict.fromkeys(self._RESULT_PAYLOAD_KEYS))
        blank["error_message"] = error_message
        return blank

    @staticmethod
    def _describe_active_task(callback_task_id: str, task_info: Any, now: float) -> Dict[str, Any]:
        """Summarize an in-memory task, including how long it has been running."""
        started = task_info.start_time
        return {
            "callback_task_id": callback_task_id,
            "user_id": task_info.user_id,
            "project_name": task_info.project_name,
            "project_type": task_info.project_type,
            "status": task_info.status,
            "start_time": started,
            "running_duration": int(now - started) if started else 0,
        }

    async def submit_outline_generation_task(self, sgbx_task_info: dict) -> str:
        """Submit an outline generation task to the Celery queue.

        Flow: pre-register the task -> initialize progress -> submit to
        Celery -> update status.

        Args:
            sgbx_task_info: Construction-plan task information dict.

        Returns:
            The Celery task ID.

        Raises:
            Exception: Whatever ``apply_async`` raised when submission failed.
        """
        from foundation.infrastructure.messaging.tasks import (
            submit_outline_generation_task as outline_celery_task,
        )

        callback_task_id = sgbx_task_info.get("callback_task_id")
        user_id = sgbx_task_info.get("user_id", "unknown")
        logger.info(
            f"Submit outline generation task: callback_task_id={callback_task_id}, user_id={user_id}"
        )

        await self._pre_register_outline_task(sgbx_task_info)
        await self.progress_manager.initialize_progress(
            callback_task_id=callback_task_id,
            user_id=user_id,
            stages=self._outline_progress_stages(),
        )
        await self.progress_manager.update_stage_progress(
            callback_task_id=callback_task_id,
            user_id=user_id,
            current=5,
            stage_name="任务提交中",
            status="processing",
            message="正在提交大纲生成任务...",
            overall_task_status="processing",
        )

        # Forward the current trace ID to the task when one is set.
        kwargs = {}
        trace_id = TraceContext.get_trace_id()
        if trace_id and trace_id != "no-trace":
            kwargs["_system_trace_id"] = trace_id

        try:
            task = outline_celery_task.apply_async(
                args=[sgbx_task_info],
                kwargs=kwargs,
                queue="construction_write",
            )
        except Exception as exc:
            failure_message = f"提交大纲生成任务失败: {exc}"
            # Without this the pre-registered record would keep reporting
            # "pending" until it expires, although no task was ever queued.
            await self._update_outline_result_status(
                callback_task_id,
                overall_task_status="failed",
                error_message=failure_message,
            )
            try:
                await self.progress_manager.update_stage_progress(
                    callback_task_id=callback_task_id,
                    user_id=user_id,
                    current=5,
                    stage_name="任务提交失败",
                    status="failed",
                    message=failure_message,
                    overall_task_status="failed",
                    event_type="failed",
                )
            except Exception as progress_exc:
                # Progress reporting is best effort; the submission error
                # re-raised below is the one the caller needs to see.
                logger.warning(
                    f"Failed to publish outline submit failure: {callback_task_id}, {progress_exc}",
                    exc_info=True,
                )
            raise

        await self._update_outline_result_status(
            callback_task_id,
            overall_task_status="processing",
            celery_task_id=task.id,
            submitted_time=str(time.time()),
        )
        await self.progress_manager.update_stage_progress(
            callback_task_id=callback_task_id,
            user_id=user_id,
            current=10,
            stage_name="任务已提交",
            status="submitted",
            message="大纲生成任务已提交,正在执行...",
            overall_task_status="processing",
        )
        logger.info(f"Outline generation Celery task submitted: {task.id}")
        return task.id

    @track_execution_time
    def submit_outline_generation_sync(self, sgbx_task_info: dict) -> dict:
        """Run an outline generation task synchronously (called inside a Celery worker).

        A callback task ID is generated when ``sgbx_task_info`` carries none.

        Args:
            sgbx_task_info: Construction-plan task information dict.

        Returns:
            Dict with the outline structure, key points, similar
            recommendations and related results.
        """
        callback_task_id = sgbx_task_info.get("callback_task_id") or f"outline_{uuid.uuid4().hex[:16]}"
        return self._run_async(self._execute_outline_generation(sgbx_task_info, callback_task_id))

    async def _execute_outline_generation(self, sgbx_task_info: dict, callback_task_id: str) -> dict:
        """Run the outline generation workflow end to end.

        Flow: check the terminate signal -> build the initial state ->
        build/reuse the LangGraph graph -> invoke the workflow -> save the
        result to Redis.

        Args:
            sgbx_task_info: Construction-plan task information dict.
            callback_task_id: Callback ID of the task.

        Returns:
            Dict with the outline structure, key points, similar
            recommendations and related results.

        Raises:
            Exception: Any error raised while running the workflow, after the
                failure state has been recorded on a best-effort basis.
        """
        from core.construction_write.component.state_models import OutlineGenerationState, OutlineTaskInfo
        from core.construction_write.workflows.outline_workflow import OutlineWorkflow

        user_id = sgbx_task_info.get("user_id", "unknown")
        outline_sgbx_task_info: Optional[OutlineTaskInfo] = None
        try:
            logger.info(f"Start outline generation workflow: {callback_task_id}")
            if await self.check_outline_terminate_signal(callback_task_id):
                return self._terminated_result(callback_task_id, user_id, "Task was cancelled before start")

            outline_sgbx_task_info = OutlineTaskInfo(
                callback_task_id=callback_task_id,
                user_id=user_id,
                project_info=sgbx_task_info.get("project_info", {}),
                template_id=sgbx_task_info.get("template_id", ""),
                generation_chapterenum=sgbx_task_info.get("generation_chapterenum", []),
                generation_template=sgbx_task_info.get("generation_template", []),
                similarity_config=sgbx_task_info.get(
                    "similarity_config",
                    {"topk_plans": 3, "topk_fragments": 10, "threshold": 0.75},
                ),
                knowledge_config=sgbx_task_info.get(
                    "knowledge_config",
                    {"topk": 3, "threshold": 0.75},
                ),
            )
            self.active_outline_tasks[callback_task_id] = outline_sgbx_task_info

            await self._update_outline_result_status(
                callback_task_id,
                overall_task_status="processing",
                worker_started_at=str(time.time()),
            )
            await self.progress_manager.initialize_progress(
                callback_task_id=callback_task_id,
                user_id=user_id,
                stages=self._outline_progress_stages(),
            )
            ProgressManagerRegistry.register_progress_manager(callback_task_id, self.progress_manager)
            outline_sgbx_task_info.start_processing()

            if self.outline_generation_graph is None:
                self.outline_generation_graph = OutlineWorkflow().build_graph()

            initial_state = OutlineGenerationState(
                callback_task_id=callback_task_id,
                user_id=user_id,
                project_info=outline_sgbx_task_info.project_info,
                template_id=outline_sgbx_task_info.template_id,
                generation_chapterenum=outline_sgbx_task_info.generation_chapterenum,
                generation_template=outline_sgbx_task_info.generation_template,
                similarity_config=outline_sgbx_task_info.similarity_config,
                knowledge_config=outline_sgbx_task_info.knowledge_config,
                template=None,
                outline_structure=None,
                key_points=None,
                similar_cases=None,
                similar_fragments=None,
                knowledge_bases=None,
                current_stage="start",
                overall_task_status="processing",
                error_message=None,
                messages=[HumanMessage(content=f"Start outline generation task {callback_task_id}")],
            )
            result = await self.outline_generation_graph.ainvoke(
                initial_state,
                config={"configurable": {"thread_id": callback_task_id}},
            )

            payload = {key: result.get(key) for key in self._RESULT_PAYLOAD_KEYS}
            status = result.get("overall_task_status")
            if status == "completed":
                outline_sgbx_task_info.complete_processing(payload)
            elif status == "failed":
                outline_sgbx_task_info.fail_processing(result.get("error_message", "unknown error"))
            elif status == "terminated":
                outline_sgbx_task_info.cancel_processing()

            await self._save_result_to_redis(callback_task_id, user_id, result)

            summary = {
                "callback_task_id": result.get("callback_task_id"),
                "user_id": result.get("user_id"),
                "overall_task_status": status,
            }
            summary.update(payload)
            summary["error_message"] = result.get("error_message")
            return summary
        except Exception as exc:
            logger.error(f"Outline generation task failed: {exc}", exc_info=True)
            error_message = str(exc)
            if outline_sgbx_task_info:
                outline_sgbx_task_info.fail_processing(error_message)
            failed_result = self._blank_result(callback_task_id, user_id, "failed", error_message)

            # The two recording steps are guarded separately so that a failed
            # progress update cannot prevent the failed result being stored.
            try:
                await self.progress_manager.update_stage_progress(
                    callback_task_id=callback_task_id,
                    user_id=user_id,
                    stage_name="任务失败",
                    status="failed",
                    message=f"大纲生成任务失败: {error_message}",
                    overall_task_status="failed",
                    event_type="failed",
                )
            except Exception as progress_exc:
                logger.warning(
                    f"Failed to publish outline failure progress: {callback_task_id}, {progress_exc}",
                    exc_info=True,
                )
            try:
                await self._save_result_to_redis(callback_task_id, user_id, failed_result)
            except Exception as update_exc:
                logger.warning(
                    f"Failed to persist outline failure state: {callback_task_id}, {update_exc}",
                    exc_info=True,
                )
            raise
        finally:
            self.active_outline_tasks.pop(callback_task_id, None)
            ProgressManagerRegistry.unregister_progress_manager(callback_task_id)

    def _terminated_result(self, callback_task_id: str, user_id: str, message: str) -> dict:
        """Build the result returned when a task was terminated.

        Args:
            callback_task_id: Callback ID of the task.
            user_id: User ID.
            message: Reason for the termination.

        Returns:
            Result dict with status ``terminated`` and empty payload fields.
        """
        return self._blank_result(callback_task_id, user_id, "terminated", message)

    async def _save_result_to_redis(self, callback_task_id: str, user_id: str, result: dict):
        """Save an outline generation result to Redis.

        Payload fields are stored JSON-encoded; falsy payloads are stored as
        an empty string.

        Args:
            callback_task_id: Callback ID of the task.
            user_id: User ID.
            result: Workflow result dict.
        """
        redis_client = await RedisConnectionFactory.get_connection()
        result_key = f"{self._outline_result_prefix}{callback_task_id}"

        result_data = {
            "callback_task_id": callback_task_id,
            "user_id": user_id,
            # `or ""` also covers a key that is present with value None.
            "overall_task_status": result.get("overall_task_status") or "",
        }
        for key in self._RESULT_PAYLOAD_KEYS:
            payload = result.get(key)
            result_data[key] = json.dumps(payload, ensure_ascii=False) if payload else ""
        result_data["error_message"] = result.get("error_message") or ""
        result_data["completed_time"] = str(time.time())

        # NOTE(review): redis-py deprecates hmset in favour of
        # hset(name, mapping=...); kept because the client type returned by
        # RedisConnectionFactory is not visible here - confirm before switching.
        await redis_client.hmset(result_key, result_data)
        await redis_client.expire(result_key, self._task_expire_time)
        logger.info(f"Outline generation result saved to Redis: {callback_task_id}")

    async def _update_outline_result_status(self, callback_task_id: str, **fields):
        """Update fields of a task's result record in Redis.

        Best effort: failures are logged and not raised. ``None`` values are
        stored as empty strings, everything else as ``str(value)``.

        Args:
            callback_task_id: Callback ID of the task; nothing is written when falsy.
            **fields: Field names and values to update.
        """
        if not callback_task_id:
            return
        try:
            normalized = {key: "" if value is None else str(value) for key, value in fields.items()}
            if not normalized:
                # Nothing to write, so no connection is needed either.
                return
            redis_client = await RedisConnectionFactory.get_connection()
            result_key = f"{self._outline_result_prefix}{callback_task_id}"
            await redis_client.hmset(result_key, normalized)
            await redis_client.expire(result_key, self._task_expire_time)
        except Exception as exc:
            logger.warning(f"Update outline result status failed: {callback_task_id}, {exc}")

    async def _pre_register_outline_task(self, sgbx_task_info: dict):
        """Pre-register task information in Redis.

        Called before the Celery submission so that progress queries have a
        record to fall back on. Failures are logged and not raised.

        Args:
            sgbx_task_info: Construction-plan task information dict.
        """
        try:
            callback_task_id = sgbx_task_info.get("callback_task_id")
            user_id = sgbx_task_info.get("user_id", "unknown")
            project_info = sgbx_task_info.get("project_info", {})

            record = {
                "callback_task_id": callback_task_id,
                "user_id": user_id,
                "project_name": project_info.get("project_name", ""),
                "project_type": project_info.get("engineering_type", ""),
                "overall_task_status": "pending",
                "outline_structure": "",
                "key_points": "",
                "similar_cases": "",
                "similar_fragments": "",
                "knowledge_bases": "",
                "error_message": "",
                "celery_task_id": "",
                "pre_registered": "true",
                "pre_registered_at": str(time.time()),
                "completed_time": "",
            }

            redis_client = await RedisConnectionFactory.get_connection()
            result_key = f"{self._outline_result_prefix}{callback_task_id}"
            await redis_client.hmset(result_key, record)
            await redis_client.expire(result_key, self._task_expire_time)
        except Exception as exc:
            logger.error(f"Pre-register outline task failed: {exc}", exc_info=True)

    async def set_outline_terminate_signal(
        self,
        callback_task_id: str,
        operator: str = "unknown",
        reason: str = "",
    ) -> Dict[str, Any]:
        """Set the terminate signal of an outline generation task.

        The task is looked up in memory first and in Redis otherwise. Only
        tasks whose status is pending, processing or submitted are accepted.

        Args:
            callback_task_id: Callback ID of the task.
            operator: Identifier of the person requesting the termination.
            reason: Reason for the termination.

        Returns:
            Dict with ``success``, ``message`` and ``sgbx_task_info``. Errors
            are reported through this dict instead of being raised.
        """
        try:
            redis_client = None
            if callback_task_id in self.active_outline_tasks:
                task_info = self.active_outline_tasks[callback_task_id]
                task_status = task_info.status
                task_user_id = task_info.user_id
                project_name = task_info.project_name
            else:
                redis_client = await RedisConnectionFactory.get_connection()
                result_key = f"{self._outline_result_prefix}{callback_task_id}"
                result_data = await redis_client.hgetall(result_key)
                if not result_data:
                    return {"success": False, "message": f"Task not found: {callback_task_id}", "sgbx_task_info": None}
                task_status = result_data.get("overall_task_status", "unknown")
                task_user_id = result_data.get("user_id", "unknown")
                project_name = result_data.get("project_name", "")

            if task_status not in self._CANCELLABLE_STATUSES:
                return {
                    "success": False,
                    "message": f"Task cannot be cancelled in status {task_status}: {callback_task_id}",
                    "sgbx_task_info": {
                        "callback_task_id": callback_task_id,
                        "status": task_status,
                        "project_name": project_name,
                    },
                }

            # Reuse the connection from the Redis lookup above when there is one.
            if redis_client is None:
                redis_client = await RedisConnectionFactory.get_connection()
            terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
            await redis_client.hmset(
                terminate_key,
                {
                    "operator": operator,
                    "reason": reason,
                    "terminate_time": str(time.time()),
                    "task_id": callback_task_id,
                },
            )
            await redis_client.expire(terminate_key, self._task_expire_time)

            await self._update_outline_result_status(
                callback_task_id,
                overall_task_status="terminated",
                error_message=reason or "Task cancellation requested",
                terminated_by=operator,
                terminated_time=str(time.time()),
            )

            # Both outcomes share one payload; only the message differs.
            if task_status == "pending":
                message = "Task cancelled before start"
            else:
                message = "Terminate signal set"
            return {
                "success": True,
                "message": message,
                "sgbx_task_info": {
                    "callback_task_id": callback_task_id,
                    "user_id": task_user_id,
                    "project_name": project_name,
                    "status": "cancelled",
                },
            }
        except Exception as exc:
            logger.error(f"Set terminate signal failed: {exc}", exc_info=True)
            return {"success": False, "message": f"Set terminate signal failed: {exc}", "sgbx_task_info": None}

    async def check_outline_terminate_signal(self, callback_task_id: str) -> bool:
        """Check whether a terminate signal exists for a task.

        Args:
            callback_task_id: Callback ID of the task.

        Returns:
            True when a terminate signal exists. False when there is none or
            when the check itself failed (the failure is logged).
        """
        try:
            redis_client = await RedisConnectionFactory.get_connection()
            terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
            if not await redis_client.exists(terminate_key):
                return False
            logger.warning(f"Detected outline terminate signal: {callback_task_id}")
            return True
        except Exception as exc:
            logger.error(f"Check terminate signal failed: {exc}", exc_info=True)
            return False

    async def clear_outline_terminate_signal(self, callback_task_id: str):
        """Delete the terminate signal of a task; failures are only logged.

        Args:
            callback_task_id: Callback ID of the task.
        """
        try:
            redis_client = await RedisConnectionFactory.get_connection()
            terminate_key = f"{self._outline_terminate_signal_prefix}{callback_task_id}"
            await redis_client.delete(terminate_key)
        except Exception as exc:
            logger.warning(f"Clear terminate signal failed: {exc}")

    async def get_outline_active_tasks(self) -> list:
        """List the outline tasks that are currently being processed.

        Only in-memory tasks whose status is ``processing`` are included.

        Returns:
            List of task summaries with task ID, user, project name, running
            duration and related fields.
        """
        current_time = time.time()
        return [
            self._describe_active_task(task_id, task_info, current_time)
            for task_id, task_info in self.active_outline_tasks.items()
            if task_info.status == "processing"
        ]

    async def get_outline_sgbx_task_info(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
        """Return detailed information about one outline task.

        The in-memory active tasks are consulted first; otherwise the
        persisted record is read from Redis.

        Args:
            callback_task_id: Callback ID of the task.

        Returns:
            Task information dict, or None when the task is unknown or the
            Redis lookup failed (the failure is logged).
        """
        task_info = self.active_outline_tasks.get(callback_task_id)
        if task_info:
            details = self._describe_active_task(callback_task_id, task_info, time.time())
            details["results"] = task_info.results
            return details

        try:
            redis_client = await RedisConnectionFactory.get_connection()
            result_key = f"{self._outline_result_prefix}{callback_task_id}"
            result_data = await redis_client.hgetall(result_key)
            if not result_data:
                return None

            # Decode the JSON payloads; empty or undecodable values become None.
            results: Dict[str, Any] = {}
            for key in self._RESULT_PAYLOAD_KEYS:
                raw = result_data.get(key)
                decoded = None
                if raw:
                    try:
                        decoded = json.loads(raw)
                    except (json.JSONDecodeError, TypeError):
                        decoded = None
                results[key] = decoded
            results["error"] = result_data.get("error_message") or None

            overall_status = result_data.get("overall_task_status", "unknown")
            response = {
                "callback_task_id": result_data.get("callback_task_id"),
                "user_id": result_data.get("user_id"),
                "project_name": result_data.get("project_name", ""),
                "project_type": result_data.get("project_type", ""),
                # Statuses without a mapping are passed through unchanged.
                "status": self._REPORTED_STATUS.get(overall_status, overall_status),
                "start_time": None,
                "running_duration": 0,
                "results": results,
            }
            if result_data.get("pre_registered") == "true":
                response["is_pre_registered"] = True
                response["pre_registered_at"] = result_data.get("pre_registered_at")
            return response
        except Exception as exc:
            logger.error(f"Get outline task info failed: {exc}", exc_info=True)
            return None
# Module-level WorkflowManager instance, created when this module is imported.
workflow_manager = WorkflowManager()
|