| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687 |
- # -*- coding: utf-8 -*-
- """
- 大纲生成 API 接口 (SSE 版本)
- 集成到 Celery + WorkflowManager 架构中
- 提供以下接口:
- - SSE /sgbx/generating_outline: SSE 流式大纲生成
- - SSE /sgbx/regenerate_outline: SSE 流式重新生成
- - POST /sgbx/task_cancel: 取消大纲生成任务
- - POST /sgbx/context_generate: SSE 流式上下文生成 (新增)
- """
- import uuid
- import json
- import time
- import asyncio
- from typing import Optional, Dict, Any, List, AsyncGenerator, Union
- from pydantic import BaseModel, Field
- from fastapi import APIRouter, HTTPException, Query
- from fastapi.responses import StreamingResponse
- from foundation.observability.logger.loggering import write_logger as logger
- from foundation.infrastructure.tracing import TraceContext, auto_trace
- from foundation.ai.agent.generate.model_generate import generate_model_client
- from core.base.workflow_manager import workflow_manager
- from core.base.sse_manager import unified_sse_manager
- from core.base.progress_manager import ProgressManager
- from redis.asyncio import Redis as AsyncRedis
# Router for the outline-generation endpoints (all paths are mounted under /sgbx).
outline_router = APIRouter(tags=["施工方案编写"], prefix="/sgbx")
# Module-wide progress manager; the outline SSE loop polls progress snapshots
# and pending stream events from it.
progress_manager: ProgressManager = ProgressManager()
async def sse_progress_callback(callback_task_id: str, current_data: dict):
    """Push a progress snapshot for ``callback_task_id`` to its SSE client.

    Thin adapter registered with ``unified_sse_manager.establish_connection``;
    it simply delegates to ``unified_sse_manager.send_progress`` and discards
    the result.
    """
    sender = unified_sse_manager.send_progress
    await sender(callback_task_id, current_data)
def format_sse_event(event_type: str, data: str) -> str:
    """Serialise one Server-Sent Events message.

    The message consists of an ``event:`` line, one ``data:`` line per line
    of *data*, and a single terminating blank line. The SSE format treats
    CRLF, CR and LF as line terminators. A raw newline inside one ``data:``
    field would therefore make the client misread the rest of the payload as
    other fields. Emitting one ``data:`` line per payload line lets the client
    reassemble the original text (it joins data lines with LF).

    Args:
        event_type: Value of the ``event:`` field (must not contain newlines).
        data: Payload text; may be empty or span several lines.

    Returns:
        The wire-format message, ending in exactly one blank line.
    """
    normalised = data.replace("\r\n", "\n").replace("\r", "\n")
    lines = [f"event: {event_type}"]
    lines.extend(f"data: {line}" for line in normalised.split("\n"))
    # The final "\n\n" ends the last field line and adds the blank line
    # that tells the client to dispatch the event.
    return "\n".join(lines) + "\n\n"
- # ==================== 请求/响应模型 ====================
class BaseInfo(BaseModel):
    """Basic facts about the construction project.

    ``construction_unit`` and ``supervision_unit`` are optional and default to
    ``""``. The request examples documented in this module omit them, and a
    required field would reject exactly those requests with a validation
    error.
    """

    project_name: str = Field(..., description="方案名称", example="罗成依达大桥上部结构专项施工方案")
    construct_location: str = Field(..., description="建设地点", example="四川省凉山州")
    engineering_type: str = Field(..., description="方案模版类型", example="T型梁")
    # Optional so that requests omitting them validate; None is also accepted.
    construction_unit: Optional[str] = Field("", description="施工单位")
    supervision_unit: Optional[str] = Field("", description="监理单位")
class ProjectInfo(BaseModel):
    """Project information: mandatory base facts plus optional free text."""

    # Name, location, template type and the responsible units.
    base_info: BaseInfo = Field(..., description="基础信息")
    # Any additional information; defaults to an empty string.
    selectable: Optional[str] = Field(default="", description="其他可选信息")
class TemplateStructureItem(BaseModel):
    """One node of the outline template; nested sections go in ``children``."""

    index: str = Field(..., example="2", description="章节编号")
    level: int = Field(..., ge=1, le=5, description="层级")
    title: str = Field(..., example="工程概况", description="章节标题")
    code: str = Field(..., example="overview", description="章节代码")
    template_content: str = Field(default="", description="模板内容")
    is_enabled: bool = Field(default=True, description="是否启用该章节")
    # Child nodes are typed as plain dicts, so they are NOT validated against
    # this model; nested structure is carried through as raw dictionaries.
    children: Optional[List[Dict[str, Any]]] = Field(default=None, description="子章节(递归结构)")
class GenerationTemplate(BaseModel):
    """Template that drives outline generation.

    Example::

        {
            "source_file": "方案编写助手原文关键词规范文档修改版-2026-2-5.md",
            "alias": "施工方案知识审查与编写体系",
            "structure": [
                {
                    "index": "2",
                    "level": 1,
                    "title": "工程概况",
                    "code": "overview",
                    "children": [...]
                }
            ]
        }
    """

    source_file: Optional[str] = Field(
        default=None,
        example="方案编写助手原文关键词规范文档修改版-2026-2-5.md",
        description="源文件",
    )
    alias: Optional[str] = Field(default=None, example="施工方案知识审查与编写体系", description="别名")
    # Each item may end up as a TemplateStructureItem or a plain dict, depending
    # on how pydantic resolves the Union; consumers must handle both forms.
    structure: List[Union[TemplateStructureItem, Dict[str, Any]]] = Field(..., description="模板结构")
class OutlineGenerationRequest(BaseModel):
    """Request body for outline generation.

    Example request body (matches the curl example)::

        {
            "user_id": "user-001",
            "project_info": {
                "base_info": {
                    "project_name": "罗成依达大桥上部结构专项施工方案",
                    "construct_location": "四川省凉山州",
                    "engineering_type": "T型梁"
                },
                "selectable": ""
            },
            "generation_template": {
                "source_file": "方案编写助手原文关键词规范文档修改版-2026-2-5.md",
                "alias": "施工方案知识审查与编写体系",
                "structure": [...]
            },
            "generation_chapterenum": ["overview_DesignSummary_ProjectIntroduction", ...]
        }
    """

    user_id: str = Field(..., example="user-001", description="用户标识")
    project_info: ProjectInfo = Field(..., description="项目基础信息")
    generation_template: GenerationTemplate = Field(..., description="大纲生成模板")
    # Chapter codes to generate; an empty list means "generate every chapter".
    generation_chapterenum: List[str] = Field(
        default_factory=list,
        description="生成章节代码列表,为空时生成全部章节",
    )
class RegenerateOutlineRequest(BaseModel):
    """Request to regenerate part or all of an existing outline.

    Reuses the fields of the outline-generation request and adds
    ``regenerate_config``. ``project_info`` and ``generation_template`` are
    optional. By design, omitted values fall back to those of the original
    task.

    Example request::

        {
            "task_id": "task-20250130-123456",
            "user_id": "user-001",
            "project_info": {            // optional; defaults to the original task's project info
                "base_info": {
                    "project_name": "罗成依达大桥上部结构专项施工方案",
                    "construct_location": "四川省凉山州",
                    "engineering_type": "T型梁"
                },
                "selectable": ""
            },
            "generation_template": {     // optional; defaults to the original task's template
                "source_file": "...",
                "alias": "...",
                "structure": [...]
            },
            "generation_chapterenum": ["overview_DesignSummary_MainTechnicalStandards"],  // optional
            "regenerate_config": {
                "regenerate_mode": "chapter",
                "target_path": "2.1",
                "preserve_children": true,
                "reason": "调整内容结构"
            }
        }
    """

    # Identifies the original outline-generation task and its owner.
    task_id: str = Field(..., description="原大纲生成任务ID")
    user_id: str = Field(..., description="用户ID")
    # Optional overrides of the outline-generation inputs (None = not supplied).
    project_info: Optional[ProjectInfo] = Field(default=None, description="项目基础信息(可选)")
    generation_template: Optional[GenerationTemplate] = Field(default=None, description="大纲生成模板(可选)")
    generation_chapterenum: Optional[List[str]] = Field(default=None, description="生成章节代码列表(可选)")
    # Regeneration-specific settings; a free-form dict (known keys shown in the example above).
    regenerate_config: Dict[str, Any] = Field(..., description="重新生成配置")
