|
@@ -10,7 +10,11 @@ from foundation.observability.logger.loggering import write_logger as logger
|
|
|
|
|
|
|
|
from core.document_chat.component.conversation_context import ConversationContextBuilder
|
|
from core.document_chat.component.conversation_context import ConversationContextBuilder
|
|
|
from core.document_chat.component.diff_service import DiffService
|
|
from core.document_chat.component.diff_service import DiffService
|
|
|
|
|
+from core.document_chat.component.document_chat_logger import log_document_chat_event
|
|
|
from core.document_chat.component.intent_recognizer import IntentRecognizer
|
|
from core.document_chat.component.intent_recognizer import IntentRecognizer
|
|
|
|
|
+from core.document_chat.component.rerank_service import DocumentChatRerankService
|
|
|
|
|
+from core.document_chat.component.retrieval_quality_gate import RetrievalQualityGate
|
|
|
|
|
+from core.document_chat.component.retrieval_service import DocumentChatRetrievalService
|
|
|
from core.document_chat.component.skill_dispatcher import SkillDispatcher
|
|
from core.document_chat.component.skill_dispatcher import SkillDispatcher
|
|
|
from core.document_chat.component.state_models import DocumentChatState
|
|
from core.document_chat.component.state_models import DocumentChatState
|
|
|
from core.document_chat.schemas import (
|
|
from core.document_chat.schemas import (
|
|
@@ -34,6 +38,9 @@ class DocumentChatWorkflow:
|
|
|
self.skill_dispatcher = SkillDispatcher()
|
|
self.skill_dispatcher = SkillDispatcher()
|
|
|
self.diff_service = DiffService()
|
|
self.diff_service = DiffService()
|
|
|
self.context_builder = ConversationContextBuilder()
|
|
self.context_builder = ConversationContextBuilder()
|
|
|
|
|
+ self.retrieval_service = DocumentChatRetrievalService()
|
|
|
|
|
+ self.rerank_service = DocumentChatRerankService(self.retrieval_service.config)
|
|
|
|
|
+ self.quality_gate = RetrievalQualityGate(self.retrieval_service.config)
|
|
|
self.graph = None
|
|
self.graph = None
|
|
|
|
|
|
|
|
def build_graph(self):
|
|
def build_graph(self):
|
|
@@ -43,6 +50,10 @@ class DocumentChatWorkflow:
|
|
|
workflow.add_node("load_skill_registry", self.load_skill_registry_node)
|
|
workflow.add_node("load_skill_registry", self.load_skill_registry_node)
|
|
|
workflow.add_node("recognize_intent", self.recognize_intent_node)
|
|
workflow.add_node("recognize_intent", self.recognize_intent_node)
|
|
|
workflow.add_node("route_intent", self.route_intent_node)
|
|
workflow.add_node("route_intent", self.route_intent_node)
|
|
|
|
|
+ workflow.add_node("build_retrieval_query", self.build_retrieval_query_node)
|
|
|
|
|
+ workflow.add_node("vector_recall", self.vector_recall_node)
|
|
|
|
|
+ workflow.add_node("rerank_context", self.rerank_context_node)
|
|
|
|
|
+ workflow.add_node("quality_gate", self.quality_gate_node)
|
|
|
workflow.add_node("clarify", self.clarify_node)
|
|
workflow.add_node("clarify", self.clarify_node)
|
|
|
workflow.add_node("unsupported", self.unsupported_node)
|
|
workflow.add_node("unsupported", self.unsupported_node)
|
|
|
workflow.add_node("run_answer_skill", self.run_answer_skill_node)
|
|
workflow.add_node("run_answer_skill", self.run_answer_skill_node)
|
|
@@ -62,6 +73,18 @@ class DocumentChatWorkflow:
|
|
|
{
|
|
{
|
|
|
"clarify": "clarify",
|
|
"clarify": "clarify",
|
|
|
"unsupported": "unsupported",
|
|
"unsupported": "unsupported",
|
|
|
|
|
+ "answer": "build_retrieval_query",
|
|
|
|
|
+ "modify": "build_retrieval_query",
|
|
|
|
|
+ "error": "error_handler",
|
|
|
|
|
+ },
|
|
|
|
|
+ )
|
|
|
|
|
+ workflow.add_edge("build_retrieval_query", "vector_recall")
|
|
|
|
|
+ workflow.add_edge("vector_recall", "rerank_context")
|
|
|
|
|
+ workflow.add_edge("rerank_context", "quality_gate")
|
|
|
|
|
+ workflow.add_conditional_edges(
|
|
|
|
|
+ "quality_gate",
|
|
|
|
|
+ self.route_after_retrieval,
|
|
|
|
|
+ {
|
|
|
"answer": "run_answer_skill",
|
|
"answer": "run_answer_skill",
|
|
|
"modify": "run_modify_skill",
|
|
"modify": "run_modify_skill",
|
|
|
"error": "error_handler",
|
|
"error": "error_handler",
|
|
@@ -81,9 +104,9 @@ class DocumentChatWorkflow:
|
|
|
self.graph = self.build_graph()
|
|
self.graph = self.build_graph()
|
|
|
return self.graph
|
|
return self.graph
|
|
|
|
|
|
|
|
- async def run(self, request: DocumentChatRequest, callback_task_id: Optional[str] = None) -> DocumentChatState:
|
|
|
|
|
|
|
+ def build_initial_state(self, request: DocumentChatRequest, callback_task_id: Optional[str] = None) -> DocumentChatState:
|
|
|
task_id = callback_task_id or f"doc_chat_{uuid.uuid4().hex[:12]}"
|
|
task_id = callback_task_id or f"doc_chat_{uuid.uuid4().hex[:12]}"
|
|
|
- initial_state: DocumentChatState = {
|
|
|
|
|
|
|
+ return {
|
|
|
"callback_task_id": task_id,
|
|
"callback_task_id": task_id,
|
|
|
"user_id": request.user_id,
|
|
"user_id": request.user_id,
|
|
|
"conversation_id": request.conversation_id,
|
|
"conversation_id": request.conversation_id,
|
|
@@ -94,6 +117,13 @@ class DocumentChatWorkflow:
|
|
|
"conversation_history": request.conversation_history,
|
|
"conversation_history": request.conversation_history,
|
|
|
"user_message": request.message,
|
|
"user_message": request.message,
|
|
|
"skill_registry": [],
|
|
"skill_registry": [],
|
|
|
|
|
+ "retrieval_query": None,
|
|
|
|
|
+ "retrieval_method": None,
|
|
|
|
|
+ "retrieval_candidates": [],
|
|
|
|
|
+ "reranked_references": [],
|
|
|
|
|
+ "approved_references": [],
|
|
|
|
|
+ "retrieval_status": None,
|
|
|
|
|
+ "retrieval_metrics": {},
|
|
|
"intent_result": None,
|
|
"intent_result": None,
|
|
|
"skill_result": None,
|
|
"skill_result": None,
|
|
|
"diff_result": None,
|
|
"diff_result": None,
|
|
@@ -104,6 +134,9 @@ class DocumentChatWorkflow:
|
|
|
"warnings": [],
|
|
"warnings": [],
|
|
|
"messages": [],
|
|
"messages": [],
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
async def run(self, request: DocumentChatRequest, callback_task_id: Optional[str] = None) -> DocumentChatState:
    """Execute the workflow graph for one chat request.

    Args:
        request: The incoming chat request; forwarded to
            ``build_initial_state``.
        callback_task_id: Optional task id, forwarded unchanged to
            ``build_initial_state``.

    Returns:
        Whatever ``ainvoke`` on the graph from ``get_graph()`` returns for
        the freshly built initial state.
    """
    # The state is built before the graph is fetched, so a failure while
    # building the state surfaces without touching the graph.
    seed_state = self.build_initial_state(request, callback_task_id)
    compiled_graph = self.get_graph()
    return await compiled_graph.ainvoke(seed_state)
|
|
|
|
|
|
|
|
async def validate_input_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
async def validate_input_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
@@ -171,14 +204,160 @@ class DocumentChatWorkflow:
|
|
|
return "error"
|
|
return "error"
|
|
|
if intent.needs_clarification or intent.intent == "clarify" or intent.confidence < 0.65:
|
|
if intent.needs_clarification or intent.intent == "clarify" or intent.confidence < 0.65:
|
|
|
return "clarify"
|
|
return "clarify"
|
|
|
- if intent.intent == "unsupported":
|
|
|
|
|
- return "unsupported"
|
|
|
|
|
if intent.skill_name == "document-answer":
|
|
if intent.skill_name == "document-answer":
|
|
|
return "answer"
|
|
return "answer"
|
|
|
if intent.skill_name == "document-modify":
|
|
if intent.skill_name == "document-modify":
|
|
|
return "modify"
|
|
return "modify"
|
|
|
|
|
+ if intent.intent == "unsupported":
|
|
|
|
|
+ return "unsupported"
|
|
|
|
|
+ return "error"
|
|
|
|
|
+
|
|
|
|
|
def route_after_retrieval(self, state: DocumentChatState) -> str:
    """Choose the branch to follow once the quality gate has run.

    Returns:
        ``"error"`` when ``error_message`` is set in the state or the
        intent's ``skill_name`` is not recognised; ``"answer"`` for
        ``document-answer``; ``"modify"`` for ``document-modify``.
    """
    if state.get("error_message"):
        return "error"

    # A missing or None intent_result is treated as an empty mapping.
    intent_payload = state.get("intent_result") or {}
    requested_skill = intent_payload.get("skill_name")

    # Ordered (skill name, branch label) pairs, checked by equality.
    for known_skill, branch in (
        ("document-answer", "answer"),
        ("document-modify", "modify"),
    ):
        if requested_skill == known_skill:
            return branch
    return "error"
|
|
|
|
|
|
|
|
|
|
async def build_retrieval_query_node(self, state: DocumentChatState) -> Dict[str, Any]:
    """Graph node that produces the retrieval query for this request.

    Returns an empty update when ``error_message`` is already set.
    Otherwise it obtains the query from
    ``self.retrieval_service.build_query(state)``, emits a
    ``rag_query_built`` log event, and returns the query together with
    the ``current_stage`` marker.
    """
    if state.get("error_message"):
        return {}

    built_query = self.retrieval_service.build_query(state)

    # Log the query next to the state fields recorded for this event.
    event_payload = {
        "retrieval_query": built_query,
        "intent_result": state.get("intent_result"),
        "selected_section": state.get("selected_section"),
        "project_info": state.get("project_info"),
        "document_context": state.get("document_context"),
    }
    log_document_chat_event(
        "rag_query_built",
        state.get("callback_task_id", ""),
        event_payload,
    )

    state_update: Dict[str, Any] = {}
    state_update["retrieval_query"] = built_query
    state_update["current_stage"] = "build_retrieval_query"
    return state_update
|
|
|
|
|
+
|
|
|
|
|
async def vector_recall_node(self, state: DocumentChatState) -> Dict[str, Any]:
    """Graph node that runs recall through the retrieval service.

    Returns an empty update when ``error_message`` is already set.
    Otherwise it calls ``self.retrieval_service.recall(state)``, emits a
    ``rag_recall_completed`` log event, and returns the candidates,
    status and method from the recall result, plus metrics and warnings
    merged with those already in the state.
    """
    if state.get("error_message"):
        return {}

    recall_result = self.retrieval_service.recall(state)

    # Read each field once; missing containers fall back to empty ones.
    method = recall_result.get("retrieval_method")
    status = recall_result.get("retrieval_status")
    metrics = recall_result.get("retrieval_metrics") or {}
    candidates = recall_result.get("retrieval_candidates") or []
    recall_warnings = recall_result.get("warnings") or []

    log_document_chat_event(
        "rag_recall_completed",
        state.get("callback_task_id", ""),
        {
            "retrieval_query": state.get("retrieval_query"),
            "retrieval_method": method,
            "retrieval_status": status,
            "retrieval_metrics": metrics,
            "retrieval_candidates": candidates,
            "warnings": recall_warnings,
        },
    )

    return {
        "retrieval_candidates": candidates,
        "retrieval_status": status,
        "retrieval_method": method,
        "retrieval_metrics": self._merge_metrics(state, metrics),
        "warnings": self._append_warnings(state, recall_warnings),
        "current_stage": "vector_recall",
    }
|
|
|
|
|
+
|
|
|
|
|
async def rerank_context_node(self, state: DocumentChatState) -> Dict[str, Any]:
    """Graph node that reranks the recalled candidates.

    Returns an empty update when ``error_message`` is already set. When
    ``retrieval_status`` is not ``"recalled"`` the rerank is skipped: a
    ``rag_rerank_skipped`` event is logged and empty reference lists are
    returned. Otherwise ``self.rerank_service.rerank`` is called, a
    ``rag_rerank_completed`` event is logged, and the reranked references
    and status are returned with merged metrics and warnings.
    """
    if state.get("error_message"):
        return {}

    task_id = state.get("callback_task_id", "")
    query = state.get("retrieval_query")
    method = state.get("retrieval_method")
    current_status = state.get("retrieval_status")

    if current_status != "recalled":
        skipped_payload = {
            "retrieval_query": query,
            "retrieval_method": method,
            "retrieval_status": current_status,
            "retrieval_metrics": state.get("retrieval_metrics") or {},
            "warnings": state.get("warnings") or [],
        }
        log_document_chat_event("rag_rerank_skipped", task_id, skipped_payload)
        return {
            "reranked_references": [],
            "approved_references": [],
            "current_stage": "rerank_context",
        }

    rerank_result = self.rerank_service.rerank(
        query=query or "",
        candidates=state.get("retrieval_candidates") or [],
    )

    new_status = rerank_result.get("retrieval_status")
    new_metrics = rerank_result.get("retrieval_metrics") or {}
    reranked = rerank_result.get("reranked_references") or []
    rerank_warnings = rerank_result.get("warnings") or []

    completed_payload = {
        "retrieval_query": query,
        "retrieval_method": method,
        "retrieval_status": new_status,
        "retrieval_metrics": new_metrics,
        # Candidates are read from the state again after the rerank call.
        "retrieval_candidates": state.get("retrieval_candidates") or [],
        "reranked_references": reranked,
        "warnings": rerank_warnings,
    }
    log_document_chat_event("rag_rerank_completed", task_id, completed_payload)

    return {
        "reranked_references": reranked,
        "retrieval_status": new_status,
        "retrieval_metrics": self._merge_metrics(state, new_metrics),
        "warnings": self._append_warnings(state, rerank_warnings),
        "current_stage": "rerank_context",
    }
|
|
|
|
|
+
|
|
|
|
|
async def quality_gate_node(self, state: DocumentChatState) -> Dict[str, Any]:
    """Graph node that filters reranked references through the quality gate.

    Returns an empty update when ``error_message`` is already set. When
    ``retrieval_status`` is not ``"reranked"`` the gate is skipped: a
    ``rag_quality_gate_skipped`` event is logged and the update carries
    no approved references and an ``approved_count`` of 0 merged into the
    metrics. Otherwise ``self.quality_gate.apply`` is called, a
    ``rag_quality_gate_completed`` event is logged, and the approved
    references and status are returned with merged metrics and warnings.
    """
    if state.get("error_message"):
        return {}

    task_id = state.get("callback_task_id", "")
    query = state.get("retrieval_query")
    method = state.get("retrieval_method")
    current_status = state.get("retrieval_status")

    if current_status != "reranked":
        skipped_metrics = self._merge_metrics(state, {"approved_count": 0})
        log_document_chat_event(
            "rag_quality_gate_skipped",
            task_id,
            {
                "retrieval_query": query,
                "retrieval_method": method,
                "retrieval_status": current_status,
                # Separate copy so the log payload and the state update
                # never share one metrics dict.
                "retrieval_metrics": dict(skipped_metrics),
                "reranked_references": state.get("reranked_references") or [],
                "warnings": state.get("warnings") or [],
            },
        )
        return {
            "approved_references": [],
            "retrieval_metrics": skipped_metrics,
            "current_stage": "quality_gate",
        }

    gate_result = self.quality_gate.apply(state.get("reranked_references") or [])

    new_status = gate_result.get("retrieval_status")
    gate_metrics = gate_result.get("retrieval_metrics") or {}
    approved = gate_result.get("approved_references") or []
    gate_warnings = gate_result.get("warnings") or []

    log_document_chat_event(
        "rag_quality_gate_completed",
        task_id,
        {
            "retrieval_query": query,
            "retrieval_method": method,
            "retrieval_status": new_status,
            "retrieval_metrics": gate_metrics,
            "reranked_references": state.get("reranked_references") or [],
            "approved_references": approved,
            "warnings": gate_warnings,
        },
    )

    return {
        "approved_references": approved,
        "retrieval_status": new_status,
        "retrieval_metrics": self._merge_metrics(state, gate_metrics),
        "warnings": self._append_warnings(state, gate_warnings),
        "current_stage": "quality_gate",
    }
|
|
|
|
|
+
|
|
|
async def clarify_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
async def clarify_node(self, state: DocumentChatState) -> Dict[str, Any]:
|
|
|
intent = IntentResult(**(state.get("intent_result") or {"intent": "clarify"}))
|
|
intent = IntentResult(**(state.get("intent_result") or {"intent": "clarify"}))
|
|
|
question = intent.clarification_question or "请补充说明希望 AI 对当前章节做什么。"
|
|
question = intent.clarification_question or "请补充说明希望 AI 对当前章节做什么。"
|
|
@@ -270,6 +449,7 @@ class DocumentChatWorkflow:
|
|
|
warnings.extend(intent_result.get("warnings") or [])
|
|
warnings.extend(intent_result.get("warnings") or [])
|
|
|
|
|
|
|
|
response_type = state.get("response_type") or skill_result.get("response_type") or "error"
|
|
response_type = state.get("response_type") or skill_result.get("response_type") or "error"
|
|
|
|
|
+ approved_references = state.get("approved_references") or []
|
|
|
return DocumentChatData(
|
|
return DocumentChatData(
|
|
|
callback_task_id=state.get("callback_task_id", ""),
|
|
callback_task_id=state.get("callback_task_id", ""),
|
|
|
response_type=response_type,
|
|
response_type=response_type,
|
|
@@ -281,7 +461,9 @@ class DocumentChatWorkflow:
|
|
|
diff=diff_result.get("diff") or [],
|
|
diff=diff_result.get("diff") or [],
|
|
|
diff_granularity=diff_result.get("diff_granularity"),
|
|
diff_granularity=diff_result.get("diff_granularity"),
|
|
|
change_summary=skill_result.get("change_summary") or [],
|
|
change_summary=skill_result.get("change_summary") or [],
|
|
|
- references=skill_result.get("references") or [],
|
|
|
|
|
|
|
+ references=approved_references,
|
|
|
|
|
+ retrieval_status=state.get("retrieval_status"),
|
|
|
|
|
+ retrieval_metrics=self._merge_metrics(state, {"retrieval_method": state.get("retrieval_method")}),
|
|
|
warnings=warnings,
|
|
warnings=warnings,
|
|
|
selected_section={
|
|
selected_section={
|
|
|
"index": selected_section.get("index", ""),
|
|
"index": selected_section.get("index", ""),
|
|
@@ -292,18 +474,35 @@ class DocumentChatWorkflow:
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def _build_skill_input(self, state: DocumentChatState) -> DocumentChatSkillInput:
    """Assemble the ``DocumentChatSkillInput`` from the workflow state.

    The state's ``document_context`` is copied and its ``references``
    entry is set to the state's ``approved_references`` (an empty list
    when there are none) before ``DocumentContext`` is built. Missing
    state entries fall back to empty values.
    """
    # Copy first so the dict held in the workflow state is not mutated.
    context_fields = dict(state.get("document_context") or {})
    context_fields.update(references=state.get("approved_references") or [])

    # Models are built in the same order as the keyword arguments below.
    section_model = SelectedSection(**(state.get("selected_section") or {}))
    context_model = DocumentContext(**context_fields)
    intent_model = IntentResult(**(state.get("intent_result") or {}))

    return DocumentChatSkillInput(
        user_id=state.get("user_id", ""),
        conversation_id=state.get("conversation_id"),
        task_id=state.get("task_id"),
        project_info=state.get("project_info") or {},
        selected_section=section_model,
        document_context=context_model,
        conversation_history=state.get("conversation_history") or [],
        user_message=state.get("user_message", ""),
        intent_result=intent_model,
    )
|
|
|
|
|
|
|
|
|
|
@staticmethod
def _append_warnings(state: DocumentChatState, new_warnings: list) -> list:
    """Return the state's warnings extended with new, non-duplicate ones.

    Each new warning is converted with ``str`` and stripped; empty results
    and texts already present are dropped. The list stored in the state is
    copied, not modified.
    """
    combined = list(state.get("warnings") or [])
    for raw_warning in new_warnings:
        text = str(raw_warning).strip()
        if not text or text in combined:
            continue
        combined.append(text)
    return combined
|
|
|
|
|
+
|
|
|
|
|
@staticmethod
def _merge_metrics(state: DocumentChatState, new_metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the state's retrieval metrics overlaid with new ones.

    Keys in ``new_metrics`` override keys already present. The dict stored
    in the state is copied, not modified; a falsy ``new_metrics`` adds
    nothing.
    """
    combined = dict(state.get("retrieval_metrics") or {})
    if new_metrics:
        combined.update(new_metrics)
    return combined
|
|
|
|
|
+
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
def _error_update(stage: str, exc: Exception) -> Dict[str, Any]:
|
|
def _error_update(stage: str, exc: Exception) -> Dict[str, Any]:
|
|
|
return {
|
|
return {
|