# -*- coding: utf-8 -*-
"""Normalization and de-duplication of retrieval result candidates."""

from __future__ import annotations

from hashlib import md5
import re
from typing import Any, Dict, List, Sequence

import yaml

from core.document_chat.retrieval.config import RetrievalConfig
from core.document_chat.retrieval.utils import to_float

DEFAULT_OUTPUT_FIELDS = (
    "pk",
    "text",
    "document_id",
    "parent_id",
    "index",
    "tag_list",
    "metadata",
    "file_name",
    "chapter_title",
    "chapter_level_1",
    "chapter_level_2",
    "chapter_level_3",
)

# Candidates whose stripped text is shorter than this are discarded.
_MIN_TEXT_CHARS = 20
# Length of the text prefix used as a fallback key and as content-hash input.
_KEY_TEXT_CHARS = 300
# Values treated as "no value" when deciding whether a field is populated.
_EMPTY_VALUES = (None, "")
_WHITESPACE_RE = re.compile(r"\s+")


def normalize_row_metadata(
    row_or_metadata: Any,
    output_fields: Sequence[str] = DEFAULT_OUTPUT_FIELDS,
) -> Dict[str, Any]:
    """Flatten a row (or metadata value) into a single metadata dict.

    The input is converted with ``normalize_metadata``. If the result carries
    a nested ``"metadata"`` entry (dict or YAML string), its keys are lifted to
    the top level wherever the top-level value is missing, ``None`` or ``""``.
    Non-empty top-level values are never overwritten.

    Args:
        row_or_metadata: A dict, a YAML string, or anything else (which
            yields an empty dict).
        output_fields: Field names whose non-empty values in the input row
            must be present unchanged in the result.

    Returns:
        A new dict; the input is not modified.
    """
    metadata = normalize_metadata(row_or_metadata)

    # normalize_metadata returns {} for None / empty / unparseable input, so
    # no separate truthiness check is needed. It also returns a copy, which
    # makes it safe to mutate ``metadata`` while iterating.
    for key, value in normalize_metadata(metadata.get("metadata")).items():
        # Fill gaps only: an empty top-level placeholder must not hide a
        # usable nested value.
        if metadata.get(key) in _EMPTY_VALUES:
            metadata[key] = value

    if isinstance(row_or_metadata, dict):
        # Defensive guarantee: non-empty row values for the output fields
        # always win over anything lifted from the nested metadata.
        for key in output_fields:
            value = row_or_metadata.get(key)
            if value not in _EMPTY_VALUES:
                metadata[key] = value
    return metadata


def normalize_metadata(metadata: Any) -> Dict[str, Any]:
    """Convert ``metadata`` to a dict, parsing YAML strings.

    Args:
        metadata: A dict, a YAML string, or any other value.

    Returns:
        A shallow copy when given a dict; the parsed mapping when given a
        YAML string that loads to a dict; otherwise an empty dict.
    """
    if isinstance(metadata, dict):
        return dict(metadata)
    if not isinstance(metadata, str) or not metadata.strip():
        return {}
    try:
        loaded = yaml.safe_load(metadata)
    except Exception:
        # Deliberately best-effort: malformed metadata must not break
        # retrieval, so any parse failure is treated as "no metadata".
        return {}
    return dict(loaded) if isinstance(loaded, dict) else {}


def metadata_value(metadata: Dict[str, Any], key: str) -> Any:
    """Look up ``key`` in ``metadata``, falling back to nested metadata.

    If ``key`` exists at the top level its value is returned as-is (even when
    it is ``None`` or empty). Otherwise the nested ``metadata["metadata"]``
    entry, given as a dict or a YAML string, is consulted.

    Returns:
        The value found, or ``None`` when the key is absent everywhere or the
        nested metadata cannot be parsed.
    """
    if key in metadata:
        return metadata.get(key)
    return normalize_metadata(metadata.get("metadata")).get(key)


def _text_field(metadata: Dict[str, Any], key: str) -> str:
    """Return ``metadata_value(metadata, key)`` as a stripped string ('' if falsy)."""
    return str(metadata_value(metadata, key) or "").strip()


def build_candidate_key(
    metadata: Dict[str, Any],
    text: Any = "",
    output_fields: Sequence[str] = DEFAULT_OUTPUT_FIELDS,
) -> str:
    """Build an identity key for a candidate.

    The first combination whose fields are all populated is used:

    1. ``document_id::parent_id::chunk_id``
    2. ``document_id::parent_id::chapter_title::index``
    3. ``pk``
    4. ``parent_id::chapter_title::index``
    5. the first 300 characters of ``text``

    Args:
        metadata: Row or metadata dict; normalized with
            ``normalize_row_metadata`` before any lookup.
        text: Candidate text, used only for the final fallback.
        output_fields: Forwarded to ``normalize_row_metadata``.

    Returns:
        The key string; may be empty when nothing usable is available.
    """
    metadata = normalize_row_metadata(metadata, output_fields)

    document_id = _text_field(metadata, "document_id")
    parent_id = _text_field(metadata, "parent_id")
    chunk_id = _text_field(metadata, "chunk_id")
    chapter_title = _text_field(metadata, "chapter_title")
    pk = _text_field(metadata, "pk")
    index = metadata_value(metadata, "index")
    # Compared against None/"" rather than tested for truthiness so that an
    # index of 0 still counts as present.
    has_chapter_position = bool(chapter_title) and index not in _EMPTY_VALUES

    if document_id and parent_id and chunk_id:
        return f"{document_id}::{parent_id}::{chunk_id}"
    if document_id and parent_id and has_chapter_position:
        return f"{document_id}::{parent_id}::{chapter_title}::{index}"
    if pk:
        return pk
    if parent_id and has_chapter_position:
        return f"{parent_id}::{chapter_title}::{index}"
    return str(text or "")[:_KEY_TEXT_CHARS]


