# -*- coding: utf-8 -*-
"""Configuration loading for document-chat retrieval."""
from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import yaml

from core.document_chat.retrieval.utils import to_float, to_int
  9. PROJECT_ROOT = Path(__file__).resolve().parents[3]
  10. RETRIEVAL_CONFIG_PATH = PROJECT_ROOT / "config" / "document_chat_retrieval.yaml"
  11. DEFAULT_DOMAIN_TERMS = (
  12. "工程概况",
  13. "编制依据",
  14. "施工部署",
  15. "施工准备",
  16. "资源配置",
  17. "测量放线",
  18. "临时用电",
  19. "临时用水",
  20. "交通组织",
  21. "土方",
  22. "基坑",
  23. "模板",
  24. "钢筋",
  25. "混凝土",
  26. "脚手架",
  27. "防水",
  28. "装饰装修",
  29. "验收",
  30. "标准",
  31. "规范",
  32. "检查",
  33. "检测",
  34. "试验",
  35. "安装",
  36. "拆除",
  37. "吊装",
  38. "质量控制",
  39. "安全文明施工",
  40. "环境保护",
  41. "水土保持",
  42. "应急预案",
  43. "成品保护",
  44. "进度计划",
  45. "机械设备",
  46. "劳动力",
  47. "材料计划",
  48. "架桥机",
  49. "龙门吊",
  50. "吊车",
  51. "塔吊",
  52. "施工电梯",
  53. "挂篮",
  54. "支架",
  55. "台车",
  56. "箱梁",
  57. "T梁",
  58. "梁板",
  59. "钢丝绳",
  60. "支座",
  61. "地基",
  62. "安全装置",
  63. "操作证",
  64. "合格证",
  65. "静载",
  66. "动载",
  67. "空载",
  68. )
  69. DEFAULT_ACTION_TERMS = (
  70. "验收",
  71. "标准",
  72. "规范",
  73. "检查",
  74. "检测",
  75. "试验",
  76. "安装",
  77. "拆除",
  78. "吊装",
  79. "要求",
  80. "控制",
  81. "保护",
  82. "预案",
  83. "计划",
  84. )
  85. DEFAULT_TAG_GENERIC_TERMS = (
  86. "验收",
  87. "标准",
  88. "规范",
  89. "检查",
  90. "检测",
  91. "试验",
  92. "安装",
  93. "拆除",
  94. "要求",
  95. "安全",
  96. "环保",
  97. "质量",
  98. "进度",
  99. "交底",
  100. )
  101. DEFAULT_TAG_PRIORITY_TERMS = (
  102. "架桥机",
  103. "龙门吊",
  104. "吊车",
  105. "塔吊",
  106. "施工电梯",
  107. "挂篮",
  108. "支架",
  109. "台车",
  110. )
  111. @dataclass(frozen=True)
  112. class RetrievalConfig:
  113. """检索配置(不可变)。各参数含义见字段注释。"""
  114. enabled: bool = True
  115. parent_collection: str = "t_kngs_construction_plan_parent"
  116. child_collection: str = "t_kngs_construction_plan_child"
  117. # 各路径召回上限
  118. parent_recall_top_k: int = 30
  119. child_recall_top_k: int = 40
  120. tag_recall_top_k: int = 30
  121. chapter_recall_top_k: int = 15
  122. recall_top_k: int = 30
  123. rerank_top_k: int = 8
  124. submit_top_k: int = 3 # 最终送入 LLM prompt 的参考条数上限
  125. # 质量阈值
  126. min_vector_similarity: float = 0.45
  127. min_rerank_score: float = 0.65 # 重排质量门,低于此值被过滤
  128. min_qualified_count: int = 1
  129. # 参考内容长度限制
  130. max_reference_chars: int = 4000
  131. max_single_reference_chars: int = 1500
  132. # 降级策略
  133. allow_vector_fallback: bool = False
  134. allow_unscoped_search: bool = False
  135. # 混合搜索权重(dense=sparse 向量融合)
  136. dense_weight: float = 0.7
  137. sparse_weight: float = 0.3
  138. child_dense_weight: float = 0.6
  139. child_sparse_weight: float = 0.4
  140. ranker_type: str = "weighted"
  141. # 标签召回
  142. tag_recall_enabled: bool = True
  143. tag_terms_limit: int = 8
  144. # RRF 参数
  145. rrf_k: int = 60
  146. # 路径权重
  147. parent_vector_weight: float = 1.0
  148. child_locator_weight: float = 0.8
  149. tag_weight: float = 1.2
  150. chapter_similarity_weight: float = 0.5
  151. # 加分项
  152. tag_exact_bonus: float = 0.08
  153. tag_partial_bonus: float = 0.03
  154. multi_source_bonus: float = 0.02
  155. scope_bonus: float = 0.03
  156. keyword_domain_terms: Tuple[str, ...] = DEFAULT_DOMAIN_TERMS
  157. keyword_action_terms: Tuple[str, ...] = DEFAULT_ACTION_TERMS
  158. tag_generic_terms: Tuple[str, ...] = DEFAULT_TAG_GENERIC_TERMS
  159. tag_priority_terms: Tuple[str, ...] = DEFAULT_TAG_PRIORITY_TERMS
  160. warnings: Optional[Dict[str, str]] = None
  161. def load_retrieval_config() -> RetrievalConfig:
  162. """从 YAML 配置文件加载检索参数,文件不存在时使用默认值。"""
  163. if not RETRIEVAL_CONFIG_PATH.exists():
  164. return RetrievalConfig(warnings=default_warnings())
  165. with open(RETRIEVAL_CONFIG_PATH, "r", encoding="utf-8") as handle:
  166. raw = yaml.safe_load(handle) or {}
  167. retrieval = raw.get("retrieval") or {}
  168. keyword_extraction = raw.get("keyword_extraction") or {}
  169. warnings = raw.get("warnings") or default_warnings()
  170. return RetrievalConfig(
  171. enabled=bool(retrieval.get("enabled", True)),
  172. parent_collection=str(retrieval.get("parent_collection") or "t_kngs_construction_plan_parent"),
  173. child_collection=str(retrieval.get("child_collection") or "t_kngs_construction_plan_child"),
  174. parent_recall_top_k=to_int(retrieval.get("parent_recall_top_k"), 30),
  175. child_recall_top_k=to_int(retrieval.get("child_recall_top_k"), 40),
  176. tag_recall_top_k=to_int(retrieval.get("tag_recall_top_k"), 30),
  177. chapter_recall_top_k=to_int(retrieval.get("chapter_recall_top_k"), 15),
  178. recall_top_k=to_int(retrieval.get("recall_top_k"), 30),
  179. rerank_top_k=to_int(retrieval.get("rerank_top_k"), 8),
  180. submit_top_k=to_int(retrieval.get("submit_top_k"), 3),
  181. min_vector_similarity=to_float(retrieval.get("min_vector_similarity"), 0.45),
  182. min_rerank_score=to_float(retrieval.get("min_rerank_score"), 0.65),
  183. min_qualified_count=to_int(retrieval.get("min_qualified_count"), 1),
  184. max_reference_chars=to_int(retrieval.get("max_reference_chars"), 4000),
  185. max_single_reference_chars=to_int(retrieval.get("max_single_reference_chars"), 1500),
  186. allow_vector_fallback=bool(retrieval.get("allow_vector_fallback", False)),
  187. allow_unscoped_search=bool(retrieval.get("allow_unscoped_search", False)),
  188. dense_weight=to_float(retrieval.get("dense_weight"), 0.7),
  189. sparse_weight=to_float(retrieval.get("sparse_weight"), 0.3),
  190. child_dense_weight=to_float(retrieval.get("child_dense_weight"), 0.6),
  191. child_sparse_weight=to_float(retrieval.get("child_sparse_weight"), 0.4),
  192. ranker_type=str(retrieval.get("ranker_type") or "weighted"),
  193. tag_recall_enabled=bool(retrieval.get("tag_recall_enabled", True)),
  194. tag_terms_limit=to_int(retrieval.get("tag_terms_limit"), 8),
  195. rrf_k=to_int(retrieval.get("rrf_k"), 60),
  196. parent_vector_weight=to_float(retrieval.get("parent_vector_weight"), 1.0),
  197. child_locator_weight=to_float(retrieval.get("child_locator_weight"), 0.8),
  198. tag_weight=to_float(retrieval.get("tag_weight"), 1.2),
  199. chapter_similarity_weight=to_float(retrieval.get("chapter_similarity_weight"), 0.5),
  200. tag_exact_bonus=to_float(retrieval.get("tag_exact_bonus"), 0.08),
  201. tag_partial_bonus=to_float(retrieval.get("tag_partial_bonus"), 0.03),
  202. multi_source_bonus=to_float(retrieval.get("multi_source_bonus"), 0.02),
  203. scope_bonus=to_float(retrieval.get("scope_bonus"), 0.03),
  204. keyword_domain_terms=to_str_tuple(keyword_extraction.get("domain_terms"), DEFAULT_DOMAIN_TERMS),
  205. keyword_action_terms=to_str_tuple(keyword_extraction.get("action_terms"), DEFAULT_ACTION_TERMS),
  206. tag_generic_terms=to_str_tuple(keyword_extraction.get("tag_generic_terms"), DEFAULT_TAG_GENERIC_TERMS),
  207. tag_priority_terms=to_str_tuple(keyword_extraction.get("tag_priority_terms"), DEFAULT_TAG_PRIORITY_TERMS),
  208. warnings=warnings,
  209. )
  210. def default_warnings() -> Dict[str, str]:
  211. return {
  212. "no_scope": "缺少可靠的知识库检索范围,本次未引用向量库内容。",
  213. "no_recall": "未召回可信知识库内容,本次回答不引用向量库。",
  214. "low_confidence": "未找到可信度足够的知识库片段,本次未引用向量库内容。",
  215. "rerank_failed": "知识库片段重排不可用,本次未引用向量库内容。",
  216. }
  217. def to_str_tuple(value: Any, default: Tuple[str, ...]) -> Tuple[str, ...]:
  218. """将 YAML 列表/元组值转为字符串元组。"""
  219. if not isinstance(value, (list, tuple)):
  220. return default
  221. terms = tuple(str(item).strip() for item in value if str(item or "").strip())
  222. return terms or default