| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757 |
- # -*- coding: utf-8 -*-
- """
- 大纲生成 API 接口 (SSE 版本)
- 集成到 Celery + WorkflowManager 架构中
- 提供以下接口:
- - SSE /sgbx/generating_outline: SSE 流式大纲生成
- - SSE /sgbx/regenerate_outline: SSE 流式重新生成
- - POST /sgbx/task_cancel: 取消大纲生成任务
- - POST /sgbx/context_generate: SSE 流式上下文生成 (新增)
- """
- import os
- import uuid
- import json
- import time
- import asyncio
- import aiohttp
- from typing import Optional, Dict, Any, List, AsyncGenerator, Union
- from pydantic import BaseModel, Field
- from fastapi import APIRouter, HTTPException, Query
- from fastapi.responses import StreamingResponse
- from foundation.observability.logger.loggering import write_logger as logger
- from foundation.infrastructure.tracing import TraceContext, auto_trace
- from foundation.infrastructure.config.config import config_handler
- from core.base.workflow_manager import WorkflowManager
- from core.base.sse_manager import unified_sse_manager
- from core.base.progress_manager import ProgressManager
- from redis import asyncio as redis_async # 新增这行
- from redis.asyncio import Redis as AsyncRedis
# Router that groups the "/sgbx" (construction-plan authoring) endpoints.
outline_router = APIRouter(prefix="/sgbx", tags=["施工方案编写"])

# Module-wide workflow manager; outline-generation tasks are submitted through it.
workflow_manager = WorkflowManager(max_concurrent_docs=3, max_concurrent_reviews=5)

# Module-wide progress manager; task progress is read through it.
progress_manager = ProgressManager()
async def sse_progress_callback(callback_task_id: str, current_data: dict):
    """Relay one progress update for ``callback_task_id`` to the unified SSE manager."""
    push_progress = unified_sse_manager.send_progress
    await push_progress(callback_task_id, current_data)
def format_sse_event(event_type: str, data: str) -> str:
    """Serialize one Server-Sent Events frame.

    Args:
        event_type: Value of the ``event:`` field.
        data: Payload text. If it contains line breaks (LF, CRLF or CR), each
            payload line is emitted as its own ``data:`` field so that the
            frame stays well-formed; an SSE client re-joins them with "\\n".

    Returns:
        The frame text. For a single-line payload the result is exactly
        ``"event: <type>\\ndata: <data>\\n\\n\\n"``.
    """
    # Normalise line endings first: a raw line break inside one "data:" field
    # would otherwise end the field and leave the remainder unprefixed.
    normalized = data.replace("\r\n", "\n").replace("\r", "\n")

    fields = [f"event: {event_type}"]
    fields.extend(f"data: {payload_line}" for payload_line in normalized.split("\n"))

    # The frame terminator is kept byte-identical to what this helper has
    # always produced (one more "\n" than strictly required; clients ignore it).
    return "\n".join(fields) + "\n\n\n"
- # ==================== 请求/响应模型 ====================
class BaseInfo(BaseModel):
    """Basic project information (all three fields are required)."""

    # Name of the construction plan.
    project_name: str = Field(
        ...,
        description="方案名称",
        example="罗成依达大桥上部结构专项施工方案",
    )
    # Where the project is built.
    construct_location: str = Field(
        ...,
        description="建设地点",
        example="四川省凉山州",
    )
    # Template type of the plan.
    engineering_type: str = Field(
        ...,
        description="方案模版类型",
        example="T型梁",
    )
class ProjectInfo(BaseModel):
    """Project information: a nested ``base_info`` plus free-form optional text."""

    # Required basic information block.
    base_info: BaseInfo = Field(
        ...,
        description="基础信息",
    )
    # Other optional information; defaults to an empty string.
    selectable: Optional[str] = Field(
        "",
        description="其他可选信息",
    )
class TemplateStructureItem(BaseModel):
    """One node of the template structure; may carry nested ``children``."""

    # Chapter number, e.g. "2".
    index: str = Field(
        ...,
        description="章节编号",
        example="2",
    )
    # Nesting depth, constrained to 1..5.
    level: int = Field(
        ...,
        description="层级",
        ge=1,
        le=5,
    )
    # Chapter title.
    title: str = Field(
        ...,
        description="章节标题",
        example="工程概况",
    )
    # Chapter code.
    code: str = Field(
        ...,
        description="章节代码",
        example="overview",
    )
    # Child chapters are kept as plain dicts, which lets the structure nest
    # to any depth without a self-referencing model.
    children: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="子章节(递归结构)",
    )
class GenerationTemplate(BaseModel):
    """Template that drives outline generation.

    Example:
        {
            "source_file": "方案编写助手原文关键词规范文档修改版-2026-2-5.md",
            "alias": "施工方案知识审查与编写体系",
            "structure": [
                {
                    "index": "2",
                    "level": 1,
                    "title": "工程概况",
                    "code": "overview",
                    "children": [...]
                }
            ]
        }
    """

    # Source file the template was derived from (optional).
    source_file: Optional[str] = Field(
        None,
        description="源文件",
        example="方案编写助手原文关键词规范文档修改版-2026-2-5.md",
    )
    # Human-readable alias of the template (optional).
    alias: Optional[str] = Field(
        None,
        description="别名",
        example="施工方案知识审查与编写体系",
    )
    # Template nodes; each entry is either a validated item or a raw dict.
    structure: List[Union[TemplateStructureItem, Dict[str, Any]]] = Field(
        ...,
        description="模板结构",
    )
class OutlineGenerationRequest(BaseModel):
    """Request body of the outline-generation endpoint.

    Example body:
        {
            "user_id": "user-001",
            "project_info": {
                "base_info": {
                    "project_name": "罗成依达大桥上部结构专项施工方案",
                    "construct_location": "四川省凉山州",
                    "engineering_type": "T型梁"
                },
                "selectable": ""
            },
            "generation_template": {
                "source_file": "方案编写助手原文关键词规范文档修改版-2026-2-5.md",
                "alias": "施工方案知识审查与编写体系",
                "structure": [...]
            },
            "generation_chapterenum": ["overview_DesignSummary_ProjectIntroduction", ...]
        }
    """

    # Identifier of the calling user.
    user_id: str = Field(
        ...,
        description="用户标识",
        example="user-001",
    )
    # Project information.
    project_info: ProjectInfo = Field(
        ...,
        description="项目基础信息",
    )
    # Outline generation template.
    generation_template: GenerationTemplate = Field(
        ...,
        description="大纲生成模板",
    )
    # Chapter codes to generate; defaults to an empty list (per the field
    # description, an empty list means "generate every chapter").
    generation_chapterenum: List[str] = Field(
        default_factory=list,
        description="生成章节代码列表,为空时生成全部章节",
    )
class RegenerateOutlineRequest(BaseModel):
    """Request body for regenerating an outline.

    Mirrors the outline-generation request and adds ``regenerate_config``.
    ``project_info``, ``generation_template`` and ``generation_chapterenum``
    are optional here. According to the original author's note, omitted values
    fall back to those of the original task; that fallback is not implemented
    in this class.

    Example body:
        {
            "task_id": "task-20250130-123456",
            "user_id": "user-001",
            "project_info": {                      // optional
                "base_info": {
                    "project_name": "罗成依达大桥上部结构专项施工方案",
                    "construct_location": "四川省凉山州",
                    "engineering_type": "T型梁"
                },
                "selectable": ""
            },
            "generation_template": {               // optional
                "source_file": "...",
                "alias": "...",
                "structure": [...]
            },
            "generation_chapterenum": ["overview_DesignSummary_MainTechnicalStandards"],  // optional
            "regenerate_config": {
                "regenerate_mode": "chapter",
                "target_path": "2.1",
                "preserve_children": true,
                "reason": "调整内容结构"
            }
        }
    """

    # ID of the original outline-generation task.
    task_id: str = Field(
        ...,
        description="原大纲生成任务ID",
    )
    # ID of the calling user.
    user_id: str = Field(
        ...,
        description="用户ID",
    )
    # Optional project information.
    project_info: Optional[ProjectInfo] = Field(
        None,
        description="项目基础信息(可选)",
    )
    # Optional generation template.
    generation_template: Optional[GenerationTemplate] = Field(
        None,
        description="大纲生成模板(可选)",
    )
    # Optional list of chapter codes.
    generation_chapterenum: Optional[List[str]] = Field(
        None,
        description="生成章节代码列表(可选)",
    )
    # Regeneration-specific settings (required, free-form dict).
    regenerate_config: Dict[str, Any] = Field(
        ...,
        description="重新生成配置",
    )
