| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134 |
- # -*- coding: utf-8 -*-
- """质量优先的多路向量检索服务。
- 四路召回架构:
- 1. parent_vector:父表向量检索(主体内容向量)
- 2. child_locator:子表向量定位 → 反查父行(精确定位片段)
- 3. tag_keyword:标签关键词匹配(设备型号、标准号等)
- 4. chapter_similarity:章节相似度检索(同类型章节参考)
- 合并策略:
- - RRF(Reciprocal Rank Fusion)融合多路排名
- - 按路径加权:parent_vector 1.0, child_locator 0.8, tag 1.2, chapter 0.5
- - 多源加分:同一条候选在多个路径中被召回时额外加分
- - 标签匹配加分:关键词出现在 tag_list 或文本中时额外加分
- - Scope 匹配加分:与当前项目/章节范围一致时额外加分
- 去重策略:
- - candidate_key 去重(基于 document_id + parent_id + chunk_id)
- - 内容哈希去重(同一文件同一文本内容仅保留一条)
- """
from __future__ import annotations

import re
from dataclasses import dataclass, field, fields
from hashlib import md5
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import yaml

from core.document_chat.component.document_chat_logger import document_chat_logger as logger
- PROJECT_ROOT = Path(__file__).resolve().parents[3]
- RETRIEVAL_CONFIG_PATH = PROJECT_ROOT / "config" / "document_chat_retrieval.yaml"
@dataclass(frozen=True)
class RetrievalConfig:
    """Immutable retrieval settings.

    Every field has a built-in default here; ``load_retrieval_config``
    overrides them from the ``retrieval`` section of the YAML config file.

    ``warnings`` is excluded from ``__hash__``: it holds a dict, which is
    unhashable, and including it made ``hash(config)`` raise ``TypeError``
    even though the dataclass is frozen. Equality still compares every field.
    """

    # Master switch: when False, recall() returns the "disabled" status immediately.
    enabled: bool = True
    # Milvus collections holding parent rows and child chunks.
    parent_collection: str = "t_kngs_construction_plan_parent"
    child_collection: str = "t_kngs_construction_plan_child"

    # Per-path recall limits.
    parent_recall_top_k: int = 30  # top_k of the parent-collection hybrid search
    child_recall_top_k: int = 40  # top_k of the child-collection hybrid search
    tag_recall_top_k: int = 30  # limit of each tag condition query (parent and child)
    chapter_recall_top_k: int = 15  # top_k of the chapter-similarity search
    recall_top_k: int = 30  # candidates kept after RRF fusion; also caps parent-id lookups
    rerank_top_k: int = 8  # NOTE(review): consumed outside this section; presumably kept after rerank
    submit_top_k: int = 3  # maximum number of references submitted to the LLM prompt

    # Quality thresholds.
    min_vector_similarity: float = 0.45  # also the base similarity given to tag-path candidates
    min_rerank_score: float = 0.65  # rerank quality gate; candidates below it are filtered out
    min_qualified_count: int = 1  # NOTE(review): consumed elsewhere; presumably min qualified refs

    # Reference length limits.
    max_reference_chars: int = 4000  # total characters across all references
    max_single_reference_chars: int = 1500  # characters per single reference

    # Degradation policy.
    allow_vector_fallback: bool = False  # NOTE(review): consumed elsewhere
    allow_unscoped_search: bool = False  # False: recall() returns "no_scope" without a reliable scope

    # Hybrid-search weights (dense + sparse vector fusion) for parent / child collections.
    dense_weight: float = 0.7
    sparse_weight: float = 0.3
    child_dense_weight: float = 0.6
    child_sparse_weight: float = 0.4
    ranker_type: str = "weighted"  # passed through to hybrid_search(ranker_type=...)

    # Tag recall.
    tag_recall_enabled: bool = True
    tag_terms_limit: int = 8  # maximum number of terms used in the tag LIKE filter

    # Reciprocal Rank Fusion constant k in weight / (k + rank).
    rrf_k: int = 60

    # Per-path RRF weights.
    parent_vector_weight: float = 1.0
    child_locator_weight: float = 0.8
    tag_weight: float = 1.2
    chapter_similarity_weight: float = 0.5

    # Fusion-score bonuses.
    tag_exact_bonus: float = 0.08  # NOTE(review): applied by _calc_tag_bonus (not shown here)
    tag_partial_bonus: float = 0.03  # NOTE(review): applied by _calc_tag_bonus (not shown here)
    multi_source_bonus: float = 0.02  # candidate returned by more than one recall path
    scope_bonus: float = 0.03  # candidate metadata compatible with the request scope

    # Warning texts; presumably keyed by warning code ("no_scope", "no_recall") — TODO confirm.
    warnings: Optional[Dict[str, str]] = field(default=None, hash=False)
