| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755 |
- # -*- coding: utf-8 -*-
- """
- 大纲生成 API 接口 (SSE 版本)
- 集成到 Celery + WorkflowManager 架构中
- 提供以下接口:
- - SSE /sgbx/generating_outline: SSE 流式大纲生成
- - SSE /sgbx/regenerate_outline: SSE 流式重新生成
- - POST /sgbx/task_cancel: 取消大纲生成任务
- - POST /sgbx/context_generate: SSE 流式上下文生成 (新增)
- """
import asyncio
import json
import os
import time
import uuid
from typing import Optional, Dict, Any, List, AsyncGenerator, Union

import aiohttp
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from redis.asyncio import Redis as AsyncRedis

from core.base.progress_manager import ProgressManager
from core.base.sse_manager import unified_sse_manager
from core.base.workflow_manager import WorkflowManager
from foundation.infrastructure.config.config import config_handler
from foundation.infrastructure.tracing import TraceContext, auto_trace
from foundation.observability.logger.loggering import write_logger as logger
# Router that hosts the outline-generation endpoints under the /sgbx prefix.
outline_router = APIRouter(prefix="/sgbx", tags=["施工方案编写"])

# Module-wide workflow manager, configured with its document / review concurrency limits.
workflow_manager = WorkflowManager(max_concurrent_docs=3, max_concurrent_reviews=5)

# Module-wide progress manager; the outline SSE loop polls task progress through it.
progress_manager = ProgressManager()
async def sse_progress_callback(callback_task_id: str, current_data: dict):
    """Forward one progress update for *callback_task_id* to its SSE client.

    Registered as the progress callback via unified_sse_manager.establish_connection.
    """
    push_progress = unified_sse_manager.send_progress
    await push_progress(callback_task_id, current_data)
def format_sse_event(event_type: str, data: str) -> str:
    """Serialize one Server-Sent Events message.

    Emits an ``event:`` line, then one ``data:`` line per line of *data*, then the
    blank-line terminator. Multi-line payloads must be split this way: a raw line break
    inside a single ``data:`` field would end that field early and corrupt the event.

    The message ends with three newlines (the blank line that dispatches the event plus
    one extra empty line). This keeps the function's established byte format; SSE
    parsers ignore the extra empty line because the data buffer is already empty.

    Args:
        event_type: value of the ``event:`` field (must not itself contain line breaks).
        data: payload text; LF, CR and CRLF line breaks are all honoured.

    Returns:
        The wire-format message text.
    """
    # SSE accepts CRLF, CR and LF as line terminators, so normalise to LF before splitting.
    # str.splitlines() is deliberately not used: it also splits on characters such as
    # U+2028 that SSE treats as ordinary text.
    payload_lines = data.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    lines = [f"event: {event_type}"]
    lines.extend(f"data: {line}" for line in payload_lines)
    return "\n".join(lines) + "\n\n\n"
# ==================== Request/response models ====================
class BaseInfo(BaseModel):
    """Basic project information; all three fields are required."""
    # Name of the construction plan.
    project_name: str = Field(..., description="方案名称", example="罗成依达大桥上部结构专项施工方案")
    # Construction location.
    construct_location: str = Field(..., description="建设地点", example="四川省凉山州")
    # Plan template type, e.g. "T型梁" (T-beam).
    engineering_type: str = Field(..., description="方案模版类型", example="T型梁")
class ProjectInfo(BaseModel):
    """Project information: required base info plus optional free-text extras."""
    # Required core project fields.
    base_info: BaseInfo = Field(..., description="基础信息")
    # Any other optional information; defaults to an empty string.
    selectable: Optional[str] = Field("", description="其他可选信息")
class TemplateStructureItem(BaseModel):
    """One node of an outline template; nested sub-sections go in ``children``."""
    # Section number, e.g. "2".
    index: str = Field(..., description="章节编号", example="2")
    # Nesting depth, validated to the range 1..5.
    level: int = Field(..., description="层级", ge=1, le=5)
    # Section title.
    title: str = Field(..., description="章节标题", example="工程概况")
    # Section code identifier.
    code: str = Field(..., description="章节代码", example="overview")
    # Child sections are typed as plain dicts rather than TemplateStructureItem, so nested
    # nodes are not validated against this model's fields (no Union is involved here).
    children: Optional[List[Dict[str, Any]]] = Field(None, description="子章节(递归结构)")
class GenerationTemplate(BaseModel):
    """Outline generation template.

    Example:
    {
        "source_file": "方案编写助手原文关键词规范文档修改版-2026-2-5.md",
        "alias": "施工方案知识审查与编写体系",
        "structure": [
            {
                "index": "2",
                "level": 1,
                "title": "工程概况",
                "code": "overview",
                "children": [...]
            }
        ]
    }
    """
    # Source document the template was derived from (optional).
    source_file: Optional[str] = Field(None, description="源文件", example="方案编写助手原文关键词规范文档修改版-2026-2-5.md")
    # Template alias; generating_outline uses it as template_id, falling back to
    # "default_template" when it is empty.
    alias: Optional[str] = Field(None, description="别名", example="施工方案知识审查与编写体系")
    # Top-level template nodes. Each entry is either a TemplateStructureItem or a plain dict;
    # generating_outline converts model entries back to dicts with .dict().
    structure: List[Union[TemplateStructureItem, Dict[str, Any]]] = Field(..., description="模板结构")
class OutlineGenerationRequest(BaseModel):
    """Outline generation request.

    Example request body (matches the curl example):
    {
        "user_id": "user-001",
        "project_info": {
            "base_info": {
                "project_name": "罗成依达大桥上部结构专项施工方案",
                "construct_location": "四川省凉山州",
                "engineering_type": "T型梁"
            },
            "selectable": ""
        },
        "generation_template": {
            "source_file": "方案编写助手原文关键词规范文档修改版-2026-2-5.md",
            "alias": "施工方案知识审查与编写体系",
            "structure": [...]
        },
        "generation_chapterenum": ["overview_DesignSummary_ProjectIntroduction", ...]
    }
    """
    # Identifier of the requesting user.
    user_id: str = Field(..., description="用户标识", example="user-001")
    # Project information (nested base_info + selectable).
    project_info: ProjectInfo = Field(..., description="项目基础信息")
    # Template whose structure drives the generated outline.
    generation_template: GenerationTemplate = Field(..., description="大纲生成模板")
    # Codes of the chapters to generate; an empty list (the default) means all chapters.
    generation_chapterenum: List[str] = Field(default_factory=list, description="生成章节代码列表,为空时生成全部章节")
class RegenerateOutlineRequest(BaseModel):
    """Request to regenerate (part of) an outline.

    Reuses the outline-generation request fields and adds ``regenerate_config`` to
    describe what to regenerate. ``project_info`` and ``generation_template`` are optional;
    when omitted, the original task's values are meant to be used.

    Example request:
    {
        "task_id": "task-20250130-123456",
        "user_id": "user-001",
        "project_info": {  // optional; defaults to the original task's project info
            "base_info": {
                "project_name": "罗成依达大桥上部结构专项施工方案",
                "construct_location": "四川省凉山州",
                "engineering_type": "T型梁"
            },
            "selectable": ""
        },
        "generation_template": {  // optional; defaults to the original task's template
            "source_file": "...",
            "alias": "...",
            "structure": [...]
        },
        "generation_chapterenum": ["overview_DesignSummary_MainTechnicalStandards"],  // optional
        "regenerate_config": {
            "regenerate_mode": "chapter",
            "target_path": "2.1",
            "preserve_children": true,
            "reason": "调整内容结构"
        }
    }
    """
    # ID of the original outline-generation task.
    task_id: str = Field(..., description="原大纲生成任务ID")
    # Requesting user.
    user_id: str = Field(..., description="用户ID")
    # Optional: same shape as in the generation request (original task's value if omitted).
    project_info: Optional[ProjectInfo] = Field(None, description="项目基础信息(可选)")
    # Optional: same shape as in the generation request (original task's template if omitted).
    generation_template: Optional[GenerationTemplate] = Field(None, description="大纲生成模板(可选)")
    # Optional: chapter codes to regenerate.
    generation_chapterenum: Optional[List[str]] = Field(None, description="生成章节代码列表(可选)")
    # Regeneration-specific settings. A free-form dict that this model does not validate;
    # see the docstring example for the expected keys.
    regenerate_config: Dict[str, Any] = Field(..., description="重新生成配置")
