Przeglądaj źródła

feat(增加通用标准回答)

tangle 2 dni temu
rodzic
commit
1406cc048e

+ 8 - 8
core/document_chat/schemas.py

@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
 
 
 
 
 class SelectedSection(BaseModel):
 class SelectedSection(BaseModel):
-    index: str = Field(..., description="Section index, for example 2.1")
-    title: str = Field(..., description="Section title")
+    index: str = Field(default="", description="Section index, for example 2.1")
+    title: str = Field(default="", description="Section title")
     content: str = Field(default="", description="Current section content from the editor")
     content: str = Field(default="", description="Current section content from the editor")
     code: str = Field(default="", description="Section code")
     code: str = Field(default="", description="Section code")
     chapter_level_1: str = Field(default="", description="Optional primary chapter classification for retrieval")
     chapter_level_1: str = Field(default="", description="Optional primary chapter classification for retrieval")
@@ -24,9 +24,9 @@ class DocumentContext(BaseModel):
 
 
 
 
 class DocumentChatRequest(BaseModel):
 class DocumentChatRequest(BaseModel):
-    user_id: str
+    user_id: Optional[str] = None
     message: str = Field(..., min_length=1, description="User message")
     message: str = Field(..., min_length=1, description="User message")
-    selected_section: SelectedSection
+    selected_section: Optional[SelectedSection] = Field(default=None, description="Selected section; null or empty for general questions")
     conversation_id: Optional[str] = None
     conversation_id: Optional[str] = None
     task_id: Optional[str] = None
     task_id: Optional[str] = None
     project_info: Dict[str, Any] = Field(default_factory=dict)
     project_info: Dict[str, Any] = Field(default_factory=dict)
@@ -52,9 +52,9 @@ class IntentResult(BaseModel):
 
 
 
 
 class DocumentChatSkillInput(BaseModel):
 class DocumentChatSkillInput(BaseModel):
-    user_id: str
+    user_id: Optional[str] = None
     user_message: str
     user_message: str
-    selected_section: SelectedSection
+    selected_section: Optional[SelectedSection] = None
     intent_result: IntentResult
     intent_result: IntentResult
     conversation_id: Optional[str] = None
     conversation_id: Optional[str] = None
     task_id: Optional[str] = None
     task_id: Optional[str] = None
@@ -65,7 +65,7 @@ class DocumentChatSkillInput(BaseModel):
 
 
 class DocumentChatSkillOutput(BaseModel):
 class DocumentChatSkillOutput(BaseModel):
     skill_name: str
     skill_name: str
-    response_type: Literal["answer", "proposal", "clarify", "unsupported", "error"]
+    response_type: Literal["answer", "proposal", "clarify", "unsupported", "general_answer", "error"]
     answer: Optional[str] = None
     answer: Optional[str] = None
     old_content: Optional[str] = None
     old_content: Optional[str] = None
     proposed_content: Optional[str] = None
     proposed_content: Optional[str] = None
@@ -89,7 +89,7 @@ class DiffResult(BaseModel):
 
 
 class DocumentChatData(BaseModel):
 class DocumentChatData(BaseModel):
     callback_task_id: str
     callback_task_id: str
-    response_type: Literal["answer", "proposal", "clarify", "unsupported", "error"]
+    response_type: Literal["answer", "proposal", "clarify", "unsupported", "general_answer", "error"]
     intent_result: Optional[Dict[str, Any]] = None
     intent_result: Optional[Dict[str, Any]] = None
     answer: Optional[str] = None
     answer: Optional[str] = None
     proposed_content: Optional[str] = None
     proposed_content: Optional[str] = None

+ 96 - 6
core/document_chat/workflows/document_chat_workflow.py

@@ -2,7 +2,7 @@
 """LangGraph workflow for document chat."""
 """LangGraph workflow for document chat."""
 
 
 import uuid
 import uuid
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 
 from langgraph.graph import END, StateGraph
 from langgraph.graph import END, StateGraph
 
 