class TaskCancelRequest(BaseModel):
    """Request body for cancelling a task."""

    # ID of the task to cancel.
    task_id: str = Field(
        ...,
        description="任务ID",
    )
    # ID of the calling user.
    user_id: str = Field(
        ...,
        description="用户ID",
    )
    # Why the task is cancelled; a default reason is supplied when omitted.
    cancel_reason: Optional[str] = Field(
        "用户主动取消",
        description="取消原因",
    )
- # ==================== 响应模型 ====================
class OutlineNodeResponse(BaseModel):
    """One node of a generated outline.

    Same shape as ``TemplateStructureItem`` plus generation output. As described
    by the field descriptions: every node has ``generated_content``; level-2 and
    level-3 nodes may carry ``similar_fragments``; leaf nodes may additionally
    carry ``key_points`` and ``knowledge_bases``. The model itself does not
    enforce which levels populate which optional field.
    """

    # Chapter number.
    index: str = Field(
        ...,
        description="章节编号",
        example="2.1.1",
    )
    # Nesting depth, constrained to 1..5.
    level: int = Field(
        ...,
        description="层级",
        ge=1,
        le=5,
        example=3,
    )
    # Chapter title.
    title: str = Field(
        ...,
        description="章节标题",
        example="工程简介",
    )
    # Chapter code.
    code: str = Field(
        ...,
        description="章节代码",
        example="overview_DesignSummary_ProjectIntroduction",
    )
    # AI-generated text for this node.
    generated_content: str = Field(
        ...,
        description="AI生成的内容",
        example="罗成依达大桥位于四川省凉山州...",
    )
    # Recommended similar fragments.
    similar_fragments: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="相似片段推荐(2级和3级节点)",
    )
    # Key points.
    key_points: Optional[List[str]] = Field(
        None,
        description="核心要点(仅末级节点)",
        example=["桥位位置", "桥梁规模"],
    )
    # Knowledge items / normative references.
    knowledge_bases: Optional[List[str]] = Field(
        None,
        description="知识点/编制依据(仅末级节点)",
        example=["《公路桥涵设计通用规范》JTG D60-2015"],
    )
    # Child nodes as plain dicts.
    children: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="子章节(递归结构)",
    )
class SimilarPlanResponse(BaseModel):
    """A similar plan recommended at whole-document level."""

    # Plan identifier.
    plan_id: str = Field(
        ...,
        description="方案ID",
    )
    # Plan title.
    plan_title: str = Field(
        ...,
        description="方案标题",
    )
    # Similarity score, constrained to 0.0..1.0.
    similarity_score: float = Field(
        ...,
        description="相似度分数",
        ge=0.0,
        le=1.0,
    )
    # Plan type.
    plan_type: str = Field(
        ...,
        description="方案类型",
    )
    # Outline structure of the plan (optional).
    outline: Optional[List[Dict]] = Field(
        None,
        description="方案大纲结构",
    )
    # Free-form metadata (optional).
    metadata: Optional[Dict[str, Any]] = Field(
        None,
        description="元数据",
    )
class SimilarFragmentResponse(BaseModel):
    """A similar text fragment and the document it comes from."""

    # Fragment identifier.
    fragment_id: str = Field(
        ...,
        description="片段ID",
    )
    # Path of the chapter the fragment belongs to.
    section_path: str = Field(
        ...,
        description="所属章节路径",
    )
    # Title of that chapter.
    section_title: str = Field(
        ...,
        description="章节标题",
    )
    # Fragment text.
    fragment_content: str = Field(
        ...,
        description="片段内容",
    )
    # Similarity score, constrained to 0.0..1.0.
    similarity_score: float = Field(
        ...,
        description="相似度分数",
        ge=0.0,
        le=1.0,
    )
    # Identifier of the source document.
    source_document_id: str = Field(
        ...,
        description="来源文档ID",
    )
    # Title of the source document.
    source_document_title: str = Field(
        ...,
        description="来源文档标题",
    )
class OutlineGenerationResult(BaseModel):
    """Result of an outline generation.

    ``outline_structure`` is the nested outline (see ``OutlineNodeResponse``);
    ``similar_plan`` holds whole-document similar-plan recommendations and
    defaults to an empty list.
    """

    # Generated outline nodes.
    outline_structure: List[OutlineNodeResponse] = Field(
        ...,
        description="大纲结构(包含AI生成内容和章节级similar_fragments)",
    )
    # Whole-document similar plans.
    similar_plan: List[SimilarPlanResponse] = Field(
        default_factory=list,
        description="相似方案推荐(整篇方案级)",
    )
- # ==================== 上下文生成新增模型 ====================
class CompletionConfig(BaseModel):
    """Settings of one context-generation (completion) request."""

    # Chapter path; required.
    section_path: str = Field(
        ...,
        description="章节路径",
    )
    # Text that already exists in the section.
    current_content: str = Field(
        default="",
        description="当前已有内容",
    )
    # Context window size, constrained to 500..5000.
    context_window: int = Field(
        default=2000,
        ge=500,
        le=5000,
    )
    # Completion mode; defaults to "continue".
    completion_mode: str = Field(
        default="continue",
        description="模式",
    )
    # Target output length, constrained to 100..5000.
    target_length: int = Field(
        default=1000,
        ge=100,
        le=5000,
    )
    include_references: bool = Field(
        default=True,
    )
    style_match: bool = Field(
        default=True,
    )
    # Optional hint keywords.
    hint_keywords: Optional[List[str]] = Field(
        default=None,
    )
class ProjectInfoSimple(BaseModel):
    """Flat project information used by context generation; every field has a default."""

    project_name: str = Field(
        default="施工方案",
    )
    construct_location: Optional[str] = Field(
        default=None,
    )
    engineering_type: Optional[str] = Field(
        default=None,
    )
class ContextGenerateRequest(BaseModel):
    """Request body of the context-generation endpoint; unknown fields are rejected."""

    task_id: Optional[str] = Field(
        default=None,
    )
    # Required user identifier.
    user_id: str = Field(
        ...,
    )
    project_info: Optional[ProjectInfoSimple] = Field(
        default=None,
    )
    # Required completion settings.
    completion_config: CompletionConfig = Field(
        ...,
    )
    model_name: Optional[str] = Field(
        default=None,
    )

    class Config:
        # Reject request bodies that carry fields not declared above.
        extra = "forbid"
class ContextGenerateResponse(BaseModel):
    """Generic ``code`` / ``message`` / ``data`` response envelope."""

    # Status code carried in the body.
    code: int
    # Human-readable status message.
    message: str
    # Optional payload.
    data: Optional[Dict[str, Any]] = None
- # ==================== 全局资源池 (速度优化核心) ====================
# Process-wide shared clients; both start as None and are created on demand
# by init_global_resources().
GLOBAL_HTTP_SESSION: Optional[aiohttp.ClientSession] = None
# Must exist at module level: init_global_resources(), _background_ping() and
# get_redis_client() all read this name, and reading an undefined global
# raises NameError on the very first call.
GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected before
# it finishes.
_BACKGROUND_TASKS: set = set()


async def init_global_resources():
    """Create the shared HTTP session and Redis client if they are missing.

    The HTTP session is (re)created when it is ``None`` or closed. The Redis
    client is created when it is ``None``; a connectivity ping is then started
    in the background so that this coroutine does not wait for Redis. If
    building the Redis client fails, the failure is logged and the client is
    left as ``None``.
    """
    global GLOBAL_HTTP_SESSION, GLOBAL_REDIS_CLIENT

    if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
        # DNS caching plus connection reuse (original note: tuned for the
        # Aliyun endpoint).
        connector = aiohttp.TCPConnector(limit=100, limit_per_host=20, ttl_dns_cache=300, force_close=False)
        GLOBAL_HTTP_SESSION = aiohttp.ClientSession(
            # Original note: connect timeout kept slightly generous to ride
            # out network jitter.
            timeout=aiohttp.ClientTimeout(total=120, connect=10, sock_read=10),
            connector=connector,
            headers={"User-Agent": "FastAPI-DashScope-Optimized/2.0"},
        )
        logger.info("✅ 全局 HTTP 连接池已初始化 (DashScope Ready)")

    if GLOBAL_REDIS_CLIENT is None:
        try:
            # SECURITY(review): host and password are hard-coded here; they
            # should come from configuration or a secret store.
            GLOBAL_REDIS_CLIENT = AsyncRedis(
                host='127.0.0.1', port=6379, password='123456', db=0,
                decode_responses=True, socket_connect_timeout=1,
                socket_keepalive=True, max_connections=50
            )
            # Keep a reference until the ping completes, then drop it.
            ping_task = asyncio.create_task(_background_ping())
            _BACKGROUND_TASKS.add(ping_task)
            ping_task.add_done_callback(_BACKGROUND_TASKS.discard)
            logger.info("✅ 全局 Redis 连接池已初始化")
        except Exception as e:
            logger.warning(f"⚠️ Redis 初始化失败: {e}")
            GLOBAL_REDIS_CLIENT = None
