|
@@ -2,7 +2,7 @@
|
|
|
"""LangGraph workflow for document chat."""
|
|
"""LangGraph workflow for document chat."""
|
|
|
|
|
|
|
|
import uuid
|
|
import uuid
|
|
|
-from typing import Any, Dict, Optional
|
|
|
|
|
|
|
+from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
from langgraph.graph import END, StateGraph
|
|
from langgraph.graph import END, StateGraph
|
|
|
|
|
|
|
@@ -55,11 +55,20 @@ class DocumentChatWorkflow:
|
|
|
workflow.add_node("unsupported", self.unsupported_node)
|
|
workflow.add_node("unsupported", self.unsupported_node)
|
|
|
workflow.add_node("run_answer_skill", self.run_answer_skill_node)
|
|
workflow.add_node("run_answer_skill", self.run_answer_skill_node)
|
|
|
workflow.add_node("run_modify_skill", self.run_modify_skill_node)
|
|
workflow.add_node("run_modify_skill", self.run_modify_skill_node)
|
|
|
|
|
+ workflow.add_node("general_answer", self.general_answer_node)
|
|
|
workflow.add_node("error_handler", self.error_handler_node)
|
|
workflow.add_node("error_handler", self.error_handler_node)
|
|
|
workflow.add_node("complete", self.complete_node)
|
|
workflow.add_node("complete", self.complete_node)
|
|
|
|
|
|
|
|
workflow.set_entry_point("validate_input")
|
|
workflow.set_entry_point("validate_input")
|
|
|
- workflow.add_edge("validate_input", "load_context")
|
|
|
|
|
|
|
+ workflow.add_conditional_edges(
|
|
|
|
|
+ "validate_input",
|
|
|
|
|
+ self.route_after_validate,
|
|
|
|
|
+ {
|
|
|
|
|
+ "general": "general_answer",
|
|
|
|
|
+ "normal": "load_context",
|
|
|
|
|
+ "error": "error_handler",
|
|
|
|
|
+ },
|
|
|
|
|
+ )
|
|
|
workflow.add_edge("load_context", "load_skill_registry")
|
|
workflow.add_edge("load_context", "load_skill_registry")
|
|
|
workflow.add_edge("load_skill_registry", "recognize_intent")
|
|
workflow.add_edge("load_skill_registry", "recognize_intent")
|
|
|
workflow.add_edge("recognize_intent", "route_intent")
|
|
workflow.add_edge("recognize_intent", "route_intent")
|
|
@@ -90,6 +99,7 @@ class DocumentChatWorkflow:
|
|
|
workflow.add_edge("unsupported", "complete")
|
|
workflow.add_edge("unsupported", "complete")
|
|
|
workflow.add_edge("run_answer_skill", "complete")
|
|
workflow.add_edge("run_answer_skill", "complete")
|
|
|
workflow.add_edge("run_modify_skill", "complete")
|
|
workflow.add_edge("run_modify_skill", "complete")
|
|
|
|
|
+ workflow.add_edge("general_answer", "complete")
|
|
|
workflow.add_edge("error_handler", "complete")
|
|
workflow.add_edge("error_handler", "complete")
|
|
|
workflow.add_edge("complete", END)
|
|
workflow.add_edge("complete", END)
|
|
|
return workflow.compile()
|
|
return workflow.compile()
|
|
@@ -137,12 +147,8 @@ class DocumentChatWorkflow:
|
|
|
try:
|
|
try:
|
|
|
selected_section = state.get("selected_section") or {}
|
|
selected_section = state.get("selected_section") or {}
|
|
|
user_message = (state.get("user_message") or "").strip()
|
|
user_message = (state.get("user_message") or "").strip()
|
|
|
- if not state.get("user_id"):
|
|
|
|
|
- raise ValueError("user_id is required")
|
|
|
|
|
if not user_message:
|
|
if not user_message:
|
|
|
raise ValueError("message is required")
|
|
raise ValueError("message is required")
|
|
|
- if not selected_section.get("index") or not selected_section.get("title"):
|
|
|
|
|
- raise ValueError("selected_section.index and selected_section.title are required")
|
|
|
|
|
if "content" not in selected_section:
|
|
if "content" not in selected_section:
|
|
|
selected_section["content"] = ""
|
|
selected_section["content"] = ""
|
|
|
return {
|
|
return {
|
|
@@ -153,6 +159,17 @@ class DocumentChatWorkflow:
|
|
|
except Exception as exc:
|
|
except Exception as exc:
|
|
|
return self._error_update("validate_input", exc)
|
|
return self._error_update("validate_input", exc)
|
|
|
|
|
|
|
|
|
|
+ def route_after_validate(self, state: DocumentChatState) -> str:
|
|
|
|
|
+ if state.get("error_message"):
|
|
|
|
|
+ return "error"
|
|
|
|
|
+ selected_section = state.get("selected_section") or {}
|
|
|
|
|
+ has_section = bool(
|
|
|
|
|
+ selected_section.get("code")
|
|
|
|
|
+ or selected_section.get("chapter_level_1")
|
|
|
|
|
+ or selected_section.get("chapter_level_2")
|
|
|
|
|
+ )
|
|
|
|
|
+ return "normal" if has_section else "general"
|
|
|
|
|
+
|
|
|
async def load_context_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
async def load_context_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
if state.get("error_message"):
|
|
if state.get("error_message"):
|
|
|
return {}
|
|
return {}
|
|
@@ -382,6 +399,79 @@ class DocumentChatWorkflow:
|
|
|
"current_stage": "unsupported",
|
|
"current_stage": "unsupported",
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ async def general_answer_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
|
|
+ """Respond directly via LLM when no section is selected."""
|
|
|
|
|
+ user_message = state.get("user_message", "")
|
|
|
|
|
+ conversation_history = state.get("conversation_history") or []
|
|
|
|
|
+ project_info = state.get("project_info") or {}
|
|
|
|
|
+
|
|
|
|
|
+ system_prompt = (
|
|
|
|
|
+ "你是施工方案编辑 AI 助手。"
|
|
|
|
|
+ "用户当前未选中任何文档章节,请以通用助手的身份回答问题。"
|
|
|
|
|
+ "你可以介绍自己的能力(如:选中章节后可进行润色、扩写、改写、问答等),"
|
|
|
|
|
+ "也可以回答与施工方案编写相关的通用问题。"
|
|
|
|
|
+ "回答应简洁专业,使用中文。"
|
|
|
|
|
+ )
|
|
|
|
|
+ user_payload = {
|
|
|
|
|
+ "user_message": user_message,
|
|
|
|
|
+ "project_info": project_info,
|
|
|
|
|
+ "conversation_history": conversation_history[-6:],
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ from foundation.ai.agent.generate.model_generate import generate_model_client
|
|
|
|
|
+ from core.document_chat.component.llm_utils import compact_json
|
|
|
|
|
+
|
|
|
|
|
+ full_text_parts: List[str] = []
|
|
|
|
|
+
|
|
|
|
|
+ def _on_chunk(chunk: str):
|
|
|
|
|
+ from langgraph.config import get_stream_writer
|
|
|
|
|
+ try:
|
|
|
|
|
+ writer = get_stream_writer()
|
|
|
|
|
+ writer({"stream_chunk": chunk})
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ pass
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ async for chunk in generate_model_client.get_model_generate_invoke_stream(
|
|
|
|
|
+ trace_id=state.get("callback_task_id", "general_answer"),
|
|
|
|
|
+ system_prompt=system_prompt,
|
|
|
|
|
+ user_prompt=compact_json(user_payload),
|
|
|
|
|
+ timeout=45,
|
|
|
|
|
+ function_name="general_answer",
|
|
|
|
|
+ ):
|
|
|
|
|
+ _on_chunk(chunk)
|
|
|
|
|
+ full_text_parts.append(chunk)
|
|
|
|
|
+ except Exception as exc:
|
|
|
|
|
+ logger.warning(f"[DocumentChat] general_answer stream failed: {exc}, falling back to non-stream")
|
|
|
|
|
+ if not full_text_parts:
|
|
|
|
|
+ response = await generate_model_client.get_model_generate_invoke(
|
|
|
|
|
+ trace_id=state.get("callback_task_id", "general_answer"),
|
|
|
|
|
+ system_prompt=system_prompt,
|
|
|
|
|
+ user_prompt=compact_json(user_payload),
|
|
|
|
|
+ timeout=45,
|
|
|
|
|
+ function_name="general_answer",
|
|
|
|
|
+ )
|
|
|
|
|
+ full_text_parts.append(response or "")
|
|
|
|
|
+
|
|
|
|
|
+ answer = "".join(full_text_parts).strip()
|
|
|
|
|
+ if not answer:
|
|
|
|
|
+ answer = "您好,我是施工方案编辑 AI 助手。选中一个文档章节后,我可以帮您润色、扩写、改写或回答章节相关问题。"
|
|
|
|
|
+
|
|
|
|
|
+ skill_result = DocumentChatSkillOutput(
|
|
|
|
|
+ skill_name="general-answer",
|
|
|
|
|
+ response_type="general_answer",
|
|
|
|
|
+ answer=answer,
|
|
|
|
|
+ )
|
|
|
|
|
+ return {
|
|
|
|
|
+ "skill_result": model_to_dict(skill_result),
|
|
|
|
|
+ "response_type": "general_answer",
|
|
|
|
|
+ "current_stage": "general_answer",
|
|
|
|
|
+ }
|
|
|
|
|
+ except Exception as exc:
|
|
|
|
|
+ logger.error(f"[DocumentChat] general_answer_node failed: {exc}", exc_info=True)
|
|
|
|
|
+ return self._error_update("general_answer", exc)
|
|
|
|
|
+
|
|
|
async def run_answer_skill_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
async def run_answer_skill_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
return await self._run_skill(state, "document-answer", "run_answer_skill")
|
|
return await self._run_skill(state, "document-answer", "run_answer_skill")
|
|
|
|
|
|