| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323 |
- # -*- coding: utf-8 -*-
- """HTTP API for document chat."""
- import json
- import time
- import uuid
- from typing import Any, AsyncGenerator, Dict, Iterable, List, Tuple
- from fastapi import APIRouter, HTTPException, Query
- from fastapi.responses import StreamingResponse
- from foundation.infrastructure.tracing import TraceContext, auto_trace
- from foundation.observability.logger.loggering import write_logger as logger
- from core.document_chat.component.document_chat_logger import log_document_chat_event
- from core.document_chat.schemas import DocumentChatRequest, DocumentChatResponse, model_to_dict
# Router for the document-editing AI chat endpoints; every route declared on it
# below is mounted under the "/sgbx" prefix. The tag label is attached to these
# routes for API documentation grouping.
document_chat_router = APIRouter(tags=["文档编辑AI对话"], prefix="/sgbx")
# Cap on how many reference previews are attached to a single SSE event.
MAX_REFERENCES_PER_EVENT = 8
# Maximum number of characters kept from each reference's content in a preview.
REFERENCE_PREVIEW_CHARS = 600
# Progress text sent to the client, keyed by workflow node/stage name. Nodes
# not listed here produce no "reasoning" event in _build_realtime_events.
# English glosses: workflow started / intent recognised / knowledge-base
# retrieval + rerank done / section answer generated / section modification
# draft generated / workflow error, entered error handling.
STAGE_MESSAGES = dict(
    workflow_started="文档 AI 对话工作流已启动",
    recognize_intent="已完成用户意图识别",
    rerank_context="知识库内容检索重排完成",
    run_answer_skill="已生成章节问答结果",
    run_modify_skill="已生成章节修改草案",
    error_handler="流程异常,已进入错误处理",
)
def format_sse_event(event_type: str, data: dict) -> str:
    """Render one Server-Sent Events frame.

    Args:
        event_type: Value for the ``event:`` field.
        data: JSON-serialisable payload for the ``data:`` field. It is dumped
            with ``ensure_ascii=False`` so non-ASCII text is sent verbatim.

    Returns:
        ``"event: <type>\\ndata: <json>\\n\\n"``. The trailing blank line
        terminates the frame.
    """
    serialized = json.dumps(data, ensure_ascii=False)
    return "event: {}\ndata: {}\n\n".format(event_type, serialized)
def get_document_chat_workflow():
    """Return the ``document_chat_workflow`` object from the workflows module.

    The import runs at call time rather than at module import. Presumably this
    avoids an import cycle or start-up cost (TODO confirm).
    """
    from core.document_chat.workflows.document_chat_workflow import (
        document_chat_workflow as workflow_instance,
    )

    return workflow_instance
- def _iter_node_updates(raw_update: Any) -> Iterable[Tuple[str, Dict[str, Any]]]:
- if not isinstance(raw_update, dict):
- return []
- updates: List[Tuple[str, Dict[str, Any]]] = []
- for node_name, node_update in raw_update.items():
- if isinstance(node_update, dict):
- updates.append((str(node_name), node_update))
- if updates:
- return updates
- stage = str(raw_update.get("current_stage") or "workflow_update")
- return [(stage, raw_update)]
- def _merge_state_update(state: Dict[str, Any], update: Dict[str, Any]) -> None:
- for key, value in update.items():
- state[key] = value
def _preview_text(text: Any, limit: int = REFERENCE_PREVIEW_CHARS) -> str:
    """Return ``text`` as a stripped string, truncated to ``limit`` characters.

    Falsy inputs (``None``, ``""``, ``0`` ...) become ``""``. When truncation
    occurs, the kept prefix is right-stripped and ``"..."`` is appended. The
    result can therefore be up to three characters longer than ``limit``.
    """
    normalized = str(text or "").strip()
    if len(normalized) > limit:
        return f"{normalized[:limit].rstrip()}..."
    return normalized
