|
|
@@ -21,115 +21,37 @@
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
-from dataclasses import dataclass
|
|
|
-from hashlib import md5
|
|
|
-from pathlib import Path
|
|
|
-import re
|
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
|
|
-import yaml
|
|
|
-
|
|
|
from core.document_chat.component.document_chat_logger import document_chat_logger as logger
|
|
|
-
|
|
|
-
|
|
|
-PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
|
|
-RETRIEVAL_CONFIG_PATH = PROJECT_ROOT / "config" / "document_chat_retrieval.yaml"
|
|
|
-
|
|
|
-
|
|
|
-@dataclass(frozen=True)
|
|
|
-class RetrievalConfig:
|
|
|
- """检索配置(不可变)。各参数含义见下方字段注释。"""
|
|
|
- enabled: bool = True
|
|
|
- parent_collection: str = "t_kngs_construction_plan_parent"
|
|
|
- child_collection: str = "t_kngs_construction_plan_child"
|
|
|
- # 各路径召回上限
|
|
|
- parent_recall_top_k: int = 30
|
|
|
- child_recall_top_k: int = 40
|
|
|
- tag_recall_top_k: int = 30
|
|
|
- chapter_recall_top_k: int = 15
|
|
|
- recall_top_k: int = 30
|
|
|
- rerank_top_k: int = 8
|
|
|
- submit_top_k: int = 3 # 最终送入 LLM prompt 的参考条数上限
|
|
|
- # 质量阈值
|
|
|
- min_vector_similarity: float = 0.45
|
|
|
- min_rerank_score: float = 0.65 # 重排质量门,低于此值被过滤
|
|
|
- min_qualified_count: int = 1
|
|
|
- # 参考内容长度限制
|
|
|
- max_reference_chars: int = 4000 # 所有参考总字符上限
|
|
|
- max_single_reference_chars: int = 1500 # 单条参考字符上限
|
|
|
- # 降级策略
|
|
|
- allow_vector_fallback: bool = False
|
|
|
- allow_unscoped_search: bool = False
|
|
|
- # 混合搜索权重(dense=sparse 向量融合)
|
|
|
- dense_weight: float = 0.7
|
|
|
- sparse_weight: float = 0.3
|
|
|
- child_dense_weight: float = 0.6
|
|
|
- child_sparse_weight: float = 0.4
|
|
|
- ranker_type: str = "weighted"
|
|
|
- # 标签召回
|
|
|
- tag_recall_enabled: bool = True
|
|
|
- tag_terms_limit: int = 8
|
|
|
- # RRF 参数
|
|
|
- rrf_k: int = 60
|
|
|
- # 路径权重
|
|
|
- parent_vector_weight: float = 1.0
|
|
|
- child_locator_weight: float = 0.8
|
|
|
- tag_weight: float = 1.2
|
|
|
- chapter_similarity_weight: float = 0.5
|
|
|
- # 加分项
|
|
|
- tag_exact_bonus: float = 0.08
|
|
|
- tag_partial_bonus: float = 0.03
|
|
|
- multi_source_bonus: float = 0.02
|
|
|
- scope_bonus: float = 0.03
|
|
|
- warnings: Optional[Dict[str, str]] = None
|
|
|
-
|
|
|
-
|
|
|
-def load_retrieval_config() -> RetrievalConfig:
|
|
|
- """从 YAML 配置文件加载检索参数,文件不存在时使用默认值。"""
|
|
|
- if not RETRIEVAL_CONFIG_PATH.exists():
|
|
|
- return RetrievalConfig(warnings=_default_warnings())
|
|
|
-
|
|
|
- with open(RETRIEVAL_CONFIG_PATH, "r", encoding="utf-8") as handle:
|
|
|
- raw = yaml.safe_load(handle) or {}
|
|
|
-
|
|
|
- retrieval = raw.get("retrieval") or {}
|
|
|
- warnings = raw.get("warnings") or _default_warnings()
|
|
|
- return RetrievalConfig(
|
|
|
- enabled=bool(retrieval.get("enabled", True)),
|
|
|
- parent_collection=str(retrieval.get("parent_collection") or "t_kngs_construction_plan_parent"),
|
|
|
- child_collection=str(retrieval.get("child_collection") or "t_kngs_construction_plan_child"),
|
|
|
- parent_recall_top_k=_to_int(retrieval.get("parent_recall_top_k"), 30),
|
|
|
- child_recall_top_k=_to_int(retrieval.get("child_recall_top_k"), 40),
|
|
|
- tag_recall_top_k=_to_int(retrieval.get("tag_recall_top_k"), 30),
|
|
|
- chapter_recall_top_k=_to_int(retrieval.get("chapter_recall_top_k"), 15),
|
|
|
- recall_top_k=_to_int(retrieval.get("recall_top_k"), 30),
|
|
|
- rerank_top_k=_to_int(retrieval.get("rerank_top_k"), 8),
|
|
|
- submit_top_k=_to_int(retrieval.get("submit_top_k"), 3),
|
|
|
- min_vector_similarity=_to_float(retrieval.get("min_vector_similarity"), 0.45),
|
|
|
- min_rerank_score=_to_float(retrieval.get("min_rerank_score"), 0.65),
|
|
|
- min_qualified_count=_to_int(retrieval.get("min_qualified_count"), 1),
|
|
|
- max_reference_chars=_to_int(retrieval.get("max_reference_chars"), 4000),
|
|
|
- max_single_reference_chars=_to_int(retrieval.get("max_single_reference_chars"), 1500),
|
|
|
- allow_vector_fallback=bool(retrieval.get("allow_vector_fallback", False)),
|
|
|
- allow_unscoped_search=bool(retrieval.get("allow_unscoped_search", False)),
|
|
|
- dense_weight=_to_float(retrieval.get("dense_weight"), 0.7),
|
|
|
- sparse_weight=_to_float(retrieval.get("sparse_weight"), 0.3),
|
|
|
- child_dense_weight=_to_float(retrieval.get("child_dense_weight"), 0.6),
|
|
|
- child_sparse_weight=_to_float(retrieval.get("child_sparse_weight"), 0.4),
|
|
|
- ranker_type=str(retrieval.get("ranker_type") or "weighted"),
|
|
|
- tag_recall_enabled=bool(retrieval.get("tag_recall_enabled", True)),
|
|
|
- tag_terms_limit=_to_int(retrieval.get("tag_terms_limit"), 8),
|
|
|
- rrf_k=_to_int(retrieval.get("rrf_k"), 60),
|
|
|
- parent_vector_weight=_to_float(retrieval.get("parent_vector_weight"), 1.0),
|
|
|
- child_locator_weight=_to_float(retrieval.get("child_locator_weight"), 0.8),
|
|
|
- tag_weight=_to_float(retrieval.get("tag_weight"), 1.2),
|
|
|
- chapter_similarity_weight=_to_float(retrieval.get("chapter_similarity_weight"), 0.5),
|
|
|
- tag_exact_bonus=_to_float(retrieval.get("tag_exact_bonus"), 0.08),
|
|
|
- tag_partial_bonus=_to_float(retrieval.get("tag_partial_bonus"), 0.03),
|
|
|
- multi_source_bonus=_to_float(retrieval.get("multi_source_bonus"), 0.02),
|
|
|
- scope_bonus=_to_float(retrieval.get("scope_bonus"), 0.03),
|
|
|
- warnings=warnings,
|
|
|
- )
|
|
|
+from core.document_chat.retrieval.candidate import (
|
|
|
+ build_candidate_key,
|
|
|
+ clean_candidates,
|
|
|
+ merge_metadata,
|
|
|
+ metadata_value,
|
|
|
+ normalize_metadata,
|
|
|
+ normalize_row_metadata,
|
|
|
+)
|
|
|
+from core.document_chat.retrieval.config import RetrievalConfig, default_warnings, load_retrieval_config
|
|
|
+from core.document_chat.retrieval.fusion import calc_tag_bonus, merge_recall_results
|
|
|
+from core.document_chat.retrieval.query_builder import (
|
|
|
+ build_query as build_retrieval_query,
|
|
|
+ build_query_keywords as build_retrieval_query_keywords,
|
|
|
+)
|
|
|
+from core.document_chat.retrieval.scope import (
|
|
|
+ build_filter_expr,
|
|
|
+ build_tag_expr,
|
|
|
+ extract_scope,
|
|
|
+ has_reliable_scope,
|
|
|
+ metadata_matches_scope,
|
|
|
+ select_tag_terms,
|
|
|
+)
|
|
|
+from core.document_chat.retrieval.utils import (
|
|
|
+ combine_expr as _combine_expr,
|
|
|
+ escape_milvus_string as _escape_milvus_string,
|
|
|
+ pack_log_items as _pack_log_items,
|
|
|
+ to_float as _to_float,
|
|
|
+)
|
|
|
|
|
|
|
|
|
class DocumentChatRetrievalService:
|
|
|
@@ -160,75 +82,21 @@ class DocumentChatRetrievalService:
|
|
|
# Query 构建
|
|
|
# ============================================================
|
|
|
def build_query(self, state: Dict[str, Any]) -> str:
|
|
|
- """构建精炼检索 query,避免冗余的项目摘要。
|
|
|
-
|
|
|
- 拼接内容:
|
|
|
- - 用户原始输入
|
|
|
- - 意图识别后的规范化指令
|
|
|
- - 当前选中章节编号 + 标题
|
|
|
- - 提取的关键词(最多 8 个)
|
|
|
- 去重后截取 120 字符。
|
|
|
- """
|
|
|
- selected_section = state.get("selected_section") or {}
|
|
|
- intent_result = state.get("intent_result") or {}
|
|
|
- keywords = self.build_query_keywords(state)
|
|
|
-
|
|
|
- parts = [
|
|
|
- state.get("user_message") or "",
|
|
|
- intent_result.get("normalized_instruction") or "",
|
|
|
- f"{selected_section.get('index', '')} {selected_section.get('title', '')}".strip(),
|
|
|
- " ".join(keywords[:8]),
|
|
|
- ]
|
|
|
- return _dedupe_join(parts, max_chars=120)
|
|
|
+ """构建精炼检索 query,避免冗余的项目摘要。"""
|
|
|
+ return build_retrieval_query(
|
|
|
+ state,
|
|
|
+ domain_terms=self.config.keyword_domain_terms,
|
|
|
+ action_terms=self.config.keyword_action_terms,
|
|
|
+ )
|
|
|
|
|
|
def build_query_keywords(self, state: Dict[str, Any], query: Optional[str] = None) -> List[str]:
|
|
|
- """从多来源提取检索关键词。
|
|
|
-
|
|
|
- 来源优先级:
|
|
|
- 1. 用户输入
|
|
|
- 2. 意图规范化指令
|
|
|
- 3. 章节编号 + 标题
|
|
|
- 4. 章节正文内容(前 500 字)
|
|
|
- 5. 已拼接的 query
|
|
|
- 6. 历史对话中用户消息(排除 AI 回复,防止助手建议污染检索)
|
|
|
-
|
|
|
- 关键词提取规则见 _extract_retrieval_keywords。
|
|
|
- """
|
|
|
- selected_section = state.get("selected_section") or {}
|
|
|
- intent_result = state.get("intent_result") or {}
|
|
|
- history = state.get("conversation_history") or []
|
|
|
-
|
|
|
- sources = [
|
|
|
- state.get("user_message") or "",
|
|
|
- intent_result.get("normalized_instruction") or "",
|
|
|
- f"{selected_section.get('index', '')} {selected_section.get('title', '')}",
|
|
|
- str(selected_section.get("content") or "")[:500],
|
|
|
- query or "",
|
|
|
- ]
|
|
|
- if history:
|
|
|
- for turn in history[-6:]:
|
|
|
- if not isinstance(turn, dict):
|
|
|
- continue
|
|
|
- role = str(turn.get("role") or turn.get("sender") or "").lower()
|
|
|
- # 仅取用户消息,跳过 AI 助手回复
|
|
|
- if role in ("assistant", "ai", "bot", "model"):
|
|
|
- continue
|
|
|
- content = str(turn.get("content") or turn.get("message") or "")
|
|
|
- if content:
|
|
|
- sources.append(content)
|
|
|
-
|
|
|
- keywords: List[str] = []
|
|
|
- seen = set()
|
|
|
- for text in sources:
|
|
|
- for keyword in _extract_retrieval_keywords(str(text or "")):
|
|
|
- normalized = keyword.strip()
|
|
|
- if not normalized or normalized in seen:
|
|
|
- continue
|
|
|
- seen.add(normalized)
|
|
|
- keywords.append(normalized)
|
|
|
- if len(keywords) >= 20:
|
|
|
- return keywords
|
|
|
- return keywords
|
|
|
+ """从多来源提取检索关键词。"""
|
|
|
+ return build_retrieval_query_keywords(
|
|
|
+ state,
|
|
|
+ query,
|
|
|
+ domain_terms=self.config.keyword_domain_terms,
|
|
|
+ action_terms=self.config.keyword_action_terms,
|
|
|
+ )
|
|
|
|
|
|
# ============================================================
|
|
|
# 主召回入口
|
|
|
@@ -580,59 +448,8 @@ class DocumentChatRetrievalService:
|
|
|
scope: Dict[str, Any],
|
|
|
keywords: List[str],
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
- """多路召回结果 RRF 融合合并。
|
|
|
-
|
|
|
- 融合分数计算:
|
|
|
- - 基础分:weight / (rrf_k + rank),按路径权重和排名计算
|
|
|
- - 多源加分:同一条候选在多个路径中被召回时额外加分
|
|
|
- - Scope 加分:与当前项目范围一致时额外加分
|
|
|
- - 标签加分:关键词出现在候选文本中时额外加分
|
|
|
- """
|
|
|
- weights = {
|
|
|
- "parent_vector": self.config.parent_vector_weight,
|
|
|
- "child_locator": self.config.child_locator_weight,
|
|
|
- "tag": self.config.tag_weight,
|
|
|
- "chapter_similarity": self.config.chapter_similarity_weight,
|
|
|
- }
|
|
|
- merged: Dict[str, Dict[str, Any]] = {}
|
|
|
-
|
|
|
- for source, candidates in source_results.items():
|
|
|
- weight = weights.get(source, 0.0)
|
|
|
- for rank, item in enumerate(candidates or [], start=1):
|
|
|
- key = str(item.get("candidate_key") or self._build_candidate_key(item, item.get("text")))
|
|
|
- if not key:
|
|
|
- continue
|
|
|
- if key not in merged:
|
|
|
- candidate = dict(item)
|
|
|
- candidate["candidate_key"] = key
|
|
|
- candidate["source_hits"] = {}
|
|
|
- candidate["fusion_score"] = 0.0
|
|
|
- merged[key] = candidate
|
|
|
-
|
|
|
- current = merged[key]
|
|
|
- # RRF 公式:累加 weight / (rrf_k + rank)
|
|
|
- current["fusion_score"] = _to_float(current.get("fusion_score"), 0.0) + weight / (self.config.rrf_k + rank)
|
|
|
- current["vector_similarity"] = max(
|
|
|
- _to_float(current.get("vector_similarity"), 0.0),
|
|
|
- _to_float(item.get("vector_similarity"), 0.0),
|
|
|
- )
|
|
|
- current.setdefault("source_hits", {})[source] = {
|
|
|
- "rank": rank,
|
|
|
- "vector_similarity": _to_float(item.get("vector_similarity"), 0.0),
|
|
|
- }
|
|
|
- self._merge_metadata(current, item)
|
|
|
-
|
|
|
- # 额外加分
|
|
|
- for candidate in merged.values():
|
|
|
- source_hits = candidate.get("source_hits") if isinstance(candidate.get("source_hits"), dict) else {}
|
|
|
- metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {}
|
|
|
- if len(source_hits) > 1:
|
|
|
- candidate["fusion_score"] += self.config.multi_source_bonus
|
|
|
- if self._metadata_matches_scope(metadata, scope):
|
|
|
- candidate["fusion_score"] += self.config.scope_bonus
|
|
|
- candidate["fusion_score"] += self._calc_tag_bonus(candidate, keywords)
|
|
|
-
|
|
|
- return sorted(merged.values(), key=lambda item: item.get("fusion_score", 0.0), reverse=True)[: self.config.recall_top_k]
|
|
|
+ """多路召回结果 RRF 融合合并。"""
|
|
|
+ return merge_recall_results(source_results, scope, keywords, self.config)
|
|
|
|
|
|
# ============================================================
|
|
|
# Milvus 查询辅助
|
|
|
@@ -726,199 +543,67 @@ class DocumentChatRetrievalService:
|
|
|
# Scope 提取与过滤
|
|
|
# ============================================================
|
|
|
def _extract_scope(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
- """从工作流状态中提取检索范围信息。
|
|
|
-
|
|
|
- 按优先级从 selected_section、document_context、project_info、retrieval_filters
|
|
|
- 中查找字段值,兼容多种字段命名。
|
|
|
- """
|
|
|
- selected = state.get("selected_section") or {}
|
|
|
- context = state.get("document_context") or {}
|
|
|
- project = state.get("project_info") or {}
|
|
|
- filters = context.get("retrieval_filters") if isinstance(context.get("retrieval_filters"), dict) else {}
|
|
|
- filters = filters or project.get("retrieval_filters") if isinstance(project.get("retrieval_filters"), dict) else filters
|
|
|
-
|
|
|
- def pick(*keys: str) -> str:
|
|
|
- for source in (selected, context, project, filters or {}):
|
|
|
- for key in keys:
|
|
|
- value = source.get(key) if isinstance(source, dict) else None
|
|
|
- if value not in (None, ""):
|
|
|
- return str(value).strip()
|
|
|
- return ""
|
|
|
-
|
|
|
- return {
|
|
|
- "tenant_id": pick("tenant_id"),
|
|
|
- "project_id": pick("project_id"),
|
|
|
- "knowledge_base_id": pick("knowledge_base_id", "kb_id"),
|
|
|
- "engineering_type": pick("engineering_type", "project_type"),
|
|
|
- "plan_type": pick("plan_type"),
|
|
|
- "chapter_level_1": pick("chapter_level_1", "level1"),
|
|
|
- "chapter_level_2": pick("chapter_level_2", "level2"),
|
|
|
- "chapter_level_3": pick("chapter_level_3", "level3"),
|
|
|
- }
|
|
|
+ """从工作流状态中提取检索范围信息。"""
|
|
|
+ return extract_scope(state)
|
|
|
|
|
|
@staticmethod
|
|
|
def _has_reliable_scope(scope: Dict[str, Any]) -> bool:
|
|
|
"""判断是否有足够可靠的 scope 用于限定检索范围。"""
|
|
|
- if scope.get("chapter_level_1") and scope.get("chapter_level_2"):
|
|
|
- return True
|
|
|
- return bool(scope.get("plan_type"))
|
|
|
+ return has_reliable_scope(scope)
|
|
|
|
|
|
def _build_filter_expr(self, scope: Dict[str, Any]) -> str:
|
|
|
"""构建 Milvus 过滤表达式,按章节层级限定检索范围。"""
|
|
|
- conditions = []
|
|
|
- for key in ("plan_type", "chapter_level_1", "chapter_level_2", "chapter_level_3"):
|
|
|
- value = str(scope.get(key) or "").strip()
|
|
|
- if value:
|
|
|
- conditions.append(f"{key} == '{_escape_milvus_string(value)}'")
|
|
|
- return " and ".join(conditions)
|
|
|
+ return build_filter_expr(scope)
|
|
|
|
|
|
def _build_tag_expr(self, tag_terms: List[str]) -> str:
|
|
|
"""构建标签 LIKE 查询表达式。"""
|
|
|
- conditions = []
|
|
|
- for term in tag_terms[: self.config.tag_terms_limit]:
|
|
|
- conditions.append(f'tag_list like "%{_escape_milvus_string(term)}%"')
|
|
|
- return " or ".join(conditions)
|
|
|
+ return build_tag_expr(tag_terms, self.config.tag_terms_limit)
|
|
|
|
|
|
def _select_tag_terms(self, keywords: List[str]) -> List[str]:
|
|
|
- """从关键词中筛选高价值标签术语。
|
|
|
-
|
|
|
- 排除:验收、标准、规范等通用词(几乎匹配所有文档)
|
|
|
- 优先:标准号(如 TB10212-2012)、设备名(架桥机、龙门吊等)
|
|
|
- """
|
|
|
- generic_terms = {
|
|
|
- "验收", "标准", "规范", "检查", "检测", "试验", "安装", "拆除",
|
|
|
- "要求", "安全", "环保", "质量", "进度", "交底",
|
|
|
- }
|
|
|
- device_terms = {"架桥机", "龙门吊", "吊车", "塔吊", "施工电梯", "挂篮", "支架", "台车"}
|
|
|
- selected = []
|
|
|
- priority = [] # 标准号和设备名优先
|
|
|
- seen = set()
|
|
|
- for keyword in keywords:
|
|
|
- value = str(keyword or "").strip()
|
|
|
- if len(value) < 2 or value in seen:
|
|
|
- continue
|
|
|
- seen.add(value)
|
|
|
- if value in generic_terms:
|
|
|
- continue
|
|
|
- if re.match(r"[A-Z]{1,3}\d{4,}", value) or value in device_terms:
|
|
|
- priority.append(value)
|
|
|
- elif len(selected) < self.config.tag_terms_limit:
|
|
|
- selected.append(value)
|
|
|
- return priority + selected
|
|
|
+ """从关键词中筛选高价值标签术语。"""
|
|
|
+ return select_tag_terms(
|
|
|
+ keywords,
|
|
|
+ self.config.tag_terms_limit,
|
|
|
+ generic_terms=self.config.tag_generic_terms,
|
|
|
+ priority_terms=self.config.tag_priority_terms,
|
|
|
+ )
|
|
|
|
|
|
@staticmethod
|
|
|
def _metadata_matches_scope(metadata: Dict[str, Any], scope: Dict[str, Any]) -> bool:
|
|
|
- """检查候选 metadata 是否与当前检索 scope 兼容。
|
|
|
-
|
|
|
- 不要求所有字段都匹配,仅校验 scope 和 metadata 同时存在且不一致的字段。
|
|
|
- """
|
|
|
- required_keys = ["tenant_id", "project_id", "knowledge_base_id", "chapter_level_1", "chapter_level_2", "chapter_level_3"]
|
|
|
- for key in required_keys:
|
|
|
- expected = str(scope.get(key) or "").strip()
|
|
|
- if not expected:
|
|
|
- continue
|
|
|
- actual = str(metadata.get(key) or "").strip()
|
|
|
- if actual and actual != expected:
|
|
|
- return False
|
|
|
- return True
|
|
|
+ """检查候选 metadata 是否与当前检索 scope 兼容。"""
|
|
|
+ return metadata_matches_scope(metadata, scope)
|
|
|
|
|
|
# ============================================================
|
|
|
# Metadata 处理
|
|
|
# ============================================================
|
|
|
def _normalize_row_metadata(self, row_or_metadata: Any) -> Dict[str, Any]:
|
|
|
"""规范化行数据为统一的 metadata 字典。处理嵌套 metadata 和 YAML 字符串。"""
|
|
|
- metadata = self._normalize_metadata(row_or_metadata)
|
|
|
- inner = self._normalize_metadata(metadata.get("metadata")) if metadata.get("metadata") else {}
|
|
|
- for key, value in inner.items():
|
|
|
- metadata.setdefault(key, value)
|
|
|
- for key in self.PARENT_OUTPUT_FIELDS:
|
|
|
- if isinstance(row_or_metadata, dict) and row_or_metadata.get(key) not in (None, ""):
|
|
|
- metadata[key] = row_or_metadata.get(key)
|
|
|
- return metadata
|
|
|
+ return normalize_row_metadata(row_or_metadata, self.PARENT_OUTPUT_FIELDS)
|
|
|
|
|
|
@staticmethod
|
|
|
def _normalize_metadata(metadata: Any) -> Dict[str, Any]:
|
|
|
"""将 metadata 转为字典,支持 YAML 字符串解析。"""
|
|
|
- if isinstance(metadata, dict):
|
|
|
- return dict(metadata)
|
|
|
- if isinstance(metadata, str) and metadata.strip():
|
|
|
- try:
|
|
|
- loaded = yaml.safe_load(metadata)
|
|
|
- return dict(loaded) if isinstance(loaded, dict) else {}
|
|
|
- except Exception:
|
|
|
- return {}
|
|
|
- return {}
|
|
|
+ return normalize_metadata(metadata)
|
|
|
|
|
|
@staticmethod
|
|
|
def _metadata_value(metadata: Dict[str, Any], key: str) -> Any:
|
|
|
"""安全获取 metadata 值,支持嵌套 metadata.metadata 和 YAML 字符串。"""
|
|
|
- if key in metadata:
|
|
|
- return metadata.get(key)
|
|
|
- nested = metadata.get("metadata")
|
|
|
- if isinstance(nested, dict):
|
|
|
- return nested.get(key)
|
|
|
- if isinstance(nested, str) and nested.strip():
|
|
|
- try:
|
|
|
- parsed = yaml.safe_load(nested)
|
|
|
- if isinstance(parsed, dict):
|
|
|
- return parsed.get(key)
|
|
|
- except Exception:
|
|
|
- return None
|
|
|
- return None
|
|
|
+ return metadata_value(metadata, key)
|
|
|
|
|
|
def _build_candidate_key(self, metadata: Dict[str, Any], text: Any = "") -> str:
|
|
|
"""构建候选唯一标识键,按优先级尝试不同字段组合。"""
|
|
|
- metadata = self._normalize_row_metadata(metadata)
|
|
|
- document_id = str(self._metadata_value(metadata, "document_id") or "").strip()
|
|
|
- parent_id = str(self._metadata_value(metadata, "parent_id") or "").strip()
|
|
|
- chunk_id = str(self._metadata_value(metadata, "chunk_id") or "").strip()
|
|
|
- chapter_title = str(self._metadata_value(metadata, "chapter_title") or "").strip()
|
|
|
- index = self._metadata_value(metadata, "index")
|
|
|
- pk = str(self._metadata_value(metadata, "pk") or "").strip()
|
|
|
-
|
|
|
- if document_id and parent_id and chunk_id:
|
|
|
- return f"{document_id}::{parent_id}::{chunk_id}"
|
|
|
- if document_id and parent_id and chapter_title and index not in (None, ""):
|
|
|
- return f"{document_id}::{parent_id}::{chapter_title}::{index}"
|
|
|
- if pk:
|
|
|
- return pk
|
|
|
- if parent_id and chapter_title and index not in (None, ""):
|
|
|
- return f"{parent_id}::{chapter_title}::{index}"
|
|
|
- return str(text or "")[:300]
|
|
|
+ return build_candidate_key(metadata, text, self.PARENT_OUTPUT_FIELDS)
|
|
|
|
|
|
def _merge_metadata(self, current: Dict[str, Any], incoming: Dict[str, Any]) -> None:
|
|
|
"""合并两条候选的 metadata,不覆盖已有非空值。"""
|
|
|
- current_meta = current.setdefault("metadata", {})
|
|
|
- incoming_meta = incoming.get("metadata") if isinstance(incoming.get("metadata"), dict) else {}
|
|
|
- for key, value in incoming_meta.items():
|
|
|
- if key not in current_meta or current_meta.get(key) in (None, "", []):
|
|
|
- current_meta[key] = value
|
|
|
- if incoming.get("source") and not current.get("source"):
|
|
|
- current["source"] = incoming.get("source")
|
|
|
+ merge_metadata(current, incoming)
|
|
|
|
|
|
# ============================================================
|
|
|
# 加分计算
|
|
|
# ============================================================
|
|
|
def _calc_tag_bonus(self, candidate: Dict[str, Any], keywords: List[str]) -> float:
|
|
|
"""计算标签匹配加分:关键词精确匹配 tag_list 加分更多,仅出现在文本中加分较少。"""
|
|
|
- metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {}
|
|
|
- text = " ".join(
|
|
|
- str(value or "")
|
|
|
- for value in (
|
|
|
- candidate.get("text"),
|
|
|
- metadata.get("tag_list"),
|
|
|
- " ".join(metadata.get("matched_child_texts") or []),
|
|
|
- )
|
|
|
- )
|
|
|
- bonus = 0.0
|
|
|
- for keyword in self._select_tag_terms(keywords):
|
|
|
- if not keyword:
|
|
|
- continue
|
|
|
- if keyword in str(metadata.get("tag_list") or ""):
|
|
|
- bonus += self.config.tag_exact_bonus
|
|
|
- elif keyword in text:
|
|
|
- bonus += self.config.tag_partial_bonus
|
|
|
- return bonus
|
|
|
+ return calc_tag_bonus(candidate, keywords, self.config)
|
|
|
|
|
|
# ============================================================
|
|
|
# 候选清理
|
|
|
@@ -930,37 +615,7 @@ class DocumentChatRetrievalService:
|
|
|
1. candidate_key 去重:相同 document+parent+chunk 视为同一条
|
|
|
2. 内容哈希去重:同一文件同一文本内容(即使路径不同)只保留一条
|
|
|
"""
|
|
|
- cleaned = []
|
|
|
- seen_keys = set()
|
|
|
- seen_hashes = set()
|
|
|
- for item in candidates:
|
|
|
- text = str(item.get("text") or "").strip()
|
|
|
- if len(text) < 20:
|
|
|
- continue
|
|
|
- metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
|
|
|
- dedupe_key = str(item.get("candidate_key") or text[:300])
|
|
|
- # 内容哈希去重
|
|
|
- content_hash = _content_hash(text[:300])
|
|
|
- file_name = str(metadata.get("file_name") or "")
|
|
|
- hash_key = f"{file_name}::{content_hash}"
|
|
|
- if dedupe_key in seen_keys or hash_key in seen_hashes:
|
|
|
- continue
|
|
|
- seen_keys.add(dedupe_key)
|
|
|
- seen_hashes.add(hash_key)
|
|
|
- metadata["candidate_key"] = dedupe_key
|
|
|
- cleaned.append(
|
|
|
- {
|
|
|
- "candidate_key": dedupe_key,
|
|
|
- "text": text[: self.config.max_single_reference_chars],
|
|
|
- "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
|
|
|
- "vector_similarity": _to_float(item.get("vector_similarity"), 0.0),
|
|
|
- "fusion_score": _to_float(item.get("fusion_score"), 0.0),
|
|
|
- "source_hits": item.get("source_hits") if isinstance(item.get("source_hits"), dict) else {},
|
|
|
- "metadata": metadata,
|
|
|
- }
|
|
|
- )
|
|
|
- cleaned.sort(key=lambda item: (item.get("fusion_score", 0.0), item.get("vector_similarity", 0.0)), reverse=True)
|
|
|
- return cleaned[: self.config.recall_top_k]
|
|
|
+ return clean_candidates(candidates, self.config)
|
|
|
|
|
|
# ============================================================
|
|
|
# 空结果/告警
|
|
|
@@ -989,146 +644,5 @@ class DocumentChatRetrievalService:
|
|
|
|
|
|
def _warning(self, key: str) -> str:
|
|
|
"""获取指定键的告警文案。"""
|
|
|
- warnings = self.config.warnings or _default_warnings()
|
|
|
+ warnings = self.config.warnings or default_warnings()
|
|
|
return warnings.get(key) or ""
|
|
|
-
|
|
|
-
|
|
|
-def _default_warnings() -> Dict[str, str]:
|
|
|
- return {
|
|
|
- "no_scope": "缺少可靠的知识库检索范围,本次未引用向量库内容。",
|
|
|
- "no_recall": "未召回可信知识库内容,本次回答不引用向量库。",
|
|
|
- "low_confidence": "未找到可信度足够的知识库片段,本次未引用向量库内容。",
|
|
|
- "rerank_failed": "知识库片段重排不可用,本次未引用向量库内容。",
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
-def _escape_milvus_string(value: str) -> str:
|
|
|
- """转义 Milvus 字符串中的特殊字符(反斜杠、单引号、双引号)。"""
|
|
|
- return str(value).replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
|
|
|
-
|
|
|
-
|
|
|
-def _combine_expr(*exprs: str) -> str:
|
|
|
- """用 AND 连接多个过滤表达式,每个子表达式加括号。"""
|
|
|
- parts = [f"({expr})" for expr in exprs if str(expr or "").strip()]
|
|
|
- return " and ".join(parts)
|
|
|
-
|
|
|
-
|
|
|
-def _dedupe_join(parts: List[str], max_chars: int) -> str:
|
|
|
- """去重后拼接文本片段,限制总长度。"""
|
|
|
- values = []
|
|
|
- seen = set()
|
|
|
- for part in parts:
|
|
|
- value = re.sub(r"\s+", " ", str(part or "")).strip()
|
|
|
- if not value or value in seen:
|
|
|
- continue
|
|
|
- seen.add(value)
|
|
|
- values.append(value)
|
|
|
- return " ".join(values)[:max_chars]
|
|
|
-
|
|
|
-
|
|
|
-def _extract_retrieval_keywords(text: str) -> List[str]:
|
|
|
- """从文本中提取检索关键词,支持多种模式:
|
|
|
-
|
|
|
- 1. 标准号/型号:如 TB10212-2012、φ48.3×3.6
|
|
|
- 2. 规范名称:《XXX规范》
|
|
|
- 3. 领域专业术语:验收、架桥机、箱梁等
|
|
|
- 4. 术语+动作组合:XX验收、XX安装
|
|
|
- 5. 长词中的领域术语片段
|
|
|
- """
|
|
|
- if not text:
|
|
|
- return []
|
|
|
-
|
|
|
- keywords: List[str] = []
|
|
|
- # 模式1:标准号/型号(字母+数字,可选连字符)
|
|
|
- for match in re.findall(r"[A-Za-z]{1,8}\s*\d{2,8}(?:[-—]\d{2,4})?", text):
|
|
|
- keywords.append(re.sub(r"\s+", "", match).upper())
|
|
|
- # 模式2:《XXX》规范名称
|
|
|
- for match in re.findall(r"《([^》]{2,40})》", text):
|
|
|
- keywords.append(match.strip())
|
|
|
-
|
|
|
- # 模式3:领域专业术语
|
|
|
- domain_terms = (
|
|
|
- "验收", "标准", "规范", "检查", "检测", "试验", "安装", "拆除", "吊装",
|
|
|
- "架桥机", "龙门吊", "吊车", "箱梁", "T梁", "梁板", "钢丝绳", "支座",
|
|
|
- "地基", "安全装置", "操作证", "合格证", "静载", "动载", "空载",
|
|
|
- )
|
|
|
- for term in domain_terms:
|
|
|
- if term in text:
|
|
|
- keywords.append(term)
|
|
|
-
|
|
|
- # 模式4:术语+动作组合
|
|
|
- for match in re.findall(r"[一-鿿A-Za-z0-9.-]{0,12}(?:验收|标准|规范|检查|检测|试验|安装|拆除|吊装|要求)", text):
|
|
|
- if 2 <= len(match) <= 20:
|
|
|
- keywords.append(match)
|
|
|
-
|
|
|
- # 模式5:分词后含领域术语的片段
|
|
|
- normalized = re.sub(r"[\s,,。;;::、/\\|()\[\]{}<>《》\"'""??]+", " ", text)
|
|
|
- for token in normalized.split():
|
|
|
- token = token.strip()
|
|
|
- if len(token) < 2 or len(token) > 12:
|
|
|
- continue
|
|
|
- if any(term in token for term in domain_terms):
|
|
|
- keywords.append(token)
|
|
|
-
|
|
|
- seen = set()
|
|
|
- unique = []
|
|
|
- for keyword in keywords:
|
|
|
- keyword = keyword.strip()
|
|
|
- if keyword and keyword not in seen:
|
|
|
- seen.add(keyword)
|
|
|
- unique.append(keyword)
|
|
|
- return unique
|
|
|
-
|
|
|
-
|
|
|
-def _pack_log_items(items: List[Dict[str, Any]], limit: int = 20, text_limit: int = 1500) -> List[Dict[str, Any]]:
|
|
|
- """打包候选条目为日志格式,限制条数和文本长度。"""
|
|
|
- packed = []
|
|
|
- for item in (items or [])[:limit]:
|
|
|
- if not isinstance(item, dict):
|
|
|
- continue
|
|
|
- metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
|
|
|
- text = str(item.get("text") or item.get("text_content") or item.get("content") or "").strip()
|
|
|
- packed.append(
|
|
|
- {
|
|
|
- "candidate_key": item.get("candidate_key"),
|
|
|
- "source": item.get("source") or metadata.get("file_name") or "",
|
|
|
- "text": text[:text_limit],
|
|
|
- "vector_similarity": _to_float(item.get("vector_similarity", item.get("similarity")), 0.0),
|
|
|
- "fusion_score": _to_float(item.get("fusion_score"), 0.0),
|
|
|
- "rerank_score": _to_float(item.get("rerank_score"), 0.0) if "rerank_score" in item else None,
|
|
|
- "source_hits": item.get("source_hits") if isinstance(item.get("source_hits"), dict) else {},
|
|
|
- "metadata": {
|
|
|
- key: metadata.get(key)
|
|
|
- for key in (
|
|
|
- "document_id", "parent_id", "file_name", "chapter_title",
|
|
|
- "chapter_level_1", "chapter_level_2", "chapter_level_3",
|
|
|
- "parent_count", "child_hit_count", "matched_child_texts",
|
|
|
- "tag_match_terms", "source_scope_valid",
|
|
|
- )
|
|
|
- if metadata.get(key) not in (None, "")
|
|
|
- },
|
|
|
- }
|
|
|
- )
|
|
|
- return packed
|
|
|
-
|
|
|
-
|
|
|
-def _to_int(value: Any, default: int) -> int:
|
|
|
- """安全整数转换。"""
|
|
|
- try:
|
|
|
- return int(value)
|
|
|
- except (TypeError, ValueError):
|
|
|
- return default
|
|
|
-
|
|
|
-
|
|
|
-def _to_float(value: Any, default: float = 0.0) -> float:
|
|
|
- """安全浮点数转换。"""
|
|
|
- try:
|
|
|
- return float(value)
|
|
|
- except (TypeError, ValueError):
|
|
|
- return default
|
|
|
-
|
|
|
-
|
|
|
-def _content_hash(text: str) -> str:
|
|
|
- """基于归一化文本的短 MD5 哈希,用于内容去重。"""
|
|
|
- normalized = re.sub(r"\s+", " ", text.strip().lower())
|
|
|
- return md5(normalized.encode("utf-8")).hexdigest()[:12]
|