| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- # -*- coding: utf-8 -*-
- """检索结果候选规范化与去重。"""
- from __future__ import annotations
- from hashlib import md5
- import re
- from typing import Any, Dict, List, Sequence
- import yaml
- from core.document_chat.retrieval.config import RetrievalConfig
- from core.document_chat.retrieval.utils import to_float
# Default field names that normalize_row_metadata copies from a dict row's top
# level (when their value is non-empty); build_candidate_key uses the same
# default. Order is preserved as a tuple so the default is immutable.
DEFAULT_OUTPUT_FIELDS = (
    "pk", "text",
    "document_id", "parent_id", "index",
    "tag_list", "metadata", "file_name",
    "chapter_title", "chapter_level_1", "chapter_level_2", "chapter_level_3",
)
def normalize_row_metadata(
    row_or_metadata: Any,
    output_fields: Sequence[str] = DEFAULT_OUTPUT_FIELDS,
) -> Dict[str, Any]:
    """Flatten a row (or a raw metadata value) into one metadata dict.

    ``row_or_metadata`` may be a dict or a YAML string (parsed by
    ``normalize_metadata``); any other value yields an empty dict.

    If the result has a truthy ``"metadata"`` entry (a dict or YAML string),
    its keys are merged into the result:

    * a nested value is used when the key is absent at the top level, and
    * a nested value also replaces a top-level ``None`` / ``""`` when the
      nested value itself is not ``None`` / ``""``, so an empty column does
      not mask a populated nested field.

    The ``"metadata"`` entry itself is kept. For a dict input, non-empty
    top-level values of ``output_fields`` always take precedence over merged
    nested values.

    Args:
        row_or_metadata: A row dict, a metadata dict, or a YAML string.
        output_fields: Field names whose non-empty top-level values win.

    Returns:
        A new dict. The input mapping itself is not modified (values are
        shared, not deep-copied).
    """
    metadata = normalize_metadata(row_or_metadata)
    nested = metadata.get("metadata")
    inner = normalize_metadata(nested) if nested else {}
    for key, value in inner.items():
        if key not in metadata:
            metadata[key] = value
        elif metadata[key] in (None, "") and value not in (None, ""):
            # Top-level slot exists but is empty: prefer the populated nested value.
            metadata[key] = value
    if isinstance(row_or_metadata, dict):
        # Guarantee that non-empty top-level output fields are never displaced
        # by anything merged from the nested metadata.
        for key in output_fields:
            value = row_or_metadata.get(key)
            if value not in (None, ""):
                metadata[key] = value
    return metadata
def normalize_metadata(metadata: Any) -> Dict[str, Any]:
    """Return ``metadata`` as a new ``dict``.

    * A dict is shallow-copied.
    * A non-blank string is parsed with ``yaml.safe_load``; the result is used
      (copied) only if it is a dict.
    * Anything else, a blank string, a non-dict YAML document, or any
      exception raised while parsing yields ``{}``.
    """
    if isinstance(metadata, dict):
        return dict(metadata)
    if not isinstance(metadata, str) or not metadata.strip():
        return {}
    try:
        parsed = yaml.safe_load(metadata)
        # Best-effort: malformed YAML (or values such as invalid dates that
        # make the loader raise) simply means "no metadata".
        return dict(parsed) if isinstance(parsed, dict) else {}
    except Exception:
        return {}
def metadata_value(metadata: Dict[str, Any], key: str) -> Any:
    """Look up ``key`` in ``metadata``, falling back to ``metadata["metadata"]``.

    The top-level value is returned unless it is missing, ``None`` or ``""``.
    In that case the nested ``"metadata"`` entry is consulted: a dict is read
    directly; a non-blank string is parsed via ``normalize_metadata`` (YAML,
    errors swallowed to ``{}``). If the nested lookup finds nothing (``None``),
    the top-level value (``None``, ``""``, or ``None`` when absent) is
    returned unchanged.

    Args:
        metadata: A metadata dict, possibly with a nested ``"metadata"`` entry.
        key: Field name to look up.

    Returns:
        The resolved value, or ``None`` when found nowhere.
    """
    value = metadata.get(key)
    # Only None and the empty string count as "missing"; other falsy values
    # (0, False, []) are real values. The isinstance guard avoids ==-comparing
    # arbitrary objects.
    if value is not None and not (isinstance(value, str) and value == ""):
        return value
    nested = metadata.get("metadata")
    if isinstance(nested, dict):
        fallback = nested.get(key)
    elif isinstance(nested, str) and nested.strip():
        fallback = normalize_metadata(nested).get(key)
    else:
        fallback = None
    return value if fallback is None else fallback
def build_candidate_key(
    metadata: Dict[str, Any],
    text: Any = "",
    output_fields: Sequence[str] = DEFAULT_OUTPUT_FIELDS,
) -> str:
    """Build a candidate's identity key from the first usable field combination.

    ``metadata`` is first flattened with ``normalize_row_metadata``. String
    fields are stringified and stripped; ``index`` is used as-is and counts as
    present unless it is ``None`` or ``""``. Priority:

    1. ``document_id::parent_id::chunk_id``
    2. ``document_id::parent_id::chapter_title::index``
    3. ``pk``
    4. ``parent_id::chapter_title::index``
    5. the first 300 characters of ``str(text or "")``
    """
    flat = normalize_row_metadata(metadata, output_fields)

    def _field(name: str) -> str:
        return str(metadata_value(flat, name) or "").strip()

    document_id = _field("document_id")
    parent_id = _field("parent_id")
    chunk_id = _field("chunk_id")
    chapter_title = _field("chapter_title")
    index = metadata_value(flat, "index")
    pk = _field("pk")

    if document_id and parent_id:
        if chunk_id:
            return f"{document_id}::{parent_id}::{chunk_id}"
        if chapter_title and index not in (None, ""):
            return f"{document_id}::{parent_id}::{chapter_title}::{index}"
    if pk:
        return pk
    if parent_id and chapter_title and index not in (None, ""):
        return f"{parent_id}::{chapter_title}::{index}"
    return str(text or "")[:300]