- def _safe_metadata(metadata: Any) -> Dict[str, Any]:
- if not isinstance(metadata, dict):
- return {}
- allowed_keys = (
- "tenant_id",
- "project_id",
- "knowledge_base_id",
- "file_name",
- "chapter_level_1",
- "chapter_level_2",
- "parent_id",
- "parent_count",
- "source_scope_valid",
- )
- return {key: metadata.get(key) for key in allowed_keys if metadata.get(key) not in (None, "")}
def _pack_reference_preview(item: Dict[str, Any]) -> Dict[str, Any]:
    """Build the client-facing preview of one retrieved reference.

    Fields produced, in order:

    - ``source``: the item's ``source``, else ``metadata["file_name"]``, else
      the literal default label (a "vector knowledge base" placeholder), as
      ``str``.
    - ``content``: a truncated preview of ``item["content"]`` if that key
      exists (even when ``None``), otherwise of ``item["text"]``.
    - ``vector_similarity``: defaults to ``0.0``.
    - ``metadata``: the allow-listed subset of the item's ``metadata`` dict;
      a non-dict ``metadata`` is treated as empty.
    - ``rerank_score``: present only when the item carries that key.
    """
    raw_metadata = item.get("metadata")
    metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
    body = item.get("content") if "content" in item else item.get("text")
    source_label = item.get("source") or metadata.get("file_name") or "向量知识库"

    packed = {
        "source": str(source_label),
        "content": _preview_text(body),
        "vector_similarity": item.get("vector_similarity", 0.0),
        "metadata": _safe_metadata(metadata),
    }
    if "rerank_score" in item:
        packed["rerank_score"] = item.get("rerank_score", 0.0)
    return packed
def _limited_items(items: List[Dict[str, Any]], packer) -> List[Dict[str, Any]]:
    """Pack at most ``MAX_REFERENCES_PER_EVENT`` dict entries of ``items``.

    Non-dict entries are skipped *before* the cap is applied, so they no longer
    use up slots that valid references could fill. Iteration stops as soon as
    the cap is reached, and any iterable is accepted, not just sliceable
    sequences.

    Args:
        items: References to pack. ``None`` or empty yields ``[]``.
        packer: Callable mapping one dict entry to its packed form, e.g.
            ``_pack_reference_preview``.

    Returns:
        The packed entries in their original order, at most
        ``MAX_REFERENCES_PER_EVENT`` of them.
    """
    packed: List[Dict[str, Any]] = []
    for item in items or []:
        if len(packed) >= MAX_REFERENCES_PER_EVENT:
            break
        if isinstance(item, dict):
            packed.append(packer(item))
    return packed