class TaskCancelRequest(BaseModel):
    """Request to cancel a running task."""
    # Task to cancel.
    task_id: str = Field(..., description="任务ID")
    # Requesting user.
    user_id: str = Field(..., description="用户ID")
    # Free-text reason; defaults to "用户主动取消" (cancelled by the user).
    cancel_reason: Optional[str] = Field("用户主动取消", description="取消原因")
# ==================== Response models ====================
class OutlineNodeResponse(BaseModel):
    """Outline node returned to the client.

    Mirrors the request's TemplateStructureItem and adds:
    - generated_content on nodes of every level
    - similar_fragments on level-2 and level-3 nodes
    - key_points and knowledge_bases on leaf nodes only
    """
    # Section number, e.g. "2.1.1".
    index: str = Field(..., description="章节编号", example="2.1.1")
    # Nesting depth, validated to the range 1..5.
    level: int = Field(..., description="层级", ge=1, le=5, example=3)
    # Section title.
    title: str = Field(..., description="章节标题", example="工程简介")
    # Section code identifier.
    code: str = Field(..., description="章节代码", example="overview_DesignSummary_ProjectIntroduction")
    # AI-generated section text (present on every node).
    generated_content: str = Field(..., description="AI生成的内容", example="罗成依达大桥位于四川省凉山州...")
    # Similar-fragment recommendations (level-2 and level-3 nodes).
    similar_fragments: Optional[List[Dict[str, Any]]] = Field(None, description="相似片段推荐(2级和3级节点)")
    # Key points (leaf nodes only).
    key_points: Optional[List[str]] = Field(None, description="核心要点(仅末级节点)", example=["桥位位置", "桥梁规模"])
    # Knowledge points / normative references (leaf nodes only).
    knowledge_bases: Optional[List[str]] = Field(None, description="知识点/编制依据(仅末级节点)", example=["《公路桥涵设计通用规范》JTG D60-2015"])
    # Child sections as plain dicts (recursive structure, not validated by this model).
    children: Optional[List[Dict[str, Any]]] = Field(None, description="子章节(递归结构)")
class SimilarPlanResponse(BaseModel):
    """Similar-plan recommendation at the whole-document level."""
    # Identifier of the recommended plan.
    plan_id: str = Field(..., description="方案ID")
    # Title of the recommended plan.
    plan_title: str = Field(..., description="方案标题")
    # Similarity score, validated to the range 0.0..1.0.
    similarity_score: float = Field(..., description="相似度分数", ge=0.0, le=1.0)
    # Plan type.
    plan_type: str = Field(..., description="方案类型")
    # Optional outline structure of the recommended plan.
    outline: Optional[List[Dict]] = Field(None, description="方案大纲结构")
    # Optional free-form metadata.
    metadata: Optional[Dict[str, Any]] = Field(None, description="元数据")
class SimilarFragmentResponse(BaseModel):
    """Similar-fragment recommendation (section-level snippet from another document)."""
    # Identifier of the fragment.
    fragment_id: str = Field(..., description="片段ID")
    # Path of the section the fragment belongs to.
    section_path: str = Field(..., description="所属章节路径")
    # Title of that section.
    section_title: str = Field(..., description="章节标题")
    # Fragment text.
    fragment_content: str = Field(..., description="片段内容")
    # Similarity score, validated to the range 0.0..1.0.
    similarity_score: float = Field(..., description="相似度分数", ge=0.0, le=1.0)
    # Identifier of the source document.
    source_document_id: str = Field(..., description="来源文档ID")
    # Title of the source document.
    source_document_title: str = Field(..., description="来源文档标题")
class OutlineGenerationResult(BaseModel):
    """Outline generation result.

    Structure:
    - outline_structure: nested outline; every node carries generated_content
    - level-2 and level-3 nodes additionally carry similar_fragments
    - leaf nodes additionally carry key_points and knowledge_bases
    - similar_plan: top-level similar-plan recommendations for the whole document
    """
    # Nested outline including AI-generated content and section-level similar_fragments.
    outline_structure: List[OutlineNodeResponse] = Field(..., description="大纲结构(包含AI生成内容和章节级similar_fragments)")
    # Document-level similar-plan recommendations; defaults to an empty list.
    similar_plan: List[SimilarPlanResponse] = Field(default_factory=list, description="相似方案推荐(整篇方案级)")
# ==================== Context-generation models ====================
class CompletionConfig(BaseModel):
    """Settings for one context-generation (completion) request."""
    # Dotted numeric section path such as "2.1"; the format is checked by validate_completion_config.
    section_path: str = Field(..., description="章节路径")
    # Text already present in the section; placed into the prompt unchanged.
    current_content: str = Field(default="", description="当前已有内容")
    # Validated to 500..5000. NOTE(review): not read by generate_content_stream or
    # build_context_generate_prompt in this file — confirm whether it is still needed.
    context_window: int = Field(default=2000, ge=500, le=5000)
    # Completion mode; /context_generate_modes lists continue / expand / polish / complete,
    # but this model does not enforce that list.
    completion_mode: str = Field(default="continue", description="模式")
    # Desired length (100..5000); also used, capped at 4000, as the LLM max_tokens.
    target_length: int = Field(default=1000, ge=100, le=5000)
    # Forwarded to build_context_generate_prompt.
    include_references: bool = Field(default=True)
    # Forwarded to build_context_generate_prompt.
    style_match: bool = Field(default=True)
    # Optional keywords, forwarded to build_context_generate_prompt.
    hint_keywords: Optional[List[str]] = Field(default=None)
class ProjectInfoSimple(BaseModel):
    """Flat project information used by the context-generation endpoint."""
    # Project name; defaults to "施工方案" (construction plan).
    project_name: str = Field(default="施工方案")
    # Optional construction location.
    construct_location: Optional[str] = Field(default=None)
    # Optional engineering / template type.
    engineering_type: Optional[str] = Field(default=None)
class ContextGenerateRequest(BaseModel):
    """Request body for /sgbx/context_generate; unknown fields are rejected."""
    # Source task id; validate_request requires either this or project_info.
    task_id: Optional[str] = Field(default=None)
    # Requesting user; checked against an allow-list by validate_user_id.
    user_id: str = Field(...)
    # Optional flat project information.
    project_info: Optional[ProjectInfoSimple] = Field(default=None)
    # Section and generation settings.
    completion_config: CompletionConfig = Field(...)
    # Optional model override. NOTE(review): generate_content_stream does not read it and
    # always uses CustomAPIConfig.get_model_name() — confirm intended.
    model_name: Optional[str] = Field(default=None)
    # Forbid extra fields so typos in the request body fail validation.
    class Config: extra = "forbid"
class ContextGenerateResponse(BaseModel):
    """Generic JSON envelope used by the context-generation helper endpoints."""
    # Status code echoed in the body (the endpoints in this file use 200).
    code: int
    # Human-readable status message.
    message: str
    # Optional payload.
    data: Optional[Dict[str, Any]] = None
# ==================== Global resource pool (latency-critical) ====================
# Process-wide HTTP session shared by all outbound LLM calls in this module.
GLOBAL_HTTP_SESSION: Optional[aiohttp.ClientSession] = None
# Process-wide async Redis client. Bound here so the `is None` checks below cannot raise
# NameError before the first assignment.
GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None
# Strong references to fire-and-forget tasks. The event loop only keeps weak references to
# tasks, so an unreferenced task may be garbage-collected before it finishes.
_BACKGROUND_TASKS: set = set()


async def init_global_resources():
    """Lazily create the shared HTTP session and Redis client.

    Each resource is created only when missing (the HTTP session also when closed), so
    calling this repeatedly is cheap. If constructing the Redis client fails, the error is
    logged and GLOBAL_REDIS_CLIENT stays None; callers must tolerate a None client.
    """
    global GLOBAL_HTTP_SESSION, GLOBAL_REDIS_CLIENT

    if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
        # DNS caching plus keep-alive connection reuse towards the DashScope endpoint.
        connector = aiohttp.TCPConnector(limit=100, limit_per_host=20, ttl_dns_cache=300, force_close=False)
        GLOBAL_HTTP_SESSION = aiohttp.ClientSession(
            # total bounds the whole (streamed) response; sock_read bounds the gap between reads.
            timeout=aiohttp.ClientTimeout(total=120, connect=10, sock_read=10),
            connector=connector,
            headers={"User-Agent": "FastAPI-DashScope-Optimized/2.0"},
        )
        logger.info("✅ 全局 HTTP 连接池已初始化 (DashScope Ready)")

    if GLOBAL_REDIS_CLIENT is None:
        try:
            # TODO(review): host and credentials are hard-coded; move them into config_handler.
            GLOBAL_REDIS_CLIENT = AsyncRedis(
                host='127.0.0.1', port=6379, password='123456', db=0,
                decode_responses=True, socket_connect_timeout=1,
                socket_keepalive=True, max_connections=50
            )
            # Probe Redis in the background so the caller is not delayed.
            ping_task = asyncio.create_task(_background_ping())
            _BACKGROUND_TASKS.add(ping_task)
            ping_task.add_done_callback(_BACKGROUND_TASKS.discard)
            logger.info("✅ 全局 Redis 连接池已初始化")
        except Exception as e:
            logger.warning(f"⚠️ Redis 初始化失败: {e}")
            GLOBAL_REDIS_CLIENT = None