def load_retrieval_config() -> RetrievalConfig:
    """Load retrieval settings from ``RETRIEVAL_CONFIG_PATH``.

    When the file does not exist, a ``RetrievalConfig`` with its built-in
    defaults and ``_default_warnings()`` is returned.

    Otherwise every ``RetrievalConfig`` field (except ``warnings``) is read
    from the top-level ``retrieval`` mapping. Each value is coerced to the
    type of the dataclass default, so the defaults live in exactly one place:
    - bool: ``true/yes/on/1`` and ``false/no/off/0`` strings are parsed,
      because ``bool("false")`` is True.
    - int / float: coerced with ``_to_int`` / ``_to_float``.
    - str: an empty value falls back to the default.

    A missing or ``null`` value falls back to the default; an unrecognized
    boolean string also falls back. The top-level ``warnings`` mapping, when
    present and non-empty, replaces ``_default_warnings()``. A file whose top
    level (or ``retrieval`` section) is not a mapping is treated as empty
    instead of raising ``AttributeError``.

    Returns:
        RetrievalConfig: the immutable retrieval configuration.
    """
    if not RETRIEVAL_CONFIG_PATH.exists():
        return RetrievalConfig(warnings=_default_warnings())

    with open(RETRIEVAL_CONFIG_PATH, "r", encoding="utf-8") as handle:
        raw = yaml.safe_load(handle) or {}
    if not isinstance(raw, dict):
        logger.warning(
            f"[DocumentChat] ignoring retrieval config {RETRIEVAL_CONFIG_PATH}: "
            f"top level is {type(raw).__name__}, expected a mapping"
        )
        raw = {}
    retrieval = raw.get("retrieval")
    if not isinstance(retrieval, dict):
        retrieval = {}
    warnings = raw.get("warnings") or _default_warnings()

    def coerce(value: Any, default: Any) -> Any:
        """Convert a raw YAML value to the type of ``default``, else return ``default``."""
        if value is None:
            return default
        # bool must be checked before int: bool is a subclass of int.
        if isinstance(default, bool):
            if isinstance(value, str):
                text = value.strip().lower()
                if text in ("true", "yes", "on", "1"):
                    return True
                if text in ("false", "no", "off", "0"):
                    return False
                return default
            return bool(value)
        if isinstance(default, int):
            return _to_int(value, default)
        if isinstance(default, float):
            return _to_float(value, default)
        if isinstance(default, str):
            return str(value or default)
        return value

    defaults = RetrievalConfig()
    values = {
        spec.name: coerce(retrieval.get(spec.name), getattr(defaults, spec.name))
        for spec in fields(RetrievalConfig)
        if spec.name != "warnings"
    }
    return RetrievalConfig(warnings=warnings, **values)
class DocumentChatRetrievalService:
    """Build retrieval queries and recall high-quality candidates from the vector store.

    Core flow:
      1. build_query: join the user input, selected-section info and intent
         into a retrieval query.
      2. recall: run multi-path recall -> RRF fusion -> deduplication.
    """
    # Output fields requested from the parent collection; _normalize_row_metadata
    # also copies these keys from a row into the candidate metadata.
    PARENT_OUTPUT_FIELDS = [
        "pk", "text", "document_id", "parent_id", "index", "tag_list",
        "metadata", "file_name", "chapter_title",
        "chapter_level_1", "chapter_level_2", "chapter_level_3",
    ]
    # Output fields requested from the child collection (same list as the parent).
    CHILD_OUTPUT_FIELDS = [
        "pk", "text", "document_id", "parent_id", "index", "tag_list",
        "metadata", "file_name", "chapter_title",
        "chapter_level_1", "chapter_level_2", "chapter_level_3",
    ]