async def _background_ping():
    """Best-effort Redis ping; a failure is logged at debug level and otherwise ignored.

    Does nothing when no Redis client has been created.
    """
    if not GLOBAL_REDIS_CLIENT:
        return
    try:
        await GLOBAL_REDIS_CLIENT.ping()
    except Exception as exc:
        # Only ordinary errors are absorbed. Task cancellation and interpreter
        # shutdown signals are not Exception subclasses and still propagate.
        logger.debug(f"Redis 后台 ping 失败: {exc}")
async def get_http_session():
    """Return the shared HTTP session, initialising shared resources first when it is missing or closed."""
    session_unusable = GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed
    if session_unusable:
        await init_global_resources()
    return GLOBAL_HTTP_SESSION
async def get_redis_client():
    """Return the shared Redis client, initialising shared resources first when it is still ``None``.

    The result can still be ``None`` if initialisation did not produce a client.
    """
    client_missing = GLOBAL_REDIS_CLIENT is None
    if client_missing:
        await init_global_resources()
    return GLOBAL_REDIS_CLIENT
- # ==================== 自定义 API 配置 (阿里云 DashScope) ====================
class CustomAPIConfig:
    """Connection settings for the Aliyun DashScope OpenAI-compatible chat API."""

    # Base URL of the compatible-mode API. The chat endpoint must include the
    # "/chat/completions" suffix.
    DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    DASHSCOPE_CHAT_URL = f"{DASHSCOPE_BASE_URL}/chat/completions"

    # SECURITY(review): a live credential is committed to source control. It
    # is kept only as a backward-compatible fallback; set the
    # DASHSCOPE_API_KEY environment variable instead, then rotate and remove
    # this value.
    DASHSCOPE_API_KEY = "sk-ae805c991b6a4a8da3a09351c34963a5"

    # Model used when configuration does not name one.
    DEFAULT_MODEL_NAME = "qwen3-30b-a3b-instruct-2507"

    @staticmethod
    def get_api_url() -> str:
        """Return the chat-completions endpoint URL."""
        return CustomAPIConfig.DASHSCOPE_CHAT_URL

    @staticmethod
    def get_api_key() -> str:
        """Return the API key.

        The ``DASHSCOPE_API_KEY`` environment variable wins when it is set to
        a non-empty value; otherwise the class-level fallback is returned.
        """
        return os.environ.get("DASHSCOPE_API_KEY") or CustomAPIConfig.DASHSCOPE_API_KEY

    @staticmethod
    def get_model_name() -> str:
        """Return the model name from config (``custom_api`` / ``MODEL_NAME``), or the default when unset."""
        configured_model = config_handler.get("custom_api", "MODEL_NAME", "")
        return configured_model if configured_model else CustomAPIConfig.DEFAULT_MODEL_NAME

    @staticmethod
    def is_enabled() -> bool:
        """Return True when both an API key and an API URL are available."""
        return bool(CustomAPIConfig.get_api_key()) and bool(CustomAPIConfig.get_api_url())
- # ==================== 极速流式调用 (核心优化) ====================
def _parse_sse_line(raw_line: bytes, trace_id: str = "") -> Optional[str]:
    """Decode one complete SSE line; return delta text, "" if nothing to emit, or None on ``[DONE]``."""
    line = raw_line.decode("utf-8", errors="ignore").strip()
    # Accept both "data: x" and "data:x"; every other field is ignored.
    if not line.startswith("data:"):
        return ""
    data = line[5:].strip()
    if data == "[DONE]":
        return None

    try:
        event_data = json.loads(data)
    except json.JSONDecodeError:
        return ""
    if not isinstance(event_data, dict):
        return ""

    if "error" in event_data:
        error_info = event_data["error"]
        # The error may be a structured object or a bare value.
        err_msg = error_info.get("message", "Unknown Error") if isinstance(error_info, dict) else str(error_info)
        logger.error(f"[{trace_id}] 流式数据中包含错误: {err_msg}")
        return ""

    choices = event_data.get("choices") or []
    if not choices:
        return ""
    # "delta" and "content" may be missing or null in some frames.
    delta = choices[0].get("delta") or {}
    return delta.get("content") or ""


async def call_custom_api_stream(
    prompt: str, system_prompt: str = "", max_tokens: int = 2000,
    temperature: float = 0.7, trace_id: str = ""
) -> AsyncGenerator[tuple[str, Optional[float]], None]:
    """Stream a chat completion from the configured API.

    Args:
        prompt: User message. Only the last 10000 characters are sent when it
            is longer than that.
        system_prompt: System message.
        max_tokens: ``max_tokens`` value of the request.
        temperature: ``temperature`` value of the request.
        trace_id: Identifier used as a prefix in log lines.

    Yields:
        ``(content, first_token_latency)`` for each non-empty delta, where
        ``first_token_latency`` is the number of seconds between the start of
        the call and the first delta (the same value on every item).

    Raises:
        Exception: When the API answers with a non-200 status, or when the
            request itself fails (the underlying error is logged and re-raised).
    """
    api_url = CustomAPIConfig.get_api_url()
    model_name = CustomAPIConfig.get_model_name()
    api_key = CustomAPIConfig.get_api_key()

    logger.debug(f"[{trace_id}] 正在调用阿里云 DashScope: {model_name} @ {api_url}")

    # Keep only the tail of an over-long prompt.
    max_prompt_len = 10000
    if len(prompt) > max_prompt_len:
        prompt = prompt[-max_prompt_len:]
        logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")

    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True,
        # NOTE(review): not a standard OpenAI field; the original author
        # expected the compatible mode to possibly honour it.
        "incremental_output": True
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    start_time = time.time()
    first_token_time: Optional[float] = None

    def stamp(content: str) -> tuple:
        """Attach the first-token latency, measuring it on the first call."""
        nonlocal first_token_time
        if first_token_time is None:
            first_token_time = time.time() - start_time
        return (content, first_token_time)

    # Bytes are buffered and only complete lines are decoded. Decoding each
    # network chunk on its own would drop any multi-byte UTF-8 character that
    # happens to be split across two chunks.
    pending = b""

    session = await get_http_session()

    try:
        async with session.post(api_url, json=payload, headers=headers, read_bufsize=1) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"[{trace_id}] API 错误 {response.status}: {error_text}")
                raise Exception(f"API 错误 {response.status}: {error_text}")

            async for chunk in response.content.iter_any():
                if not chunk:
                    continue
                pending += chunk
                # Everything before the last "\n" is complete; the remainder
                # stays buffered until more bytes arrive.
                *complete_lines, pending = pending.split(b"\n")
                for raw_line in complete_lines:
                    content = _parse_sse_line(raw_line, trace_id)
                    if content is None:
                        return
                    if content:
                        yield stamp(content)

            # The stream may end without a trailing newline.
            if pending.strip():
                content = _parse_sse_line(pending, trace_id)
                if content:
                    yield stamp(content)
    except Exception as e:
        logger.error(f"[{trace_id}] API 流式请求异常: {e}")
        raise
