# -*- coding: utf-8 -*-
"""HTTP API for document chat."""

import json
import time
import uuid
from typing import Any, AsyncGenerator, Dict, Iterable, List, Tuple

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

from foundation.infrastructure.tracing import TraceContext, auto_trace
from foundation.observability.logger.loggering import write_logger as logger
from core.document_chat.component.document_chat_logger import log_document_chat_event
from core.document_chat.schemas import DocumentChatRequest, DocumentChatResponse, model_to_dict

document_chat_router = APIRouter(prefix="/sgbx", tags=["文档编辑AI对话"])

# Upper bound on how many candidate/reference previews a single SSE event carries.
MAX_REFERENCES_PER_EVENT = 8
# Maximum number of characters kept from a candidate/reference text in a preview.
REFERENCE_PREVIEW_CHARS = 600

# User-facing progress message for each workflow node name. Only nodes listed
# here produce a "reasoning" SSE event.
STAGE_MESSAGES = {
    "validate_input": "已校验对话输入",
    "load_context": "已整理当前章节上下文",
    "load_skill_registry": "已加载文档对话技能",
    "recognize_intent": "已完成用户意图识别",
    "route_intent": "已确定对话处理路径",
    "build_retrieval_query": "已构建知识库检索问题",
    "vector_recall": "已完成知识库向量召回",
    "rerank_context": "已完成召回片段重排",
    "quality_gate": "已完成参考资料质量门控",
    "clarify": "需要用户补充说明",
    "unsupported": "当前请求不在文档对话能力范围内",
    "run_answer_skill": "已生成章节问答结果",
    "run_modify_skill": "已生成章节修改草案",
    "build_diff": "已生成新旧内容对比",
    "error_handler": "流程异常,已进入错误处理",
    "complete": "文档 AI 对话流程完成",
}


def format_sse_event(event_type: str, data: dict) -> str:
    """Serialize one Server-Sent Events frame.

    Args:
        event_type: Value written to the ``event:`` field.
        data: JSON-serializable payload written to the ``data:`` field.
            Non-ASCII characters are emitted as-is (``ensure_ascii=False``).

    Returns:
        The frame text, terminated by the blank line that ends an SSE event.

    Raises:
        TypeError: Propagated from ``json.dumps`` if ``data`` contains a
            value that is not JSON-serializable.
    """
    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def get_document_chat_workflow():
    """Return the shared ``document_chat_workflow`` object.

    The import happens at call time rather than at module import time
    (presumably to avoid an import cycle or a heavy import at startup --
    TODO confirm).
    """
    from core.document_chat.workflows.document_chat_workflow import document_chat_workflow

    return document_chat_workflow


def _iter_node_updates(raw_update: Any) -> Iterable[Tuple[str, Dict[str, Any]]]:
    """Normalize one streamed graph update into ``(node_name, update)`` pairs.

    * A non-dict ``raw_update`` yields nothing.
    * If at least one value of ``raw_update`` is a dict, every dict-valued
      entry is returned as ``(str(key), value)``; non-dict values are dropped.
    * Otherwise the whole mapping is treated as a single update whose name is
      its ``current_stage`` value, or ``"workflow_update"`` when that is
      missing or falsy.
    """
    if not isinstance(raw_update, dict):
        return []

    updates: List[Tuple[str, Dict[str, Any]]] = [
        (str(node_name), node_update)
        for node_name, node_update in raw_update.items()
        if isinstance(node_update, dict)
    ]
    if updates:
        return updates

    stage = str(raw_update.get("current_stage") or "workflow_update")
    return [(stage, raw_update)]


def _merge_state_update(state: Dict[str, Any], update: Dict[str, Any]) -> None:
    """Shallow-merge ``update`` into ``state`` in place (later keys overwrite)."""
    for key, value in update.items():
        state[key] = value


def _preview_text(text: Any, limit: int = REFERENCE_PREVIEW_CHARS) -> str:
    """Return ``text`` as a stripped string, truncated to ``limit`` characters.

    Falsy input becomes ``""``. When truncation happens, trailing whitespace of
    the kept prefix is removed and ``"..."`` is appended (so the result may be
    up to ``limit + 3`` characters long).
    """
    value = str(text or "").strip()
    if len(value) <= limit:
        return value
    return value[:limit].rstrip() + "..."


