# -*- coding: utf-8 -*-
"""Quality-first vector retrieval for document chat."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml

from foundation.observability.logger.loggering import write_logger as logger

PROJECT_ROOT = Path(__file__).resolve().parents[3]
RETRIEVAL_CONFIG_PATH = PROJECT_ROOT / "config" / "document_chat_retrieval.yaml"

# Candidates whose stripped text is shorter than this are dropped as noise.
_MIN_CANDIDATE_CHARS = 20
# Two candidates sharing this many leading characters count as duplicates.
_DEDUPE_PREFIX_CHARS = 300


@dataclass(frozen=True)
class RetrievalConfig:
    """Immutable retrieval settings, normally built by ``load_retrieval_config``.

    Several fields are not read anywhere in this module; they are presumably
    consumed by later pipeline stages (rerank / submit) -- confirm against
    the callers before changing them.
    """

    # When False, ``DocumentChatRetrievalService.recall`` short-circuits
    # with a "disabled" result.
    enabled: bool = True
    # Passed as ``collection_name`` to the Milvus hybrid search.
    child_collection: str = "t_kngs_construction_plan_child"
    # ``top_k`` for both recall paths and the cap applied after cleaning.
    recall_top_k: int = 30
    # Not read in this module.
    rerank_top_k: int = 8
    # Not read in this module.
    submit_top_k: int = 3
    # Not read in this module.
    min_vector_similarity: float = 0.45
    # Not read in this module.
    min_rerank_score: float = 0.70
    # Not read in this module.
    min_qualified_count: int = 1
    # Not read in this module.
    max_reference_chars: int = 4000
    # Each cleaned candidate text is truncated to this many characters.
    max_single_reference_chars: int = 1500
    # Not read in this module.
    allow_vector_fallback: bool = False
    # When False, recall refuses to run without a reliable scope.
    allow_unscoped_search: bool = False
    # Forwarded unchanged to the Milvus hybrid search.
    dense_weight: float = 0.7
    sparse_weight: float = 0.3
    ranker_type: str = "weighted"
    # Warning key -> user-facing message; None means "use the defaults".
    warnings: Optional[Dict[str, str]] = None


def load_retrieval_config() -> RetrievalConfig:
    """Load ``RetrievalConfig`` from ``RETRIEVAL_CONFIG_PATH``.

    Returns:
        A config built from the ``retrieval`` and ``warnings`` sections of
        the YAML file. Missing or unparsable individual values fall back to
        the dataclass defaults. If the file does not exist, a default config
        is returned.

    Raises:
        yaml.YAMLError: if the file exists but is not valid YAML.
        OSError: if the file exists but cannot be read.
    """
    if not RETRIEVAL_CONFIG_PATH.exists():
        return RetrievalConfig(warnings=_default_warnings())

    with open(RETRIEVAL_CONFIG_PATH, "r", encoding="utf-8") as handle:
        raw = yaml.safe_load(handle) or {}

    # A YAML document may legally be a list or a scalar; treat anything that
    # is not a mapping as "no configuration" instead of crashing on ``.get``.
    if not isinstance(raw, dict):
        logger.warning(
            f"[DocumentChat] retrieval config root is not a mapping, using defaults: {RETRIEVAL_CONFIG_PATH}"
        )
        raw = {}

    retrieval = _as_dict(raw.get("retrieval"))

    # Overlay configured messages on the defaults so a partial ``warnings``
    # section does not leave the remaining keys without a message.
    warnings = _default_warnings()
    configured_warnings = raw.get("warnings")
    if isinstance(configured_warnings, dict):
        warnings.update(configured_warnings)

    return RetrievalConfig(
        enabled=bool(retrieval.get("enabled", True)),
        child_collection=str(retrieval.get("child_collection") or "t_kngs_construction_plan_child"),
        recall_top_k=_to_int(retrieval.get("recall_top_k"), 30),
        rerank_top_k=_to_int(retrieval.get("rerank_top_k"), 8),
        submit_top_k=_to_int(retrieval.get("submit_top_k"), 3),
        min_vector_similarity=_to_float(retrieval.get("min_vector_similarity"), 0.45),
        min_rerank_score=_to_float(retrieval.get("min_rerank_score"), 0.70),
        min_qualified_count=_to_int(retrieval.get("min_qualified_count"), 1),
        max_reference_chars=_to_int(retrieval.get("max_reference_chars"), 4000),
        max_single_reference_chars=_to_int(retrieval.get("max_single_reference_chars"), 1500),
        allow_vector_fallback=bool(retrieval.get("allow_vector_fallback", False)),
        allow_unscoped_search=bool(retrieval.get("allow_unscoped_search", False)),
        dense_weight=_to_float(retrieval.get("dense_weight"), 0.7),
        sparse_weight=_to_float(retrieval.get("sparse_weight"), 0.3),
        ranker_type=str(retrieval.get("ranker_type") or "weighted"),
        warnings=warnings,
    )


class DocumentChatRetrievalService:
    """Build retrieval queries and fetch quality candidates.

    Retrieval is intentionally conservative: when no reliable scope is present
    and unscoped search is disabled, it returns no candidates.
    """

    def __init__(self, config: Optional[RetrievalConfig] = None):
        """Use ``config`` when given, otherwise load it from the YAML file."""
        self.config = config or load_retrieval_config()

    def build_query(self, state: Dict[str, Any]) -> str:
        """Compose the retrieval query text from the chat state.

        Reads ``selected_section``, ``project_info``, ``intent_result`` and
        ``user_message`` from ``state``. Each piece becomes one labelled
        line; lines whose value part is blank are omitted.

        Args:
            state: Chat state mapping. Non-dict sub-entries are treated as
                empty.

        Returns:
            The newline-joined query, possibly an empty string.
        """
        selected_section = _as_dict(state.get("selected_section"))
        project_info = _as_dict(state.get("project_info"))
        intent_result = _as_dict(state.get("intent_result"))

        section_content = str(selected_section.get("content") or "")
        # Only the first 1000 characters of the section are used as a preview.
        section_preview = section_content[:1000]

        parts = [
            f"项目名称:{project_info.get('project_name') or project_info.get('name') or ''}",
            f"工程类型:{project_info.get('engineering_type') or project_info.get('project_type') or ''}",
            f"施工位置:{project_info.get('construct_location') or project_info.get('location') or ''}",
            f"章节:{selected_section.get('index', '')} {selected_section.get('title', '')}",
            f"用户需求:{state.get('user_message') or ''}",
            f"归一化需求:{intent_result.get('normalized_instruction') or ''}",
            f"当前章节摘要:{section_preview}",
        ]
        # Keep a line only if something non-blank follows its label.
        return "\n".join(part for part in parts if part.split(":", 1)[-1].strip())

    def recall(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Recall candidate fragments for the query stored in ``state``.

        The query is read from ``state["retrieval_query"]``; this method does
        not call ``build_query`` itself.

        Returns:
            A dict with ``retrieval_candidates``, ``retrieval_status``,
            ``retrieval_method``, ``retrieval_metrics`` and ``warnings``.
            Status is ``"recalled"`` on success, otherwise ``"disabled"``,
            ``"no_scope"`` or ``"no_recall"`` with an empty candidate list.
            Backend exceptions are logged and reported as ``"no_recall"``.
        """
        if not self.config.enabled:
            return self._empty_result("disabled", [], retrieval_method="disabled")

        query = str(state.get("retrieval_query") or "").strip()
        if not query:
            return self._empty_result(
                "no_recall",
                [self._warning("no_recall")],
                retrieval_method="empty_query",
            )

        scope = self._extract_scope(state)
        if not self._has_reliable_scope(scope) and not self.config.allow_unscoped_search:
            return self._empty_result(
                "no_scope",
                [self._warning("no_scope")],
                retrieval_method="no_scope",
                retrieval_scope=scope,
            )

        # Decide the strategy before the try block so the failure path can
        # always report which method was attempted.
        use_chapter = bool(scope.get("chapter_level_1") and scope.get("chapter_level_2"))
        retrieval_method = "chapter_similarity" if use_chapter else "milvus_hybrid_vector"

        try:
            if use_chapter:
                candidates = self._recall_by_chapter(scope, query)
            else:
                candidates = self._recall_by_vector(scope, query)
        except Exception as exc:
            # Deliberate best-effort boundary: any backend failure degrades
            # to "no recall" instead of failing the chat turn.
            logger.warning(f"[DocumentChat] retrieval failed: {exc}", exc_info=True)
            return self._empty_result(
                "no_recall",
                [self._warning("no_recall")],
                retrieval_method=retrieval_method,
                retrieval_scope=scope,
            )

        candidates = self._clean_candidates(candidates)
        if not candidates:
            return self._empty_result(
                "no_recall",
                [self._warning("no_recall")],
                retrieval_method=retrieval_method,
                retrieval_scope=scope,
            )

        metrics = {
            "recall_count": len(candidates),
            "max_vector_similarity": max(
                (item.get("vector_similarity", 0.0) for item in candidates),
                default=0.0,
            ),
            "scope": {key: value for key, value in scope.items() if value},
            "retrieval_method": retrieval_method,
        }
        return {
            "retrieval_candidates": candidates,
            "retrieval_status": "recalled",
            "retrieval_method": retrieval_method,
            "retrieval_metrics": metrics,
            "warnings": [],
        }

    def _recall_by_chapter(self, scope: Dict[str, Any], query: str) -> List[Dict[str, Any]]:
        """Recall fragments via ``search_similar_fragments`` for a chapter pair.

        Rows are converted to candidate dicts; chapter values missing from a
        row are filled from ``scope``, and every candidate is marked
        ``source_scope_valid``.
        """
        # Imported lazily, as in the original module layout.
        from core.construction_write.component.similar_fragment_service import search_similar_fragments

        rows = search_similar_fragments(
            level1=str(scope.get("chapter_level_1") or ""),
            level2=str(scope.get("chapter_level_2") or ""),
            search_text=query,
            top_k=self.config.recall_top_k,
        )

        candidates = []
        for row in rows:
            text = str(row.get("text") or "").strip()
            metadata = {
                "tenant_id": scope.get("tenant_id") or "",
                "project_id": scope.get("project_id") or "",
                "knowledge_base_id": scope.get("knowledge_base_id") or "",
                "file_name": row.get("file_name") or "",
                "chapter_level_1": row.get("chapter_level_1") or scope.get("chapter_level_1") or "",
                "chapter_level_2": row.get("chapter_level_2") or scope.get("chapter_level_2") or "",
                "parent_count": row.get("parent_count", 0),
                "source_scope_valid": True,
            }
            candidates.append(
                {
                    "text": text,
                    "source": metadata.get("file_name") or "向量知识库",
                    "vector_similarity": _to_float(row.get("similarity"), 0.0),
                    "metadata": metadata,
                }
            )
        return candidates

    def _recall_by_vector(self, scope: Dict[str, Any], query: str) -> List[Dict[str, Any]]:
        """Recall fragments via Milvus hybrid search filtered by ``scope``.

        Returns an empty list when no filter expression can be built.
        """
        # Imported lazily, as in the original module layout.
        from foundation.database.base.vector.milvus_vector import MilvusVectorManager

        expr = self._build_filter_expr(scope)
        if not expr:
            # NOTE(review): this also returns nothing when
            # ``allow_unscoped_search`` is True and the scope is completely
            # empty, so that flag never yields a truly unfiltered search.
            # Confirm whether that is intended before changing it.
            return []

        results = MilvusVectorManager().hybrid_search(
            param={"collection_name": self.config.child_collection, "expr": expr},
            query_text=query,
            top_k=self.config.recall_top_k,
            ranker_type=self.config.ranker_type,
            dense_weight=self.config.dense_weight,
            sparse_weight=self.config.sparse_weight,
        )

        candidates = []
        for row in results:
            metadata = self._normalize_metadata(row.get("metadata") or {})
            source_scope_valid = self._metadata_matches_scope(metadata, scope)
            metadata["source_scope_valid"] = source_scope_valid
            candidates.append(
                {
                    "text": str(row.get("text_content") or "").strip(),
                    "source": metadata.get("file_name") or metadata.get("title") or "向量知识库",
                    "vector_similarity": _to_float(row.get("similarity"), 0.0),
                    "metadata": metadata,
                }
            )
        return candidates

    def _extract_scope(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Collect the retrieval scope fields from the chat state.

        Each field is looked up in ``selected_section``, then
        ``document_context``, then ``project_info``, then the
        ``retrieval_filters`` mapping (taken from the context, or from the
        project when the context has none). The first non-blank value wins.

        Returns:
            A dict with string values (``""`` when absent) for tenant,
            project, knowledge base, engineering type and both chapter levels.
        """
        selected = _as_dict(state.get("selected_section"))
        context = _as_dict(state.get("document_context"))
        project = _as_dict(state.get("project_info"))

        # Context filters take priority; project filters are only used when
        # the context provides none (or an empty mapping).
        filters = _as_dict(context.get("retrieval_filters"))
        if not filters:
            filters = _as_dict(project.get("retrieval_filters"))

        sources = (selected, context, project, filters)

        def pick(*keys: str) -> str:
            for source in sources:
                for key in keys:
                    value = source.get(key)
                    # Strip before testing emptiness so a whitespace-only
                    # value does not mask a real value in a later source.
                    text = "" if value is None else str(value).strip()
                    if text:
                        return text
            return ""

        return {
            "tenant_id": pick("tenant_id"),
            "project_id": pick("project_id"),
            "knowledge_base_id": pick("knowledge_base_id", "kb_id"),
            "engineering_type": pick("engineering_type", "project_type"),
            "chapter_level_1": pick("chapter_level_1", "level1"),
            "chapter_level_2": pick("chapter_level_2", "level2"),
        }

    @staticmethod
    def _has_reliable_scope(scope: Dict[str, Any]) -> bool:
        """Return True if the scope has both chapter levels or any owner id.

        ``engineering_type`` alone does not count as a reliable scope.
        """
        if scope.get("chapter_level_1") and scope.get("chapter_level_2"):
            return True
        return bool(scope.get("tenant_id") or scope.get("project_id") or scope.get("knowledge_base_id"))

    def _build_filter_expr(self, scope: Dict[str, Any]) -> str:
        """Build an ``and``-joined equality filter from non-empty scope fields.

        Values are escaped with ``_escape_milvus_string``. Returns ``""``
        when no field is set.
        """
        filter_keys = (
            "tenant_id",
            "project_id",
            "knowledge_base_id",
            "engineering_type",
            "chapter_level_1",
            "chapter_level_2",
        )
        conditions = []
        for key in filter_keys:
            value = str(scope.get(key) or "").strip()
            if value:
                conditions.append(f"{key} == '{_escape_milvus_string(value)}'")
        return " and ".join(conditions)

    @staticmethod
    def _metadata_matches_scope(metadata: Dict[str, Any], scope: Dict[str, Any]) -> bool:
        """Return False only when metadata explicitly contradicts the scope.

        A key is compared only if the scope has a value for it, and a blank
        or missing metadata value is treated as a match rather than a
        mismatch.
        """
        required_keys = ["tenant_id", "project_id", "knowledge_base_id", "chapter_level_1", "chapter_level_2"]
        for key in required_keys:
            expected = str(scope.get(key) or "").strip()
            if not expected:
                continue
            actual = str(metadata.get(key) or "").strip()
            if actual and actual != expected:
                return False
        return True

    @staticmethod
    def _normalize_metadata(metadata: Any) -> Dict[str, Any]:
        """Return ``metadata`` as a new dict.

        Dicts are shallow-copied. Non-blank strings are parsed with
        ``yaml.safe_load`` and kept only if they decode to a mapping.
        Anything else, including unparsable strings, yields ``{}``.
        """
        if isinstance(metadata, dict):
            return dict(metadata)
        if isinstance(metadata, str) and metadata.strip():
            try:
                loaded = yaml.safe_load(metadata)
                return dict(loaded) if isinstance(loaded, dict) else {}
            except Exception:
                # Best effort: malformed serialized metadata is ignored.
                return {}
        return {}

    def _clean_candidates(self, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Filter, de-duplicate, truncate and rank raw candidates.

        Drops texts shorter than ``_MIN_CANDIDATE_CHARS``, removes duplicates
        by their first ``_DEDUPE_PREFIX_CHARS`` characters (first occurrence
        wins), truncates text to ``max_single_reference_chars``, sorts by
        ``vector_similarity`` descending and keeps at most ``recall_top_k``.
        """
        cleaned = []
        seen = set()
        for item in candidates:
            text = str(item.get("text") or "").strip()
            if len(text) < _MIN_CANDIDATE_CHARS:
                continue
            dedupe_key = text[:_DEDUPE_PREFIX_CHARS]
            if dedupe_key in seen:
                continue
            seen.add(dedupe_key)

            metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
            cleaned.append(
                {
                    "text": text[: self.config.max_single_reference_chars],
                    "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
                    "vector_similarity": _to_float(item.get("vector_similarity"), 0.0),
                    "metadata": metadata,
                }
            )

        cleaned.sort(key=lambda item: item.get("vector_similarity", 0.0), reverse=True)
        return cleaned[: self.config.recall_top_k]

    def _empty_result(
        self,
        status: str,
        warnings: List[str],
        retrieval_method: str = "",
        retrieval_scope: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Build a result payload with no candidates.

        Args:
            status: Value for ``retrieval_status``.
            warnings: Warning messages to attach.
            retrieval_method: Method label recorded at top level and in
                the metrics.
            retrieval_scope: Scope dict; only its truthy entries are
                recorded in the metrics.
        """
        return {
            "retrieval_candidates": [],
            "retrieval_status": status,
            "retrieval_method": retrieval_method,
            "retrieval_metrics": {
                "recall_count": 0,
                "retrieval_method": retrieval_method,
                "scope": {key: value for key, value in (retrieval_scope or {}).items() if value},
            },
            "warnings": warnings,
        }

    def _warning(self, key: str) -> str:
        """Return the configured warning text for ``key``, or ``""``."""
        warnings = self.config.warnings or _default_warnings()
        return warnings.get(key) or ""


def _default_warnings() -> Dict[str, str]:
    """Return a fresh dict of the built-in warning messages."""
    return {
        "no_scope": "缺少可靠的知识库检索范围,本次未引用向量库内容。",
        "no_recall": "未召回可信知识库内容,本次回答不引用向量库。",
        "low_confidence": "未找到可信度足够的知识库片段,本次未引用向量库内容。",
        "rerank_failed": "知识库片段重排不可用,本次未引用向量库内容。",
    }


def _as_dict(value: Any) -> Dict[str, Any]:
    """Return ``value`` itself when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _escape_milvus_string(value: str) -> str:
    """Escape backslashes and single quotes for a single-quoted filter literal."""
    # Backslashes first, so the escapes added for quotes are not doubled.
    return str(value).replace("\\", "\\\\").replace("'", "\\'")


def _to_int(value: Any, default: int) -> int:
    """Convert ``value`` to ``int``; return ``default`` if it cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers infinities, e.g. a YAML ``.inf`` value.
        return default


def _to_float(value: Any, default: float) -> float:
    """Convert ``value`` to ``float``; return ``default`` if it cannot be converted."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers integers too large to represent as a float.
        return default