class TaskCancelRequest(BaseModel):
    """Request to cancel a running task."""

    task_id: str = Field(..., description="任务ID")
    user_id: str = Field(..., description="用户ID")
    # Free-text reason; defaults to "cancelled by the user".
    cancel_reason: Optional[str] = Field(default="用户主动取消", description="取消原因")
- # ==================== 响应模型 ====================
class OutlineNodeResponse(BaseModel):
    """One node of the generated outline.

    Mirrors ``TemplateStructureItem`` and adds:

    - ``generated_content`` on every node;
    - ``similar_fragments`` on level-2 and level-3 nodes;
    - ``key_points`` and ``knowledge_bases`` on leaf nodes only.
    """

    index: str = Field(..., example="2.1.1", description="章节编号")
    level: int = Field(..., example=3, ge=1, le=5, description="层级")
    title: str = Field(..., example="工程简介", description="章节标题")
    code: str = Field(..., example="overview_DesignSummary_ProjectIntroduction", description="章节代码")
    generated_content: str = Field(..., example="罗成依达大桥位于四川省凉山州...", description="AI生成的内容")
    # Present on level-2 and level-3 nodes.
    similar_fragments: Optional[List[Dict[str, Any]]] = Field(default=None, description="相似片段推荐(2级和3级节点)")
    # Leaf-node-only fields.
    key_points: Optional[List[str]] = Field(
        default=None,
        example=["桥位位置", "桥梁规模"],
        description="核心要点(仅末级节点)",
    )
    knowledge_bases: Optional[List[str]] = Field(
        default=None,
        example=["《公路桥涵设计通用规范》JTG D60-2015"],
        description="知识点/编制依据(仅末级节点)",
    )
    # Nested child nodes, carried as plain dicts.
    children: Optional[List[Dict[str, Any]]] = Field(default=None, description="子章节(递归结构)")
class SimilarPlanResponse(BaseModel):
    """A similar whole-document plan recommended for the generated outline."""

    plan_id: str = Field(..., description="方案ID")
    plan_title: str = Field(..., description="方案标题")
    # Bounded to the closed interval [0, 1].
    similarity_score: float = Field(..., ge=0.0, le=1.0, description="相似度分数")
    plan_type: str = Field(..., description="方案类型")
    outline: Optional[List[Dict]] = Field(default=None, description="方案大纲结构")
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="元数据")
class SimilarFragmentResponse(BaseModel):
    """A similar text fragment taken from another document."""

    fragment_id: str = Field(..., description="片段ID")
    # Location of the fragment inside its source document.
    section_path: str = Field(..., description="所属章节路径")
    section_title: str = Field(..., description="章节标题")
    fragment_content: str = Field(..., description="片段内容")
    # Bounded to the closed interval [0, 1].
    similarity_score: float = Field(..., ge=0.0, le=1.0, description="相似度分数")
    source_document_id: str = Field(..., description="来源文档ID")
    source_document_title: str = Field(..., description="来源文档标题")
class OutlineGenerationResult(BaseModel):
    """Final result of an outline-generation task.

    - ``outline_structure``: nested outline; every node has ``generated_content``,
      level-2/3 nodes also have ``similar_fragments`` and leaf nodes also have
      ``key_points`` and ``knowledge_bases``.
    - ``similar_plan``: whole-document similar-plan recommendations (top level).
    """

    outline_structure: List[OutlineNodeResponse] = Field(
        ...,
        description="大纲结构(包含AI生成内容和章节级similar_fragments)",
    )
    similar_plan: List[SimilarPlanResponse] = Field(
        default_factory=list,
        description="相似方案推荐(整篇方案级)",
    )
- # ==================== 上下文生成新增模型 ====================
class CompletionConfig(BaseModel):
    """Settings for one context-aware generation call."""

    # Dotted chapter number such as "2.1.3".
    section_path: str = Field(..., description="章节路径")
    current_content: str = Field(description="当前已有内容", default="")
    # NOTE(review): not read by generate_content_stream in this module.
    context_window: int = Field(ge=500, le=5000, default=2000)
    # Free-form string; not validated against the modes listed by /context_generate_modes.
    completion_mode: str = Field(description="模式", default="continue")
    target_length: int = Field(ge=100, le=5000, default=1000)
    include_references: bool = Field(True)
    style_match: bool = Field(True)
    hint_keywords: Optional[List[str]] = Field(None)
class ProjectInfoSimple(BaseModel):
    """Minimal project description accepted by /context_generate."""

    project_name: str = Field("施工方案")
    construct_location: Optional[str] = Field(None)
    engineering_type: Optional[str] = Field(None)
class ContextGenerateRequest(BaseModel):
    """Request body for the /context_generate SSE endpoint."""

    # At least one of task_id / project_info is required (checked by validate_request).
    task_id: Optional[str] = Field(None)
    user_id: str = Field(...)
    project_info: Optional[ProjectInfoSimple] = Field(None)
    completion_config: CompletionConfig = Field(...)
    # NOTE(review): not read by generate_content_stream in this module.
    model_name: Optional[str] = Field(None)

    class Config:
        # Reject unknown fields instead of silently ignoring them.
        extra = "forbid"
class ContextGenerateResponse(BaseModel):
    """Generic ``{code, message, data}`` envelope used by the auxiliary endpoints."""

    code: int
    message: str
    data: Optional[Dict[str, Any]] = Field(default=None)
- # ==================== 全局资源池 ====================
# Passed as ``function_name`` to generate_model_client and reported as the
# model name by the context-generation status endpoints.
CONTEXT_GENERATE_FUNCTION: str = "write_content_generate"
# Process-wide async Redis client, created lazily by init_global_resources().
GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task can be garbage-collected
# before it finishes.
_background_tasks: set = set()


