| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- # -*- coding: utf-8 -*-
- """多路文档对话检索结果的 RRF 融合合并。"""
- from __future__ import annotations
- from typing import Any, Dict, List
- from core.document_chat.retrieval.candidate import build_candidate_key, merge_metadata
- from core.document_chat.retrieval.config import RetrievalConfig
- from core.document_chat.retrieval.scope import metadata_matches_scope, select_tag_terms
- from core.document_chat.retrieval.utils import to_float
def merge_recall_results(
    source_results: Dict[str, List[Dict[str, Any]]],
    scope: Dict[str, Any],
    keywords: List[str],
    config: RetrievalConfig,
) -> List[Dict[str, Any]]:
    """Fuse multi-source recall results with weighted Reciprocal Rank Fusion (RRF).

    Each hit contributes ``weight / (config.rrf_k + rank)`` to its candidate's
    ``fusion_score``. Here ``rank`` is the 1-based position in that source's
    list. The weight comes from ``config`` for the sources ``parent_vector``,
    ``child_locator``, ``tag`` and ``chapter_similarity``; any other source
    name gets weight 0.0. Even so, such a source still counts as a hit for
    the multi-source bonus.

    Hits from different sources are unified by ``candidate_key``. When an item
    has no ``candidate_key``, the key comes from ``build_candidate_key``.
    Items whose key is ``None`` or empty are skipped.

    After fusion, each candidate may receive extra score:

    * ``config.multi_source_bonus`` when more than one source hit it;
    * ``config.scope_bonus`` when ``metadata_matches_scope(metadata, scope)``;
    * the value returned by ``calc_tag_bonus(candidate, keywords, config)``.

    Args:
        source_results: Mapping of source name to that source's ranked
            candidate list, best first. ``None``/empty lists are skipped.
        scope: Scope passed through to ``metadata_matches_scope``.
        keywords: Query keywords passed through to ``calc_tag_bonus``.
        config: Provides the source weights, ``rrf_k``, the bonuses and
            ``recall_top_k``.

    Returns:
        Fused candidates sorted by ``fusion_score`` (descending), truncated
        to ``config.recall_top_k``. Each candidate is a shallow copy of the
        first item seen for its key. It carries ``candidate_key``,
        ``fusion_score``, the maximum ``vector_similarity`` over all its hits,
        and ``source_hits``. ``source_hits`` maps each source to the rank and
        similarity of that source's first (best-ranked) hit.
    """
    weights = {
        "parent_vector": config.parent_vector_weight,
        "child_locator": config.child_locator_weight,
        "tag": config.tag_weight,
        "chapter_similarity": config.chapter_similarity_weight,
    }
    merged: Dict[str, Dict[str, Any]] = {}
    for source, candidates in source_results.items():
        weight = weights.get(source, 0.0)
        for rank, item in enumerate(candidates or [], start=1):
            raw_key = item.get("candidate_key") or build_candidate_key(item, item.get("text"))
            # Test for None *before* str(): str(None) == "None" is truthy. It
            # would slip past the emptiness check and merge every keyless
            # item into one bogus candidate.
            if raw_key is None:
                continue
            key = str(raw_key)
            if not key:
                continue
            current = merged.get(key)
            if current is None:
                # Shallow copy: nested values (e.g. the metadata dict) are still
                # the caller's objects.
                # NOTE(review): if merge_metadata mutates nested metadata in
                # place, the input items are modified as well — confirm.
                current = dict(item)
                current["candidate_key"] = key
                current["source_hits"] = {}
                current["fusion_score"] = 0.0
                merged[key] = current
            similarity = to_float(item.get("vector_similarity"), 0.0)
            # Every hit adds its RRF term. A key that appears several times in
            # one source (e.g. several children located under the same parent)
            # therefore accumulates score from each occurrence.
            current["fusion_score"] = to_float(current.get("fusion_score"), 0.0) + weight / (config.rrf_k + rank)
            current["vector_similarity"] = max(to_float(current.get("vector_similarity"), 0.0), similarity)
            # Ranks increase along the list, so the first hit per source is the
            # best one. Keep it instead of letting later duplicates overwrite it
            # with a worse rank.
            current.setdefault("source_hits", {}).setdefault(
                source,
                {
                    "rank": rank,
                    "vector_similarity": similarity,
                },
            )
            merge_metadata(current, item)

    for candidate in merged.values():
        source_hits = candidate.get("source_hits") if isinstance(candidate.get("source_hits"), dict) else {}
        metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {}
        if len(source_hits) > 1:
            candidate["fusion_score"] += config.multi_source_bonus
        if metadata_matches_scope(metadata, scope):
            candidate["fusion_score"] += config.scope_bonus
        candidate["fusion_score"] += calc_tag_bonus(candidate, keywords, config)

    ranked = sorted(merged.values(), key=lambda entry: entry.get("fusion_score", 0.0), reverse=True)
    return ranked[: config.recall_top_k]
def calc_tag_bonus(candidate: Dict[str, Any], keywords: List[str], config: RetrievalConfig) -> float:
    """Compute the tag-match bonus for one candidate.

    A term found in ``tag_list`` earns the larger bonus; a term found only in
    the text earns the smaller one.

    ``keywords`` is first reduced by ``select_tag_terms``, which is given
    ``config.tag_terms_limit``, ``config.tag_generic_terms`` and
    ``config.tag_priority_terms``. For each selected, non-empty term:

    * ``config.tag_exact_bonus`` is added when the term occurs in
      ``metadata["tag_list"]``;
    * otherwise ``config.tag_partial_bonus`` is added when the term occurs in
      the combined text. That text is the candidate's ``text``, the tags and
      ``metadata["matched_child_texts"]``.

    Matching is substring containment (``in``), not whole-tag equality.
    ``tag_list`` and ``matched_child_texts`` may each be a single string or a
    list/tuple/set of values. Collections are matched element by element, so
    a term cannot match across two tags or match list-repr punctuation.
    Non-string elements are stringified and ``None`` elements are ignored.

    Args:
        candidate: Candidate dict. Only ``text`` and ``metadata`` are read;
            a non-dict ``metadata`` is treated as empty.
        keywords: Query keywords.
        config: Supplies the term-selection settings and bonus amounts.

    Returns:
        The summed bonus; 0.0 when no selected term matches.
    """
    metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {}
    # Normalise once, outside the term loop.
    tags = _as_text_parts(metadata.get("tag_list"))
    child_texts = _as_text_parts(metadata.get("matched_child_texts"))
    text = " ".join([str(candidate.get("text") or ""), *tags, *child_texts])

    bonus = 0.0
    for keyword in select_tag_terms(
        keywords,
        config.tag_terms_limit,
        generic_terms=config.tag_generic_terms,
        priority_terms=config.tag_priority_terms,
    ):
        if not keyword:
            continue
        if any(keyword in tag for tag in tags):
            bonus += config.tag_exact_bonus
        elif keyword in text:
            bonus += config.tag_partial_bonus
    return bonus


def _as_text_parts(value: Any) -> List[str]:
    """Return *value* as a list of strings for substring matching.

    A falsy value gives ``[]``. A string gives a one-element list. A
    list/tuple/set gives one string per non-``None`` element. Any other value
    gives ``[str(value)]``.
    """
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple, set, frozenset)):
        return [str(part) for part in value if part is not None]
    return [str(value)]
|