- def __init__(self, config: Optional[RetrievalConfig] = None):
- self.config = config or load_retrieval_config()
- # ============================================================
- # Query 构建
- # ============================================================
- def build_query(self, state: Dict[str, Any]) -> str:
- """构建精炼检索 query,避免冗余的项目摘要。
- 拼接内容:
- - 用户原始输入
- - 意图识别后的规范化指令
- - 当前选中章节编号 + 标题
- - 提取的关键词(最多 8 个)
- 去重后截取 120 字符。
- """
- selected_section = state.get("selected_section") or {}
- intent_result = state.get("intent_result") or {}
- keywords = self.build_query_keywords(state)
- parts = [
- state.get("user_message") or "",
- intent_result.get("normalized_instruction") or "",
- f"{selected_section.get('index', '')} {selected_section.get('title', '')}".strip(),
- " ".join(keywords[:8]),
- ]
- return _dedupe_join(parts, max_chars=120)
- def build_query_keywords(self, state: Dict[str, Any], query: Optional[str] = None) -> List[str]:
- """从多来源提取检索关键词。
- 来源优先级:
- 1. 用户输入
- 2. 意图规范化指令
- 3. 章节编号 + 标题
- 4. 章节正文内容(前 500 字)
- 5. 已拼接的 query
- 6. 历史对话中用户消息(排除 AI 回复,防止助手建议污染检索)
- 关键词提取规则见 _extract_retrieval_keywords。
- """
- selected_section = state.get("selected_section") or {}
- intent_result = state.get("intent_result") or {}
- history = state.get("conversation_history") or []
- sources = [
- state.get("user_message") or "",
- intent_result.get("normalized_instruction") or "",
- f"{selected_section.get('index', '')} {selected_section.get('title', '')}",
- str(selected_section.get("content") or "")[:500],
- query or "",
- ]
- if history:
- for turn in history[-6:]:
- if not isinstance(turn, dict):
- continue
- role = str(turn.get("role") or turn.get("sender") or "").lower()
- # 仅取用户消息,跳过 AI 助手回复
- if role in ("assistant", "ai", "bot", "model"):
- continue
- content = str(turn.get("content") or turn.get("message") or "")
- if content:
- sources.append(content)
- keywords: List[str] = []
- seen = set()
- for text in sources:
- for keyword in _extract_retrieval_keywords(str(text or "")):
- normalized = keyword.strip()
- if not normalized or normalized in seen:
- continue
- seen.add(normalized)
- keywords.append(normalized)
- if len(keywords) >= 20:
- return keywords
- return keywords
- # ============================================================
- # 主召回入口
- # ============================================================
- def recall(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """执行多路向量召回,RRF 合并,去重过滤。
- 返回:
- - retrieval_candidates:去重后的候选列表
- - retrieval_status:recalled / no_scope / no_recall / disabled
- - retrieval_metrics:各路径召回统计
- - retrieval_steps:每步详细日志
- """
- if not self.config.enabled:
- return self._empty_result("disabled", [], retrieval_method="disabled")
- query = str(state.get("retrieval_query") or "").strip()
- if not query:
- return self._empty_result("no_recall", [self._warning("no_recall")], retrieval_method="empty_query")
- # 提取检索范围(项目ID、工程类型、章节分类等)
- scope = self._extract_scope(state)
- if not self._has_reliable_scope(scope) and not self.config.allow_unscoped_search:
- return self._empty_result(
- "no_scope",
- [self._warning("no_scope")],
- retrieval_method="no_scope",
- retrieval_scope=scope,
- )
- keywords = list(state.get("retrieval_keywords") or self.build_query_keywords(state, query))
- retrieval_steps: List[Dict[str, Any]] = []
- source_results: Dict[str, List[Dict[str, Any]]] = {}
- # ===== 四路召回 =====
- source_results["parent_vector"] = self._run_recall_path(
- "parent_vector",
- lambda: self._recall_by_parent_vector(scope, query),
- retrieval_steps,
- query=query,
- scope=scope,
- )
- source_results["child_locator"] = self._run_recall_path(
- "child_locator",
- lambda: self._recall_by_child_locator(scope, query),
- retrieval_steps,
- query=query,
- scope=scope,
- )
- if self.config.tag_recall_enabled:
- source_results["tag"] = self._run_recall_path(
- "tag",
- lambda: self._recall_by_tag(scope, keywords),
- retrieval_steps,
- query=" ".join(keywords[: self.config.tag_terms_limit]),
- scope=scope,
- )
- if scope.get("chapter_level_1") and scope.get("chapter_level_2"):
- source_results["chapter_similarity"] = self._run_recall_path(
- "chapter_similarity",
- lambda: self._recall_by_chapter(scope, query),
- retrieval_steps,
- query=query,
- scope=scope,
- )
- # ===== RRF 合并 + 去重 =====
- merged_candidates = self._merge_recall_results(source_results, scope, keywords)
- cleaned = self._clean_candidates(merged_candidates)
- retrieval_steps.append(
- {
- "step": "rrf_merge",
- "query": query,
- "scope": {key: value for key, value in scope.items() if value},
- "count": len(merged_candidates),
- "items": _pack_log_items(merged_candidates),
- }
- )
- retrieval_steps.append(
- {
- "step": "clean_candidates",
- "count": len(cleaned),
- "items": _pack_log_items(cleaned),
- }
- )
- if not cleaned:
- return self._empty_result(
- "no_recall",
- [self._warning("no_recall")],
- retrieval_method="multi_path_rrf",
- retrieval_scope=scope,
- retrieval_steps=retrieval_steps,
- )
- source_counts = {source: len(items or []) for source, items in source_results.items()}
- # 日志:区分请求的 scope、实际应用的过滤、实际召回的文件
- applied_expr = self._build_filter_expr(scope)
- actual_files = list(dict.fromkeys(
- str(item.get("source", ""))[:40]
- for item in cleaned
- if item.get("source")
- ))[:5]
- logger.info(
- f"[DocumentChat] recall completed: method=multi_path_rrf "
- f"requested_scope={dict((k, v) for k, v in scope.items() if v)} "
- f"applied_filter='{applied_expr}' "
- f"actual_sources={actual_files} "
- f"source_counts={source_counts} "
- f"total={len(cleaned)} max_sim={max((item.get('vector_similarity', 0.0) for item in cleaned), default=0.0):.4f}"
- )
- metrics = {
- "recall_count": len(cleaned),
- "merged_count": len(merged_candidates),
- "source_counts": source_counts,
- "max_vector_similarity": max((item.get("vector_similarity", 0.0) for item in cleaned), default=0.0),
- "max_fusion_score": max((item.get("fusion_score", 0.0) for item in cleaned), default=0.0),
- "scope": {key: value for key, value in scope.items() if value},
- "retrieval_method": "multi_path_rrf",
- }
- return {
- "retrieval_candidates": cleaned,
- "retrieval_steps": retrieval_steps,
- "retrieval_status": "recalled",
- "retrieval_method": "multi_path_rrf",
- "retrieval_metrics": metrics,
- "warnings": [],
- }
- def _run_recall_path(
- self,
- step: str,
- func: Callable[[], List[Dict[str, Any]]],
- retrieval_steps: List[Dict[str, Any]],
- query: str,
- scope: Dict[str, Any],
- ) -> List[Dict[str, Any]]:
- """执行单路召回,异常时不阻断其他路径。"""
- try:
- candidates = func() or []
- retrieval_steps.append(
- {
- "step": step,
- "query": query,
- "scope": {key: value for key, value in scope.items() if value},
- "count": len(candidates),
- "items": _pack_log_items(candidates),
- }
- )
- return candidates
- except Exception as exc:
- logger.warning(f"[DocumentChat] {step} recall failed: {exc}", exc_info=True)
- retrieval_steps.append(
- {
- "step": step,
- "query": query,
- "scope": {key: value for key, value in scope.items() if value},
- "count": 0,
- "error": str(exc),
- "items": [],
- }
- )
- return []
- # ============================================================
- # 四路召回具体实现
- # ============================================================
- def _recall_by_parent_vector(self, scope: Dict[str, Any], query: str) -> List[Dict[str, Any]]:
- """父表向量检索:Milvus 混合搜索(dense + sparse),直接返回父表文档。"""
- from foundation.database.base.vector.milvus_vector import MilvusVectorManager
- expr = self._build_filter_expr(scope)
- results = MilvusVectorManager().hybrid_search(
- param={"collection_name": self.config.parent_collection, "expr": expr},
- query_text=query,
- top_k=self.config.parent_recall_top_k,
- ranker_type=self.config.ranker_type,
- dense_weight=self.config.dense_weight,
- sparse_weight=self.config.sparse_weight,
- )
- return [
- self._candidate_from_vector_row(row, "parent_vector", scope)
- for row in results
- if str(row.get("text_content") or "").strip()
- ]
- def _recall_by_child_locator(self, scope: Dict[str, Any], query: str) -> List[Dict[str, Any]]:
- """子表向量定位 + 父表反查:先用 query 在子表中找到匹配片段,
- 再通过 parent_id 反查父表行,获取完整的父文档内容。
- 优势:子表粒度更细,能精确定位到段落级别,然后拉取对应父文档的完整内容。
- """
- from foundation.database.base.vector.milvus_vector import MilvusVectorManager
- expr = self._build_filter_expr(scope)
- child_rows = MilvusVectorManager().hybrid_search(
- param={"collection_name": self.config.child_collection, "expr": expr},
- query_text=query,
- top_k=self.config.child_recall_top_k,
- ranker_type=self.config.ranker_type,
- dense_weight=self.config.child_dense_weight,
- sparse_weight=self.config.child_sparse_weight,
- )
- # 按 parent_id 分组子表命中结果
- child_groups: Dict[str, List[Dict[str, Any]]] = {}
- for row in child_rows:
- metadata = self._normalize_row_metadata(row.get("metadata") or {})
- parent_id = str(self._metadata_value(metadata, "parent_id") or "").strip()
- if not parent_id:
- continue
- child_groups.setdefault(parent_id, []).append(row)
- # 通过 parent_id 反查父表
- parent_rows = self._fetch_parent_rows_by_parent_ids(list(child_groups.keys()), scope)
- candidates = []
- for parent_row in parent_rows:
- parent_id = str(parent_row.get("parent_id") or "").strip()
- matches = child_groups.get(parent_id) or []
- max_similarity = max((_to_float(item.get("similarity"), 0.0) for item in matches), default=0.0)
- candidate = self._candidate_from_parent_row(parent_row, "child_locator", scope, max_similarity)
- metadata = candidate.setdefault("metadata", {})
- metadata["child_hit_count"] = len(matches) # 子表命中次数
- metadata["matched_child_texts"] = [
- str(item.get("text_content") or "").strip()
- for item in matches[:5]
- if str(item.get("text_content") or "").strip()
- ]
- candidates.append(candidate)
- return candidates
- def _recall_by_tag(self, scope: Dict[str, Any], keywords: List[str]) -> List[Dict[str, Any]]:
- """标签关键词召回:从关键词中筛选标准号、设备名等专业术语,
- 在 tag_list 字段上做 LIKE 匹配。
- 注意:标签召回容易过度匹配,因此结果相似度乘以 0.7 打折。
- """
- tag_terms = self._select_tag_terms(keywords)
- if not tag_terms:
- return []
- tag_expr = self._build_tag_expr(tag_terms)
- scope_expr = self._build_filter_expr(scope)
- expr = _combine_expr(scope_expr, tag_expr)
- # 父表标签匹配
- parent_rows = self._condition_query(
- collection_name=self.config.parent_collection,
- filter_expr=expr,
- output_fields=self.PARENT_OUTPUT_FIELDS,
- limit=self.config.tag_recall_top_k,
- )
- candidates = [
- self._candidate_from_parent_row(row, "tag", scope, self.config.min_vector_similarity)
- for row in parent_rows
- ]
- # 子表标签匹配,再反查父行
- child_rows = self._condition_query(
- collection_name=self.config.child_collection,
- filter_expr=expr,
- output_fields=self.CHILD_OUTPUT_FIELDS,
- limit=self.config.tag_recall_top_k,
- )
- child_parent_ids = []
- child_tag_map: Dict[str, List[str]] = {}
- for row in child_rows:
- parent_id = str(row.get("parent_id") or self._metadata_value(row, "parent_id") or "").strip()
- if not parent_id:
- continue
- child_parent_ids.append(parent_id)
- text = str(row.get("text") or "").strip()
- if text:
- child_tag_map.setdefault(parent_id, []).append(text)
- for row in self._fetch_parent_rows_by_parent_ids(child_parent_ids, scope):
- parent_id = str(row.get("parent_id") or "").strip()
- candidate = self._candidate_from_parent_row(row, "tag", scope, self.config.min_vector_similarity)
- metadata = candidate.setdefault("metadata", {})
- metadata["matched_child_texts"] = child_tag_map.get(parent_id, [])[:5]
- candidates.append(candidate)
- # 标签结果打折,防止过度匹配
- for candidate in candidates:
- candidate["vector_similarity"] *= 0.7
- # 记录匹配的标签术语
- for candidate in candidates:
- metadata = candidate.setdefault("metadata", {})
- tag_text = " ".join(
- str(value or "")
- for value in (
- metadata.get("tag_list"),
- candidate.get("text"),
- " ".join(metadata.get("matched_child_texts") or []),
- )
- )
- metadata["tag_match_terms"] = [term for term in tag_terms if term and term in tag_text]
- return candidates
- def _recall_by_chapter(self, scope: Dict[str, Any], query: str) -> List[Dict[str, Any]]:
- """章节相似度检索:调用现有 similar_fragment_service,
- 按 chapter_level_1 + chapter_level_2 限定范围搜索相似片段。
- """
- from core.construction_write.component.similar_fragment_service import search_similar_fragments
- rows = search_similar_fragments(
- level1=str(scope.get("chapter_level_1") or ""),
- level2=str(scope.get("chapter_level_2") or ""),
- search_text=query,
- top_k=self.config.chapter_recall_top_k,
- )
- candidates = []
- for row in rows:
- metadata = {
- "tenant_id": scope.get("tenant_id") or "",
- "project_id": scope.get("project_id") or "",
- "knowledge_base_id": scope.get("knowledge_base_id") or "",
- "file_name": row.get("file_name") or "",
- "chapter_level_1": row.get("chapter_level_1") or scope.get("chapter_level_1") or "",
- "chapter_level_2": row.get("chapter_level_2") or scope.get("chapter_level_2") or "",
- "parent_count": row.get("parent_count", 0),
- "source_scope_valid": True, # 通过章节分类限定,天然 scope 匹配
- }
- text = str(row.get("text") or "").strip()
- candidates.append(
- {
- "candidate_key": self._build_candidate_key({**row, "metadata": metadata}, text),
- "text": text,
- "source": metadata.get("file_name") or "向量知识库",
- "vector_similarity": _to_float(row.get("similarity"), 0.0),
- "fusion_score": 0.0,
- "metadata": metadata,
- "source_hits": {},
- "retrieval_source": "chapter_similarity",
- }
- )
- return candidates
- # ============================================================
- # RRF 合并
- # ============================================================
- def _merge_recall_results(
- self,
- source_results: Dict[str, List[Dict[str, Any]]],
- scope: Dict[str, Any],
- keywords: List[str],
- ) -> List[Dict[str, Any]]:
- """多路召回结果 RRF 融合合并。
- 融合分数计算:
- - 基础分:weight / (rrf_k + rank),按路径权重和排名计算
- - 多源加分:同一条候选在多个路径中被召回时额外加分
- - Scope 加分:与当前项目范围一致时额外加分
- - 标签加分:关键词出现在候选文本中时额外加分
- """
- weights = {
- "parent_vector": self.config.parent_vector_weight,
- "child_locator": self.config.child_locator_weight,
- "tag": self.config.tag_weight,
- "chapter_similarity": self.config.chapter_similarity_weight,
- }
- merged: Dict[str, Dict[str, Any]] = {}
- for source, candidates in source_results.items():
- weight = weights.get(source, 0.0)
- for rank, item in enumerate(candidates or [], start=1):
- key = str(item.get("candidate_key") or self._build_candidate_key(item, item.get("text")))
- if not key:
- continue
- if key not in merged:
- candidate = dict(item)
- candidate["candidate_key"] = key
- candidate["source_hits"] = {}
- candidate["fusion_score"] = 0.0
- merged[key] = candidate
- current = merged[key]
- # RRF 公式:累加 weight / (rrf_k + rank)
- current["fusion_score"] = _to_float(current.get("fusion_score"), 0.0) + weight / (self.config.rrf_k + rank)
- current["vector_similarity"] = max(
- _to_float(current.get("vector_similarity"), 0.0),
- _to_float(item.get("vector_similarity"), 0.0),
- )
- current.setdefault("source_hits", {})[source] = {
- "rank": rank,
- "vector_similarity": _to_float(item.get("vector_similarity"), 0.0),
- }
- self._merge_metadata(current, item)
- # 额外加分
- for candidate in merged.values():
- source_hits = candidate.get("source_hits") if isinstance(candidate.get("source_hits"), dict) else {}
- metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {}
- if len(source_hits) > 1:
- candidate["fusion_score"] += self.config.multi_source_bonus
- if self._metadata_matches_scope(metadata, scope):
- candidate["fusion_score"] += self.config.scope_bonus
- candidate["fusion_score"] += self._calc_tag_bonus(candidate, keywords)
- return sorted(merged.values(), key=lambda item: item.get("fusion_score", 0.0), reverse=True)[: self.config.recall_top_k]
- # ============================================================
- # Milvus 查询辅助
- # ============================================================
- def _fetch_parent_rows_by_parent_ids(self, parent_ids: List[str], scope: Dict[str, Any]) -> List[Dict[str, Any]]:
- """根据 parent_id 列表反查父表行,去重后逐条查询。"""
- unique_ids = []
- seen = set()
- for parent_id in parent_ids:
- value = str(parent_id or "").strip()
- if value and value not in seen:
- seen.add(value)
- unique_ids.append(value)
- rows: List[Dict[str, Any]] = []
- scope_expr = self._build_filter_expr(scope)
- for parent_id in unique_ids[: self.config.recall_top_k]:
- parent_expr = f"parent_id == '{_escape_milvus_string(parent_id)}'"
- expr = _combine_expr(parent_expr, scope_expr)
- rows.extend(
- self._condition_query(
- collection_name=self.config.parent_collection,
- filter_expr=expr,
- output_fields=self.PARENT_OUTPUT_FIELDS,
- limit=100,
- )
- )
- return rows
- def _condition_query(
- self,
- collection_name: str,
- filter_expr: str,
- output_fields: List[str],
- limit: int,
- ) -> List[Dict[str, Any]]:
- """Milvus 条件查询(非向量),按 filter 表达式筛选文档。"""
- from core.construction_write.component.similar_fragment_service import get_milvus_manager
- if not filter_expr:
- return []
- return get_milvus_manager().condition_query(
- collection_name=collection_name,
- filter=filter_expr,
- output_fields=output_fields,
- limit=limit,
- )
- # ============================================================
- # 候选构建
- # ============================================================
- def _candidate_from_vector_row(self, row: Dict[str, Any], source: str, scope: Dict[str, Any]) -> Dict[str, Any]:
- """从 Milvus 混合搜索结果行构建标准候选。"""
- metadata = self._normalize_row_metadata(row.get("metadata") or {})
- text = str(row.get("text_content") or row.get("text") or "").strip()
- metadata["source_scope_valid"] = self._metadata_matches_scope(metadata, scope)
- return {
- "candidate_key": self._build_candidate_key(metadata, text),
- "text": text,
- "source": metadata.get("file_name") or metadata.get("title") or "向量知识库",
- "vector_similarity": _to_float(row.get("similarity"), 0.0),
- "fusion_score": 0.0,
- "metadata": metadata,
- "source_hits": {},
- "retrieval_source": source,
- }
- def _candidate_from_parent_row(
- self,
- row: Dict[str, Any],
- source: str,
- scope: Dict[str, Any],
- vector_similarity: float,
- ) -> Dict[str, Any]:
- """从父表行构建标准候选。"""
- metadata = self._normalize_row_metadata(row)
- text = str(row.get("text") or "").strip()
- metadata["source_scope_valid"] = self._metadata_matches_scope(metadata, scope)
- return {
- "candidate_key": self._build_candidate_key(metadata, text),
- "text": text,
- "source": metadata.get("file_name") or "向量知识库",
- "vector_similarity": _to_float(vector_similarity, 0.0),
- "fusion_score": 0.0,
- "metadata": metadata,
- "source_hits": {},
- "retrieval_source": source,
- }
- # ============================================================
- # Scope 提取与过滤
- # ============================================================
- def _extract_scope(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """从工作流状态中提取检索范围信息。
- 按优先级从 selected_section、document_context、project_info、retrieval_filters
- 中查找字段值,兼容多种字段命名。
- """
- selected = state.get("selected_section") or {}
- context = state.get("document_context") or {}
- project = state.get("project_info") or {}
- filters = context.get("retrieval_filters") if isinstance(context.get("retrieval_filters"), dict) else {}
- filters = filters or project.get("retrieval_filters") if isinstance(project.get("retrieval_filters"), dict) else filters
- def pick(*keys: str) -> str:
- for source in (selected, context, project, filters or {}):
- for key in keys:
- value = source.get(key) if isinstance(source, dict) else None
- if value not in (None, ""):
- return str(value).strip()
- return ""
- return {
- "tenant_id": pick("tenant_id"),
- "project_id": pick("project_id"),
- "knowledge_base_id": pick("knowledge_base_id", "kb_id"),
- "engineering_type": pick("engineering_type", "project_type"),
- "plan_type": pick("plan_type"),
- "chapter_level_1": pick("chapter_level_1", "level1"),
- "chapter_level_2": pick("chapter_level_2", "level2"),
- "chapter_level_3": pick("chapter_level_3", "level3"),
- }
- @staticmethod
- def _has_reliable_scope(scope: Dict[str, Any]) -> bool:
- """判断是否有足够可靠的 scope 用于限定检索范围。"""
- if scope.get("chapter_level_1") and scope.get("chapter_level_2"):
- return True
- return bool(scope.get("plan_type"))
- def _build_filter_expr(self, scope: Dict[str, Any]) -> str:
- """构建 Milvus 过滤表达式,按章节层级限定检索范围。"""
- conditions = []
- for key in ("plan_type", "chapter_level_1", "chapter_level_2", "chapter_level_3"):
- value = str(scope.get(key) or "").strip()
- if value:
- conditions.append(f"{key} == '{_escape_milvus_string(value)}'")
- return " and ".join(conditions)
- def _build_tag_expr(self, tag_terms: List[str]) -> str:
- """构建标签 LIKE 查询表达式。"""
- conditions = []
- for term in tag_terms[: self.config.tag_terms_limit]:
- conditions.append(f'tag_list like "%{_escape_milvus_string(term)}%"')
- return " or ".join(conditions)
- def _select_tag_terms(self, keywords: List[str]) -> List[str]:
- """从关键词中筛选高价值标签术语。
- 排除:验收、标准、规范等通用词(几乎匹配所有文档)
- 优先:标准号(如 TB10212-2012)、设备名(架桥机、龙门吊等)
- """
- generic_terms = {
- "验收", "标准", "规范", "检查", "检测", "试验", "安装", "拆除",
- "要求", "安全", "环保", "质量", "进度", "交底",
- }
- device_terms = {"架桥机", "龙门吊", "吊车", "塔吊", "施工电梯", "挂篮", "支架", "台车"}
- selected = []
- priority = [] # 标准号和设备名优先
- seen = set()
- for keyword in keywords:
- value = str(keyword or "").strip()
- if len(value) < 2 or value in seen:
- continue
- seen.add(value)
- if value in generic_terms:
- continue
- if re.match(r"[A-Z]{1,3}\d{4,}", value) or value in device_terms:
- priority.append(value)
- elif len(selected) < self.config.tag_terms_limit:
- selected.append(value)
- return priority + selected
- @staticmethod
- def _metadata_matches_scope(metadata: Dict[str, Any], scope: Dict[str, Any]) -> bool:
- """检查候选 metadata 是否与当前检索 scope 兼容。
- 不要求所有字段都匹配,仅校验 scope 和 metadata 同时存在且不一致的字段。
- """
- required_keys = ["tenant_id", "project_id", "knowledge_base_id", "chapter_level_1", "chapter_level_2", "chapter_level_3"]
- for key in required_keys:
- expected = str(scope.get(key) or "").strip()
- if not expected:
- continue
- actual = str(metadata.get(key) or "").strip()
- if actual and actual != expected:
- return False
- return True
- # ============================================================
- # Metadata 处理
- # ============================================================
- def _normalize_row_metadata(self, row_or_metadata: Any) -> Dict[str, Any]:
- """规范化行数据为统一的 metadata 字典。处理嵌套 metadata 和 YAML 字符串。"""
- metadata = self._normalize_metadata(row_or_metadata)
- inner = self._normalize_metadata(metadata.get("metadata")) if metadata.get("metadata") else {}
- for key, value in inner.items():
- metadata.setdefault(key, value)
- for key in self.PARENT_OUTPUT_FIELDS:
- if isinstance(row_or_metadata, dict) and row_or_metadata.get(key) not in (None, ""):
- metadata[key] = row_or_metadata.get(key)
- return metadata
- @staticmethod
- def _normalize_metadata(metadata: Any) -> Dict[str, Any]:
- """将 metadata 转为字典,支持 YAML 字符串解析。"""
- if isinstance(metadata, dict):
- return dict(metadata)
- if isinstance(metadata, str) and metadata.strip():
- try:
- loaded = yaml.safe_load(metadata)
- return dict(loaded) if isinstance(loaded, dict) else {}
- except Exception:
- return {}
- return {}
- @staticmethod
- def _metadata_value(metadata: Dict[str, Any], key: str) -> Any:
- """安全获取 metadata 值,支持嵌套 metadata.metadata 和 YAML 字符串。"""
- if key in metadata:
- return metadata.get(key)
- nested = metadata.get("metadata")
- if isinstance(nested, dict):
- return nested.get(key)
- if isinstance(nested, str) and nested.strip():
- try:
- parsed = yaml.safe_load(nested)
- if isinstance(parsed, dict):
- return parsed.get(key)
- except Exception:
- return None
- return None
- def _build_candidate_key(self, metadata: Dict[str, Any], text: Any = "") -> str:
- """构建候选唯一标识键,按优先级尝试不同字段组合。"""
- metadata = self._normalize_row_metadata(metadata)
- document_id = str(self._metadata_value(metadata, "document_id") or "").strip()
- parent_id = str(self._metadata_value(metadata, "parent_id") or "").strip()
- chunk_id = str(self._metadata_value(metadata, "chunk_id") or "").strip()
- chapter_title = str(self._metadata_value(metadata, "chapter_title") or "").strip()
- index = self._metadata_value(metadata, "index")
- pk = str(self._metadata_value(metadata, "pk") or "").strip()
- if document_id and parent_id and chunk_id:
- return f"{document_id}::{parent_id}::{chunk_id}"
- if document_id and parent_id and chapter_title and index not in (None, ""):
- return f"{document_id}::{parent_id}::{chapter_title}::{index}"
- if pk:
- return pk
- if parent_id and chapter_title and index not in (None, ""):
- return f"{parent_id}::{chapter_title}::{index}"
- return str(text or "")[:300]
- def _merge_metadata(self, current: Dict[str, Any], incoming: Dict[str, Any]) -> None:
- """合并两条候选的 metadata,不覆盖已有非空值。"""
- current_meta = current.setdefault("metadata", {})
- incoming_meta = incoming.get("metadata") if isinstance(incoming.get("metadata"), dict) else {}
- for key, value in incoming_meta.items():
- if key not in current_meta or current_meta.get(key) in (None, "", []):
- current_meta[key] = value
- if incoming.get("source") and not current.get("source"):
- current["source"] = incoming.get("source")
- # ============================================================
- # 加分计算
- # ============================================================
- def _calc_tag_bonus(self, candidate: Dict[str, Any], keywords: List[str]) -> float:
- """计算标签匹配加分:关键词精确匹配 tag_list 加分更多,仅出现在文本中加分较少。"""
- metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {}
- text = " ".join(
- str(value or "")
- for value in (
- candidate.get("text"),
- metadata.get("tag_list"),
- " ".join(metadata.get("matched_child_texts") or []),
- )
- )
- bonus = 0.0
- for keyword in self._select_tag_terms(keywords):
- if not keyword:
- continue
- if keyword in str(metadata.get("tag_list") or ""):
- bonus += self.config.tag_exact_bonus
- elif keyword in text:
- bonus += self.config.tag_partial_bonus
- return bonus
- # ============================================================
- # 候选清理
- # ============================================================
def _clean_candidates(self, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Filter, de-duplicate, rank and truncate recall candidates.

    Steps:
    1. Rank candidates by ``(fusion_score, vector_similarity)`` descending
       BEFORE de-duplication, so that when two entries collide the
       higher-scoring one survives. Ties keep input order (``sorted`` is
       stable, also with ``reverse=True``).
    2. Drop entries whose stripped text is shorter than 20 characters.
    3. De-duplicate twice:
       - by ``candidate_key`` (falling back to the first 300 chars of text);
       - by ``file_name`` + content hash of the first 300 chars, so the same
         passage reached through different paths in one file is kept once.
    4. Truncate text to ``config.max_single_reference_chars`` and return at
       most ``config.recall_top_k`` entries, best first.

    Each output entry gets its own copy of the metadata (with
    ``candidate_key`` added); the caller's metadata dicts are not mutated.

    Args:
        candidates: raw candidate dicts; ``None`` is treated as empty.

    Returns:
        Cleaned candidate dicts ordered by score, best first.
    """

    def _rank(entry: Dict[str, Any]) -> tuple:
        return (
            _to_float(entry.get("fusion_score"), 0.0),
            _to_float(entry.get("vector_similarity"), 0.0),
        )

    cleaned: List[Dict[str, Any]] = []
    seen_keys = set()
    seen_hashes = set()
    for item in sorted(candidates or [], key=_rank, reverse=True):
        text = str(item.get("text") or "").strip()
        if len(text) < 20:
            continue
        raw_meta = item.get("metadata")
        metadata = dict(raw_meta) if isinstance(raw_meta, dict) else {}
        dedupe_key = str(item.get("candidate_key") or text[:300])
        # Content-hash de-duplication scoped to the source file.
        file_name = str(metadata.get("file_name") or "")
        hash_key = f"{file_name}::{_content_hash(text[:300])}"
        if dedupe_key in seen_keys or hash_key in seen_hashes:
            continue
        seen_keys.add(dedupe_key)
        seen_hashes.add(hash_key)
        metadata["candidate_key"] = dedupe_key
        cleaned.append(
            {
                "candidate_key": dedupe_key,
                "text": text[: self.config.max_single_reference_chars],
                "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
                "vector_similarity": _to_float(item.get("vector_similarity"), 0.0),
                "fusion_score": _to_float(item.get("fusion_score"), 0.0),
                "source_hits": item.get("source_hits") if isinstance(item.get("source_hits"), dict) else {},
                "metadata": metadata,
            }
        )
    # Already in score order: the input was ranked before filtering.
    return cleaned[: self.config.recall_top_k]
- # ============================================================
- # 空结果/告警
- # ============================================================
- def _empty_result(
- self,
- status: str,
- warnings: List[str],
- retrieval_method: str = "",
- retrieval_scope: Optional[Dict[str, Any]] = None,
- retrieval_steps: Optional[List[Dict[str, Any]]] = None,
- ) -> Dict[str, Any]:
- """构建空召回结果。"""
- return {
- "retrieval_candidates": [],
- "retrieval_steps": retrieval_steps or [],
- "retrieval_status": status,
- "retrieval_method": retrieval_method,
- "retrieval_metrics": {
- "recall_count": 0,
- "retrieval_method": retrieval_method,
- "scope": {key: value for key, value in (retrieval_scope or {}).items() if value},
- },
- "warnings": warnings,
- }
- def _warning(self, key: str) -> str:
- """获取指定键的告警文案。"""
- warnings = self.config.warnings or _default_warnings()
- return warnings.get(key) or ""
- def _default_warnings() -> Dict[str, str]:
- return {
- "no_scope": "缺少可靠的知识库检索范围,本次未引用向量库内容。",
- "no_recall": "未召回可信知识库内容,本次回答不引用向量库。",
- "low_confidence": "未找到可信度足够的知识库片段,本次未引用向量库内容。",
- "rerank_failed": "知识库片段重排不可用,本次未引用向量库内容。",
- }
- def _escape_milvus_string(value: str) -> str:
- """转义 Milvus 字符串中的特殊字符(反斜杠、单引号、双引号)。"""
- return str(value).replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
- def _combine_expr(*exprs: str) -> str:
- """用 AND 连接多个过滤表达式,每个子表达式加括号。"""
- parts = [f"({expr})" for expr in exprs if str(expr or "").strip()]
- return " and ".join(parts)
- def _dedupe_join(parts: List[str], max_chars: int) -> str:
- """去重后拼接文本片段,限制总长度。"""
- values = []
- seen = set()
- for part in parts:
- value = re.sub(r"\s+", " ", str(part or "")).strip()
- if not value or value in seen:
- continue
- seen.add(value)
- values.append(value)
- return " ".join(values)[:max_chars]
- def _extract_retrieval_keywords(text: str) -> List[str]:
- """从文本中提取检索关键词,支持多种模式:
- 1. 标准号/型号:如 TB10212-2012、φ48.3×3.6
- 2. 规范名称:《XXX规范》
- 3. 领域专业术语:验收、架桥机、箱梁等
- 4. 术语+动作组合:XX验收、XX安装
- 5. 长词中的领域术语片段
- """
- if not text:
- return []
- keywords: List[str] = []
- # 模式1:标准号/型号(字母+数字,可选连字符)
- for match in re.findall(r"[A-Za-z]{1,8}\s*\d{2,8}(?:[-—]\d{2,4})?", text):
- keywords.append(re.sub(r"\s+", "", match).upper())
- # 模式2:《XXX》规范名称
- for match in re.findall(r"《([^》]{2,40})》", text):
- keywords.append(match.strip())
- # 模式3:领域专业术语
- domain_terms = (
- "验收", "标准", "规范", "检查", "检测", "试验", "安装", "拆除", "吊装",
- "架桥机", "龙门吊", "吊车", "箱梁", "T梁", "梁板", "钢丝绳", "支座",
- "地基", "安全装置", "操作证", "合格证", "静载", "动载", "空载",
- )
- for term in domain_terms:
- if term in text:
- keywords.append(term)
- # 模式4:术语+动作组合
- for match in re.findall(r"[一-鿿A-Za-z0-9.-]{0,12}(?:验收|标准|规范|检查|检测|试验|安装|拆除|吊装|要求)", text):
- if 2 <= len(match) <= 20:
- keywords.append(match)
- # 模式5:分词后含领域术语的片段
- normalized = re.sub(r"[\s,,。;;::、/\\|()\[\]{}<>《》\"'""??]+", " ", text)
- for token in normalized.split():
- token = token.strip()
- if len(token) < 2 or len(token) > 12:
- continue
- if any(term in token for term in domain_terms):
- keywords.append(token)
- seen = set()
- unique = []
- for keyword in keywords:
- keyword = keyword.strip()
- if keyword and keyword not in seen:
- seen.add(keyword)
- unique.append(keyword)
- return unique
- def _pack_log_items(items: List[Dict[str, Any]], limit: int = 20, text_limit: int = 1500) -> List[Dict[str, Any]]:
- """打包候选条目为日志格式,限制条数和文本长度。"""
- packed = []
- for item in (items or [])[:limit]:
- if not isinstance(item, dict):
- continue
- metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
- text = str(item.get("text") or item.get("text_content") or item.get("content") or "").strip()
- packed.append(
- {
- "candidate_key": item.get("candidate_key"),
- "source": item.get("source") or metadata.get("file_name") or "",
- "text": text[:text_limit],
- "vector_similarity": _to_float(item.get("vector_similarity", item.get("similarity")), 0.0),
- "fusion_score": _to_float(item.get("fusion_score"), 0.0),
- "rerank_score": _to_float(item.get("rerank_score"), 0.0) if "rerank_score" in item else None,
- "source_hits": item.get("source_hits") if isinstance(item.get("source_hits"), dict) else {},
- "metadata": {
- key: metadata.get(key)
- for key in (
- "document_id", "parent_id", "file_name", "chapter_title",
- "chapter_level_1", "chapter_level_2", "chapter_level_3",
- "parent_count", "child_hit_count", "matched_child_texts",
- "tag_match_terms", "source_scope_valid",
- )
- if metadata.get(key) not in (None, "")
- },
- }
- )
- return packed
- def _to_int(value: Any, default: int) -> int:
- """安全整数转换。"""
- try:
- return int(value)
- except (TypeError, ValueError):
- return default
- def _to_float(value: Any, default: float = 0.0) -> float:
- """安全浮点数转换。"""
- try:
- return float(value)
- except (TypeError, ValueError):
- return default
- def _content_hash(text: str) -> str:
- """基于归一化文本的短 MD5 哈希,用于内容去重。"""
- normalized = re.sub(r"\s+", " ", text.strip().lower())
- return md5(normalized.encode("utf-8")).hexdigest()[:12]
|