async def init_global_resources():
    """Create the shared async Redis client (``GLOBAL_REDIS_CLIENT``) once.

    Connection settings come from the ``REDIS_HOST``, ``REDIS_PORT``,
    ``REDIS_PASSWORD`` and ``REDIS_DB`` environment variables. When these are
    unset, the built-in defaults (127.0.0.1:6379, db 0, development password)
    apply, so existing deployments keep working. Credentials therefore no
    longer have to live only in source code.

    Constructing the client does not open a connection. A background ping is
    scheduled to warm it up. Any failure (including a malformed port/db value)
    is logged and leaves ``GLOBAL_REDIS_CLIENT`` as ``None``.
    """
    global GLOBAL_REDIS_CLIENT
    import os

    if GLOBAL_REDIS_CLIENT is not None:
        return
    try:
        GLOBAL_REDIS_CLIENT = AsyncRedis(
            host=os.getenv("REDIS_HOST", "127.0.0.1"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            password=os.getenv("REDIS_PASSWORD", "123456"),
            db=int(os.getenv("REDIS_DB", "0")),
            decode_responses=True, socket_connect_timeout=1,
            socket_keepalive=True, max_connections=50
        )
        ping_task = asyncio.create_task(_background_ping())
        _background_tasks.add(ping_task)
        ping_task.add_done_callback(_background_tasks.discard)
        logger.info("✅ 全局 Redis 连接池已初始化")
    except Exception as e:
        logger.warning(f"⚠️ Redis 初始化失败: {e}")
        GLOBAL_REDIS_CLIENT = None
async def _background_ping():
    """Best-effort warm-up ping of ``GLOBAL_REDIS_CLIENT``.

    A failed ping is logged rather than raised, because this runs as a
    detached task. Only ``Exception`` is caught. ``asyncio.CancelledError``,
    a ``BaseException``, must still propagate so the task can be cancelled;
    the previous bare ``except`` swallowed it.
    """
    client = GLOBAL_REDIS_CLIENT
    if client is None:
        return
    try:
        await client.ping()
    except Exception as e:
        logger.warning(f"⚠️ Redis 预热 ping 失败: {e}")
async def get_redis_client():
    """Return the shared async Redis client, initialising it on first use.

    May return ``None`` when initialisation failed; callers treat that as
    "Redis unavailable".
    """
    if GLOBAL_REDIS_CLIENT is not None:
        return GLOBAL_REDIS_CLIENT
    await init_global_resources()
    return GLOBAL_REDIS_CLIENT
- # ==================== 极速流式调用 (通过统一模型框架) ====================
async def call_custom_api_stream(
    prompt: str, system_prompt: str = "", max_tokens: int = 2000,
    temperature: float = 0.7, trace_id: str = ""
) -> AsyncGenerator[tuple[str, Optional[float]], None]:
    """Stream an LLM completion through ``generate_model_client``.

    ``generate_model_client.get_model_generate_stream`` is consumed as a
    synchronous iterator. Iterating it directly inside this coroutine would
    block the event loop while waiting for each chunk, stalling every other
    request served by the process (including other SSE streams). The initial
    call and every ``next()`` therefore run in a worker thread via
    ``asyncio.to_thread``, which also copies the current ``contextvars``
    context into that thread.

    Args:
        prompt: User prompt; only its last 10000 characters are kept.
        system_prompt: System prompt passed to the model client.
        max_tokens: NOTE(review): accepted but not forwarded to the model client.
        temperature: NOTE(review): accepted but not forwarded to the model client.
        trace_id: Forwarded to the model client and used as the log prefix.

    Yields:
        ``(chunk, first_token_time)`` tuples: each raw chunk from the client
        and the seconds between this call and the arrival of the first chunk.

    Raises:
        Exception: errors from the model client are logged and re-raised.
    """
    max_prompt_len = 10000
    if len(prompt) > max_prompt_len:
        # Keep the tail; build_context_generate_prompt places the instruction last.
        prompt = prompt[-max_prompt_len:]
        logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")
    start_time = time.time()
    first_token_time: Optional[float] = None
    exhausted = object()  # sentinel returned by next() once the stream is finished

    def open_stream():
        # Runs in a worker thread: obtaining the stream may itself do blocking I/O.
        return iter(generate_model_client.get_model_generate_stream(
            trace_id=trace_id,
            system_prompt=system_prompt,
            user_prompt=prompt,
            function_name=CONTEXT_GENERATE_FUNCTION
        ))

    try:
        stream = await asyncio.to_thread(open_stream)
        while True:
            chunk = await asyncio.to_thread(next, stream, exhausted)
            if chunk is exhausted:
                break
            if first_token_time is None:
                first_token_time = time.time() - start_time
            yield (chunk, first_token_time)
    except Exception as e:
        logger.error(f"[{trace_id}] 流式请求异常: {e}")
        raise
- # ==================== 上下文生成业务逻辑辅助 ====================
# System prompt for /context_generate.
# NOTE(review): it caps output at 100 characters, far below CompletionConfig's
# default target_length of 1000 — confirm this limit is intended.
CONTEXT_GENERATE_SYSTEM_PROMPT = (
    "你是一位专业的施工方案编写专家。"
    "请直接输出生成的内容文本,不要添加任何解释、标注或格式标记。"
    "要求生成的内容不超过100字。"
)
def build_context_generate_prompt(project_info, section_path, section_title, current_content, completion_mode, target_length, include_references, style_match, hint_keywords, context_before="", context_after=""):
    """Assemble the user prompt for context-aware content generation.

    The prompt is a newline-joined list of labelled parts, in this order:

    - project name;
    - section;
    - mode and target length;
    - optional preceding text (last 500 chars);
    - optional current text;
    - optional following text (first 500 chars);
    - optional hint keywords;
    - the final instruction.

    Args:
        project_info: Dict with an optional ``project_name`` key; ``None`` is
            treated as an empty dict.
        section_path: Dotted chapter number, e.g. ``"2.1"``.
        section_title: Human-readable section title.
        current_content: Existing text of the section (omitted when empty).
        completion_mode: Mode label inserted verbatim.
        target_length: Target length inserted verbatim.
        include_references: NOTE(review): accepted but not yet used in the prompt.
        style_match: NOTE(review): accepted but not yet used in the prompt.
        hint_keywords: Optional iterable of keywords; empty/falsy entries are
            skipped. Previously this argument was silently ignored.
        context_before: Text preceding the section.
        context_after: Text following the section.

    Returns:
        The prompt string.
    """
    info = project_info or {}
    parts = [
        f"【项目】{info.get('project_name', '未知')}",
        f"【章节】{section_title} ({section_path})",
        f"【模式】{completion_mode} (目标:{target_length})",
    ]

    if context_before:
        parts.append(f"【前文】...{context_before[-500:]}")
    if current_content:
        parts.append(f"【当前】{current_content}")
    if context_after:
        parts.append(f"【后文】{context_after[:500]}...")

    keywords = [str(keyword) for keyword in (hint_keywords or []) if keyword]
    if keywords:
        joined_keywords = "、".join(keywords)
        parts.append(f"【关键词】{joined_keywords}")

    parts.append("【指令】请根据上述信息继续生成专业内容,直接输出正文:")
    return "\n".join(parts)
def extract_chunk_content(chunk: Any) -> str:
    """Return the text carried by one streaming chunk.

    Supported shapes:

    - ``str``: returned as-is.
    - ``bytes``/``bytearray``: decoded as UTF-8, with invalid bytes replaced.
    - Objects with a ``content`` attribute: the attribute as a string.
    - Dicts with a ``content`` key: the value as a string.

    A missing, ``None`` or otherwise falsy content yields ``""`` for both the
    attribute and the dict form; previously a dict with ``content: None``
    produced the literal text ``"None"``. A ``None`` chunk also yields ``""``.
    Anything else falls back to ``str(chunk)``.
    """
    if chunk is None:
        return ""
    if isinstance(chunk, str):
        return chunk
    if isinstance(chunk, (bytes, bytearray)):
        return bytes(chunk).decode("utf-8", errors="replace")
    if hasattr(chunk, 'content'):
        return str(chunk.content) if chunk.content else ""
    if isinstance(chunk, dict):
        value = chunk.get('content')
        return str(value) if value else ""
    return str(chunk)
def validate_user_id(user_id: str):
    """Reject requests whose ``user_id`` is not in the built-in allow-list.

    NOTE(review): the allow-list is hard-coded here; confirm this is meant as
    a placeholder for real authentication.

    Raises:
        HTTPException: 403 with code ``INVALID_USER`` for unknown users.
    """
    if user_id in {'user-001', 'user-002', 'user-003'}:
        return
    raise HTTPException(status_code=403, detail={"code": "INVALID_USER", "message": "用户标识无效"})
def validate_completion_config(config: "CompletionConfig"):
    """Check that ``config.section_path`` is a dotted chapter number such as ``2`` or ``2.1.3``.

    Every dot-separated segment must consist of ASCII digits only. On its own,
    ``str.isdigit`` also accepts characters such as ``"²"`` or Arabic-Indic
    digits, which are not valid chapter numbers, so ``str.isascii`` is
    checked as well.

    Raises:
        HTTPException: 400 with code ``INVALID_PATH`` if the path is empty or malformed.
    """
    path = config.section_path
    segments = path.split(".") if path else []
    if not segments or not all(seg.isascii() and seg.isdigit() for seg in segments):
        raise HTTPException(status_code=400, detail={"code": "INVALID_PATH", "message": "章节路径格式错误"})
def validate_request(request: ContextGenerateRequest):
    """Require at least one source of context: a ``task_id`` or ``project_info``.

    Raises:
        HTTPException: 400 with code ``MISSING_INFO`` when both are missing.
    """
    if request.task_id or request.project_info:
        return
    raise HTTPException(status_code=400, detail={"code": "MISSING_INFO", "message": "缺少任务 ID 或项目信息"})
- # ==================== 上下文生成核心流式逻辑 ====================
async def generate_content_stream(callback_task_id, source_task_id, user_id, request, redis_client):
    """Yield the SSE events for one /context_generate request.

    Event order: ``connected``, then ``generating``, then zero or more
    ``chunk`` events, then ``completed``. The stream stops early with a
    ``cancelled`` event when the Redis key ``terminate:{callback_task_id}``
    exists; the key is checked after each item received from the model.
    Any exception ends the stream with an ``error`` event.

    Args:
        callback_task_id: Server-generated id; used as trace id, log prefix
            and in the cancellation key.
        source_task_id: Task id from the request. NOTE(review): currently unused.
        user_id: Requesting user. NOTE(review): currently unused.
        request: The validated ``ContextGenerateRequest``.
        redis_client: Async Redis client for cancellation checks, or ``None``
            to disable them.
    """
    async def is_cancelled() -> bool:
        # Best effort: Redis errors count as "not cancelled". Only Exception is
        # caught so asyncio.CancelledError (a BaseException) still stops the
        # stream, e.g. when the client disconnects.
        if not redis_client:
            return False
        try:
            return await redis_client.exists(f"terminate:{callback_task_id}") > 0
        except Exception:
            return False

    stream_start_time = time.time()
    first_token_latency: Optional[float] = None
    full_content_parts: List[str] = []
    chunk_count = 0
    try:
        yield format_sse_event("connected", json.dumps({
            "callback_task_id": callback_task_id, "status": "connected", "timestamp": int(time.time())
        }, ensure_ascii=False))
        config = request.completion_config
        project_info = request.project_info.dict() if request.project_info else {}
        section_title = f"章节 {config.section_path}"

        user_prompt = build_context_generate_prompt(
            project_info=project_info,
            section_path=config.section_path,
            section_title=section_title,
            current_content=config.current_content,
            completion_mode=config.completion_mode,
            target_length=config.target_length,
            include_references=config.include_references,
            style_match=config.style_match,
            hint_keywords=config.hint_keywords
        )
        yield format_sse_event("generating", json.dumps({
            "status": "generating",
            "message": "正在调用 LLM 模型 (write_content_generate)...",
            "timestamp": int(time.time())
        }, ensure_ascii=False))
        async for content, ftl in call_custom_api_stream(
            prompt=user_prompt,
            system_prompt=CONTEXT_GENERATE_SYSTEM_PROMPT,
            max_tokens=min(config.target_length, 4000),
            temperature=0.7,
            trace_id=callback_task_id
        ):
            if await is_cancelled():
                yield format_sse_event("cancelled", json.dumps({"status": "cancelled"}, ensure_ascii=False))
                return
            if content:
                full_content_parts.append(content)
                chunk_count += 1
                if first_token_latency is None:
                    first_token_latency = ftl if ftl is not None else (time.time() - stream_start_time)
                    logger.info(f"[{callback_task_id}] ⚡ 首字延迟: {first_token_latency:.3f}s")
                yield format_sse_event("chunk", json.dumps({
                    "chunk": content,
                    "first_token_latency": round(first_token_latency, 3),
                    "timestamp": int(time.time())
                }, ensure_ascii=False))

        # Completion statistics.
        total_duration = time.time() - stream_start_time
        full_content = "".join(full_content_parts)
        # If no content chunk arrived the latency was never measured. Formatting
        # None with ":.3f" would raise and turn a finished run into an error event.
        latency = first_token_latency if first_token_latency is not None else 0.0

        logger.info(f"[{callback_task_id}] ✅ 完成 | 首字: {latency:.3f}s | 总耗时: {total_duration:.3f}s | 字数: {len(full_content)}")
        yield format_sse_event("completed", json.dumps({
            "callback_task_id": callback_task_id,
            "status": "completed",
            "metrics": {
                "first_token_latency": round(latency, 3),
                "total_duration": round(total_duration, 3),
                "char_count": len(full_content),
                "chunk_count": chunk_count,
                "model_used": CONTEXT_GENERATE_FUNCTION
            },
            "full_content": full_content,
            "timestamp": int(time.time())
        }, ensure_ascii=False))
    except Exception as e:
        logger.error(f"[{callback_task_id}] ❌ 异常: {str(e)}", exc_info=True)
        yield format_sse_event("error", json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False))
