# -*- coding: utf-8 -*-
"""Load the retrieval configuration used by document chat."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import yaml

from core.document_chat.retrieval.utils import to_float, to_int

PROJECT_ROOT = Path(__file__).resolve().parents[3]
RETRIEVAL_CONFIG_PATH = PROJECT_ROOT / "config" / "document_chat_retrieval.yaml"

DEFAULT_DOMAIN_TERMS = (
    "工程概况",
    "编制依据",
    "施工部署",
    "施工准备",
    "资源配置",
    "测量放线",
    "临时用电",
    "临时用水",
    "交通组织",
    "土方",
    "基坑",
    "模板",
    "钢筋",
    "混凝土",
    "脚手架",
    "防水",
    "装饰装修",
    "验收",
    "标准",
    "规范",
    "检查",
    "检测",
    "试验",
    "安装",
    "拆除",
    "吊装",
    "质量控制",
    "安全文明施工",
    "环境保护",
    "水土保持",
    "应急预案",
    "成品保护",
    "进度计划",
    "机械设备",
    "劳动力",
    "材料计划",
    "架桥机",
    "龙门吊",
    "吊车",
    "塔吊",
    "施工电梯",
    "挂篮",
    "支架",
    "台车",
    "箱梁",
    "T梁",
    "梁板",
    "钢丝绳",
    "支座",
    "地基",
    "安全装置",
    "操作证",
    "合格证",
    "静载",
    "动载",
    "空载",
)

DEFAULT_ACTION_TERMS = (
    "验收",
    "标准",
    "规范",
    "检查",
    "检测",
    "试验",
    "安装",
    "拆除",
    "吊装",
    "要求",
    "控制",
    "保护",
    "预案",
    "计划",
)

DEFAULT_TAG_GENERIC_TERMS = (
    "验收",
    "标准",
    "规范",
    "检查",
    "检测",
    "试验",
    "安装",
    "拆除",
    "要求",
    "安全",
    "环保",
    "质量",
    "进度",
    "交底",
)

DEFAULT_TAG_PRIORITY_TERMS = (
    "架桥机",
    "龙门吊",
    "吊车",
    "塔吊",
    "施工电梯",
    "挂篮",
    "支架",
    "台车",
)

# String spellings accepted for boolean options. Only quoted YAML scalars
# reach the loader as strings; unquoted true/false are already parsed to bool.
_TRUE_TOKENS = frozenset({"1", "true", "yes", "y", "on"})
_FALSE_TOKENS = frozenset({"", "0", "false", "no", "n", "off"})


@dataclass(frozen=True)
class RetrievalConfig:
    """Immutable retrieval configuration; see the per-field comments."""

    enabled: bool = True
    parent_collection: str = "t_kngs_construction_plan_parent"
    child_collection: str = "t_kngs_construction_plan_child"

    # Recall limits for each retrieval path
    parent_recall_top_k: int = 30
    child_recall_top_k: int = 40
    tag_recall_top_k: int = 30
    chapter_recall_top_k: int = 15
    recall_top_k: int = 30
    rerank_top_k: int = 8
    submit_top_k: int = 3  # max number of references finally sent into the LLM prompt

    # Quality thresholds
    min_vector_similarity: float = 0.45
    min_rerank_score: float = 0.65  # rerank quality gate; lower scores are filtered out
    min_qualified_count: int = 1

    # Length limits for reference content
    max_reference_chars: int = 4000
    max_single_reference_chars: int = 1500

    # Fallback / degradation strategy
    allow_vector_fallback: bool = False
    allow_unscoped_search: bool = False

    # Hybrid search weights (dense / sparse vector fusion)
    dense_weight: float = 0.7
    sparse_weight: float = 0.3
    child_dense_weight: float = 0.6
    child_sparse_weight: float = 0.4
    ranker_type: str = "weighted"

    # Tag recall
    tag_recall_enabled: bool = True
    tag_terms_limit: int = 8

    # RRF parameter
    rrf_k: int = 60

    # Per-path weights
    parent_vector_weight: float = 1.0
    child_locator_weight: float = 0.8
    tag_weight: float = 1.2
    chapter_similarity_weight: float = 0.5

    # Bonus terms
    tag_exact_bonus: float = 0.08
    tag_partial_bonus: float = 0.03
    multi_source_bonus: float = 0.02
    scope_bonus: float = 0.03

    keyword_domain_terms: Tuple[str, ...] = DEFAULT_DOMAIN_TERMS
    keyword_action_terms: Tuple[str, ...] = DEFAULT_ACTION_TERMS
    tag_generic_terms: Tuple[str, ...] = DEFAULT_TAG_GENERIC_TERMS
    tag_priority_terms: Tuple[str, ...] = DEFAULT_TAG_PRIORITY_TERMS

    # Stays None when the dataclass is constructed directly;
    # load_retrieval_config() always fills it in.
    warnings: Optional[Dict[str, str]] = None


def _as_mapping(value: Any) -> Dict[str, Any]:
    """Return *value* if it is a dict, otherwise an empty dict.

    Guards against a YAML document or section that is a list or a scalar,
    which would otherwise fail on ``.get`` with an AttributeError.
    """
    return value if isinstance(value, dict) else {}


def _to_bool(value: Any, default: bool) -> bool:
    """Interpret a YAML scalar as a boolean.

    ``None`` (missing or empty key) and unrecognised strings yield *default*.
    Strings are matched case-insensitively against the accepted true/false
    spellings, so a quoted ``"false"`` is not mistaken for a truthy value.
    Any other type goes through ``bool()``.
    """
    if value is None:
        return default
    if isinstance(value, str):
        token = value.strip().lower()
        if token in _TRUE_TOKENS:
            return True
        if token in _FALSE_TOKENS:
            return False
        return default
    return bool(value)


def _merge_warnings(value: Any) -> Dict[str, str]:
    """Overlay user-supplied warning messages on top of the defaults.

    A partial override keeps the default text for every key it omits.
    Entries whose value is ``None`` are ignored; a non-dict *value* is
    ignored entirely.
    """
    merged = default_warnings()
    if isinstance(value, dict):
        for key, message in value.items():
            if message is not None:
                merged[str(key)] = str(message)
    return merged


def load_retrieval_config() -> RetrievalConfig:
    """Load retrieval parameters from the YAML config file.

    Returns:
        A ``RetrievalConfig`` built from ``RETRIEVAL_CONFIG_PATH``. When the
        file does not exist, every field keeps its dataclass default and
        ``warnings`` holds the default messages.

    Raises:
        yaml.YAMLError: if the file exists but is not valid YAML.
        OSError: if the file exists but cannot be read.
    """
    # Dataclass defaults are the single source of truth for fallback values.
    defaults = RetrievalConfig()

    try:
        with open(RETRIEVAL_CONFIG_PATH, "r", encoding="utf-8") as handle:
            raw = _as_mapping(yaml.safe_load(handle))
    except FileNotFoundError:
        return RetrievalConfig(warnings=default_warnings())

    retrieval = _as_mapping(raw.get("retrieval"))
    keyword_extraction = _as_mapping(raw.get("keyword_extraction"))

    # to_int / to_float are project helpers; presumably they fall back to the
    # given default on missing or invalid input -- confirm against utils.
    def int_opt(key: str) -> int:
        return to_int(retrieval.get(key), getattr(defaults, key))

    def float_opt(key: str) -> float:
        return to_float(retrieval.get(key), getattr(defaults, key))

    def bool_opt(key: str) -> bool:
        return _to_bool(retrieval.get(key), getattr(defaults, key))

    def str_opt(key: str) -> str:
        # Empty or missing values fall back to the default.
        return str(retrieval.get(key) or getattr(defaults, key))

    return RetrievalConfig(
        enabled=bool_opt("enabled"),
        parent_collection=str_opt("parent_collection"),
        child_collection=str_opt("child_collection"),
        parent_recall_top_k=int_opt("parent_recall_top_k"),
        child_recall_top_k=int_opt("child_recall_top_k"),
        tag_recall_top_k=int_opt("tag_recall_top_k"),
        chapter_recall_top_k=int_opt("chapter_recall_top_k"),
        recall_top_k=int_opt("recall_top_k"),
        rerank_top_k=int_opt("rerank_top_k"),
        submit_top_k=int_opt("submit_top_k"),
        min_vector_similarity=float_opt("min_vector_similarity"),
        min_rerank_score=float_opt("min_rerank_score"),
        min_qualified_count=int_opt("min_qualified_count"),
        max_reference_chars=int_opt("max_reference_chars"),
        max_single_reference_chars=int_opt("max_single_reference_chars"),
        allow_vector_fallback=bool_opt("allow_vector_fallback"),
        allow_unscoped_search=bool_opt("allow_unscoped_search"),
        dense_weight=float_opt("dense_weight"),
        sparse_weight=float_opt("sparse_weight"),
        child_dense_weight=float_opt("child_dense_weight"),
        child_sparse_weight=float_opt("child_sparse_weight"),
        ranker_type=str_opt("ranker_type"),
        tag_recall_enabled=bool_opt("tag_recall_enabled"),
        tag_terms_limit=int_opt("tag_terms_limit"),
        rrf_k=int_opt("rrf_k"),
        parent_vector_weight=float_opt("parent_vector_weight"),
        child_locator_weight=float_opt("child_locator_weight"),
        tag_weight=float_opt("tag_weight"),
        chapter_similarity_weight=float_opt("chapter_similarity_weight"),
        tag_exact_bonus=float_opt("tag_exact_bonus"),
        tag_partial_bonus=float_opt("tag_partial_bonus"),
        multi_source_bonus=float_opt("multi_source_bonus"),
        scope_bonus=float_opt("scope_bonus"),
        keyword_domain_terms=to_str_tuple(
            keyword_extraction.get("domain_terms"), DEFAULT_DOMAIN_TERMS
        ),
        keyword_action_terms=to_str_tuple(
            keyword_extraction.get("action_terms"), DEFAULT_ACTION_TERMS
        ),
        tag_generic_terms=to_str_tuple(
            keyword_extraction.get("tag_generic_terms"), DEFAULT_TAG_GENERIC_TERMS
        ),
        tag_priority_terms=to_str_tuple(
            keyword_extraction.get("tag_priority_terms"), DEFAULT_TAG_PRIORITY_TERMS
        ),
        warnings=_merge_warnings(raw.get("warnings")),
    )


def default_warnings() -> Dict[str, str]:
    """Return a fresh dict of the default user-facing warning messages."""
    return {
        "no_scope": "缺少可靠的知识库检索范围,本次未引用向量库内容。",
        "no_recall": "未召回可信知识库内容,本次回答不引用向量库。",
        "low_confidence": "未找到可信度足够的知识库片段,本次未引用向量库内容。",
        "rerank_failed": "知识库片段重排不可用,本次未引用向量库内容。",
    }


def to_str_tuple(value: Any, default: Tuple[str, ...]) -> Tuple[str, ...]:
    """Convert a YAML list/tuple value into a tuple of stripped strings.

    Falsy items (``None``, ``""``, ``0``, ...) and whitespace-only items are
    dropped. *default* is returned when *value* is not a list/tuple or when
    nothing survives the filtering.
    """
    if not isinstance(value, (list, tuple)):
        return default
    terms = tuple(str(item).strip() for item in value if str(item or "").strip())
    return terms or default