def _reasoning_event(callback_task_id: str, node_name: str, state: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Build the ``reasoning`` progress event for a finished workflow node.

    ``status`` is ``"failed"`` whenever ``state`` carries a truthy
    ``error_message``, and ``"processing"`` otherwise. ``message`` comes from
    ``STAGE_MESSAGES``. For unknown nodes it falls back to a generic
    "completed <node>" text.

    Returns:
        ``("reasoning", payload)``.
    """
    has_error = bool(state.get("error_message"))
    payload = {
        "callback_task_id": callback_task_id,
        "stage_name": node_name,
        "status": "failed" if has_error else "processing",
        "message": STAGE_MESSAGES.get(node_name, f"已完成 {node_name}"),
    }
    return "reasoning", payload
def _build_realtime_events(
    callback_task_id: str,
    state: Dict[str, Any],
    node_name: str,
    skill_started_sent: bool,
) -> Tuple[List[Tuple[str, Dict[str, Any]]], bool]:
    """Translate one finished workflow node into the SSE events to emit now.

    Args:
        callback_task_id: Identifier echoed as the first field of every payload.
        state: Accumulated workflow state after this node's update was merged.
        node_name: Name of the node whose update was just merged.
        skill_started_sent: Whether ``skill_started`` was already emitted for
            this request.

    Returns:
        ``(events, skill_started_sent)``. ``events`` is an ordered list of
        ``(event_type, payload)`` pairs. The flag becomes ``True`` once a
        ``skill_started`` event has been queued, so it is emitted at most once.
    """
    base = {"callback_task_id": callback_task_id}
    pending: List[Tuple[str, Dict[str, Any]]] = []

    # The generic progress line always comes first, for nodes with a known message.
    if node_name in STAGE_MESSAGES:
        pending.append(_reasoning_event(callback_task_id, node_name, state))

    # A node name can match at most one of the node-specific branches below.
    if node_name == "recognize_intent":
        intent_result = state.get("intent_result")
        if intent_result:
            pending.append(("intent", {**base, "intent_result": intent_result}))
    elif node_name == "rerank_context":
        reranked = state.get("reranked_references") or []
        retrieval_payload = {
            **base,
            "retrieval_status": state.get("retrieval_status"),
            "retrieval_method": state.get("retrieval_method"),
            "retrieval_metrics": state.get("retrieval_metrics") or {},
            "rerank_count": len(reranked),
            "references": _limited_items(reranked, _pack_reference_preview),
            "warnings": state.get("warnings") or [],
        }
        pending.append(("retrieval_result", retrieval_payload))
    elif node_name == "quality_gate":
        skill_name = (state.get("intent_result") or {}).get("skill_name") or ""
        if skill_name and not skill_started_sent:
            # The modify skill produces a proposal; every other skill produces an answer.
            kind = "proposal" if skill_name == "document-modify" else "answer"
            pending.append(
                (
                    "skill_started",
                    {**base, "skill_name": skill_name, "response_type": kind},
                )
            )
            skill_started_sent = True

    return pending, skill_started_sent
@document_chat_router.post("/document_chat")
@auto_trace(generate_if_missing=True)
async def document_chat(request: DocumentChatRequest, stream: bool = Query(False)):
    """Handle a document-chat request, either as JSON or as an SSE stream.

    A fresh ``callback_task_id`` is generated per request and installed as the
    trace id.

    - When ``stream`` is true or ``request.response_mode == "sse"``, the
      workflow runs inside ``_generate_document_chat_events`` and is returned
      as a ``text/event-stream`` response.
    - Otherwise the workflow is awaited and its result is wrapped in a
      ``DocumentChatResponse``. ``code`` is 500 for an ``error`` result and
      200 otherwise.

    Raises:
        HTTPException: Status 500 when the non-streaming path raises. The
            original exception is chained as ``__cause__``.
    """
    callback_task_id = f"doc_chat_{uuid.uuid4().hex[:12]}"
    TraceContext.set_trace_id(callback_task_id)
    log_document_chat_event(
        "request_received",
        callback_task_id,
        {
            "stream": stream,
            "response_mode": request.response_mode,
            "request": model_to_dict(request),
        },
    )

    if stream or request.response_mode == "sse":
        return StreamingResponse(
            _generate_document_chat_events(callback_task_id, request),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Asks nginx-style proxies not to buffer the event stream.
                "X-Accel-Buffering": "no",
            },
        )

    try:
        workflow = get_document_chat_workflow()
        state = await workflow.run(request, callback_task_id)
        data = workflow.to_response_data(state)
        data_dict = model_to_dict(data)
        log_document_chat_event("response_completed", callback_task_id, data_dict)
        if data.response_type == "error":
            # Fall back to "error", never "success", so an error result with an
            # empty error_message still reports a message consistent with code 500.
            return DocumentChatResponse(
                code=500, message=data.error_message or "error", data=data
            )
        return DocumentChatResponse(code=200, message="success", data=data)
    except Exception as exc:
        logger.error(f"[DocumentChat] request failed: {exc}", exc_info=True)
        log_document_chat_event(
            "request_failed",
            callback_task_id,
            {"error": str(exc), "request": model_to_dict(request)},
            level="error",
        )
        # NOTE(review): str(exc) is returned to the client verbatim and may
        # expose internal details; consider a generic detail message.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
async def _generate_document_chat_events(
    callback_task_id: str,
    request: DocumentChatRequest,
) -> AsyncGenerator[str, None]:
    """Run the document-chat workflow and stream its progress as SSE frames.

    Emitted sequence:

    1. ``connected``, then ``processing`` for the workflow start.
    2. ``chunk`` for each ``custom`` stream payload carrying ``stream_chunk``.
    3. ``reasoning`` / ``intent`` / ``retrieval_result`` / ``skill_started``
       as graph nodes report updates (see ``_build_realtime_events``).
    4. One terminal result event: ``answer_completed``,
       ``proposal_completed`` or ``error``.
    5. ``completed`` with the elapsed ``duration`` in seconds, only when the
       terminal event was not ``error``.

    Any exception is logged and reported to the client as an ``error`` frame
    rather than raised. Once streaming has begun, the HTTP status can no
    longer be changed.
    """
    # Monotonic clock for the elapsed-time measurement: time.time() follows the
    # wall clock, which can jump (NTP or manual adjustments) and skew durations.
    started_at = time.monotonic()
    try:
        yield format_sse_event(
            "connected",
            {
                "callback_task_id": callback_task_id,
                "status": "connected",
                # Wall-clock epoch seconds, intended for the client.
                "timestamp": int(time.time()),
            },
        )
        yield format_sse_event(
            "processing",
            {
                "callback_task_id": callback_task_id,
                "stage_name": "workflow_started",
                "status": "processing",
                # Taken from the shared stage table so the two texts cannot drift apart.
                "message": STAGE_MESSAGES["workflow_started"],
            },
        )

        workflow = get_document_chat_workflow()
        state = workflow.build_initial_state(request, callback_task_id)
        # The graph receives a shallow copy. `state` is the local accumulator
        # that streamed node updates are merged into below.
        graph_state = dict(state)
        skill_started_sent = False
        async for mode, payload in workflow.get_graph().astream(
            graph_state, stream_mode=["updates", "custom"]
        ):
            if mode == "custom" and isinstance(payload, dict) and payload.get("stream_chunk"):
                yield format_sse_event(
                    "chunk",
                    {
                        "callback_task_id": callback_task_id,
                        "chunk": payload["stream_chunk"],
                    },
                )
            elif mode == "updates":
                for node_name, node_update in _iter_node_updates(payload):
                    _merge_state_update(state, node_update)
                    realtime_events, skill_started_sent = _build_realtime_events(
                        callback_task_id,
                        state,
                        node_name,
                        skill_started_sent,
                    )
                    for event_type, event_data in realtime_events:
                        yield format_sse_event(event_type, event_data)

        data = workflow.to_response_data(state)
        data_dict = model_to_dict(data)
        log_document_chat_event("response_completed", callback_task_id, data_dict)

        if data.response_type == "answer":
            terminal_event = "answer_completed"
        elif data.response_type == "proposal":
            terminal_event = "proposal_completed"
        elif data.response_type in ("clarify", "unsupported"):
            # Clarification and unsupported replies are delivered like answers.
            terminal_event = "answer_completed"
        else:
            terminal_event = "error"
        yield format_sse_event(terminal_event, data_dict)

        # Only successful results get the trailing "completed" frame. An
        # unrecognised response_type is reported as "error" and must not also
        # be followed by "completed".
        if terminal_event != "error":
            yield format_sse_event(
                "completed",
                {
                    "callback_task_id": callback_task_id,
                    "status": state.get("overall_task_status", "completed"),
                    "duration": round(time.monotonic() - started_at, 3),
                },
            )
    except Exception as exc:
        logger.error(f"[DocumentChat] SSE request failed: {exc}", exc_info=True)
        log_document_chat_event(
            "request_failed",
            callback_task_id,
            {"error": str(exc), "request": model_to_dict(request)},
            level="error",
        )
        yield format_sse_event(
            "error",
            {
                "callback_task_id": callback_task_id,
                "status": "error",
                "message": str(exc),
            },
        )
@document_chat_router.get("/document_chat/health")
async def document_chat_health():
    """Return a static health payload for the document-chat module.

    The payload is a constant literal; no workflow or dependency is probed.
    """
    available_skills = ["document-answer", "document-modify"]
    return dict(
        status="healthy",
        module="document_chat",
        workflow="langgraph",
        skills=available_skills,
    )
|