def merge_metadata(current: Dict[str, Any], incoming: Dict[str, Any]) -> None:
    """Merge ``incoming``'s metadata and source into ``current`` in place.

    Existing values in ``current["metadata"]`` are kept; a key is taken from
    ``incoming`` only when it is absent in ``current`` or its value there is
    ``None``, ``""`` or ``[]``. ``incoming["source"]`` is copied only when it
    is truthy and ``current`` has no truthy ``source``.

    Either side's ``"metadata"`` may be a dict or a YAML string (parsed with
    ``normalize_metadata``); any other value counts as empty. After the call
    ``current["metadata"]`` is always a dict.

    Args:
        current: Candidate being kept; mutated in place.
        incoming: Candidate whose metadata/source fill gaps in ``current``.
    """
    current_meta = current.get("metadata")
    if not isinstance(current_meta, dict):
        # A missing/None/YAML-string value would break the membership test and
        # item assignment below, so replace it with a real dict.
        current_meta = normalize_metadata(current_meta) if current_meta else {}
        current["metadata"] = current_meta

    raw_incoming = incoming.get("metadata")
    if isinstance(raw_incoming, dict):
        incoming_meta = raw_incoming
    elif raw_incoming:
        incoming_meta = normalize_metadata(raw_incoming)
    else:
        incoming_meta = {}

    for key, value in incoming_meta.items():
        if key not in current_meta or current_meta.get(key) in (None, "", []):
            current_meta[key] = value

    incoming_source = incoming.get("source")
    if incoming_source and not current.get("source"):
        current["source"] = incoming_source
def clean_candidates(candidates: List[Dict[str, Any]], config: RetrievalConfig) -> List[Dict[str, Any]]:
    """Filter, de-duplicate and rank retrieval candidates.

    For each candidate, in input order:

    * the text is stripped; candidates with fewer than 20 characters are dropped;
    * the dedupe key is ``candidate_key`` if truthy, else the first 300 chars
      of the text; the content key is ``file_name`` plus ``_content_hash`` of
      the first 300 chars;
    * a candidate whose dedupe key OR content key was already seen is dropped
      (the first occurrence wins).

    Survivors are rebuilt as new dicts with the text truncated to
    ``config.max_single_reference_chars``, scores coerced via ``to_float``
    (default 0.0), and ``candidate_key`` also written into their metadata.
    The list is sorted by ``fusion_score`` then ``vector_similarity``
    (descending) and cut to ``config.recall_top_k``.

    The input candidates are not modified: ``metadata`` (dict or YAML string)
    is normalized into a fresh dict and ``source_hits`` is shallow-copied.

    Args:
        candidates: Raw candidate dicts.
        config: Supplies ``max_single_reference_chars`` and ``recall_top_k``.

    Returns:
        A new, ranked list of cleaned candidate dicts.
    """
    min_text_chars = 20  # candidates with shorter stripped text are discarded
    prefix_chars = 300  # text prefix used for the fallback key and the content hash
    cleaned: List[Dict[str, Any]] = []
    seen_keys = set()
    seen_hashes = set()
    for item in candidates:
        text = str(item.get("text") or "").strip()
        if len(text) < min_text_chars:
            continue
        # normalize_metadata returns a new dict (or parses a YAML string), so
        # writing candidate_key below never touches the caller's data.
        metadata = normalize_metadata(item.get("metadata"))
        dedupe_key = str(item.get("candidate_key") or text[:prefix_chars])
        file_name = str(metadata.get("file_name") or "")
        hash_key = f"{file_name}::{_content_hash(text[:prefix_chars])}"
        if dedupe_key in seen_keys or hash_key in seen_hashes:
            continue
        seen_keys.add(dedupe_key)
        seen_hashes.add(hash_key)
        metadata["candidate_key"] = dedupe_key
        source_hits = item.get("source_hits")
        cleaned.append(
            {
                "candidate_key": dedupe_key,
                "text": text[: config.max_single_reference_chars],
                "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
                "vector_similarity": to_float(item.get("vector_similarity"), 0.0),
                "fusion_score": to_float(item.get("fusion_score"), 0.0),
                "source_hits": dict(source_hits) if isinstance(source_hits, dict) else {},
                "metadata": metadata,
            }
        )
    cleaned.sort(
        key=lambda entry: (entry.get("fusion_score", 0.0), entry.get("vector_similarity", 0.0)),
        reverse=True,
    )
    return cleaned[: config.recall_top_k]
- def _content_hash(text: str) -> str:
- """基于归一化文本的短 MD5 哈希,用于内容去重。"""
- normalized = re.sub(r"\s+", " ", text.strip().lower())
- return md5(normalized.encode("utf-8")).hexdigest()[:12]
|