async def _background_ping():
    """Best-effort Redis connectivity probe; failures are logged, never raised."""
    client = GLOBAL_REDIS_CLIENT
    if client is None:
        return
    try:
        await client.ping()
    except Exception as e:
        # Catch Exception (not everything) so task cancellation still propagates.
        logger.warning(f"⚠️ Redis 后台探活失败: {e}")


async def get_http_session():
    """Return the shared aiohttp session, (re)creating it if missing or closed."""
    if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
        await init_global_resources()
    return GLOBAL_HTTP_SESSION


async def get_redis_client():
    """Return the shared Redis client, or None if it could not be created."""
    if GLOBAL_REDIS_CLIENT is None:
        await init_global_resources()
    return GLOBAL_REDIS_CLIENT
# ==================== Custom API configuration (Aliyun DashScope) ====================
class CustomAPIConfig:
    """Connection settings for DashScope's OpenAI-compatible chat-completions API.

    The API key is resolved at call time, in order, from:
    1. the ``DASHSCOPE_API_KEY`` environment variable,
    2. the ``custom_api`` / ``API_KEY`` config entry,
    3. the ``DASHSCOPE_API_KEY`` class attribute (empty by default).
    With no key available, ``is_enabled()`` returns False.
    """

    # OpenAI-compatible base URL; the chat endpoint must include the /chat/completions suffix.
    DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    DASHSCOPE_CHAT_URL = f"{DASHSCOPE_BASE_URL}/chat/completions"

    # Environment variable consulted first for the API key.
    DASHSCOPE_API_KEY_ENV = "DASHSCOPE_API_KEY"
    # Static fallback key, deliberately empty: credentials must not live in source control.
    # NOTE(security): never put a real key here; any key previously committed to this file
    # must be treated as leaked and rotated.
    DASHSCOPE_API_KEY = ""

    # Model used unless custom_api / MODEL_NAME is configured.
    DEFAULT_MODEL_NAME = "qwen3-30b-a3b-instruct-2507"

    @staticmethod
    def get_api_url() -> str:
        """Return the chat-completions endpoint URL (not configurable)."""
        return CustomAPIConfig.DASHSCOPE_CHAT_URL

    @staticmethod
    def get_api_key() -> str:
        """Return the configured API key with surrounding whitespace removed, or ""."""
        api_key = (
            os.environ.get(CustomAPIConfig.DASHSCOPE_API_KEY_ENV)
            or config_handler.get("custom_api", "API_KEY", "")
            or CustomAPIConfig.DASHSCOPE_API_KEY
        )
        return str(api_key).strip()

    @staticmethod
    def get_model_name() -> str:
        """Return the configured model name, falling back to DEFAULT_MODEL_NAME."""
        configured_model = config_handler.get("custom_api", "MODEL_NAME", "")
        return configured_model if configured_model else CustomAPIConfig.DEFAULT_MODEL_NAME

    @staticmethod
    def is_enabled() -> bool:
        """True when both an endpoint URL and a non-empty API key are available."""
        return bool(CustomAPIConfig.get_api_key()) and bool(CustomAPIConfig.get_api_url())
- # ==================== 极速流式调用 (核心优化) ====================
- async def call_custom_api_stream(
- prompt: str, system_prompt: str = "", max_tokens: int = 2000,
- temperature: float = 0.7, trace_id: str = ""
- ) -> AsyncGenerator[tuple[str, Optional[float]], None]:
-
- api_url = CustomAPIConfig.get_api_url()
- model_name = CustomAPIConfig.get_model_name()
- api_key = CustomAPIConfig.get_api_key()
-
- logger.debug(f"[{trace_id}] 正在调用阿里云 DashScope: {model_name} @ {api_url}")
- # 截断过长的 Prompt (阿里云对输入长度有限制,且为了速度)
- max_prompt_len = 10000
- if len(prompt) > max_prompt_len:
- prompt = prompt[-max_prompt_len:]
- logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")
- payload = {
- "model": model_name,
- "messages": [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": prompt}
- ],
- "max_tokens": max_tokens,
- "temperature": temperature,
- "stream": True,
- "incremental_output": True # 阿里云兼容模式可能支持此参数,优化流式体验
- }
-
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {api_key}"
- }
-
- start_time = time.time()
- first_token_time: Optional[float] = None
- buffer = ""
-
- session = await get_http_session()
-
- try:
- # 阿里云 HTTPS 连接,保持 read_bufsize=1 以获取最快首字
- async with session.post(api_url, json=payload, headers=headers, read_bufsize=1) as response:
- if response.status != 200:
- error_text = await response.text()
- logger.error(f"[{trace_id}] API 错误 {response.status}: {error_text}")
- raise Exception(f"API 错误 {response.status}: {error_text}")
-
- async for chunk in response.content.iter_any():
- if not chunk: continue
- try:
- text = chunk.decode('utf-8', errors='ignore')
- if not text: continue
- buffer += text
-
- while '\n' in buffer:
- line, buffer = buffer.split('\n', 1)
- line = line.strip()
-
- if line.startswith('data: '):
- data = line[6:]
- if data == '[DONE]':
- return
-
- try:
- event_data = json.loads(data)
- # 处理阿里云可能的错误格式
- if "error" in event_data:
- err_msg = event_data["error"].get("message", "Unknown Error")
- logger.error(f"[{trace_id}] 流式数据中包含错误: {err_msg}")
- continue
- choices = event_data.get("choices", [])
- if choices:
- delta = choices[0].get("delta", {})
- content = delta.get("content", "")
-
- if content:
- if first_token_time is None:
- first_token_time = time.time() - start_time
- yield (content, first_token_time)
- except json.JSONDecodeError:
- continue
- except UnicodeDecodeError:
- continue
- except Exception as e:
- logger.error(f"[{trace_id}] API 流式请求异常: {e}")
- raise
# ==================== Context-generation prompt helpers ====================
# System prompt for /context_generate: output plain body text only, at most 100 characters.
# NOTE(review): the 100-character cap conflicts with CompletionConfig.target_length
# (default 1000), which build_context_generate_prompt also places in the user prompt.
CONTEXT_GENERATE_SYSTEM_PROMPT = "你是一位专业的施工方案编写专家。请直接输出生成的内容文本,不要添加任何解释、标注或格式标记。要求生成的内容不超过100字。"


def build_context_generate_prompt(project_info, section_path, section_title, current_content, completion_mode, target_length, include_references, style_match, hint_keywords, context_before="", context_after=""):
    """Build the user prompt for one context-generation request.

    Args:
        project_info: dict with "project_name" ("未知" when absent) and optional
            "construct_location" / "engineering_type", each included only when non-empty.
        section_path: dotted section number, e.g. "2.1".
        section_title: display title of the section.
        current_content: existing section text; included only when non-empty.
        completion_mode: completion mode name (e.g. "continue").
        target_length: desired output length shown to the model.
        include_references: accepted for interface compatibility; not yet reflected in the prompt.
        style_match: accepted for interface compatibility; not yet reflected in the prompt.
        hint_keywords: optional iterable of keywords to cover; blank entries are skipped.
        context_before: preceding text; only its last 500 characters are used.
        context_after: following text; only its first 500 characters are used.

    Returns:
        The prompt lines joined with newlines, ending with the generation instruction.
    """
    parts = [f"【项目】{project_info.get('project_name', '未知')}"]
    # Optional project attributes supplied by the request, included only when present.
    construct_location = project_info.get("construct_location")
    if construct_location:
        parts.append(f"【建设地点】{construct_location}")
    engineering_type = project_info.get("engineering_type")
    if engineering_type:
        parts.append(f"【工程类型】{engineering_type}")

    parts.append(f"【章节】{section_title} ({section_path})")
    parts.append(f"【模式】{completion_mode} (目标:{target_length})")

    keywords = [str(keyword).strip() for keyword in (hint_keywords or []) if str(keyword).strip()]
    if keywords:
        parts.append(f"【关键词】{'、'.join(keywords)}")

    if context_before:
        parts.append(f"【前文】...{context_before[-500:]}")
    if current_content:
        parts.append(f"【当前】{current_content}")
    if context_after:
        parts.append(f"【后文】{context_after[:500]}...")

    parts.append("【指令】请根据上述信息继续生成专业内容,直接输出正文:")
    return "\n".join(parts)