- # ==================== 上下文生成业务逻辑辅助 ====================
# System prompt sent with every context-generation call.
# NOTE(review): the prompt caps the output at 100 characters while
# CompletionConfig.target_length allows up to 5000 - confirm which is intended.
CONTEXT_GENERATE_SYSTEM_PROMPT = (
    "你是一位专业的施工方案编写专家。"
    "请直接输出生成的内容文本,不要添加任何解释、标注或格式标记。"
    "要求生成的内容不超过100字。"
)
def build_context_generate_prompt(project_info, section_path, section_title, current_content, completion_mode, target_length, include_references, style_match, hint_keywords, context_before="", context_after=""):
    """Assemble the user prompt for context generation.

    Args:
        project_info: Mapping with an optional ``project_name`` key; ``None``
            is treated as an empty mapping.
        section_path: Chapter path shown next to the title.
        section_title: Chapter title.
        current_content: Existing section text; omitted from the prompt when empty.
        completion_mode: Mode label placed in the prompt.
        target_length: Target length placed in the prompt.
        include_references: Accepted for interface compatibility; not used in
            the prompt.
        style_match: Accepted for interface compatibility; not used in the prompt.
        hint_keywords: Optional keywords; non-blank entries are listed in the
            prompt. ``None`` or an empty list adds nothing.
        context_before: Preceding text; only its last 500 characters are used.
        context_after: Following text; only its first 500 characters are used.

    Returns:
        The prompt sections joined with newlines.
    """
    info = project_info or {}
    parts = [
        f"【项目】{info.get('project_name', '未知')}",
        f"【章节】{section_title} ({section_path})",
        f"【模式】{completion_mode} (目标:{target_length})",
    ]

    if context_before:
        parts.append(f"【前文】...{context_before[-500:]}")
    if current_content:
        parts.append(f"【当前】{current_content}")
    if context_after:
        parts.append(f"【后文】{context_after[:500]}...")

    # Keywords were previously accepted but silently dropped.
    keywords = [str(word).strip() for word in (hint_keywords or []) if word and str(word).strip()]
    if keywords:
        parts.append(f"【关键词】{'、'.join(keywords)}")

    parts.append("【指令】请根据上述信息继续生成专业内容,直接输出正文:")
    return "\n".join(parts)
def extract_chunk_content(chunk: Any) -> str:
    """Return the text carried by a stream chunk.

    Supported shapes, checked in this order:
        * ``str`` - returned unchanged;
        * ``None`` - returns "";
        * object with a ``content`` attribute - ``str(content)``, or "" when
          the attribute is falsy;
        * ``dict`` - ``str`` of its ``"content"`` value, or "" when the key is
          missing or its value is ``None``;
        * anything else - ``str(chunk)``.
    """
    if isinstance(chunk, str):
        return chunk
    if chunk is None:
        # A missing chunk must not turn into the literal text "None".
        return ""
    if hasattr(chunk, 'content'):
        return str(chunk.content) if chunk.content else ""
    if isinstance(chunk, dict):
        value = chunk.get('content', '')
        return "" if value is None else str(value)
    return str(chunk)
def validate_user_id(user_id: str):
    """Accept only the known user ids.

    Raises:
        HTTPException: 403 with code ``INVALID_USER`` for any other id.
    """
    if user_id in {'user-001', 'user-002', 'user-003'}:
        return
    raise HTTPException(status_code=403, detail={"code": "INVALID_USER", "message": "用户标识无效"})
def validate_completion_config(config: CompletionConfig):
    """Require ``section_path`` to be non-empty, dot-separated, all-digit segments (e.g. "2.1.3").

    Raises:
        HTTPException: 400 with code ``INVALID_PATH`` otherwise.
    """
    path = config.section_path
    # An empty path yields no segments and is rejected below.
    segments = path.split(".") if path else []
    if segments and all(segment.isdigit() for segment in segments):
        return
    raise HTTPException(status_code=400, detail={"code": "INVALID_PATH", "message": "章节路径格式错误"})
def validate_request(request: ContextGenerateRequest):
    """Require at least one of ``task_id`` / ``project_info`` on the request.

    Raises:
        HTTPException: 400 with code ``MISSING_INFO`` when both are missing.
    """
    if request.task_id or request.project_info:
        return
    raise HTTPException(status_code=400, detail={"code": "MISSING_INFO", "message": "缺少任务 ID 或项目信息"})
- # ==================== 上下文生成核心流式逻辑 ====================
async def generate_content_stream(callback_task_id, source_task_id, user_id, request, redis_client):
    """Produce the SSE event stream of one context-generation request.

    Event order: ``connected``, ``generating``, zero or more ``chunk`` events,
    then exactly one of ``completed``, ``cancelled`` or ``error``.

    Args:
        callback_task_id: Id of this stream; used in events, log lines and the
            ``terminate:<id>`` Redis key that signals cancellation.
        source_task_id: Not used by this function.
        user_id: Not used by this function.
        request: Request object; ``project_info`` and ``completion_config``
            are read from it.
        redis_client: Async Redis client used for the cancellation check, or a
            falsy value to disable that check.

    Yields:
        SSE frames produced by ``format_sse_event``.
    """
    async def is_cancelled() -> bool:
        """Return True when the termination key exists; Redis errors count as "not cancelled"."""
        if not redis_client:
            return False
        try:
            return await redis_client.exists(f"terminate:{callback_task_id}") > 0
        except Exception:
            # Narrower than a bare except so that task cancellation still
            # propagates.
            return False

    stream_start_time = time.time()
    first_token_latency: Optional[float] = None
    full_content_parts: List[str] = []
    chunk_count = 0

    try:
        yield format_sse_event("connected", json.dumps({
            "callback_task_id": callback_task_id, "status": "connected", "timestamp": int(time.time())
        }, ensure_ascii=False))

        completion = request.completion_config
        project_info = request.project_info.dict() if request.project_info else {}
        section_title = f"章节 {completion.section_path}"

        user_prompt = build_context_generate_prompt(
            project_info=project_info,
            section_path=completion.section_path,
            section_title=section_title,
            current_content=completion.current_content,
            completion_mode=completion.completion_mode,
            target_length=completion.target_length,
            include_references=completion.include_references,
            style_match=completion.style_match,
            hint_keywords=completion.hint_keywords
        )

        yield format_sse_event("generating", json.dumps({
            "status": "generating",
            "message": f"正在调用阿里云 Qwen3 ({CustomAPIConfig.get_model_name()})...",
            "timestamp": int(time.time())
        }, ensure_ascii=False))

        if not CustomAPIConfig.is_enabled():
            # Not expected to happen while a key is configured.
            logger.warning(f"[{callback_task_id}] API 配置失效,回退到默认模型 (不应发生)")
            raise Exception("API 配置未生效,请检查 CustomAPIConfig")

        logger.info(f"[{callback_task_id}] 使用阿里云 DashScope API (模型:{CustomAPIConfig.get_model_name()})")
        upstream = call_custom_api_stream(
            prompt=user_prompt,
            system_prompt=CONTEXT_GENERATE_SYSTEM_PROMPT,
            max_tokens=min(completion.target_length, 4000),
            temperature=0.7,
            trace_id=callback_task_id
        )
        try:
            async for content, ftl in upstream:
                if await is_cancelled():
                    yield format_sse_event("cancelled", json.dumps({"status": "cancelled"}, ensure_ascii=False))
                    return

                if not content:
                    continue

                full_content_parts.append(content)
                chunk_count += 1

                if first_token_latency is None:
                    first_token_latency = ftl if ftl is not None else (time.time() - stream_start_time)
                    logger.info(f"[{callback_task_id}] ⚡ 首字延迟: {first_token_latency:.3f}s (Model: {CustomAPIConfig.get_model_name()})")

                yield format_sse_event("chunk", json.dumps({
                    "chunk": content,
                    "first_token_latency": round(first_token_latency, 3),
                    "timestamp": int(time.time())
                }, ensure_ascii=False))
        finally:
            # Close the upstream generator on every exit path (including the
            # early cancel return) so its HTTP response is released promptly.
            await upstream.aclose()

        total_duration = time.time() - stream_start_time
        full_content = "".join(full_content_parts)

        # No content at all leaves the latency unset; formatting None with
        # ".3f" would raise and turn an empty completion into an error event.
        latency_for_log = first_token_latency if first_token_latency is not None else 0.0
        logger.info(f"[{callback_task_id}] ✅ 完成 | 首字: {latency_for_log:.3f}s | 总耗时: {total_duration:.3f}s | 字数: {len(full_content)}")

        yield format_sse_event("completed", json.dumps({
            "callback_task_id": callback_task_id,
            "status": "completed",
            "metrics": {
                "first_token_latency": round(first_token_latency, 3) if first_token_latency else 0.0,
                "total_duration": round(total_duration, 3),
                "char_count": len(full_content),
                "chunk_count": chunk_count,
                "model_used": CustomAPIConfig.get_model_name()
            },
            "full_content": full_content,
            "timestamp": int(time.time())
        }, ensure_ascii=False))
    except Exception as e:
        logger.error(f"[{callback_task_id}] ❌ 异常: {str(e)}", exc_info=True)
        yield format_sse_event("error", json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False))