@@ -55,11 +55,20 @@ class DocumentChatWorkflow:
         workflow.add_node("unsupported", self.unsupported_node)
         workflow.add_node("unsupported", self.unsupported_node)
         workflow.add_node("run_answer_skill", self.run_answer_skill_node)
         workflow.add_node("run_answer_skill", self.run_answer_skill_node)
         workflow.add_node("run_modify_skill", self.run_modify_skill_node)
         workflow.add_node("run_modify_skill", self.run_modify_skill_node)
+        workflow.add_node("general_answer", self.general_answer_node)
         workflow.add_node("error_handler", self.error_handler_node)
         workflow.add_node("error_handler", self.error_handler_node)
         workflow.add_node("complete", self.complete_node)
         workflow.add_node("complete", self.complete_node)
 
 
         workflow.set_entry_point("validate_input")
         workflow.set_entry_point("validate_input")
-        workflow.add_edge("validate_input", "load_context")
+        workflow.add_conditional_edges(
+            "validate_input",
+            self.route_after_validate,
+            {
+                "general": "general_answer",
+                "normal": "load_context",
+                "error": "error_handler",
+            },
+        )
         workflow.add_edge("load_context", "load_skill_registry")
         workflow.add_edge("load_context", "load_skill_registry")
         workflow.add_edge("load_skill_registry", "recognize_intent")
         workflow.add_edge("load_skill_registry", "recognize_intent")
         workflow.add_edge("recognize_intent", "route_intent")
         workflow.add_edge("recognize_intent", "route_intent")
@@ -90,6 +99,7 @@ class DocumentChatWorkflow:
         workflow.add_edge("unsupported", "complete")
         workflow.add_edge("unsupported", "complete")
         workflow.add_edge("run_answer_skill", "complete")
         workflow.add_edge("run_answer_skill", "complete")
         workflow.add_edge("run_modify_skill", "complete")
         workflow.add_edge("run_modify_skill", "complete")
+        workflow.add_edge("general_answer", "complete")
         workflow.add_edge("error_handler", "complete")
         workflow.add_edge("error_handler", "complete")
         workflow.add_edge("complete", END)
         workflow.add_edge("complete", END)
         return workflow.compile()
         return workflow.compile()
@@ -137,12 +147,8 @@ class DocumentChatWorkflow:
         try:
         try:
             selected_section = state.get("selected_section") or {}
             selected_section = state.get("selected_section") or {}
             user_message = (state.get("user_message") or "").strip()
             user_message = (state.get("user_message") or "").strip()
-            if not state.get("user_id"):
-                raise ValueError("user_id is required")
             if not user_message:
             if not user_message:
                 raise ValueError("message is required")
                 raise ValueError("message is required")
-            if not selected_section.get("index") or not selected_section.get("title"):
-                raise ValueError("selected_section.index and selected_section.title are required")
             if "content" not in selected_section:
             if "content" not in selected_section:
                 selected_section["content"] = ""
                 selected_section["content"] = ""
             return {
             return {
@@ -153,6 +159,17 @@ class DocumentChatWorkflow:
         except Exception as exc:
         except Exception as exc:
             return self._error_update("validate_input", exc)
             return self._error_update("validate_input", exc)
 
 
+    def route_after_validate(self, state: DocumentChatState) -> str:
+        if state.get("error_message"):
+            return "error"
+        selected_section = state.get("selected_section") or {}
+        has_section = bool(
+            selected_section.get("code")
+            or selected_section.get("chapter_level_1")
+            or selected_section.get("chapter_level_2")
+        )
+        return "normal" if has_section else "general"
+
     async def load_context_node(self, state: DocumentChatState) -> Dict[str, Any]:
     async def load_context_node(self, state: DocumentChatState) -> Dict[str, Any]:
         if state.get("error_message"):
         if state.get("error_message"):
             return {}
             return {}
@@ -382,6 +399,79 @@ class DocumentChatWorkflow:
             "current_stage": "unsupported",
             "current_stage": "unsupported",
         }
         }
 
 