def extract_chunk_content(chunk: Any) -> str:
    """Return the text carried by a streamed chunk of unknown shape.

    Accepts a plain string, an object with a ``content`` attribute, or a dict with a
    ``"content"`` key; anything else is converted with ``str()``. For the attribute and
    dict forms, a missing or falsy content value yields "" — in particular
    ``{"content": None}`` yields "" rather than the literal text "None".
    """
    if isinstance(chunk, str):
        return chunk
    if hasattr(chunk, "content"):
        value = chunk.content
    elif isinstance(chunk, dict):
        value = chunk.get("content")
    else:
        return str(chunk)
    return str(value) if value else ""
# Fixed allow-list of user ids accepted by the context-generation endpoint.
_SUPPORTED_USER_IDS = frozenset({'user-001', 'user-002', 'user-003'})


def validate_user_id(user_id: str):
    """Accept only allow-listed user ids.

    Raises:
        HTTPException: 403 with code INVALID_USER for any other id.
    """
    if user_id in _SUPPORTED_USER_IDS:
        return
    raise HTTPException(status_code=403, detail={"code": "INVALID_USER", "message": "用户标识无效"})
def validate_completion_config(config: CompletionConfig):
    """Require ``section_path`` to be dot-separated digit groups, e.g. "2" or "2.1.3".

    Digits are checked with str.isdigit, so non-ASCII digits also pass.

    Raises:
        HTTPException: 400 with code INVALID_PATH for an empty or malformed path.
    """
    path = config.section_path
    path_ok = bool(path) and all(segment.isdigit() for segment in path.split("."))
    if path_ok:
        return
    raise HTTPException(status_code=400, detail={"code": "INVALID_PATH", "message": "章节路径格式错误"})
def validate_request(request: ContextGenerateRequest):
    """Require at least one of ``task_id`` or ``project_info`` to be supplied.

    Raises:
        HTTPException: 400 with code MISSING_INFO when both are missing or empty.
    """
    has_context_source = bool(request.task_id) or bool(request.project_info)
    if has_context_source:
        return
    raise HTTPException(status_code=400, detail={"code": "MISSING_INFO", "message": "缺少任务 ID 或项目信息"})
# ==================== Context-generation streaming core ====================
async def generate_content_stream(callback_task_id, source_task_id, user_id, request, redis_client):
    """Yield the SSE events for one /context_generate request.

    Event sequence: ``connected`` -> ``generating`` -> ``chunk`` (one per content delta)
    -> ``completed``. If the Redis key ``terminate:{callback_task_id}`` appears while chunks
    are streaming, a ``cancelled`` event is sent and the stream ends. Any other exception
    is logged and reported as a single ``error`` event instead of being raised.

    Args:
        callback_task_id: id of this stream; also the suffix of the Redis cancellation key.
        source_task_id: id of the originating task (not used here).
        user_id: requesting user (not used here).
        request: the validated ContextGenerateRequest.
        redis_client: async Redis client for cancellation checks; a falsy value disables them.
    """
    async def is_cancelled() -> bool:
        """Best-effort check of the cancellation flag; Redis errors count as not cancelled."""
        if not redis_client:
            return False
        try:
            return await redis_client.exists(f"terminate:{callback_task_id}") > 0
        except Exception:
            # Catch Exception, not everything: a bare `except` would also swallow
            # asyncio.CancelledError and keep streaming after the client disconnected.
            return False

    stream_start_time = time.time()
    first_token_latency: Optional[float] = None
    full_content_parts: List[str] = []
    chunk_count = 0
    try:
        yield format_sse_event("connected", json.dumps({
            "callback_task_id": callback_task_id, "status": "connected", "timestamp": int(time.time())
        }, ensure_ascii=False))

        completion_config = request.completion_config
        project_info = request.project_info.dict() if request.project_info else {}
        section_title = f"章节 {completion_config.section_path}"
        user_prompt = build_context_generate_prompt(
            project_info=project_info,
            section_path=completion_config.section_path,
            section_title=section_title,
            current_content=completion_config.current_content,
            completion_mode=completion_config.completion_mode,
            target_length=completion_config.target_length,
            include_references=completion_config.include_references,
            style_match=completion_config.style_match,
            hint_keywords=completion_config.hint_keywords
        )

        model_name = CustomAPIConfig.get_model_name()
        yield format_sse_event("generating", json.dumps({
            "status": "generating",
            "message": f"正在调用阿里云 Qwen3 ({model_name})...",
            "timestamp": int(time.time())
        }, ensure_ascii=False))

        if not CustomAPIConfig.is_enabled():
            # Reached when no API key is configured; reported to the client as an error event.
            logger.warning(f"[{callback_task_id}] API 配置失效,回退到默认模型 (不应发生)")
            raise Exception("API 配置未生效,请检查 CustomAPIConfig")

        logger.info(f"[{callback_task_id}] 使用阿里云 DashScope API (模型:{model_name})")
        upstream = call_custom_api_stream(
            prompt=user_prompt,
            system_prompt=CONTEXT_GENERATE_SYSTEM_PROMPT,
            max_tokens=min(completion_config.target_length, 4000),
            temperature=0.7,
            trace_id=callback_task_id
        )
        try:
            async for content, ftl in upstream:
                if await is_cancelled():
                    yield format_sse_event("cancelled", json.dumps({"status": "cancelled"}, ensure_ascii=False))
                    return
                if not content:
                    continue

                full_content_parts.append(content)
                chunk_count += 1
                if first_token_latency is None:
                    first_token_latency = ftl if ftl is not None else (time.time() - stream_start_time)
                    logger.info(f"[{callback_task_id}] ⚡ 首字延迟: {first_token_latency:.3f}s (Model: {model_name})")

                yield format_sse_event("chunk", json.dumps({
                    "chunk": content,
                    "first_token_latency": round(first_token_latency, 3),
                    "timestamp": int(time.time())
                }, ensure_ascii=False))
        finally:
            # Close the upstream generator immediately (after cancellation, a client
            # disconnect or an error) so its HTTP response is released now rather than
            # whenever the generator happens to be garbage-collected.
            await upstream.aclose()

        total_duration = time.time() - stream_start_time
        full_content = "".join(full_content_parts)
        # first_token_latency stays None when the model returned no content; formatting None
        # with ":.3f" would raise TypeError and turn an empty-but-successful stream into an error.
        latency_text = f"{first_token_latency:.3f}s" if first_token_latency is not None else "N/A"
        logger.info(f"[{callback_task_id}] ✅ 完成 | 首字: {latency_text} | 总耗时: {total_duration:.3f}s | 字数: {len(full_content)}")
        yield format_sse_event("completed", json.dumps({
            "callback_task_id": callback_task_id,
            "status": "completed",
            "metrics": {
                "first_token_latency": round(first_token_latency, 3) if first_token_latency else 0.0,
                "total_duration": round(total_duration, 3),
                "char_count": len(full_content),
                "chunk_count": chunk_count,
                "model_used": model_name
            },
            "full_content": full_content,
            "timestamp": int(time.time())
        }, ensure_ascii=False))
    except Exception as e:
        logger.error(f"[{callback_task_id}] ❌ 异常: {str(e)}", exc_info=True)
        yield format_sse_event("error", json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False))
# ==================== Context-generation API routes ====================
context_generate_router = APIRouter(prefix="/sgbx", tags=["施工方案编写"])


