# -*- coding: utf-8 -*-
"""Document AI chat HTTP API.

Exposes two endpoints:
    POST /sgbx/document_chat        -- start a chat; supports SSE streaming and
                                       non-streaming (synchronous) modes
    GET  /sgbx/document_chat/health -- health check

SSE event types emitted in streaming mode:
    connected          -- connection established
    processing         -- workflow stage progress notification
    reasoning          -- reasoning status (start, retrieval, rerank, skill run, ...)
    intent             -- intent recognition result
    retrieval_result   -- retrieval details (including reference previews)
    skill_started      -- a skill starts executing (answer or proposal)
    chunk              -- text fragment produced by a skill (streamed piece by piece)
    answer_completed   -- answer finished
    proposal_completed -- modification draft finished
    completed          -- everything finished
    error              -- failure
"""

import json
import time
import uuid
from itertools import islice
from typing import Any, AsyncGenerator, Callable, Dict, Iterable, List, Tuple

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

from foundation.infrastructure.tracing import TraceContext, auto_trace
from core.document_chat.component.document_chat_logger import document_chat_logger as logger
from core.document_chat.component.document_chat_logger import log_document_chat_event, log_document_chat_event_truncated
from core.document_chat.schemas import DocumentChatRequest, DocumentChatResponse, model_to_dict

document_chat_router = APIRouter(prefix="/sgbx", tags=["文档编辑AI对话"])

# Upper bound on the number of references exposed to the client in a single
# SSE event, so the response body does not grow too large.
MAX_REFERENCES_PER_EVENT = 8
# Upper bound on the preview length of a single reference's content.
REFERENCE_PREVIEW_CHARS = 600

# Maps a workflow stage name to the user-facing progress message.
STAGE_MESSAGES = {
    "workflow_started": "文档 AI 对话工作流已启动",
    "recognize_intent": "已完成用户意图识别",
    "rerank_context": "知识库内容检索重排完成",
    "run_answer_skill": "已生成章节问答结果",
    "run_modify_skill": "已生成章节修改草案",
    "general_answer": "已生成通用回答",
    "error_handler": "流程异常,已进入错误处理",
}

# Metadata keys that may be passed through to the client in SSE events.
_METADATA_ALLOWED_KEYS = (
    "tenant_id",
    "project_id",
    "knowledge_base_id",
    "file_name",
    "chapter_level_1",
    "chapter_level_2",
    "parent_id",
    "parent_count",
    "source_scope_valid",
)


def format_sse_event(event_type: str, data: dict) -> str:
    """Serialize one SSE frame: an ``event:`` line, a ``data:`` line, a blank line.

    ``data`` is JSON-encoded with ``ensure_ascii=False`` so non-ASCII text is
    sent as-is rather than as ``\\uXXXX`` escapes.
    """
    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def get_document_chat_workflow():
    """Return the workflow instance, importing it lazily to avoid a circular import."""
    from core.document_chat.workflows.document_chat_workflow import document_chat_workflow

    return document_chat_workflow


def _iter_node_updates(raw_update: Any) -> Iterable[Tuple[str, Dict[str, Any]]]:
    """Extract ``(node_name, update)`` pairs from an ``updates`` stream payload.

    Every dict-valued entry of ``raw_update`` is treated as one node's update,
    keyed by node name. When no entry is a dict, the whole payload is returned
    as a single update whose stage name is ``raw_update["current_stage"]`` (or
    ``"workflow_update"`` when that is missing or falsy). A non-dict payload
    yields nothing.
    """
    if not isinstance(raw_update, dict):
        return []

    updates: List[Tuple[str, Dict[str, Any]]] = [
        (str(node_name), node_update)
        for node_name, node_update in raw_update.items()
        if isinstance(node_update, dict)
    ]
    if updates:
        return updates

    stage = str(raw_update.get("current_stage") or "workflow_update")
    return [(stage, raw_update)]


def _merge_state_update(state: Dict[str, Any], update: Dict[str, Any]) -> None:
    """Merge a node's incremental fields into ``state`` in place (last write wins)."""
    state.update(update)


def _preview_text(text: Any, limit: int = REFERENCE_PREVIEW_CHARS) -> str:
    """Return a stripped preview of ``text``.

    Falsy input becomes an empty string. Text longer than ``limit`` characters
    is cut to ``limit``, right-stripped, and suffixed with three dots.
    """
    value = str(text or "").strip()
    if len(value) <= limit:
        return value
    return value[:limit].rstrip() + "..."


def _safe_metadata(metadata: Any) -> Dict[str, Any]:
    """Keep only whitelisted metadata keys whose value is neither ``None`` nor ``""``.

    Returns an empty dict when ``metadata`` is not a dict.
    """
    if not isinstance(metadata, dict):
        return {}

    safe: Dict[str, Any] = {}
    for key in _METADATA_ALLOWED_KEYS:
        value = metadata.get(key)
        if value not in (None, ""):
            safe[key] = value
    return safe


