# -*- coding: utf-8 -*-
"""Rerank retrieved document-chat references."""

from __future__ import annotations

import math
from collections import Counter
from typing import Any, Dict, List, Optional

from foundation.observability.logger.loggering import write_logger as logger
from core.document_chat.component.retrieval_service import (
    RetrievalConfig,
    load_retrieval_config,
)


class DocumentChatRerankService:
    """Run rerank and merge scores back by original candidate index.

    The service sends the ``text`` of every candidate to the rerank model,
    maps each returned result back onto the candidate it came from, and
    returns the candidates annotated with ``rerank_score`` and
    ``rerank_index``, ordered by descending score.
    """

    # User-facing warnings; kept as constants so every failure path emits
    # exactly the same text.
    _EMPTY_QUERY_MESSAGE = "查询为空,无法进行知识库重排。"
    _RERANK_UNAVAILABLE_MESSAGE = "知识库片段重排不可用,本次未引用向量库内容。"

    def __init__(self, config: Optional[RetrievalConfig] = None):
        """Store *config*, loading the default retrieval config when omitted."""
        self.config = config or load_retrieval_config()

    def rerank(self, query: str, candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Rerank *candidates* against *query*.

        Args:
            query: The user query. A blank, ``None`` or non-string query is
                reported as a rerank failure rather than raising.
            candidates: Retrieved references; the ``text`` entry of each one
                is what gets sent to the rerank model.

        Returns:
            A dict with the keys ``reranked_references``,
            ``retrieval_status`` (``"no_recall"``, ``"rerank_failed"`` or
            ``"reranked"``), ``retrieval_metrics`` (always containing
            ``rerank_count`` and ``max_rerank_score``) and ``warnings``.
        """
        if not candidates:
            return {
                "reranked_references": [],
                "retrieval_status": "no_recall",
                # Same metric keys as the other two outcomes so consumers can
                # read them unconditionally.
                "retrieval_metrics": {"rerank_count": 0, "max_rerank_score": 0.0},
                "warnings": [],
            }

        if not isinstance(query, str) or not query.strip():
            return self._failed(self._EMPTY_QUERY_MESSAGE)

        try:
            # Imported lazily inside the ``try`` so that a failure to load the
            # model module degrades to "rerank_failed" like any other error.
            from foundation.ai.models.rerank_model import rerank_model

            documents = [str(item.get("text") or "") for item in candidates]
            raw_results = rerank_model.shutian_rerank(
                query=query,
                candidates=documents,
                top_k=self.config.rerank_top_k,
            )
        except Exception as exc:
            # Best-effort boundary: any rerank error is logged and reported to
            # the caller as a warning instead of propagating.
            logger.warning(f"[DocumentChat] rerank failed: {exc}", exc_info=True)
            return self._failed(self._RERANK_UNAVAILABLE_MESSAGE)

        reranked = self._merge_rerank_results(raw_results, candidates)
        if not reranked:
            return self._failed(self._RERANK_UNAVAILABLE_MESSAGE)

        metrics = {
            "rerank_count": len(reranked),
            "max_rerank_score": max(
                (item.get("rerank_score", 0.0) for item in reranked),
                default=0.0,
            ),
        }
        return {
            "reranked_references": reranked,
            "retrieval_status": "reranked",
            "retrieval_metrics": metrics,
            "warnings": [],
        }

    def _merge_rerank_results(
        self,
        raw_results: List[Dict[str, Any]],
        candidates: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Attach rerank scores to copies of the matching candidates.

        Results that are not dicts, cannot be mapped to a candidate, point
        outside ``candidates`` or repeat an already-used candidate are
        skipped. The output is sorted by ``rerank_score`` (descending) and
        truncated to ``self.config.rerank_top_k`` entries.

        Returns:
            The annotated candidate copies, or ``[]`` when *raw_results* is
            not a list.
        """
        if not isinstance(raw_results, list):
            return []

        merged: List[Dict[str, Any]] = []
        used_indexes = set()
        text_to_unique_index = self._unique_text_index(candidates)

        for item in raw_results:
            if not isinstance(item, dict):
                continue

            original_index = self._resolve_index(item, text_to_unique_index)
            if original_index is None or original_index in used_indexes:
                continue
            if original_index < 0 or original_index >= len(candidates):
                continue

            score = self._to_float(item.get("score", item.get("relevance_score")), 0.0)
            # Copy so the caller's candidate dicts are never mutated.
            candidate = dict(candidates[original_index])
            candidate["rerank_score"] = score
            candidate["rerank_index"] = original_index
            merged.append(candidate)
            used_indexes.add(original_index)

        merged.sort(key=lambda row: row.get("rerank_score", 0.0), reverse=True)
        return merged[: self.config.rerank_top_k]

    @staticmethod
    def _unique_text_index(candidates: List[Dict[str, Any]]) -> Dict[str, int]:
        """Map each candidate text that occurs exactly once to its index.

        Texts shared by several candidates are left out because they cannot
        identify a single candidate.
        """
        texts = [str(item.get("text") or "") for item in candidates]
        counts = Counter(texts)
        return {text: index for index, text in enumerate(texts) if counts[text] == 1}

    def _resolve_index(
        self,
        item: Dict[str, Any],
        text_to_unique_index: Dict[str, int],
    ) -> Optional[int]:
        """Return the candidate index a rerank result refers to, if any.

        An ``index`` entry convertible to ``int`` wins. Otherwise the result
        text is looked up in *text_to_unique_index*; the text is taken from
        ``document`` (a string, or a dict with a ``text`` entry) and falls
        back to the result's own ``text`` entry.
        """
        try:
            return int(item["index"])
        except (KeyError, TypeError, ValueError):
            pass

        doc = item.get("document")
        text = doc if isinstance(doc, str) else ""
        if isinstance(doc, dict):
            text = str(doc.get("text") or "")
        text = text or str(item.get("text") or "")

        if text in text_to_unique_index:
            return text_to_unique_index[text]
        return None

    @staticmethod
    def _to_float(value: Any, default: float) -> float:
        """Convert *value* to a finite float, returning *default* otherwise.

        ``NaN`` and infinities are rejected as well as unconvertible values:
        a ``NaN`` score would make sorting and ``max`` order-dependent.
        """
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return default
        return number if math.isfinite(number) else default

    @staticmethod
    def _failed(message: str) -> Dict[str, Any]:
        """Build the ``rerank_failed`` payload carrying *message* as a warning."""
        return {
            "reranked_references": [],
            "retrieval_status": "rerank_failed",
            "retrieval_metrics": {"rerank_count": 0, "max_rerank_score": 0.0},
            "warnings": [message],
        }