| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- # -*- coding: utf-8 -*-
- """AI对话引用文档重排序服务。"""
- from __future__ import annotations
- from typing import Any, Dict, List, Optional
- from core.document_chat.component.document_chat_logger import document_chat_logger as logger
- from core.document_chat.retrieval.config import RetrievalConfig, load_retrieval_config
class DocumentChatRerankService:
    """Rerank recalled knowledge-base candidates and merge scores back onto them.

    Candidate texts are scored against the user query by the rerank model.
    Each returned result is mapped back to its original candidate, either by the
    ``index`` it reports or, as a fallback, by an exact match on a candidate
    text that occurs only once. The candidate dict is then copied with
    ``rerank_score`` and ``rerank_index`` added.

    Every result returned by ``rerank`` is a dict with the keys
    ``reranked_references``, ``retrieval_status``, ``retrieval_metrics`` and
    ``warnings``.
    """

    def __init__(self, config: Optional[RetrievalConfig] = None):
        """Create the service.

        Args:
            config: Retrieval settings; this class reads ``rerank_top_k``.
                When omitted (or falsy), ``load_retrieval_config()`` is used.
        """
        self.config = config or load_retrieval_config()

    def rerank(self, query: str, candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Rerank ``candidates`` against ``query``.

        Args:
            query: User question. A blank (or ``None``) query is rejected.
            candidates: Recalled chunks. The ``text`` value of each dict is
                sent to the rerank model.

        Returns:
            A result dict whose ``retrieval_status`` is one of:

            - ``"no_recall"`` when there are no candidates.
            - ``"rerank_failed"`` when the query is blank, the model call
              raises, or no model result maps back to a candidate.
            - ``"reranked"`` otherwise.

            Failures are reported through ``warnings``; this method does not
            raise for them.
        """
        if not candidates:
            return {
                "reranked_references": [],
                "retrieval_status": "no_recall",
                # Same metric keys as the other outcomes, so consumers can read
                # ``max_rerank_score`` unconditionally.
                "retrieval_metrics": {"rerank_count": 0, "max_rerank_score": 0.0},
                "warnings": [],
            }
        if not (query or "").strip():
            return self._failed("查询为空,无法进行知识库重排。")
        try:
            # Imported inside the ``try`` so that an import failure is handled
            # exactly like a failing model call below.
            from foundation.ai.models.rerank_model import rerank_model

            documents = [str(item.get("text") or "") for item in candidates]
            raw_results = rerank_model.shutian_rerank(
                query=query,
                candidates=documents,
                top_k=self.config.rerank_top_k,
            )
        except Exception as exc:
            # Deliberate best-effort boundary: the failure is logged with its
            # traceback and reported as "rerank_failed" instead of propagating.
            logger.warning(f"[DocumentChat] rerank failed: {exc}", exc_info=True)
            return self._failed("知识库片段重排不可用,本次未引用向量库内容。")

        reranked = self._merge_rerank_results(raw_results, candidates)
        if not reranked:
            return self._failed("知识库片段重排不可用,本次未引用向量库内容。")
        metrics = {
            "rerank_count": len(reranked),
            "max_rerank_score": max((item.get("rerank_score", 0.0) for item in reranked), default=0.0),
        }
        return {
            "reranked_references": reranked,
            "retrieval_status": "reranked",
            "retrieval_metrics": metrics,
            "warnings": [],
        }

    def _merge_rerank_results(
        self,
        raw_results: List[Dict[str, Any]],
        candidates: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Map raw model results back onto copies of the original candidates.

        The following entries are skipped:

        - Non-dict entries.
        - Entries that cannot be resolved to an in-range candidate index.
        - Repeats of an already-merged candidate.

        The score is read from ``score``, falling back to ``relevance_score``
        when ``score`` is missing or ``None``. The merged list is sorted by
        ``rerank_score`` (descending) and truncated to ``config.rerank_top_k``.
        The input candidate dicts are never mutated.
        """
        if not isinstance(raw_results, list):
            return []
        merged: List[Dict[str, Any]] = []
        used_indexes = set()
        text_to_unique_index = self._unique_text_index(candidates)
        for item in raw_results:
            if not isinstance(item, dict):
                continue
            original_index = self._resolve_index(item, text_to_unique_index)
            if original_index is None or original_index in used_indexes:
                continue
            if not 0 <= original_index < len(candidates):
                continue
            raw_score = item.get("score")
            if raw_score is None:
                raw_score = item.get("relevance_score")
            candidate = dict(candidates[original_index])
            candidate["rerank_score"] = self._to_float(raw_score, 0.0)
            candidate["rerank_index"] = original_index
            merged.append(candidate)
            used_indexes.add(original_index)
        merged.sort(key=lambda row: row["rerank_score"], reverse=True)
        return merged[: self.config.rerank_top_k]

    @staticmethod
    def _unique_text_index(candidates: List[Dict[str, Any]]) -> Dict[str, int]:
        """Map each non-empty candidate text that occurs exactly once to its index.

        Duplicated texts are excluded because a text match could not tell the
        copies apart. Empty texts are excluded so that a result carrying no
        text is never matched to a candidate whose text happens to be empty.
        """
        first_index: Dict[str, int] = {}
        duplicated = set()
        for index, item in enumerate(candidates):
            text = str(item.get("text") or "")
            if not text:
                continue
            if text in first_index:
                duplicated.add(text)
            else:
                first_index[text] = index
        return {text: index for text, index in first_index.items() if text not in duplicated}

    def _resolve_index(self, item: Dict[str, Any], text_to_unique_index: Dict[str, int]) -> Optional[int]:
        """Return the candidate index a raw result refers to, or ``None``.

        An explicit ``index`` field wins when it converts cleanly to ``int``.
        Booleans are rejected, and NaN/infinite values fall through to text
        matching instead of raising.

        Otherwise the result text is looked up among the unique candidate
        texts. The text is taken from a string ``document``, from
        ``document["text"]``, or from a top-level ``text``.

        The returned index is not range-checked here.
        """
        raw_index = item.get("index")
        if raw_index is not None and not isinstance(raw_index, bool):
            try:
                return int(raw_index)
            except (TypeError, ValueError, OverflowError):
                pass
        doc = item.get("document")
        if isinstance(doc, dict):
            text = str(doc.get("text") or "")
        elif isinstance(doc, str):
            text = doc
        else:
            text = ""
        text = text or str(item.get("text") or "")
        return text_to_unique_index.get(text) if text else None

    @staticmethod
    def _to_float(value: Any, default: float) -> float:
        """Convert ``value`` to a finite float, or return ``default``.

        Unconvertible, overflowing, NaN and infinite values all yield
        ``default``. NaN in particular would make the descending sort and the
        ``max`` metric unreliable.
        """
        import math

        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return default
        return number if math.isfinite(number) else default

    @staticmethod
    def _failed(message: str) -> Dict[str, Any]:
        """Build the ``rerank_failed`` result carrying ``message`` as its only warning."""
        return {
            "reranked_references": [],
            "retrieval_status": "rerank_failed",
            "retrieval_metrics": {"rerank_count": 0, "max_rerank_score": 0.0},
            "warnings": [message],
        }
|