+    async def general_answer_node(self, state: DocumentChatState) -> Dict[str, Any]:
+        """Respond directly via LLM when no section is selected."""
+        user_message = state.get("user_message", "")
+        conversation_history = state.get("conversation_history") or []
+        project_info = state.get("project_info") or {}
+
+        system_prompt = (
+            "你是施工方案编辑 AI 助手。"
+            "用户当前未选中任何文档章节,请以通用助手的身份回答问题。"
+            "你可以介绍自己的能力(如:选中章节后可进行润色、扩写、改写、问答等),"
+            "也可以回答与施工方案编写相关的通用问题。"
+            "回答应简洁专业,使用中文。"
+        )
+        user_payload = {
+            "user_message": user_message,
+            "project_info": project_info,
+            "conversation_history": conversation_history[-6:],
+        }
+
+        try:
+            from foundation.ai.agent.generate.model_generate import generate_model_client
+            from core.document_chat.component.llm_utils import compact_json
+
+            full_text_parts: List[str] = []
+
+            def _on_chunk(chunk: str):
+                from langgraph.config import get_stream_writer
+                try:
+                    writer = get_stream_writer()
+                    writer({"stream_chunk": chunk})
+                except Exception:
+                    pass
+
+            try:
+                async for chunk in generate_model_client.get_model_generate_invoke_stream(
+                    trace_id=state.get("callback_task_id", "general_answer"),
+                    system_prompt=system_prompt,
+                    user_prompt=compact_json(user_payload),
+                    timeout=45,
+                    function_name="general_answer",
+                ):
+                    _on_chunk(chunk)
+                    full_text_parts.append(chunk)
+            except Exception as exc:
+                logger.warning(f"[DocumentChat] general_answer stream failed: {exc}, falling back to non-stream")
+                if not full_text_parts:
+                    response = await generate_model_client.get_model_generate_invoke(
+                        trace_id=state.get("callback_task_id", "general_answer"),
+                        system_prompt=system_prompt,
+                        user_prompt=compact_json(user_payload),
+                        timeout=45,
+                        function_name="general_answer",
+                    )
+                    full_text_parts.append(response or "")
+
+            answer = "".join(full_text_parts).strip()
+            if not answer:
+                answer = "您好,我是施工方案编辑 AI 助手。选中一个文档章节后,我可以帮您润色、扩写、改写或回答章节相关问题。"
+
+            skill_result = DocumentChatSkillOutput(
+                skill_name="general-answer",
+                response_type="general_answer",
+                answer=answer,
+            )
+            return {
+                "skill_result": model_to_dict(skill_result),
+                "response_type": "general_answer",
+                "current_stage": "general_answer",
+            }
+        except Exception as exc:
+            logger.error(f"[DocumentChat] general_answer_node failed: {exc}", exc_info=True)
+            return self._error_update("general_answer", exc)
+
     async def run_answer_skill_node(self, state: DocumentChatState) -> Dict[str, Any]:
     async def run_answer_skill_node(self, state: DocumentChatState) -> Dict[str, Any]:
         return await self._run_skill(state, "document-answer", "run_answer_skill")
         return await self._run_skill(state, "document-answer", "run_answer_skill")
 
 

+ 2 - 1
views/document_chat/views.py

@@ -27,6 +27,7 @@ STAGE_MESSAGES = {
     "rerank_context": "知识库内容检索重排完成",
     "rerank_context": "知识库内容检索重排完成",
     "run_answer_skill": "已生成章节问答结果",
     "run_answer_skill": "已生成章节问答结果",
     "run_modify_skill": "已生成章节修改草案",
     "run_modify_skill": "已生成章节修改草案",
+    "general_answer": "已生成通用回答",
     "error_handler": "流程异常,已进入错误处理",
     "error_handler": "流程异常,已进入错误处理",
 }
 }
 
 
@@ -281,7 +282,7 @@ async def _generate_document_chat_events(
             yield format_sse_event("answer_completed", data_dict)
             yield format_sse_event("answer_completed", data_dict)
         elif data.response_type == "proposal":
         elif data.response_type == "proposal":
             yield format_sse_event("proposal_completed", data_dict)
             yield format_sse_event("proposal_completed", data_dict)
-        elif data.response_type in ("clarify", "unsupported"):
+        elif data.response_type in ("clarify", "unsupported", "general_answer"):
             yield format_sse_event("answer_completed", data_dict)
             yield format_sse_event("answer_completed", data_dict)
         else:
         else:
             yield format_sse_event("error", data_dict)
             yield format_sse_event("error", data_dict)