| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396 |
- # -*- coding: utf-8 -*-
- """文档 AI 对话 HTTP API。
- 提供两个接口:
- POST /sgbx/document_chat — 发起对话,支持 SSE 流式和非流式同步两种模式
- GET /sgbx/document_chat/health — 健康检查
- SSE 流式输出事件类型:
- connected — 连接建立
- processing — 工作流各阶段进度通知
- reasoning — 推理状态(启动、检索、重排、技能执行等)
- intent — 意图识别结果
- retrieval_result — 检索召回详情(含参考预览)
- skill_started — 技能开始执行(answer 或 proposal)
- chunk — 技能生成的文本片段(流式逐块输出)
- answer_completed — 回答完成
- proposal_completed — 修改草案完成
- completed — 全部完成
- error — 异常错误
- """
- import json
- import time
- import uuid
- from typing import Any, AsyncGenerator, Dict, Iterable, List, Tuple
- from fastapi import APIRouter, HTTPException, Query
- from fastapi.responses import StreamingResponse
- from foundation.infrastructure.tracing import TraceContext, auto_trace
- from core.document_chat.component.document_chat_logger import document_chat_logger as logger
- from core.document_chat.component.document_chat_logger import log_document_chat_event, log_document_chat_event_truncated
- from core.document_chat.schemas import DocumentChatRequest, DocumentChatResponse, model_to_dict
# Router for the document chat endpoints; every route is mounted under /sgbx.
document_chat_router = APIRouter(prefix="/sgbx", tags=["文档编辑AI对话"])

# Upper bound on the number of references exposed to the client in a single
# SSE event, so the response body does not grow too large.
MAX_REFERENCES_PER_EVENT = 8

# Maximum length, in characters, of the content preview of one reference.
REFERENCE_PREVIEW_CHARS = 600

# Maps each workflow stage name to the hint text sent to the frontend.
STAGE_MESSAGES = {
    "workflow_started": "文档 AI 对话工作流已启动",
    "recognize_intent": "已完成用户意图识别",
    "rerank_context": "知识库内容检索重排完成",
    "run_answer_skill": "已生成章节问答结果",
    "run_modify_skill": "已生成章节修改草案",
    "general_answer": "已生成通用回答",
    "error_handler": "流程异常,已进入错误处理",
}
def format_sse_event(event_type: str, data: dict) -> str:
    """Render one Server-Sent Events frame.

    The frame consists of an ``event:`` line, a ``data:`` line carrying
    ``data`` serialized as JSON (non-ASCII characters are emitted as-is),
    and a terminating blank line.
    """
    payload = json.dumps(data, ensure_ascii=False)
    frame_lines = (f"event: {event_type}", f"data: {payload}", "", "")
    return "\n".join(frame_lines)
def get_document_chat_workflow():
    """Return the ``document_chat_workflow`` object from the workflow module.

    The import happens at call time instead of module import time so that
    this module does not take part in a circular import.
    """
    from core.document_chat.workflows.document_chat_workflow import (
        document_chat_workflow as workflow_instance,
    )

    return workflow_instance
- def _iter_node_updates(raw_update: Any) -> Iterable[Tuple[str, Dict[str, Any]]]:
- """解析 LangGraph 的 updates 负载,提取 (节点名, 更新内容) 对。
- 如果 raw_update 的键本身就是节点名,直接返回;
- 否则把整个 payload 作为 single-stage 更新处理。
- """
- if not isinstance(raw_update, dict):
- return []
- updates: List[Tuple[str, Dict[str, Any]]] = []
- for node_name, node_update in raw_update.items():
- if isinstance(node_update, dict):
- updates.append((str(node_name), node_update))
- if updates:
- return updates
- stage = str(raw_update.get("current_stage") or "workflow_update")
- return [(stage, raw_update)]
- def _merge_state_update(state: Dict[str, Any], update: Dict[str, Any]) -> None:
- """将节点返回的增量字段合并到全局状态。"""
- for key, value in update.items():
- state[key] = value
def _preview_text(text: Any, limit: int = REFERENCE_PREVIEW_CHARS) -> str:
    """Return a stripped preview of ``text`` of at most ``limit`` characters.

    Falsy input becomes an empty string. Text longer than ``limit`` is cut
    at ``limit``, right-stripped, and suffixed with ``"..."``.
    """
    cleaned = str(text or "").strip()
    if len(cleaned) > limit:
        return cleaned[:limit].rstrip() + "..."
    return cleaned
