| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611 |
- # -*- coding: utf-8 -*-
- """
- 重新生成大纲接口 (SSE 版本)
- """
- import uuid
- import json
- import time
- import asyncio
- from typing import Optional, Dict, Any, List, AsyncGenerator, Union
- from pydantic import BaseModel, Field
- from fastapi import APIRouter, HTTPException
- from fastapi.responses import StreamingResponse
- from foundation.observability.logger.loggering import write_logger as logger
- from foundation.infrastructure.tracing import TraceContext, auto_trace
- from core.base.workflow_manager import WorkflowManager
- from core.base.sse_manager import unified_sse_manager
- from core.base.progress_manager import ProgressManager
- from redis.asyncio import Redis as AsyncRedis
# Router that exposes the endpoints of this module under the /sgbx prefix.
regenerate_outline_router = APIRouter(prefix="/sgbx", tags=["施工方案编写"])

# Workflow manager shared by all requests handled in this module.
workflow_manager = WorkflowManager(max_concurrent_docs=3, max_concurrent_reviews=5)

# Progress manager shared by all requests handled in this module.
progress_manager = ProgressManager()
async def sse_progress_callback(callback_task_id: str, current_data: dict):
    """Forward one progress update to the unified SSE manager.

    Args:
        callback_task_id: Identifier of the task the update belongs to.
        current_data: Progress payload handed over unchanged.
    """
    push_progress = unified_sse_manager.send_progress
    await push_progress(callback_task_id, current_data)
def format_sse_event(event_type: str, data: str) -> str:
    """Serialize one event in the Server-Sent Events wire format.

    The SSE format is line based: a payload containing line breaks has to be
    sent as several ``data:`` fields, otherwise everything after the first
    break is parsed by the client as a separate (malformed) field. CR, LF and
    CRLF are therefore all treated as payload line separators. For a payload
    without line breaks the output is unchanged.

    Args:
        event_type: Value of the ``event:`` field.
        data: Payload text; may be empty or span several lines.

    Returns:
        The ``event:`` line, one ``data:`` line per payload line, and the
        trailing blank lines that terminate the event.
    """
    # Only CR / LF / CRLF are line terminators in SSE. str.splitlines() is
    # deliberately avoided because it also splits on characters such as
    # U+2028, which may legitimately appear inside a JSON payload.
    normalized = data.replace("\r\n", "\n").replace("\r", "\n")
    lines = [f"event: {event_type}"]
    lines.extend(f"data: {chunk}" for chunk in normalized.split("\n"))
    return "\n".join(lines) + "\n\n\n"
class BaseInfo(BaseModel):
    """Basic project information; all three fields are required."""

    # Name of the construction plan.
    project_name: str = Field(
        ...,
        description="方案名称",
        example="罗成依达大桥上部结构专项施工方案",
    )
    # Location of the construction site.
    construct_location: str = Field(
        ...,
        description="建设地点",
        example="四川省凉山州",
    )
    # Template type of the plan.
    engineering_type: str = Field(
        ...,
        description="方案模版类型",
        example="T型梁",
    )
class ProjectInfo(BaseModel):
    """Project information wrapper (nested structure)."""

    # Required basic information block.
    base_info: BaseInfo = Field(
        ...,
        description="基础信息",
    )
    # Free-form optional information; defaults to an empty string.
    selectable: Optional[str] = Field(
        default="",
        description="其他可选信息",
    )
class TemplateStructureItem(BaseModel):
    """One entry of the template structure; may carry nested ``children``."""

    # Chapter number.
    index: str = Field(
        ...,
        description="章节编号",
        example="2",
    )
    # Nesting level, validated to lie between 1 and 5 inclusive.
    level: int = Field(
        ...,
        description="层级",
        ge=1,
        le=5,
    )
    # Chapter title.
    title: str = Field(
        ...,
        description="章节标题",
        example="工程概况",
    )
    # Chapter code.
    code: str = Field(
        ...,
        description="章节代码",
        example="overview",
    )
    # Sub-chapters, kept as plain dictionaries (recursive structure).
    children: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="子章节(递归结构)",
    )
class GenerationTemplate(BaseModel):
    """Template from which an outline is generated."""

    # Source file the template originates from (optional).
    source_file: Optional[str] = Field(
        default=None,
        description="源文件",
        example="方案编写助手原文关键词规范文档修改版-2026-2-5.md",
    )
    # Alias of the template (optional).
    alias: Optional[str] = Field(
        default=None,
        description="别名",
        example="施工方案知识审查与编写体系",
    )
    # Template structure; each entry is either a validated item or a raw dict.
    structure: List[Union[TemplateStructureItem, Dict[str, Any]]] = Field(
        ...,
        description="模板结构",
    )