def _safe_metadata(metadata: Any) -> Dict[str, Any]:
    """Return only the allow-listed metadata keys that have a usable value.

    Non-dict input yields ``{}``. Keys whose value is ``None`` or ``""`` are
    omitted; other falsy values (e.g. ``0``, ``False``) are kept.
    """
    if not isinstance(metadata, dict):
        return {}

    allowed_keys = (
        "tenant_id",
        "project_id",
        "knowledge_base_id",
        "file_name",
        "chapter_level_1",
        "chapter_level_2",
        "parent_id",
        "parent_count",
        "source_scope_valid",
    )
    return {
        key: metadata.get(key)
        for key in allowed_keys
        if metadata.get(key) not in (None, "")
    }


def _pack_candidate_preview(item: Dict[str, Any]) -> Dict[str, Any]:
    """Build the client-facing preview of one recalled candidate.

    The source label falls back from ``item["source"]`` to the metadata
    ``file_name`` and finally to a fixed default label. The snippet is the
    truncated ``item["text"]``.
    """
    metadata = item.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}

    return {
        "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
        "snippet": _preview_text(item.get("text")),
        "vector_similarity": item.get("vector_similarity", 0.0),
        "metadata": _safe_metadata(metadata),
    }


def _pack_reference_preview(item: Dict[str, Any]) -> Dict[str, Any]:
    """Build the client-facing preview of one reranked/approved reference.

    The preview text is taken from ``item["content"]``; when that is missing
    or empty it falls back to ``item["text"]``. ``rerank_score`` is included
    only when the key is present on ``item``.
    """
    metadata = item.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}

    # Fall back to "text" not only when "content" is absent but also when it is
    # present-but-empty, so an empty "content" does not blank out the preview.
    content = item.get("content") or item.get("text")

    data = {
        "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
        "content": _preview_text(content),
        "vector_similarity": item.get("vector_similarity", 0.0),
        "metadata": _safe_metadata(metadata),
    }
    if "rerank_score" in item:
        data["rerank_score"] = item.get("rerank_score", 0.0)
    return data


def _limited_items(items: List[Dict[str, Any]], packer) -> List[Dict[str, Any]]:
    """Apply ``packer`` to the dict entries among the first N ``items``.

    N is ``MAX_REFERENCES_PER_EVENT``. The list is truncated *before*
    filtering, so non-dict entries inside the first N are skipped rather than
    replaced by later entries. ``None``/empty ``items`` yields ``[]``.
    """
    head = (items or [])[:MAX_REFERENCES_PER_EVENT]
    return [packer(item) for item in head if isinstance(item, dict)]


def _reasoning_event(
    callback_task_id: str,
    node_name: str,
    state: Dict[str, Any],
) -> Tuple[str, Dict[str, Any]]:
    """Build the ``reasoning`` progress event for a finished node.

    ``status`` is ``"failed"`` when the state carries a truthy
    ``error_message`` and ``"processing"`` otherwise. The message comes from
    ``STAGE_MESSAGES`` with a generic fallback for unknown node names.
    """
    status = "failed" if state.get("error_message") else "processing"
    payload = {
        "callback_task_id": callback_task_id,
        "stage_name": node_name,
        "status": status,
        "message": STAGE_MESSAGES.get(node_name, f"已完成 {node_name}"),
    }
    return "reasoning", payload