- # ==================== 上下文生成API路由 ====================
context_generate_router = APIRouter(prefix="/sgbx", tags=["施工方案编写"])


@context_generate_router.post("/context_generate")
@auto_trace(generate_if_missing=True)
async def context_generate(request: ContextGenerateRequest):
    """Validate the request and answer with an SSE stream of generated content.

    Raises:
        HTTPException: Validation errors are passed through unchanged; any
            other failure while preparing the response becomes a 500.
    """
    callback_task_id = f"ctx_{uuid.uuid4().hex[:12]}"
    TraceContext.set_trace_id(callback_task_id)

    receive_time = time.time()

    # Headers that turn off caching and proxy buffering for the event stream.
    sse_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        "Content-Type": "text/event-stream; charset=utf-8",
        "Access-Control-Allow-Origin": "*"
    }

    try:
        validate_user_id(request.user_id)
        validate_completion_config(request.completion_config)
        validate_request(request)

        redis_client = await get_redis_client()

        elapsed_ms = (time.time() - receive_time) * 1000
        logger.info(f"[{callback_task_id}] 请求接收 (预处理耗时: {elapsed_ms:.1f}ms)")

        event_stream = generate_content_stream(
            callback_task_id, request.task_id, request.user_id, request, redis_client
        )
        return StreamingResponse(
            event_stream,
            media_type="text/event-stream",
            headers=sse_headers,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[{callback_task_id}] 全局异常: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@context_generate_router.get("/context_generate_health")
async def health_check():
    """Return a static health payload plus the provider, current model and API base URL."""
    return {
        "status": "healthy",
        "provider": "Aliyun DashScope",
        "current_model": CustomAPIConfig.get_model_name(),
        # Read from the config class so the reported prefix cannot drift from
        # the URL that is actually called.
        "api_url_prefix": CustomAPIConfig.DASHSCOPE_BASE_URL
    }
@context_generate_router.get("/context_generate_modes", response_model=ContextGenerateResponse)
async def get_modes():
    """List the completion modes as ``{"mode", "name"}`` entries."""
    mode_names = (
        ("continue", "续写"),
        ("expand", "扩写"),
        ("polish", "润色"),
        ("complete", "补全"),
    )
    modes = [{"mode": mode, "name": name} for mode, name in mode_names]
    return ContextGenerateResponse(code=200, message="success", data={"modes": modes})
@context_generate_router.get("/context_generate_api_status", response_model=ContextGenerateResponse)
async def get_api_status():
    """Report whether the custom API is enabled, together with provider and model name."""
    status_payload = {
        "enabled": CustomAPIConfig.is_enabled(),
        "provider": "Aliyun DashScope",
        "model": CustomAPIConfig.get_model_name()
    }
    return ContextGenerateResponse(code=200, message="success", data=status_payload)
- # ==================== 原有大纲生成接口实现 ====================
- @outline_router.post("/generating_outline", response_model=None)
- @auto_trace(generate_if_missing=True)
- async def generating_outline(request: OutlineGenerationRequest):
- """
- 大纲生成接口 (SSE 流式响应)
- """
- callback_task_id = f"outline_{uuid.uuid4().hex[:16]}"
- TraceContext.set_trace_id(callback_task_id)
- user_id = request.user_id
- logger.info(f"接收大纲生成 SSE 请求: user_id={user_id}, project={request.project_info.base_info.project_name}")
- # ===== 新增:创建 Redis 连接用于检查终止标志 =====
- redis_check_client = None
- try:
- redis_check_client = AsyncRedis(
- host='127.0.0.1',
- port=6379,
- password='123456',
- db=0,
- decode_responses=True,
- socket_connect_timeout=2,
- socket_timeout=2
- )
- except Exception as e:
- logger.warning(f"[{callback_task_id}] 创建取消检查Redis连接失败: {e}")
- # ===== 新增:定义取消检查函数 =====
- async def is_task_cancelled() -> bool:
- """检查任务是否被取消"""
- if not redis_check_client or not callback_task_id:
- return False
- try:
- return await redis_check_client.exists(f"terminate:{callback_task_id}") > 0
- except Exception:
- return False
- # 使用统一 SSE 管理器建立连接并注册回调
- queue = await unified_sse_manager.establish_connection(callback_task_id, sse_progress_callback)
- async def generate_outline_events() -> AsyncGenerator[str, None]:
- """生成大纲生成 SSE 事件流"""
- try:
- # ===== 检查点1: 函数开始 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 连接建立前检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # 发送连接确认事件
- connected_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "连接建立",
- "status": "connected",
- "message": "SSE 连接已建立,正在启动大纲生成任务...",
- "overall_task_status": "processing",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("connected", connected_data)
- # 构建任务信息
- base_info = request.project_info.base_info
- project_info_flat = {
- "project_name": base_info.project_name,
- "construct_location": base_info.construct_location,
- "engineering_type": base_info.engineering_type,
- "selectable": request.project_info.selectable or ""
- }
- sgbx_task_info = {
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "project_info": project_info_flat,
- "template_id": request.generation_template.alias or "default_template",
- "generation_chapterenum": request.generation_chapterenum,
- "generation_template": [
- item.dict() if isinstance(item, TemplateStructureItem) else item
- for item in request.generation_template.structure
- ],
- "similarity_config": {
- "topk_plans": 3,
- "topk_fragments": 10,
- "threshold": 0.75
- },
- "knowledge_config": {
- "topk": 3,
- "threshold": 0.75
- },
- }
- # ===== 检查点2: 任务提交前 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 任务提交前检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # 发送处理中事件
- processing_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 5,
- "stage_name": "任务提交中",
- "status": "processing",
- "message": "正在提交大纲生成任务...",
- "overall_task_status": "processing",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("processing", processing_data)
- # 提交任务到 Celery
- celery_task_id = await workflow_manager.submit_outline_generation_task(sgbx_task_info)
- logger.info(f"大纲生成任务已提交: callback_task_id={callback_task_id}, celery_task_id={celery_task_id}")
- # 发送任务提交成功事件
- submitted_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 10,
- "stage_name": "任务已提交",
- "status": "submitted",
- "message": "大纲生成任务已提交,正在执行...",
- "overall_task_status": "processing",
- "updated_at": int(time.time()),
- "celery_task_id": celery_task_id
- }, ensure_ascii=False)
- yield format_sse_event("submitted", submitted_data)
- # 持续监听进度并转发
- last_progress = 10
- last_progress_data = None
- last_event_type = "processing"
- last_message = ""
- no_change_count = 0
- while True:
- try:
- # ===== 检查点3: 每次轮询前检查取消 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 进度轮询中检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # 从 Redis 获取最新进度
- progress_data = await progress_manager.get_progress(callback_task_id)
- if progress_data:
- current_progress = progress_data.get("current", last_progress)
- current_event_type = progress_data.get("event_type", "processing")
- current_message = progress_data.get("message", "")
- # 检查进度数据中是否已经是取消状态
- if progress_data.get("overall_task_status") == "cancelled":
- logger.info(f"[{callback_task_id}] 从进度数据检测到取消状态")
- yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
- return
- # 进度有变化时推送
- should_push = False
- if current_progress != last_progress:
- should_push = True
- elif current_event_type != last_event_type:
- should_push = True
- elif current_message != last_message:
- should_push = True
- elif last_progress_data is None:
- should_push = True
- elif progress_data.get("overall_task_status") != last_progress_data.get("overall_task_status"):
- should_push = True
- if should_push:
- last_progress = current_progress
- last_event_type = current_event_type
- last_message = current_message
- last_progress_data = progress_data
- yield format_sse_event("processing", json.dumps(progress_data, ensure_ascii=False))
- no_change_count = 0
- else:
- no_change_count += 1
- # 检查任务状态
- status = progress_data.get("overall_task_status")
-
- # ===== 新增:检测到取消立即返回 =====
- if status == "cancelled":
- logger.info(f"[{callback_task_id}] 检测到任务已取消")
- yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
- return
-
- # 检查任务是否完成
- if status in ["completed", "failed", "terminated"]:
- break
- await asyncio.sleep(0.5)
- # 每 6 秒发送一次心跳
- if no_change_count >= 30:
- heartbeat_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "执行中",
- "status": "processing",
- "message": "大纲生成任务正在执行中...",
- "overall_task_status": "processing",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("heartbeat", heartbeat_data)
- no_change_count = 0
- except Exception as e:
- logger.warning(f"轮询进度异常: {callback_task_id}, 错误: {str(e)}")
- await asyncio.sleep(0.5)
- # 获取最终结果
- # 获取最终结果
- final_result = await workflow_manager.get_outline_sgbx_task_info(callback_task_id)
- # ===== 检查点4: 结果返回前检查取消 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 结果返回前检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # ===== 新增:检查任务结果是否为已取消 =====
- if final_result and final_result.get("status") == "cancelled":
- logger.info(f"[{callback_task_id}] 任务结果状态为已取消,不返回实际结果")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": final_result.get("message", "任务已被用户取消"),
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- if final_result and final_result.get("status") == "completed":
- completed_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 100,
- "stage_name": "大纲生成完成",
- "status": "completed",
- "message": "大纲生成任务已完成",
- "overall_task_status": "completed",
- "updated_at": int(time.time()),
- "result": {
- "outline_structure": final_result.get("results", {}).get("outline_structure", []),
- "similar_plan": final_result.get("results", {}).get("similar_plan", [])
- }
- }, ensure_ascii=False)
- yield format_sse_event("completed", completed_data)
- else:
- failed_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务失败",
- "status": "failed",
- "message": final_result.get("results", {}).get("error", "大纲生成任务失败") if final_result else "任务执行失败",
- "overall_task_status": "failed",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("failed", failed_data)
- except Exception as e:
- logger.error(f"大纲生成 SSE 事件流错误: {str(e)}", exc_info=True)
- error_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "系统错误",
- "status": "error",
- "message": f"系统错误: {str(e)}",
- "overall_task_status": "failed",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("error", error_data)
- finally:
- # 关闭 Redis 连接
- if redis_check_client:
- try:
- await redis_check_client.close()
- except Exception:
- pass
- # 关闭 SSE 连接
- await unified_sse_manager.close_connection(callback_task_id)
- return StreamingResponse(
- generate_outline_events(),
- media_type="text/event-stream",
- headers={
- "Cache-Control": "no-cache",
- "Connection": "keep-alive",
- "X-Accel-Buffering": "no"
- }
- )
@outline_router.post("/regenerate_outline", response_model=None)
@auto_trace(generate_if_missing=True)
async def regenerate_outline(request: RegenerateOutlineRequest):
    """
    Regenerate an outline and stream its progress as Server-Sent Events.

    Task lifecycle:
        - A regeneration always creates a NEW task whose id is
          ``outline_regen_<16 hex chars>``; the source task named by
          ``request.task_id`` is only read here.
        - The new task carries the source task id in ``source_task_id``.

    Optional request fields (each falls back to the source task when absent):
        - ``generation_chapterenum``
        - ``project_info``
        - ``generation_template``

    Error handling:
        - Missing ``regenerate_config``: ``HTTPException(400)`` is raised
          before any stream is opened.
        - Source task missing, or its lookup raised: the response is a stream
          holding one ``error`` event with ``error_code``
          ``SOURCE_TASK_NOT_FOUND`` (not an HTTP 404 status).
        - Any exception inside the event stream: a final ``error`` event.

    Event order on the happy path:
        ``connected`` -> ``processing`` -> ``submitted`` ->
        (``processing`` / ``heartbeat``)* -> ``completed`` | ``failed``.
        A ``cancelled`` event may end the stream at any checkpoint.

    Args:
        request: Regeneration request; ``task_id`` names the source task.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: 400 when ``regenerate_config`` is missing or empty.
    """
    # ===== 1. Parameter validation =====
    if not request.regenerate_config:
        logger.error("重新生成配置缺失")
        raise HTTPException(status_code=400, detail="regenerate_config 为必填项")

    # A fresh id: regeneration must never overwrite the source task.
    new_callback_task_id = f"outline_regen_{uuid.uuid4().hex[:16]}"
    source_task_id = request.task_id
    TraceContext.set_trace_id(new_callback_task_id)
    user_id = request.user_id
    regenerate_config = request.regenerate_config
    target_label = regenerate_config.get('target_path', 'unknown')
    logger.info(f"接收重新生成大纲 SSE 请求: "
                f"source_task_id={source_task_id}, "
                f"new_task_id={new_callback_task_id}, "
                f"user_id={user_id}, "
                f"target={target_label}")

    def sse_response(event_stream):
        """Wrap an async event generator in the SSE response used by this endpoint."""
        return StreamingResponse(
            event_stream,
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"
            }
        )

    def sse_payload(current, stage_name, status, message, overall_task_status, **extra) -> str:
        """Serialize the common event body; ``extra`` keys are appended after ``updated_at``."""
        body = {
            "callback_task_id": new_callback_task_id,
            "source_task_id": source_task_id,
            "user_id": user_id,
            "current": current,
            "stage_name": stage_name,
            "status": status,
            "message": message,
            "overall_task_status": overall_task_status,
            "updated_at": int(time.time())
        }
        body.update(extra)
        return json.dumps(body, ensure_ascii=False)

    def cancelled_event(current, message="任务已被用户取消") -> str:
        """Build the ``cancelled`` SSE frame emitted at every cancellation checkpoint."""
        return format_sse_event(
            "cancelled",
            sse_payload(current, "任务已取消", "cancelled", message, "cancelled")
        )

    # ===== 2. Load the source task (lookup errors are treated as "not found") =====
    original_task = None
    try:
        original_task = await workflow_manager.get_outline_sgbx_task_info(source_task_id)
    except Exception as e:
        logger.warning(f"获取原任务信息异常: {source_task_id}, error={e}")

    if not original_task:
        logger.error(f"原任务不存在: {source_task_id}")

        async def error_not_found():
            # Built explicitly because ``error_code`` sits before ``updated_at`` here.
            error_data = json.dumps({
                "callback_task_id": new_callback_task_id,
                "source_task_id": source_task_id,
                "user_id": user_id,
                "current": 0,
                "stage_name": "原任务不存在",
                "status": "error",
                "message": f"原任务不存在或已过期: {source_task_id}",
                "overall_task_status": "failed",
                "error_code": "SOURCE_TASK_NOT_FOUND",
                "updated_at": int(time.time())
            }, ensure_ascii=False)
            yield format_sse_event("error", error_data)

        return sse_response(error_not_found())

    original_status = original_task.get("status") or original_task.get("overall_task_status", "unknown")
    logger.info(f"原任务状态: {source_task_id} = {original_status}")

    # Register the SSE connection under the NEW task id. The return value was
    # never used by this endpoint, so it is not bound to a name.
    await unified_sse_manager.establish_connection(new_callback_task_id, sse_progress_callback)

    # ===== 3. Event stream =====
    async def generate_regenerate_events() -> AsyncGenerator[str, None]:
        """Yield SSE frames for the regeneration task until it ends or is cancelled."""
        redis_check_client = None
        try:
            # ----- 3.1 Redis client used only for the cancellation flag -----
            # NOTE(review): connection settings, including the password, are
            # hard-coded; they should come from configuration.
            try:
                redis_check_client = AsyncRedis(
                    host='127.0.0.1',
                    port=6379,
                    password='123456',
                    db=0,
                    decode_responses=True,
                    socket_connect_timeout=2,
                    socket_timeout=2
                )
            except Exception as e:
                logger.warning(f"[{new_callback_task_id}] 创建取消检查Redis连接失败: {e}")

            async def is_task_cancelled() -> bool:
                """Return True when the ``terminate:<task id>`` key exists; False on any Redis error."""
                if not redis_check_client or not new_callback_task_id:
                    return False
                try:
                    return await redis_check_client.exists(f"terminate:{new_callback_task_id}") > 0
                except Exception:
                    return False

            # ----- 3.2 Checkpoint 1: before confirming the connection -----
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 连接建立前检测到取消信号")
                yield cancelled_event(0)
                return

            # ----- 3.3 Connection confirmation -----
            yield format_sse_event("connected", sse_payload(
                0,
                "连接建立",
                "connected",
                f"SSE 连接已建立,正在启动重新生成任务(原任务: {source_task_id}, 状态: {original_status})...",
                "processing"
            ))

            # ----- 3.4 Build the task info: request values win, source task fills gaps -----
            # ``results`` may be stored as None, so never chain ``.get`` on it directly.
            original_results = original_task.get("results") or {}

            if request.project_info:
                base_info = request.project_info.base_info
                project_info_flat = {
                    "project_name": base_info.project_name,
                    "construct_location": base_info.construct_location,
                    "engineering_type": base_info.engineering_type,
                    "selectable": request.project_info.selectable or ""
                }
            else:
                project_info_flat = original_task.get("project_info", {})

            if request.generation_template:
                outline_structure = [
                    item.dict() if isinstance(item, TemplateStructureItem) else item
                    for item in request.generation_template.structure
                ]
                template_alias = request.generation_template.alias or "default_template"
            else:
                # Prefer the stored template; fall back to the generated outline.
                outline_structure = original_task.get("generation_template", [])
                if not outline_structure:
                    outline_structure = original_results.get("outline_structure", [])
                template_alias = original_task.get("template_id", "default_template")

            generation_chapterenum = request.generation_chapterenum
            if generation_chapterenum is None:
                generation_chapterenum = original_task.get("generation_chapterenum", [])

            # Still no chapter list: infer a single chapter code by walking the
            # source outline along the dotted ``target_path``.
            target_path = regenerate_config.get("target_path")
            if not generation_chapterenum and target_path:
                original_outline = original_results.get("outline_structure", [])
                chapter_code = None
                if original_outline:
                    path_parts = target_path.split(".")

                    def find_chapter_code(nodes, depth=0):
                        """Depth-first match of ``path_parts`` against node ``index`` values."""
                        if depth >= len(path_parts):
                            return None
                        wanted = path_parts[depth]
                        for node in nodes:
                            if str(node.get("index", "")) != wanted:
                                continue
                            if depth == len(path_parts) - 1:
                                return node.get("code")
                            children = node.get("children", [])
                            if children:
                                found = find_chapter_code(children, depth + 1)
                                if found:
                                    return found
                        return None

                    chapter_code = find_chapter_code(original_outline)

                if chapter_code:
                    generation_chapterenum = [chapter_code]

            sgbx_task_info = {
                "callback_task_id": new_callback_task_id,
                "source_task_id": source_task_id,
                "user_id": user_id,
                "project_info": project_info_flat,
                "template_id": template_alias,
                "generation_chapterenum": generation_chapterenum,
                "generation_template": outline_structure,
                "similarity_config": original_task.get("similarity_config", {
                    "topk_plans": 3,
                    "topk_fragments": 10,
                    "threshold": 0.75
                }),
                "knowledge_config": original_task.get("knowledge_config", {
                    "topk": 3,
                    "threshold": 0.75
                }),
                # Regeneration-specific fields
                "regenerate_config": regenerate_config,
                "is_regenerate": True,
                "original_task_status": original_status
            }
            logger.info(f"重新生成任务信息构建完成: "
                        f"new_task_id={new_callback_task_id}, "
                        f"source_task_id={source_task_id}, "
                        f"target={target_label}, "
                        f"chapters={generation_chapterenum}")

            # ----- 3.5 Checkpoint 2: before submitting -----
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 任务提交前检测到取消信号")
                yield cancelled_event(0)
                return

            # ----- 3.6 "submitting" event -----
            yield format_sse_event("processing", sse_payload(
                5,
                "任务提交中",
                "processing",
                f"正在提交重新生成任务(目标: {target_label})...",
                "processing"
            ))

            # ----- 3.7 Submit through the workflow manager -----
            celery_task_id = await workflow_manager.submit_outline_generation_task(sgbx_task_info)
            logger.info(f"重新生成任务已提交: "
                        f"new_callback_task_id={new_callback_task_id}, "
                        f"celery_task_id={celery_task_id}")
            yield format_sse_event("submitted", sse_payload(
                10,
                "任务已提交",
                "submitted",
                "重新生成任务已提交,正在执行...",
                "processing",
                celery_task_id=celery_task_id
            ))

            # ----- 3.8 Poll progress and forward changes -----
            last_progress = 10
            last_progress_data = None
            last_event_type = "processing"
            last_message = ""
            no_change_count = 0
            while True:
                try:
                    # Checkpoint 3: before every poll
                    if await is_task_cancelled():
                        logger.info(f"[{new_callback_task_id}] 进度轮询中检测到取消信号")
                        yield cancelled_event(last_progress)
                        return

                    progress_data = await progress_manager.get_progress(new_callback_task_id)
                    pushed = False
                    if progress_data:
                        current_progress = progress_data.get("current", last_progress)
                        current_event_type = progress_data.get("event_type", "processing")
                        current_message = progress_data.get("message", "")
                        overall_status = progress_data.get("overall_task_status")

                        # The worker itself reported cancellation: forward it and stop.
                        if overall_status == "cancelled":
                            logger.info(f"[{new_callback_task_id}] 从进度数据检测到取消状态")
                            yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
                            return

                        # Short-circuit order matters: ``last_progress_data`` is only
                        # dereferenced after the ``is None`` test.
                        changed = (
                            current_progress != last_progress
                            or current_event_type != last_event_type
                            or current_message != last_message
                            or last_progress_data is None
                            or overall_status != last_progress_data.get("overall_task_status")
                        )
                        if changed:
                            last_progress = current_progress
                            last_event_type = current_event_type
                            last_message = current_message
                            last_progress_data = progress_data
                            yield format_sse_event("processing", json.dumps(progress_data, ensure_ascii=False))
                            pushed = True

                        # Terminal states end the polling loop.
                        if overall_status in ("completed", "failed", "terminated"):
                            break

                    # Count every poll that pushed nothing, including polls that
                    # found no progress record yet, so idle streams still get
                    # heartbeats.
                    no_change_count = 0 if pushed else no_change_count + 1

                    await asyncio.sleep(0.5)

                    # 30 idle polls at 0.5 s each: a heartbeat roughly every 15 s.
                    if no_change_count >= 30:
                        yield format_sse_event("heartbeat", sse_payload(
                            last_progress,
                            "执行中",
                            "processing",
                            "重新生成任务正在执行中...",
                            "processing"
                        ))
                        no_change_count = 0
                except Exception as e:
                    # Polling is best-effort: log and retry on the next tick.
                    logger.warning(f"轮询进度异常: {new_callback_task_id}, 错误: {str(e)}")
                    await asyncio.sleep(0.5)

            # ----- 3.9 Final result -----
            final_result = await workflow_manager.get_outline_sgbx_task_info(new_callback_task_id)

            # Checkpoint 4: before returning the result
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 结果返回前检测到取消信号")
                yield cancelled_event(last_progress)
                return

            final_status = final_result.get("status") if final_result else None
            final_results = (final_result.get("results") or {}) if final_result else {}

            if final_status == "cancelled":
                logger.info(f"[{new_callback_task_id}] 任务结果状态为已取消,不返回实际结果")
                yield cancelled_event(last_progress, final_result.get("message", "任务已被用户取消"))
                return

            # ----- 3.10 Completed / failed -----
            if final_status == "completed":
                yield format_sse_event("completed", sse_payload(
                    100,
                    "重新生成完成",
                    "completed",
                    "大纲重新生成任务已完成",
                    "completed",
                    result={
                        "outline_structure": final_results.get("outline_structure", []),
                        "similar_plan": final_results.get("similar_plan", [])
                    }
                ))
            else:
                if final_result:
                    failure_message = final_results.get("error", "重新生成任务失败")
                else:
                    failure_message = "任务执行失败"
                yield format_sse_event("failed", sse_payload(
                    last_progress,
                    "任务失败",
                    "failed",
                    failure_message,
                    "failed"
                ))
        except Exception as e:
            logger.error(f"重新生成大纲 SSE 事件流错误: {str(e)}", exc_info=True)
            yield format_sse_event("error", sse_payload(
                0,
                "系统错误",
                "error",
                f"系统错误: {str(e)}",
                "failed"
            ))
        finally:
            # Best-effort cleanup: a failed close must not mask the stream outcome.
            if redis_check_client:
                try:
                    await redis_check_client.close()
                except Exception as close_err:
                    logger.debug(f"[{new_callback_task_id}] 关闭取消检查Redis连接失败: {close_err}")
            await unified_sse_manager.close_connection(new_callback_task_id)

    return sse_response(generate_regenerate_events())