- # ==================== 上下文生成API路由 ====================
# Router for the context-generation endpoints; same prefix and tag as outline_router.
context_generate_router = APIRouter(tags=["施工方案编写"], prefix="/sgbx")
@context_generate_router.post("/context_generate")
@auto_trace(generate_if_missing=True)
async def context_generate(request: ContextGenerateRequest):
    """Start context-aware content generation and stream it back as SSE.

    The request is validated (user allow-list, section path, task/project
    info) before the stream opens, so validation failures come back as normal
    HTTP errors instead of inside the event stream. A fresh
    ``callback_task_id`` is generated and set as the trace id. The stream
    generator also uses it to build its cancellation key.

    Raises:
        HTTPException: 403/400 from the validators, or 500 for any other error.
    """
    callback_task_id = f"ctx_{uuid.uuid4().hex[:12]}"
    TraceContext.set_trace_id(callback_task_id)

    receive_time = time.time()

    try:
        validate_user_id(request.user_id)
        validate_completion_config(request.completion_config)
        validate_request(request)

        redis_client = await get_redis_client()

        logger.info(f"[{callback_task_id}] 请求接收 (预处理耗时: {(time.time()-receive_time)*1000:.1f}ms)")
        return StreamingResponse(
            generate_content_stream(callback_task_id, request.task_id, request.user_id, request, redis_client),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Pragma": "no-cache",
                "Expires": "0",
                "Connection": "keep-alive",
                # Disable proxy (nginx) buffering so chunks reach the client immediately.
                "X-Accel-Buffering": "no",
                "Content-Type": "text/event-stream; charset=utf-8",
                "Access-Control-Allow-Origin": "*"
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the log and chain the cause for debugging.
        logger.error(f"[{callback_task_id}] 全局异常: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@context_generate_router.get("/context_generate_health")
async def health_check():
    """Report static health information for the context-generation endpoints.

    NOTE(review): ``api_url_prefix`` is a fixed string, not read from the model
    client configuration — confirm it is still accurate.
    """
    return dict(
        status="healthy",
        provider="Shutian",
        current_model=CONTEXT_GENERATE_FUNCTION,
        api_url_prefix="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
@context_generate_router.get("/context_generate_modes", response_model=ContextGenerateResponse)
async def get_modes():
    """List the completion modes offered to clients (mode id and display name).

    NOTE(review): CompletionConfig.completion_mode is not validated against this list.
    """
    catalogue = (
        ("continue", "续写"),
        ("expand", "扩写"),
        ("polish", "润色"),
        ("complete", "补全"),
    )
    modes = [{"mode": mode_id, "name": label} for mode_id, label in catalogue]
    return ContextGenerateResponse(code=200, message="success", data={"modes": modes})
@context_generate_router.get("/context_generate_api_status", response_model=ContextGenerateResponse)
async def get_api_status():
    """Report that context generation is enabled and which model function it uses."""
    status_data = dict(enabled=True, provider="Shutian", model=CONTEXT_GENERATE_FUNCTION)
    return ContextGenerateResponse(code=200, message="success", data=status_data)
- # ==================== 原有大纲生成接口实现 ====================
- @outline_router.post("/generating_outline", response_model=None)
- @auto_trace(generate_if_missing=True)
- async def generating_outline(request: OutlineGenerationRequest):
- """
- 大纲生成接口 (SSE 流式响应)
- """
- callback_task_id = f"outline_{uuid.uuid4().hex[:16]}"
- TraceContext.set_trace_id(callback_task_id)
- user_id = request.user_id
- logger.info(f"接收大纲生成 SSE 请求: user_id={user_id}, project={request.project_info.base_info.project_name}")
- # ===== 新增:创建 Redis 连接用于检查终止标志 =====
- redis_check_client = None
- try:
- redis_check_client = AsyncRedis(
- host='127.0.0.1',
- port=6379,
- password='123456',
- db=0,
- decode_responses=True,
- socket_connect_timeout=2,
- socket_timeout=2
- )
- except Exception as e:
- logger.warning(f"[{callback_task_id}] 创建取消检查Redis连接失败: {e}")
- # ===== 新增:定义取消检查函数 =====
- async def is_task_cancelled() -> bool:
- """检查任务是否被取消"""
- if not redis_check_client or not callback_task_id:
- return False
- try:
- return await redis_check_client.exists(f"terminate:{callback_task_id}") > 0
- except Exception:
- return False
- # 使用统一 SSE 管理器建立连接并注册回调
- queue = await unified_sse_manager.establish_connection(callback_task_id, sse_progress_callback)
- async def generate_outline_events() -> AsyncGenerator[str, None]:
- """生成大纲生成 SSE 事件流"""
- try:
- # ===== 检查点1: 函数开始 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 连接建立前检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # 发送连接确认事件
- connected_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "连接建立",
- "status": "connected",
- "message": "SSE 连接已建立,正在启动大纲生成任务...",
- "overall_task_status": "processing",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("connected", connected_data)
- # 构建任务信息
- base_info = request.project_info.base_info
- project_info_flat = {
- "project_name": base_info.project_name,
- "construct_location": base_info.construct_location,
- "engineering_type": base_info.engineering_type,
- "selectable": request.project_info.selectable or "",
- "construction_unit": base_info.construction_unit or "",
- "supervision_unit": base_info.supervision_unit or ""
- }
- sgbx_task_info = {
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "project_info": project_info_flat,
- "template_id": request.generation_template.alias or "default_template",
- "generation_chapterenum": request.generation_chapterenum,
- "generation_template": [
- item.dict() if isinstance(item, TemplateStructureItem) else item
- for item in request.generation_template.structure
- ],
- "similarity_config": {
- "topk_plans": 3,
- "topk_fragments": 10,
- "threshold": 0.75
- },
- "knowledge_config": {
- "topk": 3,
- "threshold": 0.75
- },
- }
- # ===== 检查点2: 任务提交前 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 任务提交前检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # 发送处理中事件
- processing_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 5,
- "stage_name": "任务提交中",
- "status": "processing",
- "message": "正在提交大纲生成任务...",
- "overall_task_status": "processing",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("processing", processing_data)
- # 提交任务到 Celery
- celery_task_id = await workflow_manager.submit_outline_generation_task(sgbx_task_info)
- logger.info(f"大纲生成任务已提交: callback_task_id={callback_task_id}, celery_task_id={celery_task_id}")
- # 发送任务提交成功事件
- submitted_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 10,
- "stage_name": "任务已提交",
- "status": "submitted",
- "message": "大纲生成任务已提交,正在执行...",
- "overall_task_status": "processing",
- "updated_at": int(time.time()),
- "celery_task_id": celery_task_id
- }, ensure_ascii=False)
- yield format_sse_event("submitted", submitted_data)
- # 持续监听进度并转发
- last_progress = 10
- last_progress_data = None
- last_event_type = "processing"
- last_message = ""
- no_change_count = 0
- no_data_count = 0
- while True:
- try:
- # ===== 检查点3: 每次轮询前检查取消 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 进度轮询中检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # 先推送章节内容流式事件,避免并发章节覆盖进度快照
- stream_events = await progress_manager.pop_stream_events(callback_task_id)
- for stream_data in stream_events:
- last_progress = stream_data.get("current", last_progress)
- last_event_type = stream_data.get("event_type", "processing")
- last_message = stream_data.get("message", "")
- last_progress_data = stream_data
- yield format_sse_event("processing", json.dumps(stream_data, ensure_ascii=False))
- no_change_count = 0
- # 从 Redis 获取最新进度
- progress_data = await progress_manager.get_progress(callback_task_id)
- if progress_data:
- no_data_count = 0
- current_progress = progress_data.get("current", last_progress)
- current_event_type = progress_data.get("event_type", "processing")
- current_message = progress_data.get("message", "")
- # 检查进度数据中是否已经是取消状态
- if progress_data.get("overall_task_status") == "cancelled":
- logger.info(f"[{callback_task_id}] 从进度数据检测到取消状态")
- yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
- return
- # 进度有变化时推送
- should_push = False
- if current_progress != last_progress:
- should_push = True
- elif current_event_type != last_event_type:
- should_push = True
- elif current_message != last_message:
- should_push = True
- elif last_progress_data is None:
- should_push = True
- elif progress_data.get("overall_task_status") != last_progress_data.get("overall_task_status"):
- should_push = True
- if should_push:
- last_progress = current_progress
- last_event_type = current_event_type
- last_message = current_message
- last_progress_data = progress_data
- yield format_sse_event("processing", json.dumps(progress_data, ensure_ascii=False))
- no_change_count = 0
- else:
- no_change_count += 1
- # 检查任务状态
- status = progress_data.get("overall_task_status")
-
- # ===== 新增:检测到取消立即返回 =====
- if status == "cancelled":
- logger.info(f"[{callback_task_id}] 检测到任务已取消")
- yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
- return
-
- # 检查任务是否完成
- if status in ["completed", "failed", "terminated"]:
- stream_events = await progress_manager.pop_stream_events(callback_task_id)
- for stream_data in stream_events:
- last_progress = stream_data.get("current", last_progress)
- last_event_type = stream_data.get("event_type", "processing")
- last_message = stream_data.get("message", "")
- last_progress_data = stream_data
- yield format_sse_event("processing", json.dumps(stream_data, ensure_ascii=False))
- break
- else:
- no_data_count += 1
- if no_data_count >= 60: # 30 秒无数据
- logger.info(f"[{callback_task_id}] 进度数据丢失,结束SSE")
- break
- await asyncio.sleep(0.5)
- # 每 6 秒发送一次心跳
- if no_change_count >= 30:
- heartbeat_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "执行中",
- "status": "processing",
- "message": "大纲生成任务正在执行中...",
- "overall_task_status": "processing",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("heartbeat", heartbeat_data)
- no_change_count = 0
- except Exception as e:
- logger.warning(f"轮询进度异常: {callback_task_id}, 错误: {str(e)}")
- await asyncio.sleep(0.5)
- # 获取最终结果
- # 获取最终结果
- final_result = await workflow_manager.get_outline_sgbx_task_info(callback_task_id)
- # ===== 检查点4: 结果返回前检查取消 =====
- if await is_task_cancelled():
- logger.info(f"[{callback_task_id}] 结果返回前检测到取消信号")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": "任务已被用户取消",
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- # ===== 新增:检查任务结果是否为已取消 =====
- if final_result and final_result.get("status") == "cancelled":
- logger.info(f"[{callback_task_id}] 任务结果状态为已取消,不返回实际结果")
- cancelled_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务已取消",
- "status": "cancelled",
- "message": final_result.get("message", "任务已被用户取消"),
- "overall_task_status": "cancelled",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("cancelled", cancelled_data)
- return
- if final_result and final_result.get("status") == "completed":
- completed_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 100,
- "stage_name": "大纲生成完成",
- "status": "completed",
- "message": "大纲生成任务已完成",
- "overall_task_status": "completed",
- "updated_at": int(time.time()),
- "result": {
- "outline_structure": final_result.get("results", {}).get("outline_structure", []),
- "similar_plan": final_result.get("results", {}).get("similar_plan", [])
- }
- }, ensure_ascii=False)
- yield format_sse_event("completed", completed_data)
- else:
- failed_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": last_progress,
- "stage_name": "任务失败",
- "status": "failed",
- "message": final_result.get("results", {}).get("error", "大纲生成任务失败") if final_result else "任务执行失败",
- "overall_task_status": "failed",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("failed", failed_data)
- except Exception as e:
- logger.error(f"大纲生成 SSE 事件流错误: {str(e)}", exc_info=True)
- error_data = json.dumps({
- "callback_task_id": callback_task_id,
- "user_id": user_id,
- "current": 0,
- "stage_name": "系统错误",
- "status": "error",
- "message": f"系统错误: {str(e)}",
- "overall_task_status": "failed",
- "updated_at": int(time.time())
- }, ensure_ascii=False)
- yield format_sse_event("error", error_data)
- finally:
- # 关闭 Redis 连接
- if redis_check_client:
- try:
- await redis_check_client.close()
- except Exception:
- pass
- # 关闭 SSE 连接
- await unified_sse_manager.close_connection(callback_task_id)
- return StreamingResponse(
- generate_outline_events(),
- media_type="text/event-stream",
- headers={
- "Cache-Control": "no-cache",
- "Connection": "keep-alive",
- "X-Accel-Buffering": "no"
- }
- )
@outline_router.post("/regenerate_outline", response_model=None)
@auto_trace(generate_if_missing=True)
async def regenerate_outline(request: RegenerateOutlineRequest):
    """
    Regenerate an outline (SSE streaming response).

    Task state management:
    - Regeneration creates a **new task** (``outline_regen_<hex>``); the
      source task is left untouched and can still be queried.
    - The new task references the source task via ``source_task_id`` and
      carries the request's ``regenerate_config`` plus ``is_regenerate=True``.

    Optional request fields (fall back to the source task when omitted):
    - generation_chapterenum: chapter codes to generate. If neither the request
      nor the source task provides any, the chapter code is looked up in the
      source outline using ``regenerate_config["target_path"]`` (dotted index
      path such as ``"2.1"``).
    - project_info: project information.
    - generation_template: outline template.

    Error handling:
    - Missing ``regenerate_config``: raises ``HTTPException(400)`` before
      streaming starts.
    - Source task not found (or its lookup raised): streams one ``error``
      event with ``error_code="SOURCE_TASK_NOT_FOUND"``.
    - Source task already completed/failed: regeneration is still allowed.

    Relationship with /generating_outline:
    - Same SSE event flow: connected -> processing -> submitted ->
      progress / heartbeat -> completed | failed | cancelled | error.
    - Differences: task info is merged from the source task, and the new task
      is tagged with ``regenerate_config`` / ``is_regenerate``.
    """
    # Polling parameters for the progress loop (same values as /generating_outline).
    poll_interval_s = 0.5
    # Stop after this many consecutive polls without a progress snapshot (~30 s).
    max_polls_without_data = 60
    # Emit a heartbeat after this many polls with an unchanged snapshot (>= ~15 s).
    heartbeat_after_idle_polls = 30
    # Stop retrying after this many consecutive polling errors (~30 s) instead of
    # keeping the SSE stream open forever when Redis / the progress store is down.
    max_consecutive_poll_errors = 60
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }

    # ===== 1. Validate parameters =====
    if not request.regenerate_config:
        logger.error("重新生成配置缺失")
        raise HTTPException(status_code=400, detail="regenerate_config 为必填项")

    # Regeneration always gets a NEW task id so the source task is never overwritten.
    new_callback_task_id = f"outline_regen_{uuid.uuid4().hex[:16]}"
    source_task_id = request.task_id  # only used to look up the source task's data
    TraceContext.set_trace_id(new_callback_task_id)
    user_id = request.user_id
    regenerate_config = request.regenerate_config
    target_label = regenerate_config.get('target_path', 'unknown')  # for logs/messages only
    logger.info(f"接收重新生成大纲 SSE 请求: "
                f"source_task_id={source_task_id}, "
                f"new_task_id={new_callback_task_id}, "
                f"user_id={user_id}, "
                f"target={target_label}")

    def build_event(event_type, *, current, stage_name, status, message,
                    overall_task_status, **extra):
        """Serialize one SSE event carrying the fields common to every regenerate event.

        Keyword arguments in ``extra`` (e.g. ``result``, ``celery_task_id``,
        ``error_code``) are appended after ``updated_at``.
        """
        payload = {
            "callback_task_id": new_callback_task_id,
            "source_task_id": source_task_id,
            "user_id": user_id,
            "current": current,
            "stage_name": stage_name,
            "status": status,
            "message": message,
            "overall_task_status": overall_task_status,
            "updated_at": int(time.time()),
        }
        payload.update(extra)
        return format_sse_event(event_type, json.dumps(payload, ensure_ascii=False))

    def build_cancelled_event(current, message="任务已被用户取消"):
        """Serialize the ``cancelled`` event emitted at every cancellation checkpoint."""
        return build_event(
            "cancelled",
            current=current,
            stage_name="任务已取消",
            status="cancelled",
            message=message,
            overall_task_status="cancelled",
        )

    # ===== 2. Load the source task (lookup errors are treated as "not found") =====
    original_task = None
    try:
        original_task = await workflow_manager.get_outline_sgbx_task_info(source_task_id)
    except Exception as e:
        logger.warning(f"获取原任务信息异常: {source_task_id}, error={e}")

    if not original_task:
        logger.error(f"原任务不存在: {source_task_id}")

        async def error_not_found():
            """Stream a single error event describing the missing source task."""
            yield build_event(
                "error",
                current=0,
                stage_name="原任务不存在",
                status="error",
                message=f"原任务不存在或已过期: {source_task_id}",
                overall_task_status="failed",
                error_code="SOURCE_TASK_NOT_FOUND",
            )

        return StreamingResponse(
            error_not_found(),
            media_type="text/event-stream",
            headers=sse_headers,
        )

    original_status = original_task.get("status") or original_task.get("overall_task_status", "unknown")
    logger.info(f"原任务状态: {source_task_id} = {original_status}")
    # `results` may be missing or None on the source task record.
    original_results = original_task.get("results") or {}

    def find_chapter_code(outline, path):
        """Return the ``code`` of the node addressed by a dotted ``index`` path, or None.

        Each path segment is compared against ``str(node["index"])`` at the
        corresponding depth; the first match whose subtree resolves wins.
        """
        path_parts = path.split(".")
        last_depth = len(path_parts) - 1

        def search(nodes, depth):
            if depth > last_depth:
                return None
            wanted = path_parts[depth]
            for node in nodes:
                if str(node.get("index", "")) != wanted:
                    continue
                if depth == last_depth:
                    return node.get("code")
                children = node.get("children", [])
                if children:
                    found = search(children, depth + 1)
                    if found:
                        return found
            return None

        return search(outline, 0)

    def resolve_project_info():
        """Flatten the request's project_info, or fall back to the source task's."""
        if request.project_info:
            base_info = request.project_info.base_info
            return {
                "project_name": base_info.project_name,
                "construct_location": base_info.construct_location,
                "engineering_type": base_info.engineering_type,
                "selectable": request.project_info.selectable or ""
            }
        return original_task.get("project_info", {})

    def resolve_template():
        """Return ``(outline_structure, template_alias)`` from the request or the source task."""
        if request.generation_template:
            structure = [
                item.dict() if isinstance(item, TemplateStructureItem) else item
                for item in request.generation_template.structure
            ]
            return structure, request.generation_template.alias or "default_template"
        structure = original_task.get("generation_template", [])
        if not structure:
            structure = original_results.get("outline_structure", [])
        return structure, original_task.get("template_id", "default_template")

    def resolve_chapters():
        """Return the chapter codes to regenerate (request -> source task -> target_path lookup)."""
        chapters = request.generation_chapterenum
        if chapters is None:
            chapters = original_task.get("generation_chapterenum", [])
            target_path = regenerate_config.get("target_path")
            # Neither the request nor the source task lists chapters: infer from target_path.
            if not chapters and target_path:
                original_outline = original_results.get("outline_structure", [])
                chapter_code = find_chapter_code(original_outline, target_path) if original_outline else None
                if chapter_code:
                    chapters = [chapter_code]
        return chapters

    # Register the SSE connection under the NEW task id (closed in the generator's finally).
    await unified_sse_manager.establish_connection(new_callback_task_id, sse_progress_callback)

    # ===== 3. Event stream (same flow as /generating_outline) =====
    async def generate_regenerate_events() -> AsyncGenerator[str, None]:
        """Submit the regenerate task and stream its progress until a terminal state."""
        redis_check_client = None
        try:
            # 3.1 Redis client used only for cancellation checks.
            # NOTE(review): hard-coded connection settings, duplicated elsewhere in this file.
            try:
                redis_check_client = AsyncRedis(
                    host='127.0.0.1',
                    port=6379,
                    password='123456',
                    db=0,
                    decode_responses=True,
                    socket_connect_timeout=2,
                    socket_timeout=2
                )
            except Exception as e:
                logger.warning(f"[{new_callback_task_id}] 创建取消检查Redis连接失败: {e}")

            async def is_task_cancelled() -> bool:
                """True when ``terminate:<task_id>`` exists; Redis errors count as not cancelled."""
                if redis_check_client is None:
                    return False
                try:
                    return await redis_check_client.exists(f"terminate:{new_callback_task_id}") > 0
                except Exception:
                    return False

            # 3.2 Cancellation checkpoint 1: before anything is sent.
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 连接建立前检测到取消信号")
                yield build_cancelled_event(0)
                return

            # 3.3 Connection acknowledgement.
            yield build_event(
                "connected",
                current=0,
                stage_name="连接建立",
                status="connected",
                message=f"SSE 连接已建立,正在启动重新生成任务(原任务: {source_task_id}, 状态: {original_status})...",
                overall_task_status="processing",
            )

            # 3.4 Build task info: request values override the source task's.
            project_info_flat = resolve_project_info()
            outline_structure, template_alias = resolve_template()
            generation_chapterenum = resolve_chapters()
            sgbx_task_info = {
                "callback_task_id": new_callback_task_id,
                "source_task_id": source_task_id,  # link back to the source task
                "user_id": user_id,
                "project_info": project_info_flat,
                "template_id": template_alias,
                "generation_chapterenum": generation_chapterenum,
                "generation_template": outline_structure,
                "similarity_config": original_task.get("similarity_config", {
                    "topk_plans": 3,
                    "topk_fragments": 10,
                    "threshold": 0.75
                }),
                "knowledge_config": original_task.get("knowledge_config", {
                    "topk": 3,
                    "threshold": 0.75
                }),
                # Regenerate-specific fields
                "regenerate_config": regenerate_config,
                "is_regenerate": True,
                "original_task_status": original_status
            }
            logger.info(f"重新生成任务信息构建完成: "
                        f"new_task_id={new_callback_task_id}, "
                        f"source_task_id={source_task_id}, "
                        f"target={target_label}, "
                        f"chapters={generation_chapterenum}")

            # 3.5 Cancellation checkpoint 2: before submitting.
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 任务提交前检测到取消信号")
                yield build_cancelled_event(0)
                return

            # 3.6 Submitting.
            yield build_event(
                "processing",
                current=5,
                stage_name="任务提交中",
                status="processing",
                message=f"正在提交重新生成任务(目标: {target_label})...",
                overall_task_status="processing",
            )

            # 3.7 Submit to the worker queue.
            celery_task_id = await workflow_manager.submit_outline_generation_task(sgbx_task_info)
            logger.info(f"重新生成任务已提交: "
                        f"new_callback_task_id={new_callback_task_id}, "
                        f"celery_task_id={celery_task_id}")
            yield build_event(
                "submitted",
                current=10,
                stage_name="任务已提交",
                status="submitted",
                message="重新生成任务已提交,正在执行...",
                overall_task_status="processing",
                celery_task_id=celery_task_id,
            )

            # 3.8 Poll progress until a terminal state, data loss, or repeated errors.
            last_progress = 10
            last_progress_data = None
            last_event_type = "processing"
            last_message = ""
            no_change_count = 0
            no_data_count = 0
            consecutive_errors = 0
            while True:
                try:
                    # Cancellation checkpoint 3: on every poll.
                    if await is_task_cancelled():
                        logger.info(f"[{new_callback_task_id}] 进度轮询中检测到取消信号")
                        yield build_cancelled_event(last_progress)
                        return

                    # Push chapter stream events first so concurrent chapters do not
                    # overwrite each other in the progress snapshot.
                    for stream_data in await progress_manager.pop_stream_events(new_callback_task_id):
                        last_progress = stream_data.get("current", last_progress)
                        last_event_type = stream_data.get("event_type", "processing")
                        last_message = stream_data.get("message", "")
                        last_progress_data = stream_data
                        yield format_sse_event("processing", json.dumps(stream_data, ensure_ascii=False))
                        no_change_count = 0

                    progress_data = await progress_manager.get_progress(new_callback_task_id)
                    consecutive_errors = 0  # the progress store answered

                    if progress_data:
                        no_data_count = 0
                        status = progress_data.get("overall_task_status")
                        if status == "cancelled":
                            logger.info(f"[{new_callback_task_id}] 从进度数据检测到取消状态")
                            yield format_sse_event("cancelled", json.dumps(progress_data, ensure_ascii=False))
                            return

                        current_progress = progress_data.get("current", last_progress)
                        current_event_type = progress_data.get("event_type", "processing")
                        current_message = progress_data.get("message", "")
                        changed = (
                            current_progress != last_progress
                            or current_event_type != last_event_type
                            or current_message != last_message
                            or last_progress_data is None
                            or status != last_progress_data.get("overall_task_status")
                        )
                        if changed:
                            last_progress = current_progress
                            last_event_type = current_event_type
                            last_message = current_message
                            last_progress_data = progress_data
                            yield format_sse_event("processing", json.dumps(progress_data, ensure_ascii=False))
                            no_change_count = 0
                        else:
                            no_change_count += 1

                        if status in ("completed", "failed", "terminated"):
                            # Flush stream events emitted just before the terminal snapshot.
                            for stream_data in await progress_manager.pop_stream_events(new_callback_task_id):
                                last_progress = stream_data.get("current", last_progress)
                                last_event_type = stream_data.get("event_type", "processing")
                                last_message = stream_data.get("message", "")
                                last_progress_data = stream_data
                                yield format_sse_event("processing", json.dumps(stream_data, ensure_ascii=False))
                            break
                    else:
                        no_data_count += 1
                        if no_data_count >= max_polls_without_data:
                            logger.info(f"[{new_callback_task_id}] 进度数据丢失,结束SSE")
                            break

                    await asyncio.sleep(poll_interval_s)
                    if no_change_count >= heartbeat_after_idle_polls:
                        yield build_event(
                            "heartbeat",
                            current=last_progress,
                            stage_name="执行中",
                            status="processing",
                            message="重新生成任务正在执行中...",
                            overall_task_status="processing",
                        )
                        no_change_count = 0
                except Exception as e:
                    consecutive_errors += 1
                    logger.warning(f"轮询进度异常: {new_callback_task_id}, 错误: {str(e)}")
                    if consecutive_errors >= max_consecutive_poll_errors:
                        # Without this cap a persistently failing store kept the stream open forever.
                        logger.error(f"[{new_callback_task_id}] 连续轮询异常 {consecutive_errors} 次,结束轮询")
                        break
                    await asyncio.sleep(poll_interval_s)

            # 3.9 Fetch the final result.
            final_result = await workflow_manager.get_outline_sgbx_task_info(new_callback_task_id)

            # Cancellation checkpoint 4: before returning the result.
            if await is_task_cancelled():
                logger.info(f"[{new_callback_task_id}] 结果返回前检测到取消信号")
                yield build_cancelled_event(last_progress)
                return

            final_status = final_result.get("status") if final_result else None
            if final_status == "cancelled":
                logger.info(f"[{new_callback_task_id}] 任务结果状态为已取消,不返回实际结果")
                yield build_cancelled_event(last_progress, final_result.get("message", "任务已被用户取消"))
                return

            # 3.10 Terminal event. `results` may be missing or None on the record.
            final_results = (final_result or {}).get("results") or {}
            if final_status == "completed":
                yield build_event(
                    "completed",
                    current=100,
                    stage_name="重新生成完成",
                    status="completed",
                    message="大纲重新生成任务已完成",
                    overall_task_status="completed",
                    result={
                        "outline_structure": final_results.get("outline_structure", []),
                        "similar_plan": final_results.get("similar_plan", [])
                    },
                )
            else:
                failure_message = final_results.get("error", "重新生成任务失败") if final_result else "任务执行失败"
                yield build_event(
                    "failed",
                    current=last_progress,
                    stage_name="任务失败",
                    status="failed",
                    message=failure_message,
                    overall_task_status="failed",
                )
        except Exception as e:
            logger.error(f"重新生成大纲 SSE 事件流错误: {str(e)}", exc_info=True)
            yield build_event(
                "error",
                current=0,
                stage_name="系统错误",
                status="error",
                message=f"系统错误: {str(e)}",
                overall_task_status="failed",
            )
        finally:
            # Release the cancellation-check Redis client and the SSE registration.
            if redis_check_client is not None:
                try:
                    await redis_check_client.close()
                except Exception as e:
                    logger.debug(f"[{new_callback_task_id}] 关闭取消检查Redis连接失败: {e}")
            await unified_sse_manager.close_connection(new_callback_task_id)

    return StreamingResponse(
        generate_regenerate_events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
- # ==================== POST 接口 ====================
@outline_router.post("/task_cancel")
@auto_trace(generate_if_missing=True)
async def task_cancel(request: TaskCancelRequest):
    """
    Cancel an outline generation task.

    Supports both pre-registered (``pending``) tasks, i.e. submitted but not
    yet picked up by a worker, and running (``processing``) tasks.

    Flow:
    1. Look the task up via ``workflow_manager`` (which also knows pre-registered tasks).
    2. Ask ``workflow_manager.set_outline_terminate_signal`` to terminate it.
    3. Pending task: update its progress record and return.
    4. Running task: additionally write the ``terminate:<task_id>`` key and the
       cancelled ``progress:<task_id>`` hash in Redis (best effort), revoke the
       Celery task when its id is known, and push a cancel event to / close the
       SSE connection.

    Returns:
        dict with ``code`` (200 / 400 / 404 / 500), ``message`` and ``data``.
        Failures are reported in the response body; the handler does not raise.
    """
    import redis.asyncio as redis_async
    from redis.asyncio.connection import ConnectionPool

    try:
        logger.info(f"接收取消任务请求: task_id={request.task_id}")

        # workflow_manager also resolves pre-registered (pending) tasks.
        task_info = await workflow_manager.get_outline_sgbx_task_info(request.task_id)

        if not task_info:
            return {
                "code": 404,
                "message": "任务不存在",
                "data": {"task_id": request.task_id, "error_type": "TASK_NOT_FOUND"}
            }

        task_status = task_info.get("status") or task_info.get("overall_task_status", "unknown")
        is_pre_registered = task_info.get("is_pre_registered", False)

        if task_status == "cancelled":
            return {
                "code": 200,
                "message": "任务已处于取消状态",
                "data": {"task_id": request.task_id, "status": "cancelled"}
            }

        if task_status in ["completed", "failed"]:
            return {
                "code": 400,
                "message": "任务已完成,无法取消",
                "data": {"task_id": request.task_id, "current_status": task_status}
            }

        # Handles both pending (pre-registered) and processing tasks.
        result = await workflow_manager.set_outline_terminate_signal(
            callback_task_id=request.task_id,
            operator=request.user_id
        )

        if not result.get("success"):
            logger.warning(f"设置终止信号失败: {result.get('message')}")
            return {
                "code": 400,
                "message": result.get("message", "取消任务失败"),
                "data": {"task_id": request.task_id, "current_status": task_status}
            }

        cancelled_at = int(time.time())
        cancel_message = f"任务已被用户取消: {request.cancel_reason}"

        # Pending task: the terminate signal already cancelled it; only the progress record is updated.
        if is_pre_registered or task_status == "pending":
            logger.info(f"预注册任务已被取消: {request.task_id}")

            try:
                await progress_manager.update_stage_progress(
                    callback_task_id=request.task_id,
                    overall_task_status="cancelled",
                    status="cancelled",
                    message=cancel_message
                )
            except Exception as e:
                logger.warning(f"更新进度信息失败: {e}")

            return {
                "code": 200,
                "message": "任务已成功取消",
                "data": {
                    "task_id": request.task_id,
                    "status": "cancelled",
                    "cancelled_at": cancelled_at,
                    "cancel_reason": request.cancel_reason,
                    "cancelled_by": request.user_id,
                    "is_pre_registered": True
                }
            }

        # Running task: write the terminate key polled by the SSE streams and mark the
        # progress hash as cancelled. Best effort: failures are logged, not returned.
        # NOTE(review): hard-coded connection settings, duplicated elsewhere in this file.
        pool = None
        redis_client = None
        try:
            pool = ConnectionPool(
                host='127.0.0.1',
                port=6379,
                password='123456',
                db=0,
                decode_responses=True,
                max_connections=20,
                socket_connect_timeout=10,
                socket_timeout=10,
                retry_on_timeout=True,
                health_check_interval=30
            )
            redis_client = redis_async.Redis(connection_pool=pool)

            terminate_data = json.dumps({
                "cancelled": True,
                "cancelled_by": request.user_id,
                "cancel_reason": request.cancel_reason,
                "cancelled_at": cancelled_at
            })

            pipe = redis_client.pipeline()
            pipe.set(f"terminate:{request.task_id}", terminate_data, ex=3600)
            pipe.hset(f"progress:{request.task_id}", mapping={
                "overall_task_status": "cancelled",
                "status": "cancelled",
                "message": cancel_message,
                "cancelled_at": str(cancelled_at),
                "updated_at": str(cancelled_at)
            })
            await pipe.execute()
            logger.info(f"终止标志已设置: {request.task_id}")
        except Exception as e:
            logger.error(f"设置终止标志失败: {e}")
            # Does not affect the main flow; continue.
        finally:
            # Release the client AND its dedicated pool on both the success and the
            # error path. Previously the pool was only disconnected on success.
            if redis_client is not None:
                try:
                    await redis_client.close()
                except Exception as e:
                    logger.warning(f"关闭 Redis 客户端失败: {e}")
            if pool is not None:
                try:
                    await pool.disconnect()
                except Exception as e:
                    logger.warning(f"关闭 Redis 连接池失败: {e}")

        # Revoke the Celery task when its id is recorded on the task.
        # NOTE(review): revoke() is a synchronous broker call made from an async handler.
        celery_task_id = task_info.get("celery_task_id") or task_info.get("celery_id")
        if celery_task_id:
            try:
                from celery import current_app as celery_app
                celery_app.control.revoke(celery_task_id, terminate=True)
                logger.info(f"Celery 终止信号已发送: {celery_task_id}")
            except Exception as e:
                logger.warning(f"终止 Celery 任务失败: {e}")

        # Notify the SSE listener and close its connection.
        try:
            cancel_event = {
                "callback_task_id": request.task_id,
                "status": "cancelled",
                "overall_task_status": "cancelled",
                "message": cancel_message,
                "cancelled_at": cancelled_at,
                "cancelled_by": request.user_id
            }
            await unified_sse_manager.send_progress(request.task_id, cancel_event)
            await unified_sse_manager.close_connection(request.task_id)
        except Exception as e:
            logger.warning(f"关闭 SSE 失败: {e}")

        return {
            "code": 200,
            "message": "任务取消成功",
            "data": {
                "task_id": request.task_id,
                "status": "cancelled",
                "cancelled_at": cancelled_at,
                "cancel_reason": request.cancel_reason,
                "cancelled_by": request.user_id
            }
        }
    except Exception as e:
        logger.error(f"取消任务异常: {str(e)}", exc_info=True)
        return {
            "code": 500,
            "message": f"取消任务失败: {str(e)}",
            "data": {"task_id": request.task_id}
        }
- # ==================== 查询接口 ====================
@outline_router.get("/task_status")
@auto_trace(generate_if_missing=True)
async def task_status(
    task_id: str = Query(..., description="任务ID"),
    user_id: str = Query(..., description="用户ID")
):
    """
    Query the status of an outline generation task.

    Args:
        task_id: callback task id of the outline task.
        user_id: requesting user id (required by the API; not used by this handler).

    Returns:
        dict: ``{"code": 200, "message": "查询成功", "data": <task info>}`` when the
        task info is found, otherwise ``{"code": 404, ..., "data": None}``.

    Raises:
        HTTPException: status 500 when the lookup itself fails.
    """
    try:
        logger.info(f"查询任务状态: task_id={task_id}")
        task_info = await workflow_manager.get_outline_sgbx_task_info(task_id)
    except Exception as e:
        logger.error(f"查询任务状态失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")

    # `data` is None in the not-found case, exactly as the lookup returned it.
    found = task_info is not None
    return {
        "code": 200 if found else 404,
        "message": "查询成功" if found else "任务不存在或已完成",
        "data": task_info
    }
@outline_router.get("/active_tasks")
@auto_trace(generate_if_missing=True)
async def active_tasks(
    user_id: str = Query(None, description="用户ID(可选,不提供则返回所有任务)")
):
    """
    List active outline generation tasks.

    Args:
        user_id: optional; when given, only tasks whose ``user_id`` equals it are returned.

    Returns:
        dict: ``{"code": 200, "message": "查询成功", "data": {"total": n, "tasks": [...]}}``.

    Raises:
        HTTPException: status 500 when listing fails.
    """
    try:
        logger.info(f"获取活跃任务列表: user_id={user_id}")
        # Treat a None result from the manager as "no active tasks".
        active_tasks_list = await workflow_manager.get_outline_active_tasks() or []
        if user_id:
            # .get: one task record without `user_id` must not turn the whole listing into a 500.
            active_tasks_list = [task for task in active_tasks_list if task.get("user_id") == user_id]
        return {
            "code": 200,
            "message": "查询成功",
            "data": {
                "total": len(active_tasks_list),
                "tasks": active_tasks_list
            }
        }
    except Exception as e:
        logger.error(f"获取活跃任务列表失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取活跃任务列表失败: {str(e)}")
# Mount the context-generation routes onto outline_router. They reach the app
# through whichever include of outline_router is done elsewhere (not shown here).
outline_router.include_router(router=context_generate_router)
|