# -*- coding: utf-8 -*- """质量优先的多路向量检索服务。 四路召回架构: 1. parent_vector:父表向量检索(主体内容向量) 2. child_locator:子表向量定位 → 反查父行(精确定位片段) 3. tag_keyword:标签关键词匹配(设备型号、标准号等) 4. chapter_similarity:章节相似度检索(同类型章节参考) 合并策略: - RRF(Reciprocal Rank Fusion)融合多路排名 - 按路径加权:parent_vector 1.0, child_locator 0.8, tag 1.2, chapter 0.5 - 多源加分:同一条候选在多个路径中被召回时额外加分 - 标签匹配加分:关键词出现在 tag_list 或文本中时额外加分 - Scope 匹配加分:与当前项目/章节范围一致时额外加分 去重策略: - candidate_key 去重(基于 document_id + parent_id + chunk_id) - 内容哈希去重(同一文件同一文本内容仅保留一条) """ from __future__ import annotations from dataclasses import dataclass from hashlib import md5 from pathlib import Path import re from typing import Any, Callable, Dict, List, Optional import yaml from core.document_chat.component.document_chat_logger import document_chat_logger as logger PROJECT_ROOT = Path(__file__).resolve().parents[3] RETRIEVAL_CONFIG_PATH = PROJECT_ROOT / "config" / "document_chat_retrieval.yaml" @dataclass(frozen=True) class RetrievalConfig: """检索配置(不可变)。各参数含义见下方字段注释。""" enabled: bool = True parent_collection: str = "t_kngs_construction_plan_parent" child_collection: str = "t_kngs_construction_plan_child" # 各路径召回上限 parent_recall_top_k: int = 30 child_recall_top_k: int = 40 tag_recall_top_k: int = 30 chapter_recall_top_k: int = 15 recall_top_k: int = 30 rerank_top_k: int = 8 submit_top_k: int = 3 # 最终送入 LLM prompt 的参考条数上限 # 质量阈值 min_vector_similarity: float = 0.45 min_rerank_score: float = 0.65 # 重排质量门,低于此值被过滤 min_qualified_count: int = 1 # 参考内容长度限制 max_reference_chars: int = 4000 # 所有参考总字符上限 max_single_reference_chars: int = 1500 # 单条参考字符上限 # 降级策略 allow_vector_fallback: bool = False allow_unscoped_search: bool = False # 混合搜索权重(dense=sparse 向量融合) dense_weight: float = 0.7 sparse_weight: float = 0.3 child_dense_weight: float = 0.6 child_sparse_weight: float = 0.4 ranker_type: str = "weighted" # 标签召回 tag_recall_enabled: bool = True tag_terms_limit: int = 8 # RRF 参数 rrf_k: int = 60 # 路径权重 parent_vector_weight: float = 1.0 child_locator_weight: float = 0.8 tag_weight: float = 1.2 chapter_similarity_weight: float = 0.5 # 加分项 tag_exact_bonus: float = 0.08 tag_partial_bonus: float = 0.03 multi_source_bonus: float = 0.02 scope_bonus: float = 0.03 warnings: Optional[Dict[str, str]] = None def load_retrieval_config() -> RetrievalConfig: """从 YAML 配置文件加载检索参数,文件不存在时使用默认值。""" if not RETRIEVAL_CONFIG_PATH.exists(): return RetrievalConfig(warnings=_default_warnings()) with open(RETRIEVAL_CONFIG_PATH, "r", encoding="utf-8") as handle: raw = yaml.safe_load(handle) or {} retrieval = raw.get("retrieval") or {} warnings = raw.get("warnings") or _default_warnings() return RetrievalConfig( enabled=bool(retrieval.get("enabled", True)), parent_collection=str(retrieval.get("parent_collection") or "t_kngs_construction_plan_parent"), child_collection=str(retrieval.get("child_collection") or "t_kngs_construction_plan_child"), parent_recall_top_k=_to_int(retrieval.get("parent_recall_top_k"), 30), child_recall_top_k=_to_int(retrieval.get("child_recall_top_k"), 40), tag_recall_top_k=_to_int(retrieval.get("tag_recall_top_k"), 30), chapter_recall_top_k=_to_int(retrieval.get("chapter_recall_top_k"), 15), recall_top_k=_to_int(retrieval.get("recall_top_k"), 30), rerank_top_k=_to_int(retrieval.get("rerank_top_k"), 8), submit_top_k=_to_int(retrieval.get("submit_top_k"), 3), min_vector_similarity=_to_float(retrieval.get("min_vector_similarity"), 0.45), min_rerank_score=_to_float(retrieval.get("min_rerank_score"), 0.65), min_qualified_count=_to_int(retrieval.get("min_qualified_count"), 1), max_reference_chars=_to_int(retrieval.get("max_reference_chars"), 4000), max_single_reference_chars=_to_int(retrieval.get("max_single_reference_chars"), 1500), allow_vector_fallback=bool(retrieval.get("allow_vector_fallback", False)), allow_unscoped_search=bool(retrieval.get("allow_unscoped_search", False)), dense_weight=_to_float(retrieval.get("dense_weight"), 0.7), sparse_weight=_to_float(retrieval.get("sparse_weight"), 0.3), child_dense_weight=_to_float(retrieval.get("child_dense_weight"), 0.6), child_sparse_weight=_to_float(retrieval.get("child_sparse_weight"), 0.4), ranker_type=str(retrieval.get("ranker_type") or "weighted"), tag_recall_enabled=bool(retrieval.get("tag_recall_enabled", True)), tag_terms_limit=_to_int(retrieval.get("tag_terms_limit"), 8), rrf_k=_to_int(retrieval.get("rrf_k"), 60), parent_vector_weight=_to_float(retrieval.get("parent_vector_weight"), 1.0), child_locator_weight=_to_float(retrieval.get("child_locator_weight"), 0.8), tag_weight=_to_float(retrieval.get("tag_weight"), 1.2), chapter_similarity_weight=_to_float(retrieval.get("chapter_similarity_weight"), 0.5), tag_exact_bonus=_to_float(retrieval.get("tag_exact_bonus"), 0.08), tag_partial_bonus=_to_float(retrieval.get("tag_partial_bonus"), 0.03), multi_source_bonus=_to_float(retrieval.get("multi_source_bonus"), 0.02), scope_bonus=_to_float(retrieval.get("scope_bonus"), 0.03), warnings=warnings, ) class DocumentChatRetrievalService: """构建检索查询,从向量库召回高质量候选。 核心流程: 1. build_query:将用户输入、章节信息、意图拼接为检索 query 2. recall:执行多路召回 → RRF 合并 → 去重 """ # 父表查询输出字段 PARENT_OUTPUT_FIELDS = [ "pk", "text", "document_id", "parent_id", "index", "tag_list", "metadata", "file_name", "chapter_title", "chapter_level_1", "chapter_level_2", "chapter_level_3", ] # 子表查询输出字段 CHILD_OUTPUT_FIELDS = [ "pk", "text", "document_id", "parent_id", "index", "tag_list", "metadata", "file_name", "chapter_title", "chapter_level_1", "chapter_level_2", "chapter_level_3", ] def __init__(self, config: Optional[RetrievalConfig] = None): self.config = config or load_retrieval_config() # ============================================================ # Query 构建 # ============================================================ def build_query(self, state: Dict[str, Any]) -> str: """构建精炼检索 query,避免冗余的项目摘要。 拼接内容: - 用户原始输入 - 意图识别后的规范化指令 - 当前选中章节编号 + 标题 - 提取的关键词(最多 8 个) 去重后截取 120 字符。 """ selected_section = state.get("selected_section") or {} intent_result = state.get("intent_result") or {} keywords = self.build_query_keywords(state) parts = [ state.get("user_message") or "", intent_result.get("normalized_instruction") or "", f"{selected_section.get('index', '')} {selected_section.get('title', '')}".strip(), " ".join(keywords[:8]), ] return _dedupe_join(parts, max_chars=120) def build_query_keywords(self, state: Dict[str, Any], query: Optional[str] = None) -> List[str]: """从多来源提取检索关键词。 来源优先级: 1. 用户输入 2. 意图规范化指令 3. 章节编号 + 标题 4. 章节正文内容(前 500 字) 5. 已拼接的 query 6. 历史对话中用户消息(排除 AI 回复,防止助手建议污染检索) 关键词提取规则见 _extract_retrieval_keywords。 """ selected_section = state.get("selected_section") or {} intent_result = state.get("intent_result") or {} history = state.get("conversation_history") or [] sources = [ state.get("user_message") or "", intent_result.get("normalized_instruction") or "", f"{selected_section.get('index', '')} {selected_section.get('title', '')}", str(selected_section.get("content") or "")[:500], query or "", ] if history: for turn in history[-6:]: if not isinstance(turn, dict): continue role = str(turn.get("role") or turn.get("sender") or "").lower() # 仅取用户消息,跳过 AI 助手回复 if role in ("assistant", "ai", "bot", "model"): continue content = str(turn.get("content") or turn.get("message") or "") if content: sources.append(content) keywords: List[str] = [] seen = set() for text in sources: for keyword in _extract_retrieval_keywords(str(text or "")): normalized = keyword.strip() if not normalized or normalized in seen: continue seen.add(normalized) keywords.append(normalized) if len(keywords) >= 20: return keywords return keywords # ============================================================ # 主召回入口 # ============================================================ def recall(self, state: Dict[str, Any]) -> Dict[str, Any]: """执行多路向量召回,RRF 合并,去重过滤。 返回: - retrieval_candidates:去重后的候选列表 - retrieval_status:recalled / no_scope / no_recall / disabled - retrieval_metrics:各路径召回统计 - retrieval_steps:每步详细日志 """ if not self.config.enabled: return self._empty_result("disabled", [], retrieval_method="disabled") query = str(state.get("retrieval_query") or "").strip() if not query: return self._empty_result("no_recall", [self._warning("no_recall")], retrieval_method="empty_query") # 提取检索范围(项目ID、工程类型、章节分类等) scope = self._extract_scope(state) if not self._has_reliable_scope(scope) and not self.config.allow_unscoped_search: return self._empty_result( "no_scope", [self._warning("no_scope")], retrieval_method="no_scope", retrieval_scope=scope, ) keywords = list(state.get("retrieval_keywords") or self.build_query_keywords(state, query)) retrieval_steps: List[Dict[str, Any]] = [] source_results: Dict[str, List[Dict[str, Any]]] = {} # ===== 四路召回 ===== source_results["parent_vector"] = self._run_recall_path( "parent_vector", lambda: self._recall_by_parent_vector(scope, query), retrieval_steps, query=query, scope=scope, ) source_results["child_locator"] = self._run_recall_path( "child_locator", lambda: self._recall_by_child_locator(scope, query), retrieval_steps, query=query, scope=scope, ) if self.config.tag_recall_enabled: source_results["tag"] = self._run_recall_path( "tag", lambda: self._recall_by_tag(scope, keywords), retrieval_steps, query=" ".join(keywords[: self.config.tag_terms_limit]), scope=scope, ) if scope.get("chapter_level_1") and scope.get("chapter_level_2"): source_results["chapter_similarity"] = self._run_recall_path( "chapter_similarity", lambda: self._recall_by_chapter(scope, query), retrieval_steps, query=query, scope=scope, ) # ===== RRF 合并 + 去重 ===== merged_candidates = self._merge_recall_results(source_results, scope, keywords) cleaned = self._clean_candidates(merged_candidates) retrieval_steps.append( { "step": "rrf_merge", "query": query, "scope": {key: value for key, value in scope.items() if value}, "count": len(merged_candidates), "items": _pack_log_items(merged_candidates), } ) retrieval_steps.append( { "step": "clean_candidates", "count": len(cleaned), "items": _pack_log_items(cleaned), } ) if not cleaned: return self._empty_result( "no_recall", [self._warning("no_recall")], retrieval_method="multi_path_rrf", retrieval_scope=scope, retrieval_steps=retrieval_steps, ) source_counts = {source: len(items or []) for source, items in source_results.items()} # 日志:区分请求的 scope、实际应用的过滤、实际召回的文件 applied_expr = self._build_filter_expr(scope) actual_files = list(dict.fromkeys( str(item.get("source", ""))[:40] for item in cleaned if item.get("source") ))[:5] logger.info( f"[DocumentChat] recall completed: method=multi_path_rrf " f"requested_scope={dict((k, v) for k, v in scope.items() if v)} " f"applied_filter='{applied_expr}' " f"actual_sources={actual_files} " f"source_counts={source_counts} " f"total={len(cleaned)} max_sim={max((item.get('vector_similarity', 0.0) for item in cleaned), default=0.0):.4f}" ) metrics = { "recall_count": len(cleaned), "merged_count": len(merged_candidates), "source_counts": source_counts, "max_vector_similarity": max((item.get("vector_similarity", 0.0) for item in cleaned), default=0.0), "max_fusion_score": max((item.get("fusion_score", 0.0) for item in cleaned), default=0.0), "scope": {key: value for key, value in scope.items() if value}, "retrieval_method": "multi_path_rrf", } return { "retrieval_candidates": cleaned, "retrieval_steps": retrieval_steps, "retrieval_status": "recalled", "retrieval_method": "multi_path_rrf", "retrieval_metrics": metrics, "warnings": [], } def _run_recall_path( self, step: str, func: Callable[[], List[Dict[str, Any]]], retrieval_steps: List[Dict[str, Any]], query: str, scope: Dict[str, Any], ) -> List[Dict[str, Any]]: """执行单路召回,异常时不阻断其他路径。""" try: candidates = func() or [] retrieval_steps.append( { "step": step, "query": query, "scope": {key: value for key, value in scope.items() if value}, "count": len(candidates), "items": _pack_log_items(candidates), } ) return candidates except Exception as exc: logger.warning(f"[DocumentChat] {step} recall failed: {exc}", exc_info=True) retrieval_steps.append( { "step": step, "query": query, "scope": {key: value for key, value in scope.items() if value}, "count": 0, "error": str(exc), "items": [], } ) return [] # ============================================================ # 四路召回具体实现 # ============================================================ def _recall_by_parent_vector(self, scope: Dict[str, Any], query: str) -> List[Dict[str, Any]]: """父表向量检索:Milvus 混合搜索(dense + sparse),直接返回父表文档。""" from foundation.database.base.vector.milvus_vector import MilvusVectorManager expr = self._build_filter_expr(scope) results = MilvusVectorManager().hybrid_search( param={"collection_name": self.config.parent_collection, "expr": expr}, query_text=query, top_k=self.config.parent_recall_top_k, ranker_type=self.config.ranker_type, dense_weight=self.config.dense_weight, sparse_weight=self.config.sparse_weight, ) return [ self._candidate_from_vector_row(row, "parent_vector", scope) for row in results if str(row.get("text_content") or "").strip() ] def _recall_by_child_locator(self, scope: Dict[str, Any], query: str) -> List[Dict[str, Any]]: """子表向量定位 + 父表反查:先用 query 在子表中找到匹配片段, 再通过 parent_id 反查父表行,获取完整的父文档内容。 优势:子表粒度更细,能精确定位到段落级别,然后拉取对应父文档的完整内容。 """ from foundation.database.base.vector.milvus_vector import MilvusVectorManager expr = self._build_filter_expr(scope) child_rows = MilvusVectorManager().hybrid_search( param={"collection_name": self.config.child_collection, "expr": expr}, query_text=query, top_k=self.config.child_recall_top_k, ranker_type=self.config.ranker_type, dense_weight=self.config.child_dense_weight, sparse_weight=self.config.child_sparse_weight, ) # 按 parent_id 分组子表命中结果 child_groups: Dict[str, List[Dict[str, Any]]] = {} for row in child_rows: metadata = self._normalize_row_metadata(row.get("metadata") or {}) parent_id = str(self._metadata_value(metadata, "parent_id") or "").strip() if not parent_id: continue child_groups.setdefault(parent_id, []).append(row) # 通过 parent_id 反查父表 parent_rows = self._fetch_parent_rows_by_parent_ids(list(child_groups.keys()), scope) candidates = [] for parent_row in parent_rows: parent_id = str(parent_row.get("parent_id") or "").strip() matches = child_groups.get(parent_id) or [] max_similarity = max((_to_float(item.get("similarity"), 0.0) for item in matches), default=0.0) candidate = self._candidate_from_parent_row(parent_row, "child_locator", scope, max_similarity) metadata = candidate.setdefault("metadata", {}) metadata["child_hit_count"] = len(matches) # 子表命中次数 metadata["matched_child_texts"] = [ str(item.get("text_content") or "").strip() for item in matches[:5] if str(item.get("text_content") or "").strip() ] candidates.append(candidate) return candidates def _recall_by_tag(self, scope: Dict[str, Any], keywords: List[str]) -> List[Dict[str, Any]]: """标签关键词召回:从关键词中筛选标准号、设备名等专业术语, 在 tag_list 字段上做 LIKE 匹配。 注意:标签召回容易过度匹配,因此结果相似度乘以 0.7 打折。 """ tag_terms = self._select_tag_terms(keywords) if not tag_terms: return [] tag_expr = self._build_tag_expr(tag_terms) scope_expr = self._build_filter_expr(scope) expr = _combine_expr(scope_expr, tag_expr) # 父表标签匹配 parent_rows = self._condition_query( collection_name=self.config.parent_collection, filter_expr=expr, output_fields=self.PARENT_OUTPUT_FIELDS, limit=self.config.tag_recall_top_k, ) candidates = [ self._candidate_from_parent_row(row, "tag", scope, self.config.min_vector_similarity) for row in parent_rows ] # 子表标签匹配,再反查父行 child_rows = self._condition_query( collection_name=self.config.child_collection, filter_expr=expr, output_fields=self.CHILD_OUTPUT_FIELDS, limit=self.config.tag_recall_top_k, ) child_parent_ids = [] child_tag_map: Dict[str, List[str]] = {} for row in child_rows: parent_id = str(row.get("parent_id") or self._metadata_value(row, "parent_id") or "").strip() if not parent_id: continue child_parent_ids.append(parent_id) text = str(row.get("text") or "").strip() if text: child_tag_map.setdefault(parent_id, []).append(text) for row in self._fetch_parent_rows_by_parent_ids(child_parent_ids, scope): parent_id = str(row.get("parent_id") or "").strip() candidate = self._candidate_from_parent_row(row, "tag", scope, self.config.min_vector_similarity) metadata = candidate.setdefault("metadata", {}) metadata["matched_child_texts"] = child_tag_map.get(parent_id, [])[:5] candidates.append(candidate) # 标签结果打折,防止过度匹配 for candidate in candidates: candidate["vector_similarity"] *= 0.7 # 记录匹配的标签术语 for candidate in candidates: metadata = candidate.setdefault("metadata", {}) tag_text = " ".join( str(value or "") for value in ( metadata.get("tag_list"), candidate.get("text"), " ".join(metadata.get("matched_child_texts") or []), ) ) metadata["tag_match_terms"] = [term for term in tag_terms if term and term in tag_text] return candidates def _recall_by_chapter(self, scope: Dict[str, Any], query: str) -> List[Dict[str, Any]]: """章节相似度检索:调用现有 similar_fragment_service, 按 chapter_level_1 + chapter_level_2 限定范围搜索相似片段。 """ from core.construction_write.component.similar_fragment_service import search_similar_fragments rows = search_similar_fragments( level1=str(scope.get("chapter_level_1") or ""), level2=str(scope.get("chapter_level_2") or ""), search_text=query, top_k=self.config.chapter_recall_top_k, ) candidates = [] for row in rows: metadata = { "tenant_id": scope.get("tenant_id") or "", "project_id": scope.get("project_id") or "", "knowledge_base_id": scope.get("knowledge_base_id") or "", "file_name": row.get("file_name") or "", "chapter_level_1": row.get("chapter_level_1") or scope.get("chapter_level_1") or "", "chapter_level_2": row.get("chapter_level_2") or scope.get("chapter_level_2") or "", "parent_count": row.get("parent_count", 0), "source_scope_valid": True, # 通过章节分类限定,天然 scope 匹配 } text = str(row.get("text") or "").strip() candidates.append( { "candidate_key": self._build_candidate_key({**row, "metadata": metadata}, text), "text": text, "source": metadata.get("file_name") or "向量知识库", "vector_similarity": _to_float(row.get("similarity"), 0.0), "fusion_score": 0.0, "metadata": metadata, "source_hits": {}, "retrieval_source": "chapter_similarity", } ) return candidates # ============================================================ # RRF 合并 # ============================================================ def _merge_recall_results( self, source_results: Dict[str, List[Dict[str, Any]]], scope: Dict[str, Any], keywords: List[str], ) -> List[Dict[str, Any]]: """多路召回结果 RRF 融合合并。 融合分数计算: - 基础分:weight / (rrf_k + rank),按路径权重和排名计算 - 多源加分:同一条候选在多个路径中被召回时额外加分 - Scope 加分:与当前项目范围一致时额外加分 - 标签加分:关键词出现在候选文本中时额外加分 """ weights = { "parent_vector": self.config.parent_vector_weight, "child_locator": self.config.child_locator_weight, "tag": self.config.tag_weight, "chapter_similarity": self.config.chapter_similarity_weight, } merged: Dict[str, Dict[str, Any]] = {} for source, candidates in source_results.items(): weight = weights.get(source, 0.0) for rank, item in enumerate(candidates or [], start=1): key = str(item.get("candidate_key") or self._build_candidate_key(item, item.get("text"))) if not key: continue if key not in merged: candidate = dict(item) candidate["candidate_key"] = key candidate["source_hits"] = {} candidate["fusion_score"] = 0.0 merged[key] = candidate current = merged[key] # RRF 公式:累加 weight / (rrf_k + rank) current["fusion_score"] = _to_float(current.get("fusion_score"), 0.0) + weight / (self.config.rrf_k + rank) current["vector_similarity"] = max( _to_float(current.get("vector_similarity"), 0.0), _to_float(item.get("vector_similarity"), 0.0), ) current.setdefault("source_hits", {})[source] = { "rank": rank, "vector_similarity": _to_float(item.get("vector_similarity"), 0.0), } self._merge_metadata(current, item) # 额外加分 for candidate in merged.values(): source_hits = candidate.get("source_hits") if isinstance(candidate.get("source_hits"), dict) else {} metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {} if len(source_hits) > 1: candidate["fusion_score"] += self.config.multi_source_bonus if self._metadata_matches_scope(metadata, scope): candidate["fusion_score"] += self.config.scope_bonus candidate["fusion_score"] += self._calc_tag_bonus(candidate, keywords) return sorted(merged.values(), key=lambda item: item.get("fusion_score", 0.0), reverse=True)[: self.config.recall_top_k] # ============================================================ # Milvus 查询辅助 # ============================================================ def _fetch_parent_rows_by_parent_ids(self, parent_ids: List[str], scope: Dict[str, Any]) -> List[Dict[str, Any]]: """根据 parent_id 列表反查父表行,去重后逐条查询。""" unique_ids = [] seen = set() for parent_id in parent_ids: value = str(parent_id or "").strip() if value and value not in seen: seen.add(value) unique_ids.append(value) rows: List[Dict[str, Any]] = [] scope_expr = self._build_filter_expr(scope) for parent_id in unique_ids[: self.config.recall_top_k]: parent_expr = f"parent_id == '{_escape_milvus_string(parent_id)}'" expr = _combine_expr(parent_expr, scope_expr) rows.extend( self._condition_query( collection_name=self.config.parent_collection, filter_expr=expr, output_fields=self.PARENT_OUTPUT_FIELDS, limit=100, ) ) return rows def _condition_query( self, collection_name: str, filter_expr: str, output_fields: List[str], limit: int, ) -> List[Dict[str, Any]]: """Milvus 条件查询(非向量),按 filter 表达式筛选文档。""" from core.construction_write.component.similar_fragment_service import get_milvus_manager if not filter_expr: return [] return get_milvus_manager().condition_query( collection_name=collection_name, filter=filter_expr, output_fields=output_fields, limit=limit, ) # ============================================================ # 候选构建 # ============================================================ def _candidate_from_vector_row(self, row: Dict[str, Any], source: str, scope: Dict[str, Any]) -> Dict[str, Any]: """从 Milvus 混合搜索结果行构建标准候选。""" metadata = self._normalize_row_metadata(row.get("metadata") or {}) text = str(row.get("text_content") or row.get("text") or "").strip() metadata["source_scope_valid"] = self._metadata_matches_scope(metadata, scope) return { "candidate_key": self._build_candidate_key(metadata, text), "text": text, "source": metadata.get("file_name") or metadata.get("title") or "向量知识库", "vector_similarity": _to_float(row.get("similarity"), 0.0), "fusion_score": 0.0, "metadata": metadata, "source_hits": {}, "retrieval_source": source, } def _candidate_from_parent_row( self, row: Dict[str, Any], source: str, scope: Dict[str, Any], vector_similarity: float, ) -> Dict[str, Any]: """从父表行构建标准候选。""" metadata = self._normalize_row_metadata(row) text = str(row.get("text") or "").strip() metadata["source_scope_valid"] = self._metadata_matches_scope(metadata, scope) return { "candidate_key": self._build_candidate_key(metadata, text), "text": text, "source": metadata.get("file_name") or "向量知识库", "vector_similarity": _to_float(vector_similarity, 0.0), "fusion_score": 0.0, "metadata": metadata, "source_hits": {}, "retrieval_source": source, } # ============================================================ # Scope 提取与过滤 # ============================================================ def _extract_scope(self, state: Dict[str, Any]) -> Dict[str, Any]: """从工作流状态中提取检索范围信息。 按优先级从 selected_section、document_context、project_info、retrieval_filters 中查找字段值,兼容多种字段命名。 """ selected = state.get("selected_section") or {} context = state.get("document_context") or {} project = state.get("project_info") or {} filters = context.get("retrieval_filters") if isinstance(context.get("retrieval_filters"), dict) else {} filters = filters or project.get("retrieval_filters") if isinstance(project.get("retrieval_filters"), dict) else filters def pick(*keys: str) -> str: for source in (selected, context, project, filters or {}): for key in keys: value = source.get(key) if isinstance(source, dict) else None if value not in (None, ""): return str(value).strip() return "" return { "tenant_id": pick("tenant_id"), "project_id": pick("project_id"), "knowledge_base_id": pick("knowledge_base_id", "kb_id"), "engineering_type": pick("engineering_type", "project_type"), "plan_type": pick("plan_type"), "chapter_level_1": pick("chapter_level_1", "level1"), "chapter_level_2": pick("chapter_level_2", "level2"), "chapter_level_3": pick("chapter_level_3", "level3"), } @staticmethod def _has_reliable_scope(scope: Dict[str, Any]) -> bool: """判断是否有足够可靠的 scope 用于限定检索范围。""" if scope.get("chapter_level_1") and scope.get("chapter_level_2"): return True return bool(scope.get("plan_type")) def _build_filter_expr(self, scope: Dict[str, Any]) -> str: """构建 Milvus 过滤表达式,按章节层级限定检索范围。""" conditions = [] for key in ("plan_type", "chapter_level_1", "chapter_level_2", "chapter_level_3"): value = str(scope.get(key) or "").strip() if value: conditions.append(f"{key} == '{_escape_milvus_string(value)}'") return " and ".join(conditions) def _build_tag_expr(self, tag_terms: List[str]) -> str: """构建标签 LIKE 查询表达式。""" conditions = [] for term in tag_terms[: self.config.tag_terms_limit]: conditions.append(f'tag_list like "%{_escape_milvus_string(term)}%"') return " or ".join(conditions) def _select_tag_terms(self, keywords: List[str]) -> List[str]: """从关键词中筛选高价值标签术语。 排除:验收、标准、规范等通用词(几乎匹配所有文档) 优先:标准号(如 TB10212-2012)、设备名(架桥机、龙门吊等) """ generic_terms = { "验收", "标准", "规范", "检查", "检测", "试验", "安装", "拆除", "要求", "安全", "环保", "质量", "进度", "交底", } device_terms = {"架桥机", "龙门吊", "吊车", "塔吊", "施工电梯", "挂篮", "支架", "台车"} selected = [] priority = [] # 标准号和设备名优先 seen = set() for keyword in keywords: value = str(keyword or "").strip() if len(value) < 2 or value in seen: continue seen.add(value) if value in generic_terms: continue if re.match(r"[A-Z]{1,3}\d{4,}", value) or value in device_terms: priority.append(value) elif len(selected) < self.config.tag_terms_limit: selected.append(value) return priority + selected @staticmethod def _metadata_matches_scope(metadata: Dict[str, Any], scope: Dict[str, Any]) -> bool: """检查候选 metadata 是否与当前检索 scope 兼容。 不要求所有字段都匹配,仅校验 scope 和 metadata 同时存在且不一致的字段。 """ required_keys = ["tenant_id", "project_id", "knowledge_base_id", "chapter_level_1", "chapter_level_2", "chapter_level_3"] for key in required_keys: expected = str(scope.get(key) or "").strip() if not expected: continue actual = str(metadata.get(key) or "").strip() if actual and actual != expected: return False return True # ============================================================ # Metadata 处理 # ============================================================ def _normalize_row_metadata(self, row_or_metadata: Any) -> Dict[str, Any]: """规范化行数据为统一的 metadata 字典。处理嵌套 metadata 和 YAML 字符串。""" metadata = self._normalize_metadata(row_or_metadata) inner = self._normalize_metadata(metadata.get("metadata")) if metadata.get("metadata") else {} for key, value in inner.items(): metadata.setdefault(key, value) for key in self.PARENT_OUTPUT_FIELDS: if isinstance(row_or_metadata, dict) and row_or_metadata.get(key) not in (None, ""): metadata[key] = row_or_metadata.get(key) return metadata @staticmethod def _normalize_metadata(metadata: Any) -> Dict[str, Any]: """将 metadata 转为字典,支持 YAML 字符串解析。""" if isinstance(metadata, dict): return dict(metadata) if isinstance(metadata, str) and metadata.strip(): try: loaded = yaml.safe_load(metadata) return dict(loaded) if isinstance(loaded, dict) else {} except Exception: return {} return {} @staticmethod def _metadata_value(metadata: Dict[str, Any], key: str) -> Any: """安全获取 metadata 值,支持嵌套 metadata.metadata 和 YAML 字符串。""" if key in metadata: return metadata.get(key) nested = metadata.get("metadata") if isinstance(nested, dict): return nested.get(key) if isinstance(nested, str) and nested.strip(): try: parsed = yaml.safe_load(nested) if isinstance(parsed, dict): return parsed.get(key) except Exception: return None return None def _build_candidate_key(self, metadata: Dict[str, Any], text: Any = "") -> str: """构建候选唯一标识键,按优先级尝试不同字段组合。""" metadata = self._normalize_row_metadata(metadata) document_id = str(self._metadata_value(metadata, "document_id") or "").strip() parent_id = str(self._metadata_value(metadata, "parent_id") or "").strip() chunk_id = str(self._metadata_value(metadata, "chunk_id") or "").strip() chapter_title = str(self._metadata_value(metadata, "chapter_title") or "").strip() index = self._metadata_value(metadata, "index") pk = str(self._metadata_value(metadata, "pk") or "").strip() if document_id and parent_id and chunk_id: return f"{document_id}::{parent_id}::{chunk_id}" if document_id and parent_id and chapter_title and index not in (None, ""): return f"{document_id}::{parent_id}::{chapter_title}::{index}" if pk: return pk if parent_id and chapter_title and index not in (None, ""): return f"{parent_id}::{chapter_title}::{index}" return str(text or "")[:300] def _merge_metadata(self, current: Dict[str, Any], incoming: Dict[str, Any]) -> None: """合并两条候选的 metadata,不覆盖已有非空值。""" current_meta = current.setdefault("metadata", {}) incoming_meta = incoming.get("metadata") if isinstance(incoming.get("metadata"), dict) else {} for key, value in incoming_meta.items(): if key not in current_meta or current_meta.get(key) in (None, "", []): current_meta[key] = value if incoming.get("source") and not current.get("source"): current["source"] = incoming.get("source") # ============================================================ # 加分计算 # ============================================================ def _calc_tag_bonus(self, candidate: Dict[str, Any], keywords: List[str]) -> float: """计算标签匹配加分:关键词精确匹配 tag_list 加分更多,仅出现在文本中加分较少。""" metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {} text = " ".join( str(value or "") for value in ( candidate.get("text"), metadata.get("tag_list"), " ".join(metadata.get("matched_child_texts") or []), ) ) bonus = 0.0 for keyword in self._select_tag_terms(keywords): if not keyword: continue if keyword in str(metadata.get("tag_list") or ""): bonus += self.config.tag_exact_bonus elif keyword in text: bonus += self.config.tag_partial_bonus return bonus # ============================================================ # 候选清理 # ============================================================ def _clean_candidates(self, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """清理候选:过滤过短文本、双重去重(candidate_key + 内容哈希)。 去重策略: 1. candidate_key 去重:相同 document+parent+chunk 视为同一条 2. 内容哈希去重:同一文件同一文本内容(即使路径不同)只保留一条 """ cleaned = [] seen_keys = set() seen_hashes = set() for item in candidates: text = str(item.get("text") or "").strip() if len(text) < 20: continue metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {} dedupe_key = str(item.get("candidate_key") or text[:300]) # 内容哈希去重 content_hash = _content_hash(text[:300]) file_name = str(metadata.get("file_name") or "") hash_key = f"{file_name}::{content_hash}" if dedupe_key in seen_keys or hash_key in seen_hashes: continue seen_keys.add(dedupe_key) seen_hashes.add(hash_key) metadata["candidate_key"] = dedupe_key cleaned.append( { "candidate_key": dedupe_key, "text": text[: self.config.max_single_reference_chars], "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"), "vector_similarity": _to_float(item.get("vector_similarity"), 0.0), "fusion_score": _to_float(item.get("fusion_score"), 0.0), "source_hits": item.get("source_hits") if isinstance(item.get("source_hits"), dict) else {}, "metadata": metadata, } ) cleaned.sort(key=lambda item: (item.get("fusion_score", 0.0), item.get("vector_similarity", 0.0)), reverse=True) return cleaned[: self.config.recall_top_k] # ============================================================ # 空结果/告警 # ============================================================ def _empty_result( self, status: str, warnings: List[str], retrieval_method: str = "", retrieval_scope: Optional[Dict[str, Any]] = None, retrieval_steps: Optional[List[Dict[str, Any]]] = None, ) -> Dict[str, Any]: """构建空召回结果。""" return { "retrieval_candidates": [], "retrieval_steps": retrieval_steps or [], "retrieval_status": status, "retrieval_method": retrieval_method, "retrieval_metrics": { "recall_count": 0, "retrieval_method": retrieval_method, "scope": {key: value for key, value in (retrieval_scope or {}).items() if value}, }, "warnings": warnings, } def _warning(self, key: str) -> str: """获取指定键的告警文案。""" warnings = self.config.warnings or _default_warnings() return warnings.get(key) or "" def _default_warnings() -> Dict[str, str]: return { "no_scope": "缺少可靠的知识库检索范围,本次未引用向量库内容。", "no_recall": "未召回可信知识库内容,本次回答不引用向量库。", "low_confidence": "未找到可信度足够的知识库片段,本次未引用向量库内容。", "rerank_failed": "知识库片段重排不可用,本次未引用向量库内容。", } def _escape_milvus_string(value: str) -> str: """转义 Milvus 字符串中的特殊字符(反斜杠、单引号、双引号)。""" return str(value).replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"') def _combine_expr(*exprs: str) -> str: """用 AND 连接多个过滤表达式,每个子表达式加括号。""" parts = [f"({expr})" for expr in exprs if str(expr or "").strip()] return " and ".join(parts) def _dedupe_join(parts: List[str], max_chars: int) -> str: """去重后拼接文本片段,限制总长度。""" values = [] seen = set() for part in parts: value = re.sub(r"\s+", " ", str(part or "")).strip() if not value or value in seen: continue seen.add(value) values.append(value) return " ".join(values)[:max_chars] def _extract_retrieval_keywords(text: str) -> List[str]: """从文本中提取检索关键词,支持多种模式: 1. 标准号/型号:如 TB10212-2012、φ48.3×3.6 2. 规范名称:《XXX规范》 3. 领域专业术语:验收、架桥机、箱梁等 4. 术语+动作组合:XX验收、XX安装 5. 长词中的领域术语片段 """ if not text: return [] keywords: List[str] = [] # 模式1:标准号/型号(字母+数字,可选连字符) for match in re.findall(r"[A-Za-z]{1,8}\s*\d{2,8}(?:[-—]\d{2,4})?", text): keywords.append(re.sub(r"\s+", "", match).upper()) # 模式2:《XXX》规范名称 for match in re.findall(r"《([^》]{2,40})》", text): keywords.append(match.strip()) # 模式3:领域专业术语 domain_terms = ( "验收", "标准", "规范", "检查", "检测", "试验", "安装", "拆除", "吊装", "架桥机", "龙门吊", "吊车", "箱梁", "T梁", "梁板", "钢丝绳", "支座", "地基", "安全装置", "操作证", "合格证", "静载", "动载", "空载", ) for term in domain_terms: if term in text: keywords.append(term) # 模式4:术语+动作组合 for match in re.findall(r"[一-鿿A-Za-z0-9.-]{0,12}(?:验收|标准|规范|检查|检测|试验|安装|拆除|吊装|要求)", text): if 2 <= len(match) <= 20: keywords.append(match) # 模式5:分词后含领域术语的片段 normalized = re.sub(r"[\s,,。;;::、/\\|()\[\]{}<>《》\"'""??]+", " ", text) for token in normalized.split(): token = token.strip() if len(token) < 2 or len(token) > 12: continue if any(term in token for term in domain_terms): keywords.append(token) seen = set() unique = [] for keyword in keywords: keyword = keyword.strip() if keyword and keyword not in seen: seen.add(keyword) unique.append(keyword) return unique def _pack_log_items(items: List[Dict[str, Any]], limit: int = 20, text_limit: int = 1500) -> List[Dict[str, Any]]: """打包候选条目为日志格式,限制条数和文本长度。""" packed = [] for item in (items or [])[:limit]: if not isinstance(item, dict): continue metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {} text = str(item.get("text") or item.get("text_content") or item.get("content") or "").strip() packed.append( { "candidate_key": item.get("candidate_key"), "source": item.get("source") or metadata.get("file_name") or "", "text": text[:text_limit], "vector_similarity": _to_float(item.get("vector_similarity", item.get("similarity")), 0.0), "fusion_score": _to_float(item.get("fusion_score"), 0.0), "rerank_score": _to_float(item.get("rerank_score"), 0.0) if "rerank_score" in item else None, "source_hits": item.get("source_hits") if isinstance(item.get("source_hits"), dict) else {}, "metadata": { key: metadata.get(key) for key in ( "document_id", "parent_id", "file_name", "chapter_title", "chapter_level_1", "chapter_level_2", "chapter_level_3", "parent_count", "child_hit_count", "matched_child_texts", "tag_match_terms", "source_scope_valid", ) if metadata.get(key) not in (None, "") }, } ) return packed def _to_int(value: Any, default: int) -> int: """安全整数转换。""" try: return int(value) except (TypeError, ValueError): return default def _to_float(value: Any, default: float = 0.0) -> float: """安全浮点数转换。""" try: return float(value) except (TypeError, ValueError): return default def _content_hash(text: str) -> str: """基于归一化文本的短 MD5 哈希,用于内容去重。""" normalized = re.sub(r"\s+", " ", text.strip().lower()) return md5(normalized.encode("utf-8")).hexdigest()[:12]