- # ==================== POST 接口 ====================
@outline_router.post("/task_cancel")
@auto_trace(generate_if_missing=True)
async def task_cancel(request: TaskCancelRequest):
    """
    Cancel an outline-generation task.

    Tasks that are still pre-registered (``pending``: submitted but not yet
    picked up by a worker) are cancelled through the workflow manager alone.
    For any other non-finished task this handler additionally writes a
    ``terminate:<task_id>`` key and a cancelled ``progress:<task_id>`` hash to
    Redis, asks Celery to revoke the task, and pushes a cancel event to the
    SSE connection before closing it.

    Args:
        request: Carries ``task_id``, ``user_id`` and ``cancel_reason``.

    Returns:
        A dict with ``code`` / ``message`` / ``data``. ``code`` is 200 on
        success or when the task is already cancelled, 404 when the task is
        unknown, 400 when it already finished or the terminate signal could
        not be set, and 500 on an unexpected error. Errors are reported in
        the body; this handler does not raise.
    """
    import redis.asyncio as redis_async
    from redis.asyncio.connection import ConnectionPool

    try:
        logger.info(f"接收取消任务请求: task_id={request.task_id}")

        # The workflow manager also knows about pre-registered tasks.
        task_info = await workflow_manager.get_outline_sgbx_task_info(request.task_id)
        if not task_info:
            return {
                "code": 404,
                "message": "任务不存在",
                "data": {"task_id": request.task_id, "error_type": "TASK_NOT_FOUND"}
            }

        task_status = task_info.get("status") or task_info.get("overall_task_status", "unknown")
        is_pre_registered = task_info.get("is_pre_registered", False)

        if task_status == "cancelled":
            return {
                "code": 200,
                "message": "任务已处于取消状态",
                "data": {"task_id": request.task_id, "status": "cancelled"}
            }

        if task_status in ["completed", "failed"]:
            return {
                "code": 400,
                "message": "任务已完成,无法取消",
                "data": {"task_id": request.task_id, "current_status": task_status}
            }

        # Works for both pending (pre-registered) and processing tasks.
        result = await workflow_manager.set_outline_terminate_signal(
            callback_task_id=request.task_id,
            operator=request.user_id
        )
        if not result.get("success"):
            logger.warning(f"设置终止信号失败: {result.get('message')}")
            return {
                "code": 400,
                "message": result.get("message", "取消任务失败"),
                "data": {"task_id": request.task_id, "current_status": task_status}
            }

        cancelled_at = int(time.time())

        # Pre-registered task: only the progress record is updated on top of
        # the terminate signal set above.
        if is_pre_registered or task_status == "pending":
            logger.info(f"预注册任务已被取消: {request.task_id}")
            try:
                await progress_manager.update_stage_progress(
                    callback_task_id=request.task_id,
                    overall_task_status="cancelled",
                    status="cancelled",
                    message=f"任务已被用户取消: {request.cancel_reason}"
                )
            except Exception as e:
                logger.warning(f"更新进度信息失败: {e}")

            return {
                "code": 200,
                "message": "任务已成功取消",
                "data": {
                    "task_id": request.task_id,
                    "status": "cancelled",
                    "cancelled_at": cancelled_at,
                    "cancel_reason": request.cancel_reason,
                    "cancelled_by": request.user_id,
                    "is_pre_registered": True
                }
            }

        # Running task: write the extra cancel markers that the SSE pollers read.
        # NOTE(review): connection settings, including the password, are
        # hard-coded; they should come from configuration.
        pool = None
        redis_client = None
        try:
            pool = ConnectionPool(
                host='127.0.0.1',
                port=6379,
                password='123456',
                db=0,
                decode_responses=True,
                max_connections=20,
                socket_connect_timeout=10,
                socket_timeout=10,
                retry_on_timeout=True,
                health_check_interval=30
            )
            redis_client = redis_async.Redis(connection_pool=pool)

            terminate_data = json.dumps({
                "cancelled": True,
                "cancelled_by": request.user_id,
                "cancel_reason": request.cancel_reason,
                "cancelled_at": cancelled_at
            })

            pipe = redis_client.pipeline()
            pipe.set(f"terminate:{request.task_id}", terminate_data, ex=3600)
            pipe.hset(f"progress:{request.task_id}", mapping={
                "overall_task_status": "cancelled",
                "status": "cancelled",
                "message": f"任务已被用户取消: {request.cancel_reason}",
                "cancelled_at": str(cancelled_at),
                "updated_at": str(cancelled_at)
            })
            await pipe.execute()
            logger.info(f"终止标志已设置: {request.task_id}")
        except Exception as e:
            # Deliberately non-fatal: the workflow-manager signal is already set.
            logger.error(f"设置终止标志失败: {e}")
        finally:
            # Release the client AND the pool on every path. Previously the
            # pool was left open whenever the pipeline raised.
            if redis_client is not None:
                try:
                    await redis_client.close()
                except Exception as close_err:
                    logger.warning(f"关闭 Redis 客户端失败: {close_err}")
            if pool is not None:
                try:
                    await pool.disconnect()
                except Exception as close_err:
                    logger.warning(f"释放 Redis 连接池失败: {close_err}")

        # Ask Celery to terminate the running task, when its id is known.
        celery_task_id = task_info.get("celery_task_id") or task_info.get("celery_id")
        if celery_task_id:
            try:
                from celery import current_app as celery_app
                celery_app.control.revoke(celery_task_id, terminate=True)
                logger.info(f"Celery 终止信号已发送: {celery_task_id}")
            except Exception as e:
                logger.warning(f"终止 Celery 任务失败: {e}")

        # Notify the SSE listener, then close its connection.
        try:
            cancel_event = {
                "callback_task_id": request.task_id,
                "status": "cancelled",
                "overall_task_status": "cancelled",
                "message": f"任务已被用户取消: {request.cancel_reason}",
                "cancelled_at": cancelled_at,
                "cancelled_by": request.user_id
            }
            await unified_sse_manager.send_progress(request.task_id, cancel_event)
            await unified_sse_manager.close_connection(request.task_id)
        except Exception as e:
            logger.warning(f"关闭 SSE 失败: {e}")

        return {
            "code": 200,
            "message": "任务取消成功",
            "data": {
                "task_id": request.task_id,
                "status": "cancelled",
                "cancelled_at": cancelled_at,
                "cancel_reason": request.cancel_reason,
                "cancelled_by": request.user_id
            }
        }
    except Exception as e:
        logger.error(f"取消任务异常: {str(e)}", exc_info=True)
        return {
            "code": 500,
            "message": f"取消任务失败: {str(e)}",
            "data": {"task_id": request.task_id}
        }