def merge_metadata(current: Dict[str, Any], incoming: Dict[str, Any]) -> None:
    """Merge ``incoming`` into ``current`` in place without clobbering values.

    Keys from ``incoming["metadata"]`` are copied into ``current["metadata"]``
    only where the current value is missing, ``None``, ``""`` or ``[]``.
    ``current["source"]`` is filled from ``incoming`` only when it is falsy.

    If ``current["metadata"]`` is missing or is not a dict (for example
    ``None`` or a YAML string), it is first replaced by the result of
    ``normalize_metadata`` so the merge cannot fail on it.

    Args:
        current: Candidate to update; mutated in place.
        incoming: Candidate supplying additional values; read only. Its
            ``"metadata"`` entry is ignored unless it is a dict.
    """
    current_meta = current.get("metadata")
    if not isinstance(current_meta, dict):
        current_meta = normalize_metadata(current_meta)
        current["metadata"] = current_meta

    incoming_meta = incoming.get("metadata")
    if isinstance(incoming_meta, dict):
        for key, value in incoming_meta.items():
            # A missing key reads as None, so one check covers both
            # "absent" and "present but empty".
            if current_meta.get(key) in (None, "", []):
                current_meta[key] = value

    if incoming.get("source") and not current.get("source"):
        current["source"] = incoming.get("source")


def clean_candidates(candidates: List[Dict[str, Any]], config: RetrievalConfig) -> List[Dict[str, Any]]:
    """Filter, de-duplicate, rank and truncate retrieval candidates.

    A candidate is dropped when its stripped text is shorter than 20
    characters, when its key (``candidate_key`` or, failing that, the first
    300 characters of text) was already seen, or when the pair of file name
    and content hash of the first 300 characters was already seen.

    Args:
        candidates: Raw candidate dicts. Neither the dicts nor their
            ``"metadata"`` entries are modified.
        config: Supplies ``max_single_reference_chars`` (text truncation) and
            ``recall_top_k`` (maximum number of results).

    Returns:
        New candidate dicts sorted by ``(fusion_score, vector_similarity)``
        in descending order, at most ``config.recall_top_k`` of them.
    """
    cleaned: List[Dict[str, Any]] = []
    seen_keys = set()
    seen_hashes = set()

    # NOTE: de-duplication runs in input order, before ranking, so among
    # duplicates the first occurrence is kept, not the highest scoring one.
    for item in candidates:
        text = str(item.get("text") or "").strip()
        if len(text) < _MIN_TEXT_CHARS:
            continue

        raw_metadata = item.get("metadata")
        # Work on a copy so that writing candidate_key below does not leak
        # into the caller's input dicts.
        metadata = dict(raw_metadata) if isinstance(raw_metadata, dict) else {}

        text_prefix = text[:_KEY_TEXT_CHARS]
        dedupe_key = str(item.get("candidate_key") or text_prefix)
        file_name = str(metadata.get("file_name") or "")
        hash_key = f"{file_name}::{_content_hash(text_prefix)}"
        if dedupe_key in seen_keys or hash_key in seen_hashes:
            continue
        seen_keys.add(dedupe_key)
        seen_hashes.add(hash_key)

        metadata["candidate_key"] = dedupe_key
        source_hits = item.get("source_hits")
        cleaned.append(
            {
                "candidate_key": dedupe_key,
                "text": text[: config.max_single_reference_chars],
                "source": str(item.get("source") or metadata.get("file_name") or "向量知识库"),
                "vector_similarity": to_float(item.get("vector_similarity"), 0.0),
                "fusion_score": to_float(item.get("fusion_score"), 0.0),
                "source_hits": source_hits if isinstance(source_hits, dict) else {},
                "metadata": metadata,
            }
        )

    cleaned.sort(
        key=lambda entry: (entry.get("fusion_score", 0.0), entry.get("vector_similarity", 0.0)),
        reverse=True,
    )
    return cleaned[: config.recall_top_k]


def _content_hash(text: str) -> str:
    """Return a 12-hex-digit MD5 prefix of whitespace- and case-normalized text.

    Used only to detect duplicate content, not for any security purpose.
    """
    normalized = _WHITESPACE_RE.sub(" ", text.strip().lower())
    return md5(normalized.encode("utf-8")).hexdigest()[:12]