|
@@ -0,0 +1,317 @@
|
|
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
|
|
+"""LangGraph workflow for document chat."""
|
|
|
|
|
+
|
|
|
|
|
+import uuid
|
|
|
|
|
+from typing import Any, Dict, Optional
|
|
|
|
|
+
|
|
|
|
|
+from langgraph.graph import END, StateGraph
|
|
|
|
|
+
|
|
|
|
|
+from foundation.observability.logger.loggering import write_logger as logger
|
|
|
|
|
+
|
|
|
|
|
+from core.document_chat.component.conversation_context import ConversationContextBuilder
|
|
|
|
|
+from core.document_chat.component.diff_service import DiffService
|
|
|
|
|
+from core.document_chat.component.intent_recognizer import IntentRecognizer
|
|
|
|
|
+from core.document_chat.component.skill_dispatcher import SkillDispatcher
|
|
|
|
|
+from core.document_chat.component.state_models import DocumentChatState
|
|
|
|
|
+from core.document_chat.schemas import (
|
|
|
|
|
+ DiffResult,
|
|
|
|
|
+ DocumentChatData,
|
|
|
|
|
+ DocumentChatRequest,
|
|
|
|
|
+ DocumentChatSkillInput,
|
|
|
|
|
+ DocumentChatSkillOutput,
|
|
|
|
|
+ DocumentContext,
|
|
|
|
|
+ IntentResult,
|
|
|
|
|
+ SelectedSection,
|
|
|
|
|
+ model_to_dict,
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class DocumentChatWorkflow:
    """Document chat workflow built with LangGraph.

    The graph runs ``validate_input -> load_context -> load_skill_registry ->
    recognize_intent -> route_intent`` and then branches to one of
    ``clarify``, ``unsupported``, ``run_answer_skill``,
    ``run_modify_skill`` (followed by ``build_diff``) or ``error_handler``.
    Every branch ends in ``complete``.

    Failures inside a node are converted into a state update by
    ``_error_update`` (which sets ``error_message``); later nodes check that
    key and skip their work, and ``route_intent`` sends the run to the
    ``error_handler`` branch.
    """

    # Intents recognized with a confidence below this value are routed to "clarify".
    CLARIFY_CONFIDENCE_THRESHOLD = 0.65

    def __init__(self):
        """Create the collaborating components; the graph is compiled lazily."""
        self.intent_recognizer = IntentRecognizer()
        self.skill_dispatcher = SkillDispatcher()
        self.diff_service = DiffService()
        self.context_builder = ConversationContextBuilder()
        self.graph = None

    def build_graph(self):
        """Assemble and compile the LangGraph state graph.

        Returns:
            The compiled graph returned by ``StateGraph.compile()``.
        """
        workflow = StateGraph(DocumentChatState)
        workflow.add_node("validate_input", self.validate_input_node)
        workflow.add_node("load_context", self.load_context_node)
        workflow.add_node("load_skill_registry", self.load_skill_registry_node)
        workflow.add_node("recognize_intent", self.recognize_intent_node)
        workflow.add_node("route_intent", self.route_intent_node)
        workflow.add_node("clarify", self.clarify_node)
        workflow.add_node("unsupported", self.unsupported_node)
        workflow.add_node("run_answer_skill", self.run_answer_skill_node)
        workflow.add_node("run_modify_skill", self.run_modify_skill_node)
        workflow.add_node("build_diff", self.build_diff_node)
        workflow.add_node("error_handler", self.error_handler_node)
        workflow.add_node("complete", self.complete_node)

        workflow.set_entry_point("validate_input")
        workflow.add_edge("validate_input", "load_context")
        workflow.add_edge("load_context", "load_skill_registry")
        workflow.add_edge("load_skill_registry", "recognize_intent")
        workflow.add_edge("recognize_intent", "route_intent")
        # Keys are the labels returned by ``route_intent``; values are node names.
        workflow.add_conditional_edges(
            "route_intent",
            self.route_intent,
            {
                "clarify": "clarify",
                "unsupported": "unsupported",
                "answer": "run_answer_skill",
                "modify": "run_modify_skill",
                "error": "error_handler",
            },
        )
        workflow.add_edge("clarify", "complete")
        workflow.add_edge("unsupported", "complete")
        workflow.add_edge("run_answer_skill", "complete")
        workflow.add_edge("run_modify_skill", "build_diff")
        workflow.add_edge("build_diff", "complete")
        workflow.add_edge("error_handler", "complete")
        workflow.add_edge("complete", END)
        return workflow.compile()

    def get_graph(self):
        """Return the compiled graph, building and caching it on first use."""
        if self.graph is None:
            self.graph = self.build_graph()
        return self.graph

    async def run(self, request: DocumentChatRequest, callback_task_id: Optional[str] = None) -> DocumentChatState:
        """Run the workflow for one chat request.

        Args:
            request: The incoming chat request.
            callback_task_id: Identifier stored in the state as
                ``callback_task_id``. When omitted (or empty) a
                ``doc_chat_<12 hex chars>`` id is generated.

        Returns:
            The final state returned by the compiled graph's ``ainvoke``.
        """
        task_id = callback_task_id or f"doc_chat_{uuid.uuid4().hex[:12]}"
        initial_state: DocumentChatState = {
            "callback_task_id": task_id,
            "user_id": request.user_id,
            "conversation_id": request.conversation_id,
            "task_id": request.task_id,
            "project_info": request.project_info,
            "selected_section": model_to_dict(request.selected_section),
            "document_context": model_to_dict(request.document_context),
            "conversation_history": request.conversation_history,
            "user_message": request.message,
            "skill_registry": [],
            "intent_result": None,
            "skill_result": None,
            "diff_result": None,
            "response_type": None,
            "current_stage": "start",
            "overall_task_status": "processing",
            "error_message": None,
            "warnings": [],
            "messages": [],
        }
        return await self.get_graph().ainvoke(initial_state)

    async def validate_input_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Check required request fields and normalize the message and section.

        Returns:
            A state update with the stripped ``user_message`` and a
            ``selected_section`` that always has a ``content`` key, or an
            error update when ``user_id``, the message, or the section's
            ``index``/``title`` is missing.
        """
        try:
            # Copy so the dict held by the incoming state is not mutated in place.
            selected_section = dict(state.get("selected_section") or {})
            user_message = (state.get("user_message") or "").strip()
            if not state.get("user_id"):
                raise ValueError("user_id is required")
            if not user_message:
                raise ValueError("message is required")
            if not selected_section.get("index") or not selected_section.get("title"):
                raise ValueError("selected_section.index and selected_section.title are required")
            selected_section.setdefault("content", "")
            return {
                "selected_section": selected_section,
                "user_message": user_message,
                "current_stage": "validate_input",
            }
        except Exception as exc:  # node boundary: report the failure through state
            return self._error_update("validate_input", exc)

    async def load_context_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Replace the context fields with what ``context_builder.build`` returns.

        Does nothing when an earlier node already recorded an error.
        """
        if state.get("error_message"):
            return {}
        try:
            context = self.context_builder.build(state)
            return {
                "project_info": context["project_info"],
                "selected_section": context["selected_section"],
                "document_context": context["document_context"],
                "conversation_history": context["conversation_history"],
                "current_stage": "load_context",
            }
        except Exception as exc:  # keep the graph running so the error branch can respond
            return self._error_update("load_context", exc)

    async def load_skill_registry_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Store ``skill_dispatcher.registry_for_prompt()`` as ``skill_registry``.

        Does nothing when an earlier node already recorded an error.
        """
        if state.get("error_message"):
            return {}
        try:
            return {
                "skill_registry": self.skill_dispatcher.registry_for_prompt(),
                "current_stage": "load_skill_registry",
            }
        except Exception as exc:  # keep the graph running so the error branch can respond
            return self._error_update("load_skill_registry", exc)

    async def recognize_intent_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Run the intent recognizer and store its result as a plain dict.

        Does nothing when an earlier node already recorded an error.
        """
        if state.get("error_message"):
            return {}
        try:
            intent_result = await self.intent_recognizer.recognize(state)
            return {
                "intent_result": model_to_dict(intent_result),
                "current_stage": "recognize_intent",
            }
        except Exception as exc:
            return self._error_update("recognize_intent", exc)

    async def route_intent_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Mark the routing stage; the branch itself is chosen by ``route_intent``."""
        return {"current_stage": "route_intent"}

    def route_intent(self, state: DocumentChatState) -> str:
        """Pick the branch label for the conditional edge after ``route_intent``.

        Returns:
            ``"error"`` when an error is recorded, the intent result cannot be
            parsed, or the skill name is unknown; ``"clarify"`` when
            clarification is requested or confidence is below
            ``CLARIFY_CONFIDENCE_THRESHOLD``; otherwise ``"unsupported"``,
            ``"answer"`` or ``"modify"``.
        """
        if state.get("error_message"):
            return "error"
        intent_data = state.get("intent_result") or {}
        try:
            intent = IntentResult(**intent_data)
        except Exception:
            return "error"
        # The clarify check runs first, so a low-confidence result is clarified
        # even when its intent is "unsupported".
        if (
            intent.needs_clarification
            or intent.intent == "clarify"
            or intent.confidence < self.CLARIFY_CONFIDENCE_THRESHOLD
        ):
            return "clarify"
        if intent.intent == "unsupported":
            return "unsupported"
        if intent.skill_name == "document-answer":
            return "answer"
        if intent.skill_name == "document-modify":
            return "modify"
        return "error"

    async def clarify_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Reply with the recognizer's clarification question, or a default one."""
        intent = IntentResult(**(state.get("intent_result") or {"intent": "clarify"}))
        question = intent.clarification_question or "请补充说明希望 AI 对当前章节做什么。"
        return self._canned_reply(intent, "clarify", question)

    async def unsupported_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Reply with the recognizer's reason for rejecting the request, or a default one."""
        intent = IntentResult(**(state.get("intent_result") or {"intent": "unsupported"}))
        message = intent.reason or "当前 AI 对话模块只支持选中章节的问答和修改。"
        return self._canned_reply(intent, "unsupported", message)

    @staticmethod
    def _canned_reply(intent: IntentResult, response_type: str, text: str) -> Dict[str, Any]:
        """Build the state update for a reply that does not run a skill.

        ``response_type`` is also used as the ``current_stage`` value, matching
        the node names ``clarify`` and ``unsupported``.
        """
        skill_result = DocumentChatSkillOutput(
            skill_name="",
            response_type=response_type,
            answer=text,
            warnings=intent.warnings,
        )
        return {
            "skill_result": model_to_dict(skill_result),
            "response_type": response_type,
            "current_stage": response_type,
        }

    async def run_answer_skill_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Run the ``document-answer`` skill."""
        return await self._run_skill(state, "document-answer", "run_answer_skill")

    async def run_modify_skill_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Run the ``document-modify`` skill."""
        return await self._run_skill(state, "document-modify", "run_modify_skill")

    async def _run_skill(self, state: DocumentChatState, skill_name: str, stage: str) -> Dict[str, Any]:
        """Dispatch ``skill_name`` with input built from the state.

        Returns:
            A state update carrying the skill result and its response type,
            or an error update tagged with ``stage`` when anything raises.
        """
        try:
            skill_input = self._build_skill_input(state)
            skill_result = await self.skill_dispatcher.run_skill(skill_name, skill_input)
            return {
                "skill_result": model_to_dict(skill_result),
                "response_type": skill_result.response_type,
                "current_stage": stage,
            }
        except Exception as exc:
            return self._error_update(stage, exc)

    async def build_diff_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Diff the section's old content against the skill's proposed content.

        Does nothing when the modify skill (or an earlier node) recorded an error.
        """
        if state.get("error_message"):
            return {}
        try:
            skill_result = state.get("skill_result") or {}
            old_content = skill_result.get("old_content")
            if old_content is None:
                # Fall back to the content of the section selected in the request.
                old_content = (state.get("selected_section") or {}).get("content", "")
            # NOTE(review): a missing proposed_content is diffed as an empty
            # string — confirm that is the intended result for that case.
            new_content = skill_result.get("proposed_content") or ""
            diff_result = self.diff_service.build_diff(old_content, new_content)
            return {
                "diff_result": model_to_dict(diff_result),
                "current_stage": "build_diff",
            }
        except Exception as exc:  # keep the graph running so "complete" still executes
            return self._error_update("build_diff", exc)

    async def error_handler_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Log the failure and mark the run as failed.

        ``route_intent`` can select this branch without any ``error_message``
        in the state (invalid intent result, unknown skill name), so the
        fallback message is written back to keep the response informative.
        """
        error_message = state.get("error_message") or "document chat workflow failed"
        logger.error(f"[DocumentChat] workflow error: {error_message}")
        return {
            "response_type": "error",
            "overall_task_status": "failed",
            "current_stage": "error_handler",
            "error_message": error_message,
        }

    async def complete_node(self, state: DocumentChatState) -> Dict[str, Any]:
        """Mark the run completed unless it has already been marked failed."""
        if state.get("overall_task_status") == "failed":
            return {"current_stage": "complete"}
        return {
            "overall_task_status": "completed",
            "current_stage": "complete",
        }

    def to_response_data(self, state: DocumentChatState) -> DocumentChatData:
        """Convert a final workflow state into the response payload.

        Warnings from the state, the skill result and the intent result are
        merged in that order with duplicates removed.
        """
        skill_result = state.get("skill_result") or {}
        intent_result = state.get("intent_result")
        diff_result = state.get("diff_result") or {}
        selected_section = state.get("selected_section") or {}
        # clarify/unsupported replies copy the intent warnings into the skill
        # result, so a plain concatenation would list them twice.
        warnings = self._merge_warnings(
            state.get("warnings"),
            skill_result.get("warnings"),
            (intent_result or {}).get("warnings"),
        )

        response_type = state.get("response_type") or skill_result.get("response_type") or "error"
        return DocumentChatData(
            callback_task_id=state.get("callback_task_id", ""),
            response_type=response_type,
            intent_result=intent_result,
            answer=skill_result.get("answer"),
            proposed_content=skill_result.get("proposed_content"),
            old_content_hash=diff_result.get("old_content_hash"),
            new_content_hash=diff_result.get("new_content_hash"),
            diff=diff_result.get("diff") or [],
            diff_granularity=diff_result.get("diff_granularity"),
            change_summary=skill_result.get("change_summary") or [],
            references=skill_result.get("references") or [],
            warnings=warnings,
            selected_section={
                "index": selected_section.get("index", ""),
                "code": selected_section.get("code", ""),
                "title": selected_section.get("title", ""),
            },
            error_message=state.get("error_message"),
        )

    @staticmethod
    def _merge_warnings(*sources: Any) -> list:
        """Concatenate warning lists, skipping ``None``/empty sources and duplicates.

        Uses a membership test on the result list rather than a set so that
        unhashable warning values are also supported; first occurrence wins.
        """
        merged: list = []
        for source in sources:
            for warning in source or []:
                if warning not in merged:
                    merged.append(warning)
        return merged

    def _build_skill_input(self, state: DocumentChatState) -> DocumentChatSkillInput:
        """Rebuild the typed skill input from the plain-dict workflow state."""
        return DocumentChatSkillInput(
            user_id=state.get("user_id", ""),
            conversation_id=state.get("conversation_id"),
            task_id=state.get("task_id"),
            project_info=state.get("project_info") or {},
            selected_section=SelectedSection(**(state.get("selected_section") or {})),
            document_context=DocumentContext(**(state.get("document_context") or {})),
            conversation_history=state.get("conversation_history") or [],
            user_message=state.get("user_message", ""),
            intent_result=IntentResult(**(state.get("intent_result") or {})),
        )

    @staticmethod
    def _error_update(stage: str, exc: Exception) -> Dict[str, Any]:
        """Log ``exc`` and build the state update that marks ``stage`` as failed.

        Logging happens here because not every failure passes through
        ``error_handler_node`` (skill and diff failures go straight to
        ``complete``).
        """
        logger.error(f"[DocumentChat] {stage} failed: {exc}")
        return {
            "current_stage": stage,
            "overall_task_status": "failed",
            "response_type": "error",
            "error_message": str(exc),
        }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+document_chat_workflow = DocumentChatWorkflow()  # Module-level instance, created once at import time.
|