views.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. # -*- coding: utf-8 -*-
  2. """文档 AI 对话 HTTP API。
  3. 提供两个接口:
  4. POST /sgbx/document_chat — 发起对话,支持 SSE 流式和非流式同步两种模式
  5. GET /sgbx/document_chat/health — 健康检查
  6. SSE 流式输出事件类型:
  7. connected — 连接建立
  8. processing — 工作流各阶段进度通知
  9. reasoning — 推理状态(启动、检索、重排、技能执行等)
  10. intent — 意图识别结果
  11. retrieval_result — 检索召回详情(含参考预览)
  12. skill_started — 技能开始执行(answer 或 proposal)
  13. chunk — 技能生成的文本片段(流式逐块输出)
  14. answer_completed — 回答完成
  15. proposal_completed — 修改草案完成
  16. completed — 全部完成
  17. error — 异常错误
  18. """
  19. import json
  20. import time
  21. import uuid
  22. from typing import Any, AsyncGenerator, Dict, Iterable, List, Tuple
  23. from fastapi import APIRouter, HTTPException, Query
  24. from fastapi.responses import StreamingResponse
  25. from foundation.infrastructure.tracing import TraceContext, auto_trace
  26. from core.document_chat.component.document_chat_logger import document_chat_logger as logger
  27. from core.document_chat.component.document_chat_logger import log_document_chat_event, log_document_chat_event_truncated
  28. from core.document_chat.schemas import DocumentChatRequest, DocumentChatResponse, model_to_dict
  # Router for the document-editing AI chat endpoints; every route is mounted under /sgbx.
  document_chat_router = APIRouter(tags=["文档编辑AI对话"], prefix="/sgbx")
  # Upper bound on how many references a single SSE event exposes to the client,
  # so one event frame cannot grow unboundedly large.
  MAX_REFERENCES_PER_EVENT = 8
  # Maximum number of characters kept in each reference's content preview.
  REFERENCE_PREVIEW_CHARS = 600
  # Client-facing progress text for each workflow stage, keyed by graph node name.
  STAGE_MESSAGES = dict(
      workflow_started="文档 AI 对话工作流已启动",
      recognize_intent="已完成用户意图识别",
      rerank_context="知识库内容检索重排完成",
      run_answer_skill="已生成章节问答结果",
      run_modify_skill="已生成章节修改草案",
      general_answer="已生成通用回答",
      error_handler="流程异常,已进入错误处理",
  )
  def format_sse_event(event_type: str, data: dict) -> str:
      """Frame one Server-Sent Events message.

      ``data`` is serialised as JSON with non-ASCII characters kept verbatim.
      The result is an ``event:`` line, a ``data:`` line and the blank line
      that terminates the message.
      """
      serialized = json.dumps(data, ensure_ascii=False)
      return f"event: {event_type}\ndata: {serialized}\n\n"
  def get_document_chat_workflow():
      """Return the module-level ``document_chat_workflow`` object.

      The import is deferred to call time to avoid a circular import between
      this view module and the workflow package.
      """
      from core.document_chat.workflows.document_chat_workflow import (
          document_chat_workflow as workflow_instance,
      )

      return workflow_instance
  51. def _iter_node_updates(raw_update: Any) -> Iterable[Tuple[str, Dict[str, Any]]]:
  52. """解析 LangGraph 的 updates 负载,提取 (节点名, 更新内容) 对。
  53. 如果 raw_update 的键本身就是节点名,直接返回;
  54. 否则把整个 payload 作为 single-stage 更新处理。
  55. """
  56. if not isinstance(raw_update, dict):
  57. return []
  58. updates: List[Tuple[str, Dict[str, Any]]] = []
  59. for node_name, node_update in raw_update.items():
  60. if isinstance(node_update, dict):
  61. updates.append((str(node_name), node_update))
  62. if updates:
  63. return updates
  64. stage = str(raw_update.get("current_stage") or "workflow_update")
  65. return [(stage, raw_update)]
  66. def _merge_state_update(state: Dict[str, Any], update: Dict[str, Any]) -> None:
  67. """将节点返回的增量字段合并到全局状态。"""
  68. for key, value in update.items():
  69. state[key] = value
  def _preview_text(text: Any, limit: int = REFERENCE_PREVIEW_CHARS) -> str:
      """Return a whitespace-stripped preview of ``text``.

      Falsy input (``None``, ``""``, ``0`` ...) yields ``""``. Other values are
      converted with ``str()``. When the stripped value is longer than ``limit``,
      it is cut to ``limit`` characters, trailing whitespace is trimmed, and
      ``"..."`` is appended.
      """
      cleaned = str(text or "").strip()
      if len(cleaned) > limit:
          return f"{cleaned[:limit].rstrip()}..."
      return cleaned
  76. def _safe_metadata(metadata: Any) -> Dict[str, Any]:
  77. """过滤出 SSE 事件允许透传的 metadata 白名单字段。"""
  78. if not isinstance(metadata, dict):
  79. return {}
  80. allowed_keys = (
  81. "tenant_id",
  82. "project_id",
  83. "knowledge_base_id",
  84. "file_name",
  85. "chapter_level_1",
  86. "chapter_level_2",
  87. "parent_id",
  88. "parent_count",
  89. "source_scope_valid",
  90. )
  91. return {key: metadata.get(key) for key in allowed_keys if metadata.get(key) not in (None, "")}
  def _pack_reference_preview(item: Dict[str, Any]) -> Dict[str, Any]:
      """Compress one retrieval reference into the client preview format.

      The result contains ``source``, a truncated ``content`` preview,
      ``vector_similarity`` (default ``0.0``) and whitelisted ``metadata``.
      ``rerank_score`` is included only when the reference carries one.

      ``source`` falls back to ``metadata["file_name"]`` and then to a fixed
      label. ``content`` falls back to the ``text`` field when no ``content``
      key is present.
      """
      raw_metadata = item.get("metadata")
      metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
      body = item.get("content") if "content" in item else item.get("text")
      source = item.get("source") or metadata.get("file_name") or "向量知识库"
      preview = {
          "source": str(source),
          "content": _preview_text(body),
          "vector_similarity": item.get("vector_similarity", 0.0),
          "metadata": _safe_metadata(metadata),
      }
      if "rerank_score" in item:
          preview["rerank_score"] = item["rerank_score"]
      return preview
  105. def _limited_items(items: List[Dict[str, Any]], packer) -> List[Dict[str, Any]]:
  106. """截断列表至上限,并对每条应用打包函数。"""
  107. return [packer(item) for item in (items or [])[:MAX_REFERENCES_PER_EVENT] if isinstance(item, dict)]
  def _reasoning_event(callback_task_id: str, node_name: str, state: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
      """Build the ``reasoning`` progress event for a finished node.

      ``status`` is ``"failed"`` whenever the state holds a truthy
      ``error_message``, and ``"processing"`` otherwise. ``message`` comes from
      ``STAGE_MESSAGES``, with a generic text naming the node as the fallback.
      """
      message = STAGE_MESSAGES.get(node_name, f"已完成 {node_name}")
      payload = {
          "callback_task_id": callback_task_id,
          "stage_name": node_name,
          "status": "failed" if state.get("error_message") else "processing",
          "message": message,
      }
      return "reasoning", payload
  def _build_realtime_events(
      callback_task_id: str,
      state: Dict[str, Any],
      node_name: str,
      skill_started_sent: bool,
  ) -> Tuple[List[Tuple[str, Dict[str, Any]]], bool]:
      """Translate one finished graph node into the SSE events it should trigger.

      ``state`` must already include the node's update. Events are emitted in
      this order:

      * ``reasoning`` for any node listed in ``STAGE_MESSAGES``;
      * ``intent`` after ``recognize_intent``, when the state has an intent result;
      * ``retrieval_result`` after ``rerank_context``, with the total rerank count
        and up to ``MAX_REFERENCES_PER_EVENT`` packed reference previews;
      * ``skill_started`` after ``quality_gate``, when a skill name was chosen and
        ``skill_started_sent`` is still false.

      Returns:
          The events to send, plus the updated ``skill_started_sent`` flag. The
          caller passes the flag back in so ``skill_started`` is sent only once.
      """
      outgoing: List[Tuple[str, Dict[str, Any]]] = []

      # Generic progress event for known stages.
      if node_name in STAGE_MESSAGES:
          outgoing.append(_reasoning_event(callback_task_id, node_name, state))

      # The node names below are distinct, so at most one branch fires.
      if node_name == "recognize_intent":
          if state.get("intent_result"):
              intent_payload = {
                  "callback_task_id": callback_task_id,
                  "intent_result": state.get("intent_result"),
              }
              outgoing.append(("intent", intent_payload))
      elif node_name == "rerank_context":
          reranked = state.get("reranked_references") or []
          retrieval_payload = {
              "callback_task_id": callback_task_id,
              "retrieval_status": state.get("retrieval_status"),
              "retrieval_method": state.get("retrieval_method"),
              "retrieval_metrics": state.get("retrieval_metrics") or {},
              "rerank_count": len(reranked),
              "references": _limited_items(reranked, _pack_reference_preview),
              "warnings": state.get("warnings") or [],
          }
          outgoing.append(("retrieval_result", retrieval_payload))
      elif node_name == "quality_gate":
          # Announce the chosen skill after quality_gate, before the skill runs.
          skill_name = (state.get("intent_result") or {}).get("skill_name") or ""
          if skill_name and not skill_started_sent:
              skill_payload = {
                  "callback_task_id": callback_task_id,
                  "skill_name": skill_name,
                  "response_type": "proposal" if skill_name == "document-modify" else "answer",
              }
              outgoing.append(("skill_started", skill_payload))
              skill_started_sent = True

      return outgoing, skill_started_sent
  @document_chat_router.post("/document_chat")
  @auto_trace(generate_if_missing=True)
  async def document_chat(request: DocumentChatRequest, stream: bool = Query(False)):
      """Main document AI chat endpoint.

      The endpoint answers with an SSE stream when the ``stream`` query flag is
      true or ``request.response_mode == "sse"``. Otherwise it runs the workflow
      to completion and returns a single ``DocumentChatResponse``.

      Every request gets a fresh ``callback_task_id``. The ID is also installed as
      the trace id, and the request is first logged through the truncating logger.

      In JSON mode, a workflow result whose ``response_type`` is ``"error"`` is
      wrapped with ``code=500`` and the workflow's error message, inside the
      response body. An exception raised while running the workflow is logged
      and re-raised as an ``HTTPException`` with status 500.
      """
      callback_task_id = f"doc_chat_{uuid.uuid4().hex[:12]}"
      TraceContext.set_trace_id(callback_task_id)
      request_snapshot = {
          "stream": stream,
          "response_mode": request.response_mode,
          "request": model_to_dict(request),
      }
      log_document_chat_event_truncated("request_received", callback_task_id, request_snapshot)

      use_sse = stream or request.response_mode == "sse"
      if use_sse:
          # Headers disable client/proxy caching and buffering so events arrive immediately.
          sse_headers = {
              "Cache-Control": "no-cache",
              "Connection": "keep-alive",
              "X-Accel-Buffering": "no",
          }
          return StreamingResponse(
              _generate_document_chat_events(callback_task_id, request),
              media_type="text/event-stream",
              headers=sse_headers,
          )

      # JSON mode: wait for the whole workflow before answering.
      try:
          workflow = get_document_chat_workflow()
          final_state = await workflow.run(request, callback_task_id)
          data = workflow.to_response_data(final_state)
          log_document_chat_event("response_completed", callback_task_id, model_to_dict(data))
          if data.response_type == "error":
              code, message = 500, data.error_message
          else:
              code, message = 200, "success"
          return DocumentChatResponse(code=code, message=message or "success", data=data)
      except Exception as exc:
          logger.error(f"[DocumentChat] request failed: {exc}", exc_info=True)
          log_document_chat_event(
              "request_failed",
              callback_task_id,
              {"error": str(exc), "request": model_to_dict(request)},
              level="error",
          )
          raise HTTPException(status_code=500, detail=str(exc))
  async def _generate_document_chat_events(
      callback_task_id: str,
      request: DocumentChatRequest,
  ) -> AsyncGenerator[str, None]:
      """Stream the document-chat workflow to the client as SSE frames.

      Event order:

      1. ``connected``, then ``processing``.
      2. While the graph runs, any mix of ``reasoning`` / ``intent`` /
         ``retrieval_result`` / ``skill_started`` (from node updates) and
         ``chunk`` (from custom stream events).
      3. One final event: ``answer_completed``, ``proposal_completed`` or
         ``error``.
      4. ``completed`` with the elapsed duration. This is omitted only when the
         result's ``response_type`` is ``"error"``.

      Any ``Exception`` raised along the way is logged and reported as a
      trailing ``error`` event instead of propagating out of the generator.
      """
      started_at = time.time()
      try:
          yield format_sse_event(
              "connected",
              {"callback_task_id": callback_task_id, "status": "connected", "timestamp": int(time.time())},
          )
          yield format_sse_event(
              "processing",
              {
                  "callback_task_id": callback_task_id,
                  "stage_name": "workflow_started",
                  "status": "processing",
                  "message": "文档 AI 对话工作流已启动",
              },
          )

          workflow = get_document_chat_workflow()
          state = workflow.build_initial_state(request, callback_task_id)
          # The graph receives its own shallow copy. ``state`` accumulates node
          # updates locally so the final response can be built from it.
          graph_input = dict(state)
          graph = workflow.get_graph()
          skill_started_sent = False
          custom_event_count = 0

          async for mode, payload in graph.astream(graph_input, stream_mode=["updates", "custom"]):
              if mode == "updates":
                  # A node finished: fold its update into the state, then emit the derived events.
                  for node_name, node_update in _iter_node_updates(payload):
                      _merge_state_update(state, node_update)
                      node_events, skill_started_sent = _build_realtime_events(
                          callback_task_id, state, node_name, skill_started_sent
                      )
                      for event_name, event_payload in node_events:
                          yield format_sse_event(event_name, event_payload)
              elif mode == "custom" and isinstance(payload, dict):
                  # Custom events carry incremental text produced by the running skill.
                  custom_event_count += 1
                  text_piece = payload.get("stream_chunk")
                  if text_piece:
                      yield format_sse_event("chunk", {"callback_task_id": callback_task_id, "chunk": text_piece})

          logger.info(f"[DocumentChat] SSE stream completed: custom_events_received={custom_event_count}")

          # Graph finished: build the final payload and pick the matching completion event.
          data = workflow.to_response_data(state)
          data_dict = model_to_dict(data)
          log_document_chat_event("response_completed", callback_task_id, data_dict)

          response_type = data.response_type
          if response_type == "proposal":
              final_event = "proposal_completed"
          elif response_type in ("answer", "clarify", "unsupported", "general_answer"):
              final_event = "answer_completed"
          else:
              final_event = "error"
          yield format_sse_event(final_event, data_dict)

          if response_type != "error":
              yield format_sse_event(
                  "completed",
                  {
                      "callback_task_id": callback_task_id,
                      "status": state.get("overall_task_status", "completed"),
                      "duration": round(time.time() - started_at, 3),
                  },
              )
      except Exception as exc:
          logger.error(f"[DocumentChat] SSE request failed: {exc}", exc_info=True)
          log_document_chat_event(
              "request_failed",
              callback_task_id,
              {"error": str(exc), "request": model_to_dict(request)},
              level="error",
          )
          yield format_sse_event(
              "error",
              {"callback_task_id": callback_task_id, "status": "error", "message": str(exc)},
          )
  @document_chat_router.get("/document_chat/health")
  async def document_chat_health():
      """Report a static health description of the document-chat module.

      The response lists the module name, the workflow engine and the available
      skill names. No dependency is actually probed.
      """
      supported_skills = ["document-answer", "document-modify"]
      return dict(
          status="healthy",
          module="document_chat",
          workflow="langgraph",
          skills=supported_skills,
      )