def _pack_reference_preview(item: Dict[str, Any]) -> Dict[str, Any]:
    """Compress one retrieved reference into the preview shape sent to the client.

    The result carries the source label, a truncated content preview, the
    vector similarity, the whitelisted metadata and, only when the item has
    one, the rerank score.
    """
    raw_metadata = item.get("metadata")
    metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
    # Fall back to "text" when "content" is absent *or* empty, so a blank
    # "content" field does not hide a usable "text" field.
    content = item.get("content") or item.get("text")

    data = {
        "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
        "content": _preview_text(content),
        "vector_similarity": item.get("vector_similarity", 0.0),
        "metadata": _safe_metadata(metadata),
    }
    if "rerank_score" in item:
        data["rerank_score"] = item["rerank_score"]
    return data


def _limited_items(
    items: List[Dict[str, Any]],
    packer: Callable[[Dict[str, Any]], Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Pack at most ``MAX_REFERENCES_PER_EVENT`` dict entries of ``items``.

    Non-dict entries are dropped *before* the cap is applied, so they do not
    consume slots that valid references could have used. ``None`` or an empty
    list yields an empty list.
    """
    valid_items = (item for item in (items or []) if isinstance(item, dict))
    return [packer(item) for item in islice(valid_items, MAX_REFERENCES_PER_EVENT)]


def _reasoning_event(
    callback_task_id: str,
    node_name: str,
    state: Dict[str, Any],
) -> Tuple[str, Dict[str, Any]]:
    """Build a ``reasoning`` event for ``node_name``.

    The status is ``"failed"`` when ``state`` has a truthy ``error_message``,
    otherwise ``"processing"``. The message comes from ``STAGE_MESSAGES`` with
    a generic fallback for unknown stages.
    """
    status = "failed" if state.get("error_message") else "processing"
    return (
        "reasoning",
        {
            "callback_task_id": callback_task_id,
            "stage_name": node_name,
            "status": status,
            "message": STAGE_MESSAGES.get(node_name, f"已完成 {node_name}"),
        },
    )


def _build_realtime_events(
    callback_task_id: str,
    state: Dict[str, Any],
    node_name: str,
    skill_started_sent: bool,
) -> Tuple[List[Tuple[str, Dict[str, Any]]], bool]:
    """Build the SSE events to push after ``node_name`` has updated ``state``.

    A node may produce several events (a generic ``reasoning`` event plus a
    node-specific one). ``skill_started_sent`` prevents ``skill_started`` from
    being pushed more than once for the ``quality_gate`` node.

    Returns:
        ``(events, skill_started_sent)`` where ``events`` is a list of
        ``(event_type, payload)`` pairs and the flag is the possibly updated
        value to pass to the next call.
    """
    events: List[Tuple[str, Dict[str, Any]]] = []

    # Generic progress event, only for stages that have a user-facing message.
    if node_name in STAGE_MESSAGES:
        events.append(_reasoning_event(callback_task_id, node_name, state))

    # Intent recognition finished.
    if node_name == "recognize_intent" and state.get("intent_result"):
        events.append(
            (
                "intent",
                {
                    "callback_task_id": callback_task_id,
                    "intent_result": state.get("intent_result"),
                },
            )
        )

    # Retrieval result, including capped reference previews.
    if node_name == "rerank_context":
        reranked = state.get("reranked_references") or []
        events.append(
            (
                "retrieval_result",
                {
                    "callback_task_id": callback_task_id,
                    "retrieval_status": state.get("retrieval_status"),
                    "retrieval_method": state.get("retrieval_method"),
                    "retrieval_metrics": state.get("retrieval_metrics") or {},
                    # Total number of reranked references, not the capped preview count.
                    "rerank_count": len(reranked),
                    "references": _limited_items(reranked, _pack_reference_preview),
                    "warnings": state.get("warnings") or [],
                },
            )
        )

    # Skill start notification, emitted when the quality_gate node reports.
    # NOTE(review): presumably this precedes the actual skill node -- confirm
    # against the workflow graph.
    if node_name == "quality_gate":
        intent_result = state.get("intent_result") or {}
        skill_name = intent_result.get("skill_name") or ""
        if skill_name and not skill_started_sent:
            response_type = "proposal" if skill_name == "document-modify" else "answer"
            events.append(
                (
                    "skill_started",
                    {
                        "callback_task_id": callback_task_id,
                        "skill_name": skill_name,
                        "response_type": response_type,
                    },
                )
            )
            skill_started_sent = True

    return events, skill_started_sent


@document_chat_router.post("/document_chat")
@auto_trace(generate_if_missing=True)
async def document_chat(request: DocumentChatRequest, stream: bool = Query(False)):
    """Main document AI chat endpoint.

    Args:
        request: Chat request body. ``request.response_mode == "sse"`` selects
            the SSE streaming response as well.
        stream: When true, respond with an SSE stream.

    Flow:
        1. Generate a ``callback_task_id`` and set it as the trace id.
        2. Log the incoming request through the truncating logger.
        3. Streaming: return a ``StreamingResponse`` that pushes events.
        4. Non-streaming: run the workflow to completion and return one result.

    Raises:
        HTTPException: Status 500 when the synchronous workflow run raises.
    """
    callback_task_id = f"doc_chat_{uuid.uuid4().hex[:12]}"
    TraceContext.set_trace_id(callback_task_id)
    log_document_chat_event_truncated(
        "request_received",
        callback_task_id,
        {
            "stream": stream,
            "response_mode": request.response_mode,
            "request": model_to_dict(request),
        },
    )

    if stream or request.response_mode == "sse":
        return StreamingResponse(
            _generate_document_chat_events(callback_task_id, request),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # Synchronous mode: wait for the workflow to finish.
    try:
        workflow = get_document_chat_workflow()
        state = await workflow.run(request, callback_task_id)
        data = workflow.to_response_data(state)
        data_dict = model_to_dict(data)
        log_document_chat_event("response_completed", callback_task_id, data_dict)

        is_error = data.response_type == "error"
        code = 500 if is_error else 200
        message = data.error_message if is_error else "success"
        return DocumentChatResponse(code=code, message=message or "success", data=data)
    except Exception as exc:
        logger.error(f"[DocumentChat] request failed: {exc}", exc_info=True)
        log_document_chat_event(
            "request_failed",
            callback_task_id,
            {"error": str(exc), "request": model_to_dict(request)},
            level="error",
        )
        # Chain the cause explicitly so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc


async def _generate_document_chat_events(
    callback_task_id: str,
    request: DocumentChatRequest,
) -> AsyncGenerator[str, None]:
    """SSE generator that pushes workflow events as they happen.

    Event order:
        connected -> processing -> (reasoning / intent / retrieval_result) x N
        -> chunk x M -> answer_completed / proposal_completed -> completed

    Any exception raised while streaming is logged and reported to the client
    as a final ``error`` event instead of propagating.
    """
    started_at = time.time()
    try:
        yield format_sse_event(
            "connected",
            {
                "callback_task_id": callback_task_id,
                "status": "connected",
                "timestamp": int(time.time()),
            },
        )
        yield format_sse_event(
            "processing",
            {
                "callback_task_id": callback_task_id,
                "stage_name": "workflow_started",
                "status": "processing",
                "message": STAGE_MESSAGES["workflow_started"],
            },
        )

        workflow = get_document_chat_workflow()
        state = workflow.build_initial_state(request, callback_task_id)
        # The graph gets its own shallow copy; ``state`` is the local
        # accumulator that node updates are merged into below.
        graph_state = dict(state)
        skill_started_sent = False
        custom_event_count = 0

        async for mode, payload in workflow.get_graph().astream(
            graph_state, stream_mode=["updates", "custom"]
        ):
            if mode == "custom" and isinstance(payload, dict):
                # Custom event: a text fragment streamed by a skill.
                custom_event_count += 1
                if payload.get("stream_chunk"):
                    yield format_sse_event(
                        "chunk",
                        {
                            "callback_task_id": callback_task_id,
                            "chunk": payload["stream_chunk"],
                        },
                    )
            elif mode == "updates":
                # Updates event: a node finished; merge its update and push
                # the matching events.
                for node_name, node_update in _iter_node_updates(payload):
                    _merge_state_update(state, node_update)
                    realtime_events, skill_started_sent = _build_realtime_events(
                        callback_task_id,
                        state,
                        node_name,
                        skill_started_sent,
                    )
                    for event_type, event_data in realtime_events:
                        yield format_sse_event(event_type, event_data)

        logger.info(f"[DocumentChat] SSE stream completed: custom_events_received={custom_event_count}")

        # Workflow finished: push the final result event.
        data = workflow.to_response_data(state)
        data_dict = model_to_dict(data)
        log_document_chat_event("response_completed", callback_task_id, data_dict)

        if data.response_type == "answer":
            yield format_sse_event("answer_completed", data_dict)
        elif data.response_type == "proposal":
            yield format_sse_event("proposal_completed", data_dict)
        elif data.response_type in ("clarify", "unsupported", "general_answer"):
            yield format_sse_event("answer_completed", data_dict)
        else:
            yield format_sse_event("error", data_dict)

        # "completed" (with elapsed seconds) is sent only for non-error results.
        if data.response_type != "error":
            yield format_sse_event(
                "completed",
                {
                    "callback_task_id": callback_task_id,
                    "status": state.get("overall_task_status", "completed"),
                    "duration": round(time.time() - started_at, 3),
                },
            )
    except Exception as exc:
        logger.error(f"[DocumentChat] SSE request failed: {exc}", exc_info=True)
        log_document_chat_event(
            "request_failed",
            callback_task_id,
            {"error": str(exc), "request": model_to_dict(request)},
            level="error",
        )
        yield format_sse_event(
            "error",
            {
                "callback_task_id": callback_task_id,
                "status": "error",
                "message": str(exc),
            },
        )


@document_chat_router.get("/document_chat/health")
async def document_chat_health():
    """Health check: return a static module status and basic workflow info."""
    return {
        "status": "healthy",
        "module": "document_chat",
        "workflow": "langgraph",
        "skills": ["document-answer", "document-modify"],
    }