candidate.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # -*- coding: utf-8 -*-
  2. """检索结果候选规范化与去重。"""
  3. from __future__ import annotations
  4. from hashlib import md5
  5. import re
  6. from typing import Any, Dict, List, Sequence
  7. import yaml
  8. from core.document_chat.retrieval.config import RetrievalConfig
  9. from core.document_chat.retrieval.utils import to_float
  10. DEFAULT_OUTPUT_FIELDS = (
  11. "pk",
  12. "text",
  13. "document_id",
  14. "parent_id",
  15. "index",
  16. "tag_list",
  17. "metadata",
  18. "file_name",
  19. "chapter_title",
  20. "chapter_level_1",
  21. "chapter_level_2",
  22. "chapter_level_3",
  23. )
  24. def normalize_row_metadata(
  25. row_or_metadata: Any,
  26. output_fields: Sequence[str] = DEFAULT_OUTPUT_FIELDS,
  27. ) -> Dict[str, Any]:
  28. """规范化行数据为统一的 metadata 字典。处理嵌套 metadata 和 YAML 字符串。"""
  29. metadata = normalize_metadata(row_or_metadata)
  30. inner = normalize_metadata(metadata.get("metadata")) if metadata.get("metadata") else {}
  31. for key, value in inner.items():
  32. metadata.setdefault(key, value)
  33. for key in output_fields:
  34. if isinstance(row_or_metadata, dict) and row_or_metadata.get(key) not in (None, ""):
  35. metadata[key] = row_or_metadata.get(key)
  36. return metadata
  37. def normalize_metadata(metadata: Any) -> Dict[str, Any]:
  38. """将 metadata 转为字典,支持 YAML 字符串解析。"""
  39. if isinstance(metadata, dict):
  40. return dict(metadata)
  41. if isinstance(metadata, str) and metadata.strip():
  42. try:
  43. loaded = yaml.safe_load(metadata)
  44. return dict(loaded) if isinstance(loaded, dict) else {}
  45. except Exception:
  46. return {}
  47. return {}
  48. def metadata_value(metadata: Dict[str, Any], key: str) -> Any:
  49. """安全获取 metadata 值,支持嵌套 metadata.metadata 和 YAML 字符串。"""
  50. if key in metadata:
  51. return metadata.get(key)
  52. nested = metadata.get("metadata")
  53. if isinstance(nested, dict):
  54. return nested.get(key)
  55. if isinstance(nested, str) and nested.strip():
  56. try:
  57. parsed = yaml.safe_load(nested)
  58. if isinstance(parsed, dict):
  59. return parsed.get(key)
  60. except Exception:
  61. return None
  62. return None
  63. def build_candidate_key(
  64. metadata: Dict[str, Any],
  65. text: Any = "",
  66. output_fields: Sequence[str] = DEFAULT_OUTPUT_FIELDS,
  67. ) -> str:
  68. """构建候选唯一标识键,按优先级尝试不同字段组合。"""
  69. metadata = normalize_row_metadata(metadata, output_fields)
  70. document_id = str(metadata_value(metadata, "document_id") or "").strip()
  71. parent_id = str(metadata_value(metadata, "parent_id") or "").strip()
  72. chunk_id = str(metadata_value(metadata, "chunk_id") or "").strip()
  73. chapter_title = str(metadata_value(metadata, "chapter_title") or "").strip()
  74. index = metadata_value(metadata, "index")
  75. pk = str(metadata_value(metadata, "pk") or "").strip()
  76. if document_id and parent_id and chunk_id:
  77. return f"{document_id}::{parent_id}::{chunk_id}"
  78. if document_id and parent_id and chapter_title and index not in (None, ""):
  79. return f"{document_id}::{parent_id}::{chapter_title}::{index}"
  80. if pk:
  81. return pk
  82. if parent_id and chapter_title and index not in (None, ""):
  83. return f"{parent_id}::{chapter_title}::{index}"
  84. return str(text or "")[:300]
  85. def merge_metadata(current: Dict[str, Any], incoming: Dict[str, Any]) -> None:
  86. """合并两条候选的 metadata,不覆盖已有非空值。"""
  87. current_meta = current.setdefault("metadata", {})
  88. incoming_meta = incoming.get("metadata") if isinstance(incoming.get("metadata"), dict) else {}
  89. for key, value in incoming_meta.items():
  90. if key not in current_meta or current_meta.get(key) in (None, "", []):
  91. current_meta[key] = value
  92. if incoming.get("source") and not current.get("source"):
  93. current["source"] = incoming.get("source")
  94. def clean_candidates(candidates: List[Dict[str, Any]], config: RetrievalConfig) -> List[Dict[str, Any]]:
  95. """清理候选:过滤过短文本、双重去重(candidate_key + 内容哈希)。"""
  96. cleaned = []
  97. seen_keys = set()
  98. seen_hashes = set()
  99. for item in candidates:
  100. text = str(item.get("text") or "").strip()
  101. if len(text) < 20:
  102. continue
  103. metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
  104. dedupe_key = str(item.get("candidate_key") or text[:300])
  105. content_hash = _content_hash(text[:300])
  106. file_name = str(metadata.get("file_name") or "")
  107. hash_key = f"{file_name}::{content_hash}"
  108. if dedupe_key in seen_keys or hash_key in seen_hashes:
  109. continue
  110. seen_keys.add(dedupe_key)
  111. seen_hashes.add(hash_key)
  112. metadata["candidate_key"] = dedupe_key
  113. cleaned.append(
  114. {
  115. "candidate_key": dedupe_key,
  116. "text": text[: config.max_single_reference_chars],
  117. "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
  118. "vector_similarity": to_float(item.get("vector_similarity"), 0.0),
  119. "fusion_score": to_float(item.get("fusion_score"), 0.0),
  120. "source_hits": item.get("source_hits") if isinstance(item.get("source_hits"), dict) else {},
  121. "metadata": metadata,
  122. }
  123. )
  124. cleaned.sort(key=lambda item: (item.get("fusion_score", 0.0), item.get("vector_similarity", 0.0)), reverse=True)
  125. return cleaned[: config.recall_top_k]
  126. def _content_hash(text: str) -> str:
  127. """基于归一化文本的短 MD5 哈希,用于内容去重。"""
  128. normalized = re.sub(r"\s+", " ", text.strip().lower())
  129. return md5(normalized.encode("utf-8")).hexdigest()[:12]