|
|
@@ -0,0 +1,516 @@
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""LangGraph workflow for document chat."""
|
|
|
+
|
|
|
+import uuid
|
|
|
+from typing import Any, Dict, Optional
|
|
|
+
|
|
|
+from langgraph.graph import END, StateGraph
|
|
|
+
|
|
|
+from foundation.observability.logger.loggering import write_logger as logger
|
|
|
+
|
|
|
+from core.document_chat.component.conversation_context import ConversationContextBuilder
|
|
|
+from core.document_chat.component.diff_service import DiffService
|
|
|
+from core.document_chat.component.document_chat_logger import log_document_chat_event
|
|
|
+from core.document_chat.component.intent_recognizer import IntentRecognizer
|
|
|
+from core.document_chat.component.rerank_service import DocumentChatRerankService
|
|
|
+from core.document_chat.component.retrieval_quality_gate import RetrievalQualityGate
|
|
|
+from core.document_chat.component.retrieval_service import DocumentChatRetrievalService
|
|
|
+from core.document_chat.component.skill_dispatcher import SkillDispatcher
|
|
|
+from core.document_chat.component.state_models import DocumentChatState
|
|
|
+from core.document_chat.schemas import (
|
|
|
+ DiffResult,
|
|
|
+ DocumentChatData,
|
|
|
+ DocumentChatRequest,
|
|
|
+ DocumentChatSkillInput,
|
|
|
+ DocumentChatSkillOutput,
|
|
|
+ DocumentContext,
|
|
|
+ IntentResult,
|
|
|
+ SelectedSection,
|
|
|
+ model_to_dict,
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
+class DocumentChatWorkflow:
|
|
|
+ """Document chat workflow built with LangGraph."""
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ self.intent_recognizer = IntentRecognizer()
|
|
|
+ self.skill_dispatcher = SkillDispatcher()
|
|
|
+ self.diff_service = DiffService()
|
|
|
+ self.context_builder = ConversationContextBuilder()
|
|
|
+ self.retrieval_service = DocumentChatRetrievalService()
|
|
|
+ self.rerank_service = DocumentChatRerankService(self.retrieval_service.config)
|
|
|
+ self.quality_gate = RetrievalQualityGate(self.retrieval_service.config)
|
|
|
+ self.graph = None
|
|
|
+
|
|
|
+ def build_graph(self):
|
|
|
+ workflow = StateGraph(DocumentChatState)
|
|
|
+ workflow.add_node("validate_input", self.validate_input_node)
|
|
|
+ workflow.add_node("load_context", self.load_context_node)
|
|
|
+ workflow.add_node("load_skill_registry", self.load_skill_registry_node)
|
|
|
+ workflow.add_node("recognize_intent", self.recognize_intent_node)
|
|
|
+ workflow.add_node("route_intent", self.route_intent_node)
|
|
|
+ workflow.add_node("build_retrieval_query", self.build_retrieval_query_node)
|
|
|
+ workflow.add_node("vector_recall", self.vector_recall_node)
|
|
|
+ workflow.add_node("rerank_context", self.rerank_context_node)
|
|
|
+ workflow.add_node("quality_gate", self.quality_gate_node)
|
|
|
+ workflow.add_node("clarify", self.clarify_node)
|
|
|
+ workflow.add_node("unsupported", self.unsupported_node)
|
|
|
+ workflow.add_node("run_answer_skill", self.run_answer_skill_node)
|
|
|
+ workflow.add_node("run_modify_skill", self.run_modify_skill_node)
|
|
|
+ workflow.add_node("build_diff", self.build_diff_node)
|
|
|
+ workflow.add_node("error_handler", self.error_handler_node)
|
|
|
+ workflow.add_node("complete", self.complete_node)
|
|
|
+
|
|
|
+ workflow.set_entry_point("validate_input")
|
|
|
+ workflow.add_edge("validate_input", "load_context")
|
|
|
+ workflow.add_edge("load_context", "load_skill_registry")
|
|
|
+ workflow.add_edge("load_skill_registry", "recognize_intent")
|
|
|
+ workflow.add_edge("recognize_intent", "route_intent")
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "route_intent",
|
|
|
+ self.route_intent,
|
|
|
+ {
|
|
|
+ "clarify": "clarify",
|
|
|
+ "unsupported": "unsupported",
|
|
|
+ "answer": "build_retrieval_query",
|
|
|
+ "modify": "build_retrieval_query",
|
|
|
+ "error": "error_handler",
|
|
|
+ },
|
|
|
+ )
|
|
|
+ workflow.add_edge("build_retrieval_query", "vector_recall")
|
|
|
+ workflow.add_edge("vector_recall", "rerank_context")
|
|
|
+ workflow.add_edge("rerank_context", "quality_gate")
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "quality_gate",
|
|
|
+ self.route_after_retrieval,
|
|
|
+ {
|
|
|
+ "answer": "run_answer_skill",
|
|
|
+ "modify": "run_modify_skill",
|
|
|
+ "error": "error_handler",
|
|
|
+ },
|
|
|
+ )
|
|
|
+ workflow.add_edge("clarify", "complete")
|
|
|
+ workflow.add_edge("unsupported", "complete")
|
|
|
+ workflow.add_edge("run_answer_skill", "complete")
|
|
|
+ workflow.add_edge("run_modify_skill", "build_diff")
|
|
|
+ workflow.add_edge("build_diff", "complete")
|
|
|
+ workflow.add_edge("error_handler", "complete")
|
|
|
+ workflow.add_edge("complete", END)
|
|
|
+ return workflow.compile()
|
|
|
+
|
|
|
+ def get_graph(self):
|
|
|
+ if self.graph is None:
|
|
|
+ self.graph = self.build_graph()
|
|
|
+ return self.graph
|
|
|
+
|
|
|
+ def build_initial_state(self, request: DocumentChatRequest, callback_task_id: Optional[str] = None) -> DocumentChatState:
|
|
|
+ task_id = callback_task_id or f"doc_chat_{uuid.uuid4().hex[:12]}"
|
|
|
+ return {
|
|
|
+ "callback_task_id": task_id,
|
|
|
+ "user_id": request.user_id,
|
|
|
+ "conversation_id": request.conversation_id,
|
|
|
+ "task_id": request.task_id,
|
|
|
+ "project_info": request.project_info,
|
|
|
+ "selected_section": model_to_dict(request.selected_section),
|
|
|
+ "document_context": model_to_dict(request.document_context),
|
|
|
+ "conversation_history": request.conversation_history,
|
|
|
+ "user_message": request.message,
|
|
|
+ "skill_registry": [],
|
|
|
+ "retrieval_query": None,
|
|
|
+ "retrieval_method": None,
|
|
|
+ "retrieval_candidates": [],
|
|
|
+ "reranked_references": [],
|
|
|
+ "approved_references": [],
|
|
|
+ "retrieval_status": None,
|
|
|
+ "retrieval_metrics": {},
|
|
|
+ "intent_result": None,
|
|
|
+ "skill_result": None,
|
|
|
+ "diff_result": None,
|
|
|
+ "response_type": None,
|
|
|
+ "current_stage": "start",
|
|
|
+ "overall_task_status": "processing",
|
|
|
+ "error_message": None,
|
|
|
+ "warnings": [],
|
|
|
+ "messages": [],
|
|
|
+ }
|
|
|
+
|
|
|
+ async def run(self, request: DocumentChatRequest, callback_task_id: Optional[str] = None) -> DocumentChatState:
|
|
|
+ initial_state = self.build_initial_state(request, callback_task_id)
|
|
|
+ return await self.get_graph().ainvoke(initial_state)
|
|
|
+
|
|
|
+ async def validate_input_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ try:
|
|
|
+ selected_section = state.get("selected_section") or {}
|
|
|
+ user_message = (state.get("user_message") or "").strip()
|
|
|
+ if not state.get("user_id"):
|
|
|
+ raise ValueError("user_id is required")
|
|
|
+ if not user_message:
|
|
|
+ raise ValueError("message is required")
|
|
|
+ if not selected_section.get("index") or not selected_section.get("title"):
|
|
|
+ raise ValueError("selected_section.index and selected_section.title are required")
|
|
|
+ if "content" not in selected_section:
|
|
|
+ selected_section["content"] = ""
|
|
|
+ return {
|
|
|
+ "selected_section": selected_section,
|
|
|
+ "user_message": user_message,
|
|
|
+ "current_stage": "validate_input",
|
|
|
+ }
|
|
|
+ except Exception as exc:
|
|
|
+ return self._error_update("validate_input", exc)
|
|
|
+
|
|
|
+ async def load_context_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return {}
|
|
|
+ context = self.context_builder.build(state)
|
|
|
+ return {
|
|
|
+ "project_info": context["project_info"],
|
|
|
+ "selected_section": context["selected_section"],
|
|
|
+ "document_context": context["document_context"],
|
|
|
+ "conversation_history": context["conversation_history"],
|
|
|
+ "current_stage": "load_context",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def load_skill_registry_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return {}
|
|
|
+ return {
|
|
|
+ "skill_registry": self.skill_dispatcher.registry_for_prompt(),
|
|
|
+ "current_stage": "load_skill_registry",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def recognize_intent_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return {}
|
|
|
+ try:
|
|
|
+ intent_result = await self.intent_recognizer.recognize(state)
|
|
|
+ return {
|
|
|
+ "intent_result": model_to_dict(intent_result),
|
|
|
+ "current_stage": "recognize_intent",
|
|
|
+ }
|
|
|
+ except Exception as exc:
|
|
|
+ return self._error_update("recognize_intent", exc)
|
|
|
+
|
|
|
+ async def route_intent_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ return {"current_stage": "route_intent"}
|
|
|
+
|
|
|
+ def route_intent(self, state: DocumentChatState) -> str:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return "error"
|
|
|
+ intent_data = state.get("intent_result") or {}
|
|
|
+ try:
|
|
|
+ intent = IntentResult(**intent_data)
|
|
|
+ except Exception:
|
|
|
+ return "error"
|
|
|
+ if intent.needs_clarification or intent.intent == "clarify" or intent.confidence < 0.65:
|
|
|
+ return "clarify"
|
|
|
+ if intent.skill_name == "document-answer":
|
|
|
+ return "answer"
|
|
|
+ if intent.skill_name == "document-modify":
|
|
|
+ return "modify"
|
|
|
+ if intent.intent == "unsupported":
|
|
|
+ return "unsupported"
|
|
|
+ return "error"
|
|
|
+
|
|
|
+ def route_after_retrieval(self, state: DocumentChatState) -> str:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return "error"
|
|
|
+ intent_data = state.get("intent_result") or {}
|
|
|
+ skill_name = intent_data.get("skill_name")
|
|
|
+ if skill_name == "document-answer":
|
|
|
+ return "answer"
|
|
|
+ if skill_name == "document-modify":
|
|
|
+ return "modify"
|
|
|
+ return "error"
|
|
|
+
|
|
|
+ async def build_retrieval_query_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return {}
|
|
|
+ query = self.retrieval_service.build_query(state)
|
|
|
+ log_document_chat_event(
|
|
|
+ "rag_query_built",
|
|
|
+ state.get("callback_task_id", ""),
|
|
|
+ {
|
|
|
+ "retrieval_query": query,
|
|
|
+ "intent_result": state.get("intent_result"),
|
|
|
+ "selected_section": state.get("selected_section"),
|
|
|
+ "project_info": state.get("project_info"),
|
|
|
+ "document_context": state.get("document_context"),
|
|
|
+ },
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "retrieval_query": query,
|
|
|
+ "current_stage": "build_retrieval_query",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def vector_recall_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return {}
|
|
|
+ result = self.retrieval_service.recall(state)
|
|
|
+ log_document_chat_event(
|
|
|
+ "rag_recall_completed",
|
|
|
+ state.get("callback_task_id", ""),
|
|
|
+ {
|
|
|
+ "retrieval_query": state.get("retrieval_query"),
|
|
|
+ "retrieval_method": result.get("retrieval_method"),
|
|
|
+ "retrieval_status": result.get("retrieval_status"),
|
|
|
+ "retrieval_metrics": result.get("retrieval_metrics") or {},
|
|
|
+ "retrieval_candidates": result.get("retrieval_candidates") or [],
|
|
|
+ "warnings": result.get("warnings") or [],
|
|
|
+ },
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "retrieval_candidates": result.get("retrieval_candidates") or [],
|
|
|
+ "retrieval_status": result.get("retrieval_status"),
|
|
|
+ "retrieval_method": result.get("retrieval_method"),
|
|
|
+ "retrieval_metrics": self._merge_metrics(state, result.get("retrieval_metrics") or {}),
|
|
|
+ "warnings": self._append_warnings(state, result.get("warnings") or []),
|
|
|
+ "current_stage": "vector_recall",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def rerank_context_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return {}
|
|
|
+ if state.get("retrieval_status") != "recalled":
|
|
|
+ log_document_chat_event(
|
|
|
+ "rag_rerank_skipped",
|
|
|
+ state.get("callback_task_id", ""),
|
|
|
+ {
|
|
|
+ "retrieval_query": state.get("retrieval_query"),
|
|
|
+ "retrieval_method": state.get("retrieval_method"),
|
|
|
+ "retrieval_status": state.get("retrieval_status"),
|
|
|
+ "retrieval_metrics": state.get("retrieval_metrics") or {},
|
|
|
+ "warnings": state.get("warnings") or [],
|
|
|
+ },
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "reranked_references": [],
|
|
|
+ "approved_references": [],
|
|
|
+ "current_stage": "rerank_context",
|
|
|
+ }
|
|
|
+
|
|
|
+ result = self.rerank_service.rerank(
|
|
|
+ query=state.get("retrieval_query") or "",
|
|
|
+ candidates=state.get("retrieval_candidates") or [],
|
|
|
+ )
|
|
|
+ log_document_chat_event(
|
|
|
+ "rag_rerank_completed",
|
|
|
+ state.get("callback_task_id", ""),
|
|
|
+ {
|
|
|
+ "retrieval_query": state.get("retrieval_query"),
|
|
|
+ "retrieval_method": state.get("retrieval_method"),
|
|
|
+ "retrieval_status": result.get("retrieval_status"),
|
|
|
+ "retrieval_metrics": result.get("retrieval_metrics") or {},
|
|
|
+ "retrieval_candidates": state.get("retrieval_candidates") or [],
|
|
|
+ "reranked_references": result.get("reranked_references") or [],
|
|
|
+ "warnings": result.get("warnings") or [],
|
|
|
+ },
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "reranked_references": result.get("reranked_references") or [],
|
|
|
+ "retrieval_status": result.get("retrieval_status"),
|
|
|
+ "retrieval_metrics": self._merge_metrics(state, result.get("retrieval_metrics") or {}),
|
|
|
+ "warnings": self._append_warnings(state, result.get("warnings") or []),
|
|
|
+ "current_stage": "rerank_context",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def quality_gate_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return {}
|
|
|
+ if state.get("retrieval_status") != "reranked":
|
|
|
+ log_document_chat_event(
|
|
|
+ "rag_quality_gate_skipped",
|
|
|
+ state.get("callback_task_id", ""),
|
|
|
+ {
|
|
|
+ "retrieval_query": state.get("retrieval_query"),
|
|
|
+ "retrieval_method": state.get("retrieval_method"),
|
|
|
+ "retrieval_status": state.get("retrieval_status"),
|
|
|
+ "retrieval_metrics": self._merge_metrics(state, {"approved_count": 0}),
|
|
|
+ "reranked_references": state.get("reranked_references") or [],
|
|
|
+ "warnings": state.get("warnings") or [],
|
|
|
+ },
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "approved_references": [],
|
|
|
+ "retrieval_metrics": self._merge_metrics(state, {"approved_count": 0}),
|
|
|
+ "current_stage": "quality_gate",
|
|
|
+ }
|
|
|
+
|
|
|
+ result = self.quality_gate.apply(state.get("reranked_references") or [])
|
|
|
+ log_document_chat_event(
|
|
|
+ "rag_quality_gate_completed",
|
|
|
+ state.get("callback_task_id", ""),
|
|
|
+ {
|
|
|
+ "retrieval_query": state.get("retrieval_query"),
|
|
|
+ "retrieval_method": state.get("retrieval_method"),
|
|
|
+ "retrieval_status": result.get("retrieval_status"),
|
|
|
+ "retrieval_metrics": result.get("retrieval_metrics") or {},
|
|
|
+ "reranked_references": state.get("reranked_references") or [],
|
|
|
+ "approved_references": result.get("approved_references") or [],
|
|
|
+ "warnings": result.get("warnings") or [],
|
|
|
+ },
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "approved_references": result.get("approved_references") or [],
|
|
|
+ "retrieval_status": result.get("retrieval_status"),
|
|
|
+ "retrieval_metrics": self._merge_metrics(state, result.get("retrieval_metrics") or {}),
|
|
|
+ "warnings": self._append_warnings(state, result.get("warnings") or []),
|
|
|
+ "current_stage": "quality_gate",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def clarify_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ intent = IntentResult(**(state.get("intent_result") or {"intent": "clarify"}))
|
|
|
+ question = intent.clarification_question or "请补充说明希望 AI 对当前章节做什么。"
|
|
|
+ skill_result = DocumentChatSkillOutput(
|
|
|
+ skill_name="",
|
|
|
+ response_type="clarify",
|
|
|
+ answer=question,
|
|
|
+ warnings=intent.warnings,
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "skill_result": model_to_dict(skill_result),
|
|
|
+ "response_type": "clarify",
|
|
|
+ "current_stage": "clarify",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def unsupported_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ intent = IntentResult(**(state.get("intent_result") or {"intent": "unsupported"}))
|
|
|
+ message = intent.reason or "当前 AI 对话模块只支持选中章节的问答和修改。"
|
|
|
+ skill_result = DocumentChatSkillOutput(
|
|
|
+ skill_name="",
|
|
|
+ response_type="unsupported",
|
|
|
+ answer=message,
|
|
|
+ warnings=intent.warnings,
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "skill_result": model_to_dict(skill_result),
|
|
|
+ "response_type": "unsupported",
|
|
|
+ "current_stage": "unsupported",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def run_answer_skill_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ return await self._run_skill(state, "document-answer", "run_answer_skill")
|
|
|
+
|
|
|
+ async def run_modify_skill_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ return await self._run_skill(state, "document-modify", "run_modify_skill")
|
|
|
+
|
|
|
+ async def _run_skill(self, state: DocumentChatState, skill_name: str, stage: str) -> Dict[str, Any]:
|
|
|
+ try:
|
|
|
+ skill_input = self._build_skill_input(state)
|
|
|
+ skill_result = await self.skill_dispatcher.run_skill(skill_name, skill_input)
|
|
|
+ return {
|
|
|
+ "skill_result": model_to_dict(skill_result),
|
|
|
+ "response_type": skill_result.response_type,
|
|
|
+ "current_stage": stage,
|
|
|
+ }
|
|
|
+ except Exception as exc:
|
|
|
+ return self._error_update(stage, exc)
|
|
|
+
|
|
|
+ async def build_diff_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ if state.get("error_message"):
|
|
|
+ return {}
|
|
|
+ skill_result = state.get("skill_result") or {}
|
|
|
+ old_content = skill_result.get("old_content")
|
|
|
+ if old_content is None:
|
|
|
+ old_content = (state.get("selected_section") or {}).get("content", "")
|
|
|
+ new_content = skill_result.get("proposed_content") or ""
|
|
|
+ diff_result = self.diff_service.build_diff(old_content, new_content)
|
|
|
+ return {
|
|
|
+ "diff_result": model_to_dict(diff_result),
|
|
|
+ "current_stage": "build_diff",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def error_handler_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ error_message = state.get("error_message") or "document chat workflow failed"
|
|
|
+ logger.error(f"[DocumentChat] workflow error: {error_message}")
|
|
|
+ return {
|
|
|
+ "response_type": "error",
|
|
|
+ "overall_task_status": "failed",
|
|
|
+ "current_stage": "error_handler",
|
|
|
+ }
|
|
|
+
|
|
|
+ async def complete_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
+ if state.get("overall_task_status") == "failed":
|
|
|
+ return {"current_stage": "complete"}
|
|
|
+ return {
|
|
|
+ "overall_task_status": "completed",
|
|
|
+ "current_stage": "complete",
|
|
|
+ }
|
|
|
+
|
|
|
+ def to_response_data(self, state: DocumentChatState) -> DocumentChatData:
|
|
|
+ skill_result = state.get("skill_result") or {}
|
|
|
+ intent_result = state.get("intent_result")
|
|
|
+ diff_result = state.get("diff_result") or {}
|
|
|
+ selected_section = state.get("selected_section") or {}
|
|
|
+ warnings = []
|
|
|
+ warnings.extend(state.get("warnings") or [])
|
|
|
+ warnings.extend(skill_result.get("warnings") or [])
|
|
|
+ if intent_result:
|
|
|
+ warnings.extend(intent_result.get("warnings") or [])
|
|
|
+
|
|
|
+ response_type = state.get("response_type") or skill_result.get("response_type") or "error"
|
|
|
+ approved_references = state.get("approved_references") or []
|
|
|
+ return DocumentChatData(
|
|
|
+ callback_task_id=state.get("callback_task_id", ""),
|
|
|
+ response_type=response_type,
|
|
|
+ intent_result=intent_result,
|
|
|
+ answer=skill_result.get("answer"),
|
|
|
+ proposed_content=skill_result.get("proposed_content"),
|
|
|
+ old_content_hash=diff_result.get("old_content_hash"),
|
|
|
+ new_content_hash=diff_result.get("new_content_hash"),
|
|
|
+ diff=diff_result.get("diff") or [],
|
|
|
+ diff_granularity=diff_result.get("diff_granularity"),
|
|
|
+ change_summary=skill_result.get("change_summary") or [],
|
|
|
+ references=approved_references,
|
|
|
+ retrieval_status=state.get("retrieval_status"),
|
|
|
+ retrieval_metrics=self._merge_metrics(state, {"retrieval_method": state.get("retrieval_method")}),
|
|
|
+ warnings=warnings,
|
|
|
+ selected_section={
|
|
|
+ "index": selected_section.get("index", ""),
|
|
|
+ "code": selected_section.get("code", ""),
|
|
|
+ "title": selected_section.get("title", ""),
|
|
|
+ },
|
|
|
+ error_message=state.get("error_message"),
|
|
|
+ )
|
|
|
+
|
|
|
+ def _build_skill_input(self, state: DocumentChatState) -> DocumentChatSkillInput:
|
|
|
+ document_context = dict(state.get("document_context") or {})
|
|
|
+ document_context["references"] = state.get("approved_references") or []
|
|
|
+ return DocumentChatSkillInput(
|
|
|
+ user_id=state.get("user_id", ""),
|
|
|
+ conversation_id=state.get("conversation_id"),
|
|
|
+ task_id=state.get("task_id"),
|
|
|
+ project_info=state.get("project_info") or {},
|
|
|
+ selected_section=SelectedSection(**(state.get("selected_section") or {})),
|
|
|
+ document_context=DocumentContext(**document_context),
|
|
|
+ conversation_history=state.get("conversation_history") or [],
|
|
|
+ user_message=state.get("user_message", ""),
|
|
|
+ intent_result=IntentResult(**(state.get("intent_result") or {})),
|
|
|
+ )
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _append_warnings(state: DocumentChatState, new_warnings: list) -> list:
|
|
|
+ warnings = list(state.get("warnings") or [])
|
|
|
+ for warning in new_warnings:
|
|
|
+ warning = str(warning).strip()
|
|
|
+ if warning and warning not in warnings:
|
|
|
+ warnings.append(warning)
|
|
|
+ return warnings
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _merge_metrics(state: DocumentChatState, new_metrics: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
+ metrics = dict(state.get("retrieval_metrics") or {})
|
|
|
+ metrics.update(new_metrics or {})
|
|
|
+ return metrics
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _error_update(stage: str, exc: Exception) -> Dict[str, Any]:
|
|
|
+ return {
|
|
|
+ "current_stage": stage,
|
|
|
+ "overall_task_status": "failed",
|
|
|
+ "response_type": "error",
|
|
|
+ "error_message": str(exc),
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+document_chat_workflow = DocumentChatWorkflow()
|