class RegenerateOutlineRequest(BaseModel):
    """Request body of the regenerate-outline endpoint.

    Mirrors the request of the outline generation endpoint and adds the
    ``regenerate_config`` field describing what to regenerate.
    ``project_info`` and ``generation_template`` are optional; when omitted,
    the values stored with the original task are used.

    Example request:
    {
        "task_id": "task-20250130-123456",
        "user_id": "user-001",
        "project_info": {  // optional, falls back to the original task
            "base_info": {
                "project_name": "罗成依达大桥上部结构专项施工方案",
                "construct_location": "四川省凉山州",
                "engineering_type": "T型梁"
            },
            "selectable": ""
        },
        "generation_template": {  // optional, falls back to the original task
            "source_file": "...",
            "alias": "...",
            "structure": [...]
        },
        "generation_chapterenum": ["overview_DesignSummary_MainTechnicalStandards"],  // optional
        "regenerate_config": {
            "regenerate_mode": "chapter",
            "target_path": "2.1",
            "preserve_children": true,
            "reason": "调整内容结构"
        }
    }
    """

    # Identifier of the original outline generation task.
    task_id: str = Field(
        ...,
        description="原大纲生成任务ID",
    )
    # Identifier of the requesting user.
    user_id: str = Field(
        ...,
        description="用户ID",
    )
    # Optional project information.
    project_info: Optional[ProjectInfo] = Field(
        default=None,
        description="项目基础信息(可选)",
    )
    # Optional outline generation template.
    generation_template: Optional[GenerationTemplate] = Field(
        default=None,
        description="大纲生成模板(可选)",
    )
    # Optional list of chapter codes to generate.
    generation_chapterenum: Optional[List[str]] = Field(
        default=None,
        description="生成章节代码列表(可选)",
    )
    # Required regeneration settings, kept as a free-form dictionary.
    regenerate_config: Dict[str, Any] = Field(
        ...,
        description="重新生成配置",
    )
# Headers attached to every SSE response produced by this endpoint.
_SSE_RESPONSE_HEADERS = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "X-Accel-Buffering": "no",
}

# Delay between two progress polls, in seconds.
_PROGRESS_POLL_INTERVAL = 0.5

# Number of consecutive polls without a pushed update after which a heartbeat
# is sent (30 polls x 0.5 s sleep = at least 15 s of silence).
_HEARTBEAT_AFTER_IDLE_POLLS = 30

# Values of "overall_task_status" that end the progress polling loop.
_TERMINAL_TASK_STATES = ("completed", "failed", "terminated")


def _find_chapter_code(outline_nodes, target_path):
    """Return the ``code`` of the outline node addressed by a dotted path.

    Each path segment is compared with the stringified ``index`` of the nodes
    on the corresponding nesting level; ``children`` are followed for the
    remaining segments.

    Args:
        outline_nodes: List of outline node dictionaries (top level).
        target_path: Dotted path such as ``"2.1"``; coerced to ``str``.

    Returns:
        The ``code`` value of the matching node, or ``None`` when the outline
        or the path is empty or no node matches.
    """
    if not outline_nodes or not target_path:
        return None
    path_parts = str(target_path).split(".")

    def search(nodes, depth=0):
        if depth >= len(path_parts):
            return None
        wanted_index = path_parts[depth]
        for node in nodes:
            if str(node.get("index", "")) != wanted_index:
                continue
            if depth == len(path_parts) - 1:
                return node.get("code")
            children = node.get("children", [])
            if children:
                found = search(children, depth + 1)
                if found:
                    return found
        return None

    return search(outline_nodes)