- # ==================== 查询接口 ====================
@outline_router.get("/task_status")
@auto_trace(generate_if_missing=True)
async def task_status(
    task_id: str = Query(..., description="任务ID"),
    user_id: str = Query(..., description="用户ID")
):
    """
    Look up the stored state of an outline-generation task.

    Args:
        task_id: Callback id of the task to look up.
        user_id: Required query parameter; the lookup itself does not use it.

    Returns:
        dict: ``code`` 200 with the task info in ``data``, or ``code`` 404
        with ``data`` set to ``None`` when the workflow manager has no record.

    Raises:
        HTTPException: 500 when the lookup raises.
    """
    try:
        logger.info(f"查询任务状态: task_id={task_id}")
        task_record = await workflow_manager.get_outline_sgbx_task_info(task_id)
        found = task_record is not None
        # ``task_record`` is None exactly when the task is missing, so it can
        # serve as ``data`` in both outcomes.
        return {
            "code": 200 if found else 404,
            "message": "查询成功" if found else "任务不存在或已完成",
            "data": task_record
        }
    except Exception as e:
        logger.error(f"查询任务状态失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")
@outline_router.get("/active_tasks")
@auto_trace(generate_if_missing=True)
async def active_tasks(
    user_id: str = Query(None, description="用户ID(可选,不提供则返回所有任务)")
):
    """
    List the active outline-generation tasks.

    Args:
        user_id: Optional filter; when given, only tasks whose ``user_id``
            equals it are returned. Otherwise every active task is returned.

    Returns:
        dict: ``code`` 200 with ``data.total`` and ``data.tasks``.

    Raises:
        HTTPException: 500 when the workflow manager lookup raises.
    """
    try:
        logger.info(f"获取活跃任务列表: user_id={user_id}")
        # Treat a missing result as "no active tasks" instead of failing on len().
        active_tasks_list = await workflow_manager.get_outline_active_tasks() or []
        if user_id:
            # ``.get`` keeps one task record without a ``user_id`` key from
            # turning the whole listing into a 500.
            active_tasks_list = [
                task for task in active_tasks_list
                if task.get("user_id") == user_id
            ]
        return {
            "code": 200,
            "message": "查询成功",
            "data": {
                "total": len(active_tasks_list),
                "tasks": active_tasks_list
            }
        }
    except Exception as e:
        logger.error(f"获取活跃任务列表失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取活跃任务列表失败: {str(e)}")
- # Mount the context-generation sub-router onto outline_router.
- outline_router.include_router(context_generate_router)
|