fusion.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # -*- coding: utf-8 -*-
  2. """多路文档对话检索结果的 RRF 融合合并。"""
  3. from __future__ import annotations
  4. from typing import Any, Dict, List
  5. from core.document_chat.retrieval.candidate import build_candidate_key, merge_metadata
  6. from core.document_chat.retrieval.config import RetrievalConfig
  7. from core.document_chat.retrieval.scope import metadata_matches_scope, select_tag_terms
  8. from core.document_chat.retrieval.utils import to_float
  9. def merge_recall_results(
  10. source_results: Dict[str, List[Dict[str, Any]]],
  11. scope: Dict[str, Any],
  12. keywords: List[str],
  13. config: RetrievalConfig,
  14. ) -> List[Dict[str, Any]]:
  15. """多路召回结果 RRF 融合合并。"""
  16. weights = {
  17. "parent_vector": config.parent_vector_weight,
  18. "child_locator": config.child_locator_weight,
  19. "tag": config.tag_weight,
  20. "chapter_similarity": config.chapter_similarity_weight,
  21. }
  22. merged: Dict[str, Dict[str, Any]] = {}
  23. for source, candidates in source_results.items():
  24. weight = weights.get(source, 0.0)
  25. for rank, item in enumerate(candidates or [], start=1):
  26. key = str(item.get("candidate_key") or build_candidate_key(item, item.get("text")))
  27. if not key:
  28. continue
  29. if key not in merged:
  30. candidate = dict(item)
  31. candidate["candidate_key"] = key
  32. candidate["source_hits"] = {}
  33. candidate["fusion_score"] = 0.0
  34. merged[key] = candidate
  35. current = merged[key]
  36. current["fusion_score"] = to_float(current.get("fusion_score"), 0.0) + weight / (config.rrf_k + rank)
  37. current["vector_similarity"] = max(
  38. to_float(current.get("vector_similarity"), 0.0),
  39. to_float(item.get("vector_similarity"), 0.0),
  40. )
  41. current.setdefault("source_hits", {})[source] = {
  42. "rank": rank,
  43. "vector_similarity": to_float(item.get("vector_similarity"), 0.0),
  44. }
  45. merge_metadata(current, item)
  46. for candidate in merged.values():
  47. source_hits = candidate.get("source_hits") if isinstance(candidate.get("source_hits"), dict) else {}
  48. metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {}
  49. if len(source_hits) > 1:
  50. candidate["fusion_score"] += config.multi_source_bonus
  51. if metadata_matches_scope(metadata, scope):
  52. candidate["fusion_score"] += config.scope_bonus
  53. candidate["fusion_score"] += calc_tag_bonus(candidate, keywords, config)
  54. return sorted(merged.values(), key=lambda item: item.get("fusion_score", 0.0), reverse=True)[: config.recall_top_k]
  55. def calc_tag_bonus(candidate: Dict[str, Any], keywords: List[str], config: RetrievalConfig) -> float:
  56. """计算标签匹配加分:关键词精确匹配 tag_list 加分更多,仅出现在文本中加分较少。"""
  57. metadata = candidate.get("metadata") if isinstance(candidate.get("metadata"), dict) else {}
  58. text = " ".join(
  59. str(value or "")
  60. for value in (
  61. candidate.get("text"),
  62. metadata.get("tag_list"),
  63. " ".join(metadata.get("matched_child_texts") or []),
  64. )
  65. )
  66. bonus = 0.0
  67. for keyword in select_tag_terms(
  68. keywords,
  69. config.tag_terms_limit,
  70. generic_terms=config.tag_generic_terms,
  71. priority_terms=config.tag_priority_terms,
  72. ):
  73. if not keyword:
  74. continue
  75. if keyword in str(metadata.get("tag_list") or ""):
  76. bonus += config.tag_exact_bonus
  77. elif keyword in text:
  78. bonus += config.tag_partial_bonus
  79. return bonus