@context_generate_router.post("/context_generate")
@auto_trace(generate_if_missing=True)
async def context_generate(request: ContextGenerateRequest):
    """Validate a context-generation request and stream the result as SSE.

    Validation failures propagate as the validators' HTTPException (403/400). Any other
    failure before streaming starts becomes an HTTP 500. Failures after the stream has
    started are reported in-band by generate_content_stream.
    """
    callback_task_id = f"ctx_{uuid.uuid4().hex[:12]}"
    TraceContext.set_trace_id(callback_task_id)
    received_at = time.time()

    try:
        validate_user_id(request.user_id)
        validate_completion_config(request.completion_config)
        validate_request(request)

        redis_client = await get_redis_client()
        preprocess_ms = (time.time() - received_at) * 1000
        logger.info(f"[{callback_task_id}] 请求接收 (预处理耗时: {preprocess_ms:.1f}ms)")

        # No-cache / no-buffering headers so each event reaches the client immediately
        # (X-Accel-Buffering disables nginx response buffering).
        # NOTE(review): the wildcard CORS header bypasses any app-level CORS policy — confirm.
        sse_headers = {
            "Cache-Control": "no-cache, no-store, must-revalidate",
            "Pragma": "no-cache",
            "Expires": "0",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Content-Type": "text/event-stream; charset=utf-8",
            "Access-Control-Allow-Origin": "*"
        }
        event_stream = generate_content_stream(
            callback_task_id, request.task_id, request.user_id, request, redis_client
        )
        return StreamingResponse(event_stream, media_type="text/event-stream", headers=sse_headers)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[{callback_task_id}] 全局异常: {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
@context_generate_router.get("/context_generate_health")
async def health_check():
    """Static health probe reporting the configured provider, model and endpoint prefix.

    Does not contact DashScope; it only reads local configuration.
    """
    return {
        "status": "healthy",
        "provider": "Aliyun DashScope",
        "current_model": CustomAPIConfig.get_model_name(),
        # Read from CustomAPIConfig instead of repeating the literal, so the two cannot drift.
        "api_url_prefix": CustomAPIConfig.DASHSCOPE_BASE_URL
    }
@context_generate_router.get("/context_generate_modes", response_model=ContextGenerateResponse)
async def get_modes():
    """List the completion modes advertised for /context_generate, with display names."""
    mode_table = (
        ("continue", "续写"),
        ("expand", "扩写"),
        ("polish", "润色"),
        ("complete", "补全"),
    )
    modes = [{"mode": mode, "name": display_name} for mode, display_name in mode_table]
    return ContextGenerateResponse(code=200, message="success", data={"modes": modes})
@context_generate_router.get("/context_generate_api_status", response_model=ContextGenerateResponse)
async def get_api_status():
    """Report whether the DashScope backend is configured and which model it will use."""
    status_payload = {
        "enabled": CustomAPIConfig.is_enabled(),
        "provider": "Aliyun DashScope",
        "model": CustomAPIConfig.get_model_name()
    }
    return ContextGenerateResponse(code=200, message="success", data=status_payload)
- # ==================== 原有大纲生成接口实现 ====================
- @outline_router.post("/generating_outline", response_model=None)
- @auto_trace(generate_if_missing=True)
- async def generating_outline(request: OutlineGenerationRequest):
- """
- 大纲生成接口 (SSE 流式响应)
- """
- callback_task_id = f"outline_{uuid.uuid4().hex[:16]}"
- TraceContext.set_trace_id(callback_task_id)
- user_id = request.user_id
- logger.info(f"接收大纲生成 SSE 请求: user_id={user_id}, project={request.project_info.base_info.project_name}")
- # ===== 新增:创建 Redis 连接用于检查终止标志 =====
- redis_check_client = None
- try:
- redis_check_client = AsyncRedis(
- host='127.0.0.1',
- port=6379,
- password='123456',
- db=0,
- decode_responses=True,
- socket_connect_timeout=2,
- socket_timeout=2
- )
- except Exception as e:
- logger.warning(f"[{callback_task_id}] 创建取消检查Redis连接失败: {e}")
- # ===== 新增:定义取消检查函数 =====
- async def is_task_cancelled() -> bool:
- """检查任务是否被取消"""
- if not redis_check_client or not callback_task_id:
- return False
- try:
- return await redis_check_client.exists(f"terminate:{callback_task_id}") > 0
- except Exception:
- return False
- # 使用统一 SSE 管理器建立连接并注册回调
- queue = await unified_sse_manager.establish_connection(callback_task_id, sse_progress_callback)
- async def generate_outline_events() -> AsyncGenerator[str, None]:
- """生成大纲生成 SSE 事件流"""
- try:
- # ===== 检查点1: 函数开始 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 连接建立前检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # 发送连接确认事件
- connected_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "连接建立",
- "status": "connected",
- "message": "SSE 连接已建立,正在启动大纲生成任务...",
- "overall_task_status": "processing",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("connected", connected_data)
- # 构建任务信息
- base_info = request.project_info.base_info
- project_info_flat = {
- "project_name": base_info.project_name,
- "construct_location": base_info.construct_location,
- "engineering_type": base_info.engineering_type,
- "selectable": request.project_info.selectable or ""
- }
- sgbx_task_info = {
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "project_info": project_info_flat,
- "template_id": request.generation_template.alias or "default_template",
- "generation_chapterenum": request.generation_chapterenum,
- "generation_template": [
- item.dict() if isinstance(item, TemplateStructureItem) else item
- for item in request.generation_template.structure
- ],
- "similarity_config": {
- "topk_plans": 3,
- "topk_fragments": 10,
- "threshold": 0.75
- },
- "knowledge_config": {
- "topk": 3,
- "threshold": 0.75
- },
- }
- # ===== 检查点2: 任务提交前 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 任务提交前检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # 发送处理中事件
- processing_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 5,
- "stage_name": "任务提交中",
- "status": "processing",
- "message": "正在提交大纲生成任务...",
- "overall_task_status": "processing",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("processing", processing_data)
- # 提交任务到 Celery
- celery_task_id = await workflow_manager.submit_outline_generation_task(sgbx_task_info)
- logger.info(f"大纲生成任务已提交: callback_task_id={callback_task_id}, celery_task_id={celery_task_id}")
- # 发送任务提交成功事件
- submitted_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 10,
- "stage_name": "任务已提交",
- "status": "submitted",
- "message": "大纲生成任务已提交,正在执行...",
- "overall_task_status": "processing",
- "updated_at": int(time.time()),
- "celery_task_id": celery_task_id
- }, ensure_ascii=False)
- yield format_sse_event("submitted", submitted_data)
- # 持续监听进度并转发
- last_progress = 10
- last_progress_data = None
- last_event_type = "processing"
- last_message = ""
- no_change_count = 0
- while True:
- try:
- # ===== 检查点3: 每次轮询前检查取消 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 进度轮询中检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # 从 Redis 获取最新进度
- progress_data = await progress_manager.get_progress(callback_task_id)
- if progress_data:
- current_progress = progress_data.get("current", last_progress)
- current_event_type = progress_data.get("event_type", "processing")
- current_message = progress_data.get("message", "")
- # 检查进度数据中是否已经是取消状态
- if progress_data.get("overall_task_status") == "cancelled":
- logger.info(f"[{callback_task_id}] 从进度数据检测到取消状态")
- yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
- return
- # 进度有变化时推送
- should_push = False
- if current_progress != last_progress:
- should_push = True
- elif current_event_type != last_event_type:
- should_push = True
- elif current_message != last_message:
- should_push = True
- elif last_progress_data is None:
- should_push = True
- elif progress_data.get("overall_task_status") != last_progress_data.get("overall_task_status"):
- should_push = True
- if should_push:
- last_progress = current_progress
- last_event_type = current_event_type
- last_message = current_message
- last_progress_data = progress_data
- yield format_sse_event("processing", json.dumps(progress_data, ensure_ascii=False))
- no_change_count = 0
- else:
- no_change_count += 1
- # 检查任务状态
- status = progress_data.get("overall_task_status")
-
- # ===== 新增:检测到取消立即返回 =====
- if status == "cancelled":
- logger.info(f"[{callback_task_id}] 检测到任务已取消")
- yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
- return
-
- # 检查任务是否完成
- if status in ["completed", "failed", "terminated"]:
- break
- await asyncio.sleep(0.5)
- # 每 6 秒发送一次心跳
- if no_change_count >= 30:
- heartbeat_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "执行中",
- "status": "processing",
- "message": "大纲生成任务正在执行中...",
- "overall_task_status": "processing",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("heartbeat", heartbeat_data)
- no_change_count = 0
- except Exception as e:
- logger.warning(f"轮询进度异常: {callback_task_id}, 错误: {str(e)}")
- await asyncio.sleep(0.5)
- # 获取最终结果
- # 获取最终结果
- final_result = await workflow_manager.get_outline_sgbx_task_info(callback_task_id)
- # ===== 检查点4: 结果返回前检查取消 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 结果返回前检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # ===== 新增:检查任务结果是否为已取消 =====
- if final_result and final_result.get("status") == "cancelled":
- logger.info(f"[{callback_task_id}] 任务结果状态为已取消,不返回实际结果")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": final_result.get("message", "任务已被用户取消"),
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- if final_result and final_result.get("status") == "completed":
- completed_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 100,
- "stage_name": "大纲生成完成",
- "status": "completed",
- "message": "大纲生成任务已完成",
- "overall_task_status": "completed",
- "updated_at": int(time.time()),
- "result": {
- "outline_structure": final_result.get("results", {}).get("outline_structure", []),
- "similar_plan": final_result.get("results", {}).get("similar_plan", [])
- }
- }, ensure_ascii=False)
- yield format_sse_event("completed", completed_data)
- else:
- failed_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务失败",
- "status": "failed",
- "message": final_result.get("results", {}).get("error", "大纲生成任务失败") if final_result else "任务执行失败",
- "overall_task_status": "failed",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("failed", failed_data)
- except Exception as e:
- logger.error(f"大纲生成 SSE 事件流错误: {str(e)}", exc_info=True)
- error_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "系统错误",
- "status": "error",
- "message": f"系统错误: {str(e)}",
- "overall_task_status": "failed",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("error", error_data)
- finally:
- # 关闭 Redis 连接
- if redis_check_client:
- try:
- await redis_check_client.close()
- except Exception:
- pass
- # 关闭 SSE 连接
- await unified_sse_manager.close_connection(callback_task_id)
- return StreamingResponse(
- generate_outline_events(),
- media_type="text/event-stream",
- headers={
- "Cache-Control": "no-cache",
- "Connection": "keep-alive",
- "X-Accel-Buffering": "no"
- }
- )
@outline_router.post("/regenerate_outline", response_model=None)
@auto_trace(generate_if_missing=True)
async def regenerate_outline(request: RegenerateOutlineRequest):
    """Regenerate an outline and stream progress as Server-Sent Events.

    Task state management:
    - Regeneration always creates a **new** task (id prefixed ``outline_regen_``);
      the source task is never overwritten and can still be queried.
    - The new task is linked to the source task through ``source_task_id``
      (taken from ``request.task_id``), both in the payload submitted to the
      workflow manager and in every emitted event.

    Field defaults (each falls back to the source task when omitted):
    - ``project_info``: the source task's ``project_info``.
    - ``generation_template``: the source task's ``generation_template``, then its
      ``results.outline_structure``; the alias falls back to ``template_id``.
    - ``generation_chapterenum``: the source task's chapter list; if that is also
      empty, a single chapter code is looked up by walking
      ``regenerate_config["target_path"]`` (dot-separated ``index`` values)
      through the source task's outline.

    Error handling:
    - Missing ``regenerate_config``: raises ``HTTPException(400)`` before any
      stream is opened.
    - Source task not found (or its lookup raised): streams a single ``error``
      event with ``error_code == "SOURCE_TASK_NOT_FOUND"``.
    - Source task already completed/failed: regeneration is still allowed.

    Event flow mirrors ``/generating_outline``: ``connected`` -> ``processing``
    -> ``submitted`` -> ``processing``/``heartbeat`` ... -> exactly one of
    ``completed`` / ``failed`` / ``cancelled`` / ``error``.
    """
    # ===== 1. Parameter validation =====
    if not request.regenerate_config:
        logger.error("重新生成配置缺失")
        raise HTTPException(status_code=400, detail="regenerate_config 为必填项")

    # Brand-new task id: regeneration must not overwrite the source task.
    new_callback_task_id = f"outline_regen_{uuid.uuid4().hex[:16]}"
    source_task_id = request.task_id  # only used to look up the source task's data
    TraceContext.set_trace_id(new_callback_task_id)
    user_id = request.user_id
    regenerate_config = request.regenerate_config
    target_label = regenerate_config.get("target_path", "unknown")
    logger.info(f"接收重新生成大纲 SSE 请求: "
                f"source_task_id={source_task_id}, "
                f"new_task_id={new_callback_task_id}, "
                f"user_id={user_id}, "
                f"target={target_label}")

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        # Ask nginx-style proxies not to buffer, so events are flushed immediately.
        "X-Accel-Buffering": "no",
    }
    poll_interval_s = 0.5
    # A heartbeat is sent after this many consecutive idle polls (30 x 0.5 s ~= 15 s).
    heartbeat_after_polls = 30

    def build_payload(current, stage_name, status, message, overall_task_status, **extra):
        """Serialize the common event envelope for this task; ``extra`` keys are appended."""
        payload = {
            "callback_task_id": new_callback_task_id,
            "source_task_id": source_task_id,
            "user_id": user_id,
            "current": current,
            "stage_name": stage_name,
            "status": status,
            "message": message,
            "overall_task_status": overall_task_status,
            "updated_at": int(time.time()),
        }
        payload.update(extra)
        return json.dumps(payload, ensure_ascii=False)

    def cancelled_event(current, message="任务已被用户取消"):
        """Build the ``cancelled`` SSE event shared by every cancellation checkpoint."""
        return format_sse_event(
            "cancelled",
            build_payload(current, "任务已取消", "cancelled", message, "cancelled"),
        )

    def results_of(task_info):
        """Return ``task_info["results"]`` as a dict, tolerating a missing/None task or results."""
        return (task_info or {}).get("results") or {}

    def find_chapter_code(nodes, path_parts, depth=0):
        """Walk ``nodes`` along ``path_parts`` (matched against each node's ``index``) and return the leaf's ``code``."""
        if depth >= len(path_parts):
            return None
        for node in nodes:
            if str(node.get("index", "")) != path_parts[depth]:
                continue
            if depth == len(path_parts) - 1:
                return node.get("code")
            children = node.get("children", [])
            if children:
                found = find_chapter_code(children, path_parts, depth + 1)
                if found:
                    return found
        return None

    # ===== 2. Load the source task (lookup errors are treated as "not found") =====
    original_task = None
    try:
        original_task = await workflow_manager.get_outline_sgbx_task_info(source_task_id)
    except Exception as e:
        logger.warning(f"获取原任务信息异常: {source_task_id}, error={e}")

    if not original_task:
        logger.error(f"原任务不存在: {source_task_id}")

        async def error_not_found() -> AsyncGenerator[str, None]:
            yield format_sse_event("error", build_payload(
                0,
                "原任务不存在",
                "error",
                f"原任务不存在或已过期: {source_task_id}",
                "failed",
                error_code="SOURCE_TASK_NOT_FOUND",
            ))

        return StreamingResponse(
            error_not_found(),
            media_type="text/event-stream",
            headers=sse_headers,
        )

    original_status = original_task.get("status") or original_task.get("overall_task_status", "unknown")
    logger.info(f"原任务状态: {source_task_id} = {original_status}")

    # Register the SSE connection under the NEW task id. Events are produced by
    # polling below, so the queue returned here is not used in this function.
    await unified_sse_manager.establish_connection(new_callback_task_id, sse_progress_callback)

    def build_task_info():
        """Merge request overrides with source-task data into the payload submitted to the workflow manager."""
        if request.project_info:
            base_info = request.project_info.base_info
            project_info_flat = {
                "project_name": base_info.project_name,
                "construct_location": base_info.construct_location,
                "engineering_type": base_info.engineering_type,
                "selectable": request.project_info.selectable or ""
            }
        else:
            project_info_flat = original_task.get("project_info", {})

        if request.generation_template:
            outline_structure = [
                item.dict() if isinstance(item, TemplateStructureItem) else item
                for item in request.generation_template.structure
            ]
            template_alias = request.generation_template.alias or "default_template"
        else:
            # Prefer the template the source task was generated from, else its generated outline.
            outline_structure = (
                original_task.get("generation_template", [])
                or results_of(original_task).get("outline_structure", [])
            )
            template_alias = original_task.get("template_id", "default_template")

        generation_chapterenum = request.generation_chapterenum
        if generation_chapterenum is None:
            generation_chapterenum = original_task.get("generation_chapterenum", [])
        target_path = regenerate_config.get("target_path")
        if not generation_chapterenum and target_path:
            # No chapter list anywhere: derive one chapter code from target_path.
            original_outline = results_of(original_task).get("outline_structure", [])
            chapter_code = (
                find_chapter_code(original_outline, target_path.split("."))
                if original_outline else None
            )
            if chapter_code:
                generation_chapterenum = [chapter_code]

        # Same shape as the generating_outline payload, plus regeneration markers.
        return {
            "callback_task_id": new_callback_task_id,
            "source_task_id": source_task_id,  # link to the source task
            "user_id": user_id,
            "project_info": project_info_flat,
            "template_id": template_alias,
            "generation_chapterenum": generation_chapterenum,
            "generation_template": outline_structure,
            "similarity_config": original_task.get("similarity_config", {
                "topk_plans": 3,
                "topk_fragments": 10,
                "threshold": 0.75
            }),
            "knowledge_config": original_task.get("knowledge_config", {
                "topk": 3,
                "threshold": 0.75
            }),
            # Regeneration-specific fields
            "regenerate_config": regenerate_config,
            "is_regenerate": True,
            "original_task_status": original_status
        }

    # ===== 3. SSE event stream =====
    async def generate_regenerate_events() -> AsyncGenerator[str, None]:
        """Submit the regeneration task and translate its progress into SSE events."""
        redis_check_client = None
        try:
            # ===== 3.1 Redis client used only for cancellation checks =====
            # NOTE(review): hard-coded host/password; consider moving to shared configuration.
            try:
                redis_check_client = AsyncRedis(
                    host='127.0.0.1',
                    port=6379,
                    password='123456',
                    db=0,
                    decode_responses=True,
                    socket_connect_timeout=2,
                    socket_timeout=2
                )
            except Exception as e:
                logger.warning(f"[{new_callback_task_id}] 创建取消检查Redis连接失败: {e}")

            async def is_task_cancelled() -> bool:
                """True if a ``terminate:<task_id>`` key exists; Redis errors count as "not cancelled"."""
                if not redis_check_client:
                    return False
                try:
                    return await redis_check_client.exists(f"terminate:{new_callback_task_id}") > 0
                except Exception:
                    return False

            # ===== 3.2 Cancellation checkpoint 1: before anything is sent =====
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 连接建立前检测到取消信号")
                yield cancelled_event(0)
                return

            # ===== 3.3 Connection acknowledgement =====
            yield format_sse_event("connected", build_payload(
                0,
                "连接建立",
                "connected",
                f"SSE 连接已建立,正在启动重新生成任务(原任务: {source_task_id}, 状态: {original_status})...",
                "processing",
            ))

            # ===== 3.4 Build task info (source task data + request overrides) =====
            sgbx_task_info = build_task_info()
            logger.info(f"重新生成任务信息构建完成: "
                        f"new_task_id={new_callback_task_id}, "
                        f"source_task_id={source_task_id}, "
                        f"target={target_label}, "
                        f"chapters={sgbx_task_info['generation_chapterenum']}")

            # ===== 3.5 Cancellation checkpoint 2: before submission =====
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 任务提交前检测到取消信号")
                yield cancelled_event(0)
                return

            # ===== 3.6 Announce submission =====
            yield format_sse_event("processing", build_payload(
                5,
                "任务提交中",
                "processing",
                f"正在提交重新生成任务(目标: {target_label})...",
                "processing",
            ))

            # ===== 3.7 Submit to Celery =====
            celery_task_id = await workflow_manager.submit_outline_generation_task(sgbx_task_info)
            logger.info(f"重新生成任务已提交: "
                        f"new_callback_task_id={new_callback_task_id}, "
                        f"celery_task_id={celery_task_id}")
            yield format_sse_event("submitted", build_payload(
                10,
                "任务已提交",
                "submitted",
                "重新生成任务已提交,正在执行...",
                "processing",
                celery_task_id=celery_task_id,
            ))

            # ===== 3.8 Poll progress until a terminal state =====
            last_progress = 10
            last_progress_data = None
            last_event_type = "processing"
            last_message = ""
            idle_polls = 0
            while True:
                try:
                    # Cancellation checkpoint 3: on every poll.
                    if await is_task_cancelled():
                        logger.info(f"[{new_callback_task_id}] 进度轮询中检测到取消信号")
                        yield cancelled_event(last_progress)
                        return

                    progress_data = await progress_manager.get_progress(new_callback_task_id)
                    if progress_data:
                        status = progress_data.get("overall_task_status")
                        if status == "cancelled":
                            logger.info(f"[{new_callback_task_id}] 从进度数据检测到取消状态")
                            yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
                            return

                        current_progress = progress_data.get("current", last_progress)
                        current_event_type = progress_data.get("event_type", "processing")
                        current_message = progress_data.get("message", "")
                        changed = (
                            last_progress_data is None
                            or current_progress != last_progress
                            or current_event_type != last_event_type
                            or current_message != last_message
                            or status != last_progress_data.get("overall_task_status")
                        )
                        if changed:
                            last_progress = current_progress
                            last_event_type = current_event_type
                            last_message = current_message
                            last_progress_data = progress_data
                            yield format_sse_event("processing", json.dumps(progress_data, ensure_ascii=False))
                            idle_polls = 0
                        else:
                            idle_polls += 1

                        if status in ("completed", "failed", "terminated"):
                            break
                    else:
                        # No progress available yet: still an idle poll, so heartbeats
                        # keep flowing and the connection is not considered dead.
                        idle_polls += 1

                    await asyncio.sleep(poll_interval_s)

                    if idle_polls >= heartbeat_after_polls:
                        yield format_sse_event("heartbeat", build_payload(
                            last_progress,
                            "执行中",
                            "processing",
                            "重新生成任务正在执行中...",
                            "processing",
                        ))
                        idle_polls = 0
                except Exception as e:
                    logger.warning(f"轮询进度异常: {new_callback_task_id}, 错误: {str(e)}")
                    await asyncio.sleep(poll_interval_s)

            # ===== 3.9 Fetch the final result =====
            final_result = await workflow_manager.get_outline_sgbx_task_info(new_callback_task_id)

            # Cancellation checkpoint 4: before returning results.
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 结果返回前检测到取消信号")
                yield cancelled_event(last_progress)
                return

            final_status = final_result.get("status") if final_result else None
            if final_status == "cancelled":
                logger.info(f"[{new_callback_task_id}] 任务结果状态为已取消,不返回实际结果")
                yield cancelled_event(last_progress, final_result.get("message", "任务已被用户取消"))
                return

            # ===== 3.10 Emit the final result =====
            results = results_of(final_result)
            if final_status == "completed":
                yield format_sse_event("completed", build_payload(
                    100,
                    "重新生成完成",
                    "completed",
                    "大纲重新生成任务已完成",
                    "completed",
                    result={
                        "outline_structure": results.get("outline_structure", []),
                        "similar_plan": results.get("similar_plan", []),
                    },
                ))
            else:
                failure_message = results.get("error", "重新生成任务失败") if final_result else "任务执行失败"
                yield format_sse_event("failed", build_payload(
                    last_progress,
                    "任务失败",
                    "failed",
                    failure_message,
                    "failed",
                ))
        except Exception as e:
            logger.error(f"重新生成大纲 SSE 事件流错误: {str(e)}", exc_info=True)
            yield format_sse_event("error", build_payload(
                0,
                "系统错误",
                "error",
                f"系统错误: {str(e)}",
                "failed",
            ))
        finally:
            if redis_check_client:
                try:
                    # redis-py >= 5.0.1 deprecates close() in favour of aclose(); support both.
                    close_client = getattr(redis_check_client, "aclose", None) or redis_check_client.close
                    await close_client()
                except Exception:
                    pass
            await unified_sse_manager.close_connection(new_callback_task_id)

    return StreamingResponse(
        generate_regenerate_events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
# ==================== POST endpoints ====================
@outline_router.post("/task_cancel")
@auto_trace(generate_if_missing=True)
async def task_cancel(request: TaskCancelRequest):
    """Cancel an outline generation task.

    Handles both pre-registered (``pending``) tasks, i.e. submitted but not yet
    picked up by a worker, and tasks that are already running.

    Flow:
    1. Look the task up through ``workflow_manager``; unknown tasks -> code 404.
    2. Already cancelled -> code 200; ``completed``/``failed`` -> code 400.
    3. Call ``workflow_manager.set_outline_terminate_signal``; failure -> code 400.
    4. Pre-registered / ``pending`` tasks: update the progress record and return.
    5. Running tasks: best-effort set the ``terminate:<id>`` key and mark the
       ``progress:<id>`` hash as cancelled in Redis, revoke the Celery task when
       its id is known, then push a cancel event and close the SSE connection.

    Returns:
        A ``{"code", "message", "data"}`` dict in every case (code 500 on
        unexpected errors); this endpoint does not raise.
    """
    import redis.asyncio as redis_async
    from redis.asyncio.connection import ConnectionPool

    try:
        logger.info(f"接收取消任务请求: task_id={request.task_id}")

        # Use workflow_manager so that pre-registered tasks are found as well.
        task_info = await workflow_manager.get_outline_sgbx_task_info(request.task_id)

        if not task_info:
            return {
                "code": 404,
                "message": "任务不存在",
                "data": {"task_id": request.task_id, "error_type": "TASK_NOT_FOUND"}
            }

        task_status = task_info.get("status") or task_info.get("overall_task_status", "unknown")
        is_pre_registered = task_info.get("is_pre_registered", False)

        if task_status == "cancelled":
            return {
                "code": 200,
                "message": "任务已处于取消状态",
                "data": {"task_id": request.task_id, "status": "cancelled"}
            }

        if task_status in ["completed", "failed"]:
            return {
                "code": 400,
                "message": "任务已完成,无法取消",
                "data": {"task_id": request.task_id, "current_status": task_status}
            }

        # Handles both pending (pre-registered) and processing tasks.
        result = await workflow_manager.set_outline_terminate_signal(
            callback_task_id=request.task_id,
            operator=request.user_id
        )

        if not result.get("success"):
            logger.warning(f"设置终止信号失败: {result.get('message')}")
            return {
                "code": 400,
                "message": result.get("message", "取消任务失败"),
                "data": {"task_id": request.task_id, "current_status": task_status}
            }

        cancelled_at = int(time.time())
        cancel_message = f"任务已被用户取消: {request.cancel_reason}"

        # Pre-registered (pending) task: the terminate signal already cancelled it.
        if is_pre_registered or task_status == "pending":
            logger.info(f"预注册任务已被取消: {request.task_id}")

            try:
                await progress_manager.update_stage_progress(
                    callback_task_id=request.task_id,
                    overall_task_status="cancelled",
                    status="cancelled",
                    message=cancel_message
                )
            except Exception as e:
                logger.warning(f"更新进度信息失败: {e}")

            return {
                "code": 200,
                "message": "任务已成功取消",
                "data": {
                    "task_id": request.task_id,
                    "status": "cancelled",
                    "cancelled_at": cancelled_at,
                    "cancel_reason": request.cancel_reason,
                    "cancelled_by": request.user_id,
                    "is_pre_registered": True
                }
            }

        # Running task: also set the terminate key (polled by the SSE streams) and
        # mark the progress hash as cancelled. Best-effort: failures are logged only.
        # NOTE(review): hard-coded Redis settings; consider shared configuration.
        pool = None
        redis_client = None
        try:
            pool = ConnectionPool(
                host='127.0.0.1',
                port=6379,
                password='123456',
                db=0,
                decode_responses=True,
                max_connections=20,
                socket_connect_timeout=10,
                socket_timeout=10,
                retry_on_timeout=True,
                health_check_interval=30
            )
            redis_client = redis_async.Redis(connection_pool=pool)

            terminate_data = json.dumps({
                "cancelled": True,
                "cancelled_by": request.user_id,
                "cancel_reason": request.cancel_reason,
                "cancelled_at": cancelled_at
            })

            pipe = redis_client.pipeline()
            pipe.set(f"terminate:{request.task_id}", terminate_data, ex=3600)
            pipe.hset(f"progress:{request.task_id}", mapping={
                "overall_task_status": "cancelled",
                "status": "cancelled",
                "message": cancel_message,
                "cancelled_at": str(cancelled_at),
                "updated_at": str(cancelled_at)
            })
            await pipe.execute()
            logger.info(f"终止标志已设置: {request.task_id}")
        except Exception as e:
            logger.error(f"设置终止标志失败: {e}")
            # Does not affect the main flow; continue.
        finally:
            # Release the client AND its dedicated pool even when the pipeline failed;
            # closing a client built on an explicit pool does not disconnect that pool.
            if redis_client is not None:
                try:
                    close_client = getattr(redis_client, "aclose", None) or redis_client.close
                    await close_client()
                except Exception as e:
                    logger.debug(f"关闭 Redis 客户端失败: {e}")
            if pool is not None:
                try:
                    await pool.disconnect()
                except Exception as e:
                    logger.debug(f"关闭 Redis 连接池失败: {e}")

        # Try to terminate the Celery task as well.
        celery_task_id = task_info.get("celery_task_id") or task_info.get("celery_id")
        if celery_task_id:
            try:
                from celery import current_app as celery_app
                celery_app.control.revoke(celery_task_id, terminate=True)
                logger.info(f"Celery 终止信号已发送: {celery_task_id}")
            except Exception as e:
                logger.warning(f"终止 Celery 任务失败: {e}")

        # Notify listeners, then close the SSE connection.
        try:
            cancel_event = {
                "callback_task_id": request.task_id,
                "status": "cancelled",
                "overall_task_status": "cancelled",
                "message": cancel_message,
                "cancelled_at": cancelled_at,
                "cancelled_by": request.user_id
            }
            await unified_sse_manager.send_progress(request.task_id, cancel_event)
            await unified_sse_manager.close_connection(request.task_id)
        except Exception as e:
            logger.warning(f"关闭 SSE 失败: {e}")

        return {
            "code": 200,
            "message": "任务取消成功",
            "data": {
                "task_id": request.task_id,
                "status": "cancelled",
                "cancelled_at": cancelled_at,
                "cancel_reason": request.cancel_reason,
                "cancelled_by": request.user_id
            }
        }
    except Exception as e:
        logger.error(f"取消任务异常: {str(e)}", exc_info=True)
        return {
            "code": 500,
            "message": f"取消任务失败: {str(e)}",
            "data": {"task_id": request.task_id}
        }
# ==================== Query endpoints ====================
@outline_router.get("/task_status")
@auto_trace(generate_if_missing=True)
async def task_status(
    task_id: str = Query(..., description="任务ID"),
    user_id: str = Query(..., description="用户ID")
):
    """Return the stored state of a single outline generation task.

    Args:
        task_id: callback id of the task to look up.
        user_id: required query parameter; not used by this lookup.

    Returns:
        ``{"code": 200, "message": "查询成功", "data": <task info>}`` when the
        task exists, otherwise ``{"code": 404, "message": "任务不存在或已完成",
        "data": None}``.

    Raises:
        HTTPException: status 500 when the lookup itself fails.
    """
    try:
        logger.info(f"查询任务状态: task_id={task_id}")
        found = await workflow_manager.get_outline_sgbx_task_info(task_id)
    except Exception as exc:
        logger.error(f"查询任务状态失败: {str(exc)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(exc)}")

    if found is None:
        return {"code": 404, "message": "任务不存在或已完成", "data": None}
    return {"code": 200, "message": "查询成功", "data": found}
@outline_router.get("/active_tasks")
@auto_trace(generate_if_missing=True)
async def active_tasks(
    user_id: str = Query(None, description="用户ID(可选,不提供则返回所有任务)")
):
    """List active outline generation tasks, optionally filtered by user.

    Args:
        user_id: optional; when given, only tasks whose ``user_id`` equals it are
            returned. Task records without a ``user_id`` field are skipped rather
            than failing the whole request.

    Returns:
        ``{"code": 200, "message": "查询成功",
        "data": {"total": <count>, "tasks": [...]}}``.

    Raises:
        HTTPException: status 500 when fetching the task list fails.
    """
    try:
        logger.info(f"获取活跃任务列表: user_id={user_id}")
        # A None result is treated as "no active tasks" instead of failing on len(None).
        active_tasks_list = await workflow_manager.get_outline_active_tasks() or []
        if user_id:
            active_tasks_list = [
                task for task in active_tasks_list
                if task.get("user_id") == user_id
            ]
        return {
            "code": 200,
            "message": "查询成功",
            "data": {
                "total": len(active_tasks_list),
                "tasks": active_tasks_list
            }
        }
    except Exception as e:
        logger.error(f"获取活跃任务列表失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取活跃任务列表失败: {str(e)}")
# Mount the context-generation routes onto this outline router.
outline_router.include_router(router=context_generate_router)
|