- def _safe_metadata(metadata: Any) -> Dict[str, Any]:
- """过滤出 SSE 事件允许透传的 metadata 白名单字段。"""
- if not isinstance(metadata, dict):
- return {}
- allowed_keys = (
- "tenant_id",
- "project_id",
- "knowledge_base_id",
- "file_name",
- "chapter_level_1",
- "chapter_level_2",
- "parent_id",
- "parent_count",
- "source_scope_valid",
- )
- return {key: metadata.get(key) for key in allowed_keys if metadata.get(key) not in (None, "")}
def _pack_reference_preview(item: Dict[str, Any]) -> Dict[str, Any]:
    """Condense one retrieved reference into the preview sent to the frontend.

    The preview holds the source label, a truncated content preview, the
    vector similarity, and the whitelisted metadata. ``rerank_score`` is
    included only when the item carries that key.
    """
    raw_metadata = item.get("metadata")
    metadata = raw_metadata if isinstance(raw_metadata, dict) else {}

    # "content" wins whenever the key exists, even if its value is empty;
    # "text" is read only when the "content" key is absent.
    if "content" in item:
        content = item.get("content")
    else:
        content = item.get("text")

    source = item.get("source") or metadata.get("file_name") or "向量知识库"
    preview = {
        "source": str(source),
        "content": _preview_text(content),
        "vector_similarity": item.get("vector_similarity", 0.0),
        "metadata": _safe_metadata(metadata),
    }
    if "rerank_score" in item:
        preview["rerank_score"] = item.get("rerank_score", 0.0)
    return preview
def _limited_items(items: List[Dict[str, Any]], packer) -> List[Dict[str, Any]]:
    """Pack the leading ``MAX_REFERENCES_PER_EVENT`` entries of ``items``.

    The list is truncated first and then filtered, so non-dict entries
    inside the truncated window are skipped without being replaced by later
    entries. A falsy ``items`` yields an empty list.
    """
    window = (items or [])[:MAX_REFERENCES_PER_EVENT]
    packed = []
    for entry in window:
        if isinstance(entry, dict):
            packed.append(packer(entry))
    return packed