@regenerate_outline_router.post("/regenerate_outline", response_model=None)
@auto_trace(generate_if_missing=True)
async def regenerate_outline(request: RegenerateOutlineRequest):
    """Regenerate an outline and stream the progress as Server-Sent Events.

    Task handling:
    - A regeneration always creates a NEW task; the original task is only read.
    - The new task info carries ``source_task_id`` (the original task ID),
      ``regenerate_config`` and ``is_regenerate=True``.

    Optional request fields:
    - ``generation_chapterenum``: defaults to the chapter list of the original
      task; if that is empty too, a single chapter code is looked up in the
      original outline via ``regenerate_config["target_path"]``.
    - ``project_info`` / ``generation_template``: default to the values stored
      with the original task.

    Error handling:
    - Empty ``regenerate_config``: raises ``HTTPException`` with status 400.
    - Original task missing (or its lookup failed): the response is an SSE
      stream containing one ``error`` event with
      ``error_code="SOURCE_TASK_NOT_FOUND"``.
    - Any other failure inside the stream is reported as an ``error`` event.

    Event sequence of a normal run: ``connected`` -> ``processing`` ->
    ``submitted`` -> ``processing``/``heartbeat`` updates -> ``completed`` or
    ``failed``. A ``cancelled`` event ends the stream whenever a
    ``terminate:<task id>`` key exists in Redis or the progress data reports
    the cancelled state.

    Args:
        request: Validated request body.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: 400 when ``regenerate_config`` is empty.
    """
    # ===== 1. Validate parameters =====
    if not request.regenerate_config:
        logger.error("重新生成配置缺失")
        raise HTTPException(status_code=400, detail="regenerate_config 为必填项")

    # A fresh task ID is generated so the original task is never overwritten.
    new_callback_task_id = f"outline_regen_{uuid.uuid4().hex[:16]}"
    source_task_id = request.task_id  # original task ID, used for lookups
    TraceContext.set_trace_id(new_callback_task_id)
    user_id = request.user_id
    regenerate_config = request.regenerate_config
    logger.info(f"接收重新生成大纲 SSE 请求: "
                f"source_task_id={source_task_id}, "
                f"new_task_id={new_callback_task_id}, "
                f"user_id={user_id}, "
                f"target={regenerate_config.get('target_path', 'unknown')}")

    def build_payload(stage_name, status, message, overall_task_status,
                      current=0, **extra) -> str:
        """Serialize the common event payload; ``extra`` keys are appended."""
        payload = {
            "callback_task_id": new_callback_task_id,
            "source_task_id": source_task_id,
            "user_id": user_id,
            "current": current,
            "stage_name": stage_name,
            "status": status,
            "message": message,
            "overall_task_status": overall_task_status,
            "updated_at": int(time.time()),
        }
        payload.update(extra)
        return json.dumps(payload, ensure_ascii=False)

    def cancelled_event(current, message="任务已被用户取消") -> str:
        """Build the SSE ``cancelled`` event for the given progress value."""
        return format_sse_event(
            "cancelled",
            build_payload("任务已取消", "cancelled", message, "cancelled", current=current),
        )

    # ===== 2. Load the original task =====
    original_task = None
    try:
        original_task = await workflow_manager.get_outline_sgbx_task_info(source_task_id)
    except Exception as e:
        # A failed lookup is reported to the client like a missing task.
        logger.warning(f"获取原任务信息异常: {source_task_id}, error={e}")

    if not original_task:
        logger.error(f"原任务不存在: {source_task_id}")

        async def error_not_found():
            yield format_sse_event(
                "error",
                build_payload(
                    "原任务不存在",
                    "error",
                    f"原任务不存在或已过期: {source_task_id}",
                    "failed",
                    error_code="SOURCE_TASK_NOT_FOUND",
                ),
            )

        return StreamingResponse(
            error_not_found(),
            media_type="text/event-stream",
            headers=dict(_SSE_RESPONSE_HEADERS),
        )

    original_status = original_task.get("status") or original_task.get("overall_task_status", "unknown")
    logger.info(f"原任务状态: {source_task_id} = {original_status}")

    # Register the SSE connection under the NEW task ID. The returned value is
    # not used here: progress is polled from the progress manager below.
    await unified_sse_manager.establish_connection(new_callback_task_id, sse_progress_callback)

    # ===== 3. Event stream =====
    async def generate_regenerate_events() -> AsyncGenerator[str, None]:
        """Yield the SSE events of one regeneration run."""
        redis_check_client = None
        try:
            # ----- 3.1 Redis client used only for the cancellation flag -----
            # NOTE(review): host, port and password are hard-coded here; they
            # should come from the project's configuration instead.
            try:
                redis_check_client = AsyncRedis(
                    host='127.0.0.1',
                    port=6379,
                    password='123456',
                    db=0,
                    decode_responses=True,
                    socket_connect_timeout=2,
                    socket_timeout=2
                )
            except Exception as e:
                logger.warning(f"[{new_callback_task_id}] 创建取消检查Redis连接失败: {e}")

            async def is_task_cancelled() -> bool:
                """Return True when the ``terminate:<task id>`` key exists.

                Best effort: without a client, or when Redis raises, the task
                is treated as not cancelled.
                """
                if not redis_check_client or not new_callback_task_id:
                    return False
                try:
                    return await redis_check_client.exists(f"terminate:{new_callback_task_id}") > 0
                except Exception:
                    return False

            # ----- 3.2 Cancellation checkpoint 1 -----
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 连接建立前检测到取消信号")
                yield cancelled_event(0)
                return

            # ----- 3.3 Connection confirmation -----
            yield format_sse_event(
                "connected",
                build_payload(
                    "连接建立",
                    "connected",
                    f"SSE 连接已建立,正在启动重新生成任务(原任务: {source_task_id}, 状态: {original_status})...",
                    "processing",
                ),
            )

            # ----- 3.4 Build task info (original task data + new config) -----
            # `results` may be stored as None, hence the `or {}`.
            original_results = original_task.get("results") or {}

            # Project info: the request wins over the original task.
            if request.project_info:
                base_info = request.project_info.base_info
                project_info_flat = {
                    "project_name": base_info.project_name,
                    "construct_location": base_info.construct_location,
                    "engineering_type": base_info.engineering_type,
                    "selectable": request.project_info.selectable or ""
                }
            else:
                project_info_flat = original_task.get("project_info", {})

            # Template: the request wins over the original task.
            if request.generation_template:
                outline_structure = [
                    item.dict() if isinstance(item, TemplateStructureItem) else item
                    for item in request.generation_template.structure
                ]
                template_alias = request.generation_template.alias or "default_template"
            else:
                outline_structure = original_task.get("generation_template", [])
                if not outline_structure:
                    outline_structure = original_results.get("outline_structure", [])
                template_alias = original_task.get("template_id", "default_template")

            # Chapter list: request -> original task -> lookup via target_path.
            generation_chapterenum = request.generation_chapterenum
            if generation_chapterenum is None:
                generation_chapterenum = original_task.get("generation_chapterenum", [])
                if not generation_chapterenum and regenerate_config.get("target_path"):
                    chapter_code = _find_chapter_code(
                        original_results.get("outline_structure", []),
                        regenerate_config.get("target_path"),
                    )
                    if chapter_code:
                        generation_chapterenum = [chapter_code]

            sgbx_task_info = {
                "callback_task_id": new_callback_task_id,
                "source_task_id": source_task_id,  # link to the original task
                "user_id": user_id,
                "project_info": project_info_flat,
                "template_id": template_alias,
                "generation_chapterenum": generation_chapterenum,
                "generation_template": outline_structure,
                "similarity_config": original_task.get("similarity_config", {
                    "topk_plans": 3,
                    "topk_fragments": 10,
                    "threshold": 0.75
                }),
                "knowledge_config": original_task.get("knowledge_config", {
                    "topk": 3,
                    "threshold": 0.75
                }),
                # Regeneration-specific entries.
                "regenerate_config": regenerate_config,
                "is_regenerate": True,
                "original_task_status": original_status
            }
            logger.info(f"重新生成任务信息构建完成: "
                        f"new_task_id={new_callback_task_id}, "
                        f"source_task_id={source_task_id}, "
                        f"target={regenerate_config.get('target_path', 'unknown')}, "
                        f"chapters={generation_chapterenum}")

            # ----- 3.5 Cancellation checkpoint 2 -----
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 任务提交前检测到取消信号")
                yield cancelled_event(0)
                return

            # ----- 3.6 "Submitting" notification -----
            yield format_sse_event(
                "processing",
                build_payload(
                    "任务提交中",
                    "processing",
                    f"正在提交重新生成任务(目标: {regenerate_config.get('target_path', 'unknown')})...",
                    "processing",
                    current=5,
                ),
            )

            # ----- 3.7 Submit the task -----
            celery_task_id = await workflow_manager.submit_outline_generation_task(sgbx_task_info)
            logger.info(f"重新生成任务已提交: "
                        f"new_callback_task_id={new_callback_task_id}, "
                        f"celery_task_id={celery_task_id}")
            yield format_sse_event(
                "submitted",
                build_payload(
                    "任务已提交",
                    "submitted",
                    "重新生成任务已提交,正在执行...",
                    "processing",
                    current=10,
                    celery_task_id=celery_task_id,
                ),
            )

            # ----- 3.8 Poll progress until a terminal state is reported -----
            # NOTE(review): the loop has no overall deadline; it only ends on
            # a terminal state or a cancellation signal.
            last_progress = 10
            last_progress_data = None
            last_event_type = "processing"
            last_message = ""
            idle_polls = 0
            while True:
                try:
                    # Cancellation checkpoint 3.
                    if await is_task_cancelled():
                        logger.info(f"[{new_callback_task_id}] 进度轮询中检测到取消信号")
                        yield cancelled_event(last_progress)
                        return

                    progress_data = await progress_manager.get_progress(new_callback_task_id)
                    pushed = False
                    if progress_data:
                        current_progress = progress_data.get("current", last_progress)
                        current_event_type = progress_data.get("event_type", "processing")
                        current_message = progress_data.get("message", "")
                        status = progress_data.get("overall_task_status")

                        # The worker itself reported the cancelled state.
                        if status == "cancelled":
                            logger.info(f"[{new_callback_task_id}] 从进度数据检测到取消状态")
                            yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
                            return

                        # Push only when something visible changed.
                        should_push = (
                            current_progress != last_progress
                            or current_event_type != last_event_type
                            or current_message != last_message
                            or last_progress_data is None
                            or status != last_progress_data.get("overall_task_status")
                        )
                        if should_push:
                            last_progress = current_progress
                            last_event_type = current_event_type
                            last_message = current_message
                            last_progress_data = progress_data
                            yield format_sse_event("processing", json.dumps(progress_data, ensure_ascii=False))
                            pushed = True

                        if status in _TERMINAL_TASK_STATES:
                            break

                    # Count every poll that sent nothing, including polls that
                    # found no progress data at all, so that a task still
                    # waiting in the queue keeps the connection alive too.
                    idle_polls = 0 if pushed else idle_polls + 1

                    await asyncio.sleep(_PROGRESS_POLL_INTERVAL)

                    # Heartbeat after ~15 s without any pushed update.
                    if idle_polls >= _HEARTBEAT_AFTER_IDLE_POLLS:
                        yield format_sse_event(
                            "heartbeat",
                            build_payload(
                                "执行中",
                                "processing",
                                "重新生成任务正在执行中...",
                                "processing",
                                current=last_progress,
                            ),
                        )
                        idle_polls = 0
                except Exception as e:
                    # Polling errors are treated as transient: log and retry.
                    logger.warning(f"轮询进度异常: {new_callback_task_id}, 错误: {str(e)}")
                    await asyncio.sleep(_PROGRESS_POLL_INTERVAL)

            # ----- 3.9 Fetch the final result -----
            final_result = await workflow_manager.get_outline_sgbx_task_info(new_callback_task_id)

            # Cancellation checkpoint 4.
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 结果返回前检测到取消信号")
                yield cancelled_event(last_progress)
                return

            # A result marked as cancelled is never returned as a real result.
            if final_result and final_result.get("status") == "cancelled":
                logger.info(f"[{new_callback_task_id}] 任务结果状态为已取消,不返回实际结果")
                yield cancelled_event(
                    last_progress,
                    final_result.get("message", "任务已被用户取消"),
                )
                return

            # ----- 3.10 Emit the final event -----
            final_results = (final_result.get("results") or {}) if final_result else {}
            if final_result and final_result.get("status") == "completed":
                yield format_sse_event(
                    "completed",
                    build_payload(
                        "重新生成完成",
                        "completed",
                        "大纲重新生成任务已完成",
                        "completed",
                        current=100,
                        result={
                            "outline_structure": final_results.get("outline_structure", []),
                            "similar_plan": final_results.get("similar_plan", []),
                        },
                    ),
                )
            else:
                failure_message = (
                    final_results.get("error", "重新生成任务失败")
                    if final_result else "任务执行失败"
                )
                yield format_sse_event(
                    "failed",
                    build_payload(
                        "任务失败",
                        "failed",
                        failure_message,
                        "failed",
                        current=last_progress,
                    ),
                )
        except Exception as e:
            logger.error(f"重新生成大纲 SSE 事件流错误: {str(e)}", exc_info=True)
            yield format_sse_event(
                "error",
                build_payload("系统错误", "error", f"系统错误: {str(e)}", "failed"),
            )
        finally:
            # Release the Redis client; a failure here must not mask the
            # outcome of the stream, so it is only logged.
            if redis_check_client:
                try:
                    await redis_check_client.close()
                except Exception as e:
                    logger.warning(f"[{new_callback_task_id}] 关闭取消检查Redis连接失败: {e}")
            # Always unregister the SSE connection of the new task.
            await unified_sse_manager.close_connection(new_callback_task_id)

    return StreamingResponse(
        generate_regenerate_events(),
        media_type="text/event-stream",
        headers=dict(_SSE_RESPONSE_HEADERS),
    )
|