# -*- coding: utf-8 -*-
"""Reciprocal-rank-fusion (RRF) merge of multi-source document-chat retrieval results."""
from __future__ import annotations

from typing import Any, Dict, List

from core.document_chat.retrieval.candidate import build_candidate_key, merge_metadata
from core.document_chat.retrieval.config import RetrievalConfig
from core.document_chat.retrieval.scope import metadata_matches_scope, select_tag_terms
from core.document_chat.retrieval.utils import to_float


def merge_recall_results(
    source_results: Dict[str, List[Dict[str, Any]]],
    scope: Dict[str, Any],
    keywords: List[str],
    config: RetrievalConfig,
) -> List[Dict[str, Any]]:
    """Fuse the ranked candidate lists of several recall sources into one ranking.

    Every hit adds ``weight / (config.rrf_k + rank)`` to its candidate's
    ``fusion_score``, where ``rank`` is the 1-based position inside its source
    list and ``weight`` is the per-source weight taken from ``config``. Sources
    that have no configured weight contribute ``0.0`` but are still recorded in
    ``source_hits``. After accumulation, three bonuses are applied per
    candidate: ``config.multi_source_bonus`` when more than one source hit it,
    ``config.scope_bonus`` when ``metadata_matches_scope`` accepts its metadata,
    and the keyword bonus returned by ``calc_tag_bonus``.

    Args:
        source_results: Mapping of source name to its ranked candidate dicts.
            ``None`` or empty lists are tolerated.
        scope: Scope description forwarded to ``metadata_matches_scope``.
        keywords: Query keywords forwarded to ``calc_tag_bonus``.
        config: Retrieval configuration supplying weights, ``rrf_k``, bonuses
            and ``recall_top_k``.

    Returns:
        Merged candidate dicts sorted by descending ``fusion_score`` and
        truncated to ``config.recall_top_k``. Each dict is a shallow copy of
        the first hit seen for that key, extended with ``candidate_key``,
        ``source_hits``, ``fusion_score`` and ``vector_similarity``.
    """
    weights = {
        "parent_vector": config.parent_vector_weight,
        "child_locator": config.child_locator_weight,
        "tag": config.tag_weight,
        "chapter_similarity": config.chapter_similarity_weight,
    }
    merged: Dict[str, Dict[str, Any]] = {}

    for source, candidates in (source_results or {}).items():
        weight = weights.get(source, 0.0)
        for rank, item in enumerate(candidates or [], start=1):
            if not isinstance(item, dict):
                # Malformed entry: skip it, but it still occupies its rank slot.
                continue

            raw_key = item.get("candidate_key") or build_candidate_key(item, item.get("text"))
            # Check for None *before* stringifying: str(None) is the truthy
            # string "None", which would otherwise lump every keyless
            # candidate into a single bogus entry.
            if raw_key is None:
                continue
            key = str(raw_key)
            if not key:
                continue

            current = merged.get(key)
            if current is None:
                current = dict(item)
                current["candidate_key"] = key
                current["source_hits"] = {}
                current["fusion_score"] = 0.0
                merged[key] = current

            # NOTE: a key repeated inside one source accumulates once per
            # occurrence, and source_hits keeps only the last occurrence.
            current["fusion_score"] = (
                to_float(current.get("fusion_score"), 0.0) + weight / (config.rrf_k + rank)
            )
            item_similarity = to_float(item.get("vector_similarity"), 0.0)
            current["vector_similarity"] = max(
                to_float(current.get("vector_similarity"), 0.0),
                item_similarity,
            )
            current.setdefault("source_hits", {})[source] = {
                "rank": rank,
                "vector_similarity": item_similarity,
            }
            # Project helper; presumably folds this hit's metadata into the
            # accumulated entry (defined outside this file).
            merge_metadata(current, item)

    for candidate in merged.values():
        source_hits = candidate.get("source_hits")
        if not isinstance(source_hits, dict):
            source_hits = {}
        metadata = candidate.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}

        if len(source_hits) > 1:
            candidate["fusion_score"] += config.multi_source_bonus
        if metadata_matches_scope(metadata, scope):
            candidate["fusion_score"] += config.scope_bonus
        candidate["fusion_score"] += calc_tag_bonus(candidate, keywords, config)

    ranked = sorted(
        merged.values(),
        key=lambda item: item.get("fusion_score", 0.0),
        reverse=True,
    )
    return ranked[: config.recall_top_k]


def _flatten_text(value: Any) -> str:
    """Return ``value`` as one string, joining collection members with spaces."""
    if not value:
        return ""
    if isinstance(value, (list, tuple, set, frozenset)):
        return " ".join(str(part) for part in value if part is not None)
    return str(value)


def calc_tag_bonus(candidate: Dict[str, Any], keywords: List[str], config: RetrievalConfig) -> float:
    """Compute the keyword bonus for one candidate.

    The keywords are first narrowed by ``select_tag_terms``. Each remaining
    keyword earns ``config.tag_exact_bonus`` when it occurs in the candidate's
    ``metadata["tag_list"]`` text; otherwise it earns
    ``config.tag_partial_bonus`` when it occurs in the candidate text or in
    ``metadata["matched_child_texts"]``. Matching is substring containment,
    not whole-tag equality.

    Args:
        candidate: Candidate dict; ``text`` and ``metadata`` are read.
        keywords: Query keywords to look for.
        config: Supplies the term-selection settings and both bonus values.

    Returns:
        The summed bonus, ``0.0`` when nothing matches.
    """
    metadata = candidate.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}

    # Built once instead of once per keyword. Flattening also tolerates a
    # list-valued tag_list and non-string / plain-string matched_child_texts.
    tag_text = _flatten_text(metadata.get("tag_list"))
    haystack = " ".join(
        (
            str(candidate.get("text") or ""),
            tag_text,
            _flatten_text(metadata.get("matched_child_texts")),
        )
    )

    terms = select_tag_terms(
        keywords,
        config.tag_terms_limit,
        generic_terms=config.tag_generic_terms,
        priority_terms=config.tag_priority_terms,
    )

    bonus = 0.0
    for keyword in terms:
        if not keyword:
            continue
        if keyword in tag_text:
            bonus += config.tag_exact_bonus
        elif keyword in haystack:
            bonus += config.tag_partial_bonus
    return bonus