def _reasoning_event(callback_task_id: str, node_name: str, state: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Build the ``reasoning`` event for a finished workflow node.

    The status is ``"failed"`` when the state carries a truthy
    ``error_message`` and ``"processing"`` otherwise. The message comes from
    ``STAGE_MESSAGES``, with a generic fallback for unknown node names.
    """
    if state.get("error_message"):
        status = "failed"
    else:
        status = "processing"

    payload = {
        "callback_task_id": callback_task_id,
        "stage_name": node_name,
        "status": status,
        "message": STAGE_MESSAGES.get(node_name, f"已完成 {node_name}"),
    }
    return "reasoning", payload
def _build_realtime_events(
    callback_task_id: str,
    state: Dict[str, Any],
    node_name: str,
    skill_started_sent: bool,
) -> Tuple[List[Tuple[str, Dict[str, Any]]], bool]:
    """Build the SSE events to push after ``node_name`` has finished.

    A node may yield a generic ``reasoning`` event plus one node-specific
    event (``intent``, ``retrieval_result`` or ``skill_started``).

    Args:
        callback_task_id: Task id echoed in every event payload.
        state: Accumulated workflow state, already merged with this node's
            update by the caller.
        node_name: Name of the node that just completed.
        skill_started_sent: Whether ``skill_started`` was already emitted.

    Returns:
        ``(events, skill_started_sent)`` where ``events`` is a list of
        ``(event_type, payload)`` pairs and the flag is ``True`` once a
        ``skill_started`` event has been produced.
    """
    pending: List[Tuple[str, Dict[str, Any]]] = []

    # Generic progress event for every stage that has a frontend message.
    if node_name in STAGE_MESSAGES:
        pending.append(_reasoning_event(callback_task_id, node_name, state))

    if node_name == "recognize_intent":
        # Intent recognition result, only when one is present.
        recognized = state.get("intent_result")
        if recognized:
            pending.append(
                (
                    "intent",
                    {
                        "callback_task_id": callback_task_id,
                        "intent_result": recognized,
                    },
                )
            )
    elif node_name == "rerank_context":
        # Retrieval summary together with the capped reference previews.
        reranked = state.get("reranked_references") or []
        retrieval_payload = {
            "callback_task_id": callback_task_id,
            "retrieval_status": state.get("retrieval_status"),
            "retrieval_method": state.get("retrieval_method"),
            "retrieval_metrics": state.get("retrieval_metrics") or {},
            "rerank_count": len(reranked),
            "references": _limited_items(reranked, _pack_reference_preview),
            "warnings": state.get("warnings") or [],
        }
        pending.append(("retrieval_result", retrieval_payload))
    elif node_name == "quality_gate":
        # Announce the selected skill once; the flag blocks repeats.
        skill_name = (state.get("intent_result") or {}).get("skill_name") or ""
        if skill_name and not skill_started_sent:
            if skill_name == "document-modify":
                response_type = "proposal"
            else:
                response_type = "answer"
            pending.append(
                (
                    "skill_started",
                    {
                        "callback_task_id": callback_task_id,
                        "skill_name": skill_name,
                        "response_type": response_type,
                    },
                )
            )
            skill_started_sent = True

    return pending, skill_started_sent
@document_chat_router.post("/document_chat")
@auto_trace(generate_if_missing=True)
async def document_chat(request: DocumentChatRequest, stream: bool = Query(False)):
    """Main endpoint for document AI chat.

    Flow:
        1. Generate a ``callback_task_id`` and set it as the trace id.
        2. Log the incoming request through the truncating logger.
        3. Streaming mode: return a ``StreamingResponse`` that pushes events.
        4. Synchronous mode: run the workflow and return the result at once.

    Args:
        request: Chat request body; ``request.response_mode == "sse"``
            selects the streaming response.
        stream: Query flag; a true value also selects the streaming response.

    Returns:
        A ``StreamingResponse`` of SSE frames in streaming mode, otherwise a
        ``DocumentChatResponse`` whose ``code`` is 500 when the workflow
        produced an ``error`` response and 200 otherwise.

    Raises:
        HTTPException: Status 500 when the synchronous workflow run raises;
            the original exception is chained as its cause.
    """
    callback_task_id = f"doc_chat_{uuid.uuid4().hex[:12]}"
    TraceContext.set_trace_id(callback_task_id)
    log_document_chat_event_truncated(
        "request_received",
        callback_task_id,
        {
            "stream": stream,
            "response_mode": request.response_mode,
            "request": model_to_dict(request),
        },
    )

    if stream or request.response_mode == "sse":
        return StreamingResponse(
            _generate_document_chat_events(callback_task_id, request),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # Synchronous mode: wait for the workflow to finish before responding.
    try:
        workflow = get_document_chat_workflow()
        state = await workflow.run(request, callback_task_id)
        data = workflow.to_response_data(state)
        data_dict = model_to_dict(data)
        log_document_chat_event("response_completed", callback_task_id, data_dict)

        if data.response_type == "error":
            # An error response must never be labelled "success", even when
            # the workflow left error_message empty.
            code = 500
            message = data.error_message or "error"
        else:
            code = 200
            message = "success"
        return DocumentChatResponse(code=code, message=message, data=data)
    except Exception as exc:
        logger.error(f"[DocumentChat] request failed: {exc}", exc_info=True)
        log_document_chat_event(
            "request_failed",
            callback_task_id,
            {"error": str(exc), "request": model_to_dict(request)},
            level="error",
        )
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
async def _generate_document_chat_events(
    callback_task_id: str,
    request: DocumentChatRequest,
) -> AsyncGenerator[str, None]:
    """Async generator that yields the SSE frames of one chat request.

    Emission order:
        connected -> processing -> (reasoning / intent / retrieval_result /
        skill_started / chunk, as the graph streams them) -> answer_completed
        / proposal_completed / error -> completed (skipped when the response
        type is "error").

    Any exception raised along the way is logged and reported to the client
    as a final ``error`` frame.
    """
    started_at = time.time()
    try:
        connected_payload = {
            "callback_task_id": callback_task_id,
            "status": "connected",
            "timestamp": int(time.time()),
        }
        yield format_sse_event("connected", connected_payload)

        processing_payload = {
            "callback_task_id": callback_task_id,
            "stage_name": "workflow_started",
            "status": "processing",
            "message": "文档 AI 对话工作流已启动",
        }
        yield format_sse_event("processing", processing_payload)

        workflow = get_document_chat_workflow()
        state = workflow.build_initial_state(request, callback_task_id)
        # The graph receives a shallow copy; ``state`` is the local view that
        # node updates are merged into.
        graph_state = dict(state)
        skill_started_sent = False
        custom_event_count = 0

        graph_stream = workflow.get_graph().astream(
            graph_state, stream_mode=["updates", "custom"]
        )
        async for mode, payload in graph_stream:
            if mode == "updates":
                # A node finished: merge its update, then push its events.
                for node_name, node_update in _iter_node_updates(payload):
                    _merge_state_update(state, node_update)
                    realtime_events, skill_started_sent = _build_realtime_events(
                        callback_task_id,
                        state,
                        node_name,
                        skill_started_sent,
                    )
                    for event_type, event_data in realtime_events:
                        yield format_sse_event(event_type, event_data)
                continue

            if mode != "custom" or not isinstance(payload, dict):
                continue

            # Custom event: a text fragment streamed by the running skill.
            custom_event_count += 1
            chunk = payload.get("stream_chunk")
            if chunk:
                yield format_sse_event(
                    "chunk",
                    {
                        "callback_task_id": callback_task_id,
                        "chunk": chunk,
                    },
                )

        logger.info(f"[DocumentChat] SSE stream completed: custom_events_received={custom_event_count}")

        # Workflow finished: push the final result event.
        data = workflow.to_response_data(state)
        data_dict = model_to_dict(data)
        log_document_chat_event("response_completed", callback_task_id, data_dict)

        response_type = data.response_type
        if response_type == "proposal":
            final_event = "proposal_completed"
        elif response_type in ("answer", "clarify", "unsupported", "general_answer"):
            final_event = "answer_completed"
        else:
            final_event = "error"
        yield format_sse_event(final_event, data_dict)

        # The closing "completed" frame (with duration) is skipped only for
        # the "error" response type.
        if response_type != "error":
            completed_payload = {
                "callback_task_id": callback_task_id,
                "status": state.get("overall_task_status", "completed"),
                "duration": round(time.time() - started_at, 3),
            }
            yield format_sse_event("completed", completed_payload)
    except Exception as exc:
        logger.error(f"[DocumentChat] SSE request failed: {exc}", exc_info=True)
        log_document_chat_event(
            "request_failed",
            callback_task_id,
            {"error": str(exc), "request": model_to_dict(request)},
            level="error",
        )
        failure_payload = {
            "callback_task_id": callback_task_id,
            "status": "error",
            "message": str(exc),
        }
        yield format_sse_event("error", failure_payload)
@document_chat_router.get("/document_chat/health")
async def document_chat_health():
    """Health check: return a static description of the module and workflow."""
    supported_skills = ["document-answer", "document-modify"]
    health = {}
    health["status"] = "healthy"
    health["module"] = "document_chat"
    health["workflow"] = "langgraph"
    health["skills"] = supported_skills
    return health
|