views.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # -*- coding: utf-8 -*-
  2. """HTTP API for document chat."""
  3. import json
  4. import time
  5. import uuid
  6. from typing import Any, AsyncGenerator, Dict, Iterable, List, Tuple
  7. from fastapi import APIRouter, HTTPException, Query
  8. from fastapi.responses import StreamingResponse
  9. from foundation.infrastructure.tracing import TraceContext, auto_trace
  10. from foundation.observability.logger.loggering import write_logger as logger
  11. from core.document_chat.component.document_chat_logger import log_document_chat_event
  12. from core.document_chat.schemas import DocumentChatRequest, DocumentChatResponse, model_to_dict
# Router for the document-chat endpoints; every route below is mounted under /sgbx.
document_chat_router = APIRouter(prefix="/sgbx", tags=["文档编辑AI对话"])
# Maximum number of reference items considered when packing a single event payload.
MAX_REFERENCES_PER_EVENT = 8
# Default character limit applied when previewing reference text.
REFERENCE_PREVIEW_CHARS = 600
# Progress message shown to the client for each known workflow stage / node name.
STAGE_MESSAGES = {
    "workflow_started": "文档 AI 对话工作流已启动",
    "recognize_intent": "已完成用户意图识别",
    "rerank_context": "知识库内容检索重排完成",
    "run_answer_skill": "已生成章节问答结果",
    "run_modify_skill": "已生成章节修改草案",
    "error_handler": "流程异常,已进入错误处理",
}
  24. def format_sse_event(event_type: str, data: dict) -> str:
  25. return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
  26. def get_document_chat_workflow():
  27. from core.document_chat.workflows.document_chat_workflow import document_chat_workflow
  28. return document_chat_workflow
  29. def _iter_node_updates(raw_update: Any) -> Iterable[Tuple[str, Dict[str, Any]]]:
  30. if not isinstance(raw_update, dict):
  31. return []
  32. updates: List[Tuple[str, Dict[str, Any]]] = []
  33. for node_name, node_update in raw_update.items():
  34. if isinstance(node_update, dict):
  35. updates.append((str(node_name), node_update))
  36. if updates:
  37. return updates
  38. stage = str(raw_update.get("current_stage") or "workflow_update")
  39. return [(stage, raw_update)]
  40. def _merge_state_update(state: Dict[str, Any], update: Dict[str, Any]) -> None:
  41. for key, value in update.items():
  42. state[key] = value
  43. def _preview_text(text: Any, limit: int = REFERENCE_PREVIEW_CHARS) -> str:
  44. value = str(text or "").strip()
  45. if len(value) <= limit:
  46. return value
  47. return value[:limit].rstrip() + "..."
  48. def _safe_metadata(metadata: Any) -> Dict[str, Any]:
  49. if not isinstance(metadata, dict):
  50. return {}
  51. allowed_keys = (
  52. "tenant_id",
  53. "project_id",
  54. "knowledge_base_id",
  55. "file_name",
  56. "chapter_level_1",
  57. "chapter_level_2",
  58. "parent_id",
  59. "parent_count",
  60. "source_scope_valid",
  61. )
  62. return {key: metadata.get(key) for key in allowed_keys if metadata.get(key) not in (None, "")}
  63. def _pack_reference_preview(item: Dict[str, Any]) -> Dict[str, Any]:
  64. metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
  65. content = item.get("content") if "content" in item else item.get("text")
  66. data = {
  67. "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
  68. "content": _preview_text(content),
  69. "vector_similarity": item.get("vector_similarity", 0.0),
  70. "metadata": _safe_metadata(metadata),
  71. }
  72. if "rerank_score" in item:
  73. data["rerank_score"] = item.get("rerank_score", 0.0)
  74. return data
  75. def _limited_items(items: List[Dict[str, Any]], packer) -> List[Dict[str, Any]]:
  76. return [packer(item) for item in (items or [])[:MAX_REFERENCES_PER_EVENT] if isinstance(item, dict)]
  77. def _reasoning_event(callback_task_id: str, node_name: str, state: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
  78. status = "failed" if state.get("error_message") else "processing"
  79. return (
  80. "reasoning",
  81. {
  82. "callback_task_id": callback_task_id,
  83. "stage_name": node_name,
  84. "status": status,
  85. "message": STAGE_MESSAGES.get(node_name, f"已完成 {node_name}"),
  86. },
  87. )
  88. def _build_realtime_events(
  89. callback_task_id: str,
  90. state: Dict[str, Any],
  91. node_name: str,
  92. skill_started_sent: bool,
  93. ) -> Tuple[List[Tuple[str, Dict[str, Any]]], bool]:
  94. events: List[Tuple[str, Dict[str, Any]]] = []
  95. if node_name in STAGE_MESSAGES:
  96. events.append(_reasoning_event(callback_task_id, node_name, state))
  97. if node_name == "recognize_intent" and state.get("intent_result"):
  98. events.append(
  99. (
  100. "intent",
  101. {
  102. "callback_task_id": callback_task_id,
  103. "intent_result": state.get("intent_result"),
  104. },
  105. )
  106. )
  107. if node_name == "rerank_context":
  108. reranked = state.get("reranked_references") or []
  109. events.append(
  110. (
  111. "retrieval_result",
  112. {
  113. "callback_task_id": callback_task_id,
  114. "retrieval_status": state.get("retrieval_status"),
  115. "retrieval_method": state.get("retrieval_method"),
  116. "retrieval_metrics": state.get("retrieval_metrics") or {},
  117. "rerank_count": len(reranked),
  118. "references": _limited_items(reranked, _pack_reference_preview),
  119. "warnings": state.get("warnings") or [],
  120. },
  121. )
  122. )
  123. if node_name == "quality_gate":
  124. intent_result = state.get("intent_result") or {}
  125. skill_name = intent_result.get("skill_name") or ""
  126. if skill_name and not skill_started_sent:
  127. response_type = "proposal" if skill_name == "document-modify" else "answer"
  128. events.append(
  129. (
  130. "skill_started",
  131. {
  132. "callback_task_id": callback_task_id,
  133. "skill_name": skill_name,
  134. "response_type": response_type,
  135. },
  136. )
  137. )
  138. skill_started_sent = True
  139. return events, skill_started_sent
  140. @document_chat_router.post("/document_chat")
  141. @auto_trace(generate_if_missing=True)
  142. async def document_chat(request: DocumentChatRequest, stream: bool = Query(False)):
  143. callback_task_id = f"doc_chat_{uuid.uuid4().hex[:12]}"
  144. TraceContext.set_trace_id(callback_task_id)
  145. log_document_chat_event(
  146. "request_received",
  147. callback_task_id,
  148. {
  149. "stream": stream,
  150. "response_mode": request.response_mode,
  151. "request": model_to_dict(request),
  152. },
  153. )
  154. if stream or request.response_mode == "sse":
  155. return StreamingResponse(
  156. _generate_document_chat_events(callback_task_id, request),
  157. media_type="text/event-stream",
  158. headers={
  159. "Cache-Control": "no-cache",
  160. "Connection": "keep-alive",
  161. "X-Accel-Buffering": "no",
  162. },
  163. )
  164. try:
  165. workflow = get_document_chat_workflow()
  166. state = await workflow.run(request, callback_task_id)
  167. data = workflow.to_response_data(state)
  168. data_dict = model_to_dict(data)
  169. log_document_chat_event("response_completed", callback_task_id, data_dict)
  170. code = 500 if data.response_type == "error" else 200
  171. message = data.error_message if data.response_type == "error" else "success"
  172. return DocumentChatResponse(code=code, message=message or "success", data=data)
  173. except Exception as exc:
  174. logger.error(f"[DocumentChat] request failed: {exc}", exc_info=True)
  175. log_document_chat_event(
  176. "request_failed",
  177. callback_task_id,
  178. {"error": str(exc), "request": model_to_dict(request)},
  179. level="error",
  180. )
  181. raise HTTPException(status_code=500, detail=str(exc))
  182. async def _generate_document_chat_events(
  183. callback_task_id: str,
  184. request: DocumentChatRequest,
  185. ) -> AsyncGenerator[str, None]:
  186. started_at = time.time()
  187. try:
  188. yield format_sse_event(
  189. "connected",
  190. {
  191. "callback_task_id": callback_task_id,
  192. "status": "connected",
  193. "timestamp": int(time.time()),
  194. },
  195. )
  196. yield format_sse_event(
  197. "processing",
  198. {
  199. "callback_task_id": callback_task_id,
  200. "stage_name": "workflow_started",
  201. "status": "processing",
  202. "message": "文档 AI 对话工作流已启动",
  203. },
  204. )
  205. workflow = get_document_chat_workflow()
  206. state = workflow.build_initial_state(request, callback_task_id)
  207. graph_state = dict(state)
  208. skill_started_sent = False
  209. async for mode, payload in workflow.get_graph().astream(
  210. graph_state, stream_mode=["updates", "custom"]
  211. ):
  212. if mode == "custom" and isinstance(payload, dict) and payload.get("stream_chunk"):
  213. yield format_sse_event(
  214. "chunk",
  215. {
  216. "callback_task_id": callback_task_id,
  217. "chunk": payload["stream_chunk"],
  218. },
  219. )
  220. elif mode == "updates":
  221. for node_name, node_update in _iter_node_updates(payload):
  222. _merge_state_update(state, node_update)
  223. realtime_events, skill_started_sent = _build_realtime_events(
  224. callback_task_id,
  225. state,
  226. node_name,
  227. skill_started_sent,
  228. )
  229. for event_type, event_data in realtime_events:
  230. yield format_sse_event(event_type, event_data)
  231. data = workflow.to_response_data(state)
  232. data_dict = model_to_dict(data)
  233. log_document_chat_event("response_completed", callback_task_id, data_dict)
  234. if data.response_type == "answer":
  235. yield format_sse_event("answer_completed", data_dict)
  236. elif data.response_type == "proposal":
  237. yield format_sse_event("proposal_completed", data_dict)
  238. elif data.response_type in ("clarify", "unsupported"):
  239. yield format_sse_event("answer_completed", data_dict)
  240. else:
  241. yield format_sse_event("error", data_dict)
  242. if data.response_type != "error":
  243. yield format_sse_event(
  244. "completed",
  245. {
  246. "callback_task_id": callback_task_id,
  247. "status": state.get("overall_task_status", "completed"),
  248. "duration": round(time.time() - started_at, 3),
  249. },
  250. )
  251. except Exception as exc:
  252. logger.error(f"[DocumentChat] SSE request failed: {exc}", exc_info=True)
  253. log_document_chat_event(
  254. "request_failed",
  255. callback_task_id,
  256. {"error": str(exc), "request": model_to_dict(request)},
  257. level="error",
  258. )
  259. yield format_sse_event(
  260. "error",
  261. {
  262. "callback_task_id": callback_task_id,
  263. "status": "error",
  264. "message": str(exc),
  265. },
  266. )
  267. @document_chat_router.get("/document_chat/health")
  268. async def document_chat_health():
  269. return {
  270. "status": "healthy",
  271. "module": "document_chat",
  272. "workflow": "langgraph",
  273. "skills": ["document-answer", "document-modify"],
  274. }