def _build_realtime_events(
    callback_task_id: str,
    state: Dict[str, Any],
    node_name: str,
    skill_started_sent: bool,
) -> Tuple[List[Tuple[str, Dict[str, Any]]], bool]:
    """Translate one node completion into the SSE events to emit for it.

    Args:
        callback_task_id: Task id echoed in every event payload.
        state: Accumulated workflow state *after* merging this node's update.
        node_name: Name of the node that just produced an update.
        skill_started_sent: Whether ``skill_started`` was already emitted
            earlier in this stream.

    Returns:
        ``(events, skill_started_sent)`` where ``events`` is an ordered list of
        ``(event_type, payload)`` pairs and the flag is the updated value to
        pass into the next call.
    """
    events: List[Tuple[str, Dict[str, Any]]] = []

    if node_name in STAGE_MESSAGES:
        events.append(_reasoning_event(callback_task_id, node_name, state))

    if node_name == "recognize_intent" and state.get("intent_result"):
        events.append(
            (
                "intent",
                {
                    "callback_task_id": callback_task_id,
                    "intent_result": state.get("intent_result"),
                },
            )
        )

    if node_name == "build_retrieval_query":
        events.append(
            (
                "retrieval_query",
                {
                    "callback_task_id": callback_task_id,
                    "query": state.get("retrieval_query") or "",
                },
            )
        )

    if node_name == "vector_recall":
        candidates = state.get("retrieval_candidates") or []
        events.append(
            (
                "retrieval_recalled",
                {
                    "callback_task_id": callback_task_id,
                    "retrieval_status": state.get("retrieval_status"),
                    "retrieval_method": state.get("retrieval_method"),
                    "retrieval_metrics": state.get("retrieval_metrics") or {},
                    # Full count is reported even though previews are capped.
                    "candidate_count": len(candidates),
                    "candidates": _limited_items(candidates, _pack_candidate_preview),
                    "warnings": state.get("warnings") or [],
                },
            )
        )

    if node_name == "rerank_context":
        reranked = state.get("reranked_references") or []
        events.append(
            (
                "retrieval_reranked",
                {
                    "callback_task_id": callback_task_id,
                    "retrieval_status": state.get("retrieval_status"),
                    "retrieval_method": state.get("retrieval_method"),
                    "retrieval_metrics": state.get("retrieval_metrics") or {},
                    "rerank_count": len(reranked),
                    "references": _limited_items(reranked, _pack_reference_preview),
                    "warnings": state.get("warnings") or [],
                },
            )
        )

    if node_name == "quality_gate":
        approved = state.get("approved_references") or []
        retrieval_payload = {
            "callback_task_id": callback_task_id,
            "retrieval_status": state.get("retrieval_status"),
            "retrieval_method": state.get("retrieval_method"),
            "retrieval_metrics": state.get("retrieval_metrics") or {},
            "approved_count": len(approved),
            "references": _limited_items(approved, _pack_reference_preview),
            "warnings": state.get("warnings") or [],
        }
        # The same payload is published under two event names
        # (presumably "retrieval" is kept for older clients -- TODO confirm).
        events.append(("retrieval_approved", retrieval_payload))
        events.append(("retrieval", retrieval_payload))

    # "skill_started" is not tied to a node: it is emitted once, on the first
    # update after which the state holds an intent with a non-empty skill name.
    intent_result = state.get("intent_result") or {}
    skill_name = intent_result.get("skill_name") or ""
    if skill_name and not skill_started_sent:
        response_type = "proposal" if skill_name == "document-modify" else "answer"
        events.append(
            (
                "skill_started",
                {
                    "callback_task_id": callback_task_id,
                    "skill_name": skill_name,
                    "response_type": response_type,
                },
            )
        )
        skill_started_sent = True

    if node_name == "build_diff":
        diff_result = state.get("diff_result") or {}
        events.append(
            (
                "diff_ready",
                {
                    "callback_task_id": callback_task_id,
                    "diff_granularity": diff_result.get("diff_granularity"),
                    "diff_count": len(diff_result.get("diff") or []),
                    "old_content_hash": diff_result.get("old_content_hash"),
                    "new_content_hash": diff_result.get("new_content_hash"),
                },
            )
        )

    return events, skill_started_sent


@document_chat_router.post("/document_chat")
@auto_trace(generate_if_missing=True)
async def document_chat(request: DocumentChatRequest, stream: bool = Query(False)):
    """Handle a document-chat request, either as SSE or as a single JSON reply.

    A fresh ``callback_task_id`` is generated per request and also installed as
    the trace id. Streaming is selected by ``?stream=true`` or by
    ``request.response_mode == "sse"``.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) in streaming mode,
        otherwise a ``DocumentChatResponse`` whose ``code`` is 500 when the
        workflow produced an ``"error"`` response type and 200 otherwise.

    Raises:
        HTTPException: Status 500 when the non-streaming workflow raises.
    """
    callback_task_id = f"doc_chat_{uuid.uuid4().hex[:12]}"
    TraceContext.set_trace_id(callback_task_id)
    log_document_chat_event(
        "request_received",
        callback_task_id,
        {
            "stream": stream,
            "response_mode": request.response_mode,
            "request": model_to_dict(request),
        },
    )

    if stream or request.response_mode == "sse":
        return StreamingResponse(
            _generate_document_chat_events(callback_task_id, request),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Ask nginx-style reverse proxies not to buffer the stream.
                "X-Accel-Buffering": "no",
            },
        )

    try:
        workflow = get_document_chat_workflow()
        state = await workflow.run(request, callback_task_id)
        data = workflow.to_response_data(state)
        data_dict = model_to_dict(data)
        log_document_chat_event("response_completed", callback_task_id, data_dict)

        is_error = data.response_type == "error"
        code = 500 if is_error else 200
        message = data.error_message if is_error else "success"
        return DocumentChatResponse(code=code, message=message or "success", data=data)
    except Exception as exc:
        logger.error(f"[DocumentChat] request failed: {exc}", exc_info=True)
        log_document_chat_event(
            "request_failed",
            callback_task_id,
            {"error": str(exc), "request": model_to_dict(request)},
            level="error",
        )
        # NOTE(review): the raw exception text is returned to the client;
        # confirm this is acceptable for externally reachable deployments.
        raise HTTPException(status_code=500, detail=str(exc)) from exc


async def _generate_document_chat_events(
    callback_task_id: str,
    request: DocumentChatRequest,
) -> AsyncGenerator[str, None]:
    """Run the workflow graph and yield its progress as SSE frames.

    Order of frames on success: ``connected``, ``processing``, the per-node
    events from ``_build_realtime_events``, then the final result
    (``chunk`` + ``answer_completed`` / ``proposal_completed``, or ``error``),
    and finally ``completed`` with the elapsed duration in seconds.

    Any exception is logged and reported to the client as a single ``error``
    frame; no ``completed`` frame follows in that case.
    """
    # Monotonic clock: the reported duration must not be affected by
    # wall-clock adjustments while the stream is running.
    started_at = time.perf_counter()
    try:
        yield format_sse_event(
            "connected",
            {
                "callback_task_id": callback_task_id,
                "status": "connected",
                "timestamp": int(time.time()),
            },
        )
        yield format_sse_event(
            "processing",
            {
                "callback_task_id": callback_task_id,
                "stage_name": "workflow_started",
                "status": "processing",
                "message": "文档 AI 对话工作流已启动",
            },
        )

        workflow = get_document_chat_workflow()
        state = workflow.build_initial_state(request, callback_task_id)
        # The graph gets its own copy; ``state`` is the local accumulator that
        # node updates are merged into below.
        graph_state = dict(state)
        skill_started_sent = False

        async for raw_update in workflow.get_graph().astream(graph_state, stream_mode="updates"):
            for node_name, node_update in _iter_node_updates(raw_update):
                _merge_state_update(state, node_update)
                realtime_events, skill_started_sent = _build_realtime_events(
                    callback_task_id,
                    state,
                    node_name,
                    skill_started_sent,
                )
                for event_type, event_data in realtime_events:
                    yield format_sse_event(event_type, event_data)

        data = workflow.to_response_data(state)
        data_dict = model_to_dict(data)
        log_document_chat_event("response_completed", callback_task_id, data_dict)

        if data.response_type == "answer" and data.answer:
            yield format_sse_event(
                "chunk",
                {
                    "callback_task_id": callback_task_id,
                    "chunk": data.answer,
                },
            )
            yield format_sse_event("answer_completed", data_dict)
        elif data.response_type == "proposal":
            if data.proposed_content:
                yield format_sse_event(
                    "chunk",
                    {
                        "callback_task_id": callback_task_id,
                        "chunk": data.proposed_content,
                    },
                )
            yield format_sse_event("proposal_completed", data_dict)
        elif data.response_type in ("clarify", "unsupported"):
            yield format_sse_event("answer_completed", data_dict)
        else:
            # Includes "error" and an "answer" response with empty answer text.
            yield format_sse_event("error", data_dict)

        yield format_sse_event(
            "completed",
            {
                "callback_task_id": callback_task_id,
                "status": state.get("overall_task_status", "completed"),
                "duration": round(time.perf_counter() - started_at, 3),
            },
        )
    except Exception as exc:
        logger.error(f"[DocumentChat] SSE request failed: {exc}", exc_info=True)
        log_document_chat_event(
            "request_failed",
            callback_task_id,
            {"error": str(exc), "request": model_to_dict(request)},
            level="error",
        )
        yield format_sse_event(
            "error",
            {
                "callback_task_id": callback_task_id,
                "status": "error",
                "message": str(exc),
            },
        )


@document_chat_router.get("/document_chat/health")
async def document_chat_health():
    """Return a static health payload for the document-chat module."""
    return {
        "status": "healthy",
        "module": "document_chat",
        "workflow": "langgraph",
        "skills": ["document-answer", "document-modify"],
    }