rerank_service.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # -*- coding: utf-8 -*-
  2. """AI对话引用文档重排序服务。"""
  3. from __future__ import annotations
  4. from typing import Any, Dict, List, Optional
  5. from core.document_chat.component.document_chat_logger import document_chat_logger as logger
  6. from core.document_chat.retrieval.config import RetrievalConfig, load_retrieval_config
  7. class DocumentChatRerankService:
  8. """执行重排并将分数合并回原始候选索引。"""
  9. def __init__(self, config: Optional[RetrievalConfig] = None):
  10. self.config = config or load_retrieval_config()
  11. def rerank(self, query: str, candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
  12. if not candidates:
  13. return {
  14. "reranked_references": [],
  15. "retrieval_status": "no_recall",
  16. "retrieval_metrics": {"rerank_count": 0},
  17. "warnings": [],
  18. }
  19. if not query.strip():
  20. return self._failed("查询为空,无法进行知识库重排。")
  21. try:
  22. from foundation.ai.models.rerank_model import rerank_model
  23. documents = [str(item.get("text") or "") for item in candidates]
  24. raw_results = rerank_model.shutian_rerank(
  25. query=query,
  26. candidates=documents,
  27. top_k=self.config.rerank_top_k,
  28. )
  29. except Exception as exc:
  30. logger.warning(f"[DocumentChat] rerank failed: {exc}", exc_info=True)
  31. return self._failed("知识库片段重排不可用,本次未引用向量库内容。")
  32. reranked = self._merge_rerank_results(raw_results, candidates)
  33. if not reranked:
  34. return self._failed("知识库片段重排不可用,本次未引用向量库内容。")
  35. metrics = {
  36. "rerank_count": len(reranked),
  37. "max_rerank_score": max((item.get("rerank_score", 0.0) for item in reranked), default=0.0),
  38. }
  39. return {
  40. "reranked_references": reranked,
  41. "retrieval_status": "reranked",
  42. "retrieval_metrics": metrics,
  43. "warnings": [],
  44. }
  45. def _merge_rerank_results(
  46. self,
  47. raw_results: List[Dict[str, Any]],
  48. candidates: List[Dict[str, Any]],
  49. ) -> List[Dict[str, Any]]:
  50. if not isinstance(raw_results, list):
  51. return []
  52. merged = []
  53. used_indexes = set()
  54. text_to_unique_index = self._unique_text_index(candidates)
  55. for item in raw_results:
  56. if not isinstance(item, dict):
  57. continue
  58. original_index = self._resolve_index(item, text_to_unique_index)
  59. if original_index is None or original_index in used_indexes:
  60. continue
  61. if original_index < 0 or original_index >= len(candidates):
  62. continue
  63. score = self._to_float(item.get("score", item.get("relevance_score")), 0.0)
  64. candidate = dict(candidates[original_index])
  65. candidate["rerank_score"] = score
  66. candidate["rerank_index"] = original_index
  67. merged.append(candidate)
  68. used_indexes.add(original_index)
  69. merged.sort(key=lambda row: row.get("rerank_score", 0.0), reverse=True)
  70. return merged[: self.config.rerank_top_k]
  71. @staticmethod
  72. def _unique_text_index(candidates: List[Dict[str, Any]]) -> Dict[str, int]:
  73. counts = {}
  74. for item in candidates:
  75. text = str(item.get("text") or "")
  76. counts[text] = counts.get(text, 0) + 1
  77. return {
  78. str(item.get("text") or ""): index
  79. for index, item in enumerate(candidates)
  80. if counts.get(str(item.get("text") or ""), 0) == 1
  81. }
  82. def _resolve_index(self, item: Dict[str, Any], text_to_unique_index: Dict[str, int]) -> Optional[int]:
  83. try:
  84. return int(item["index"])
  85. except (KeyError, TypeError, ValueError):
  86. pass
  87. doc = item.get("document")
  88. text = doc if isinstance(doc, str) else ""
  89. if isinstance(doc, dict):
  90. text = str(doc.get("text") or "")
  91. text = text or str(item.get("text") or "")
  92. if text in text_to_unique_index:
  93. return text_to_unique_index[text]
  94. return None
  95. @staticmethod
  96. def _to_float(value: Any, default: float) -> float:
  97. try:
  98. return float(value)
  99. except (TypeError, ValueError):
  100. return default
  101. @staticmethod
  102. def _failed(message: str) -> Dict[str, Any]:
  103. return {
  104. "reranked_references": [],
  105. "retrieval_status": "rerank_failed",
  106. "retrieval_metrics": {"rerank_count": 0, "max_rerank_score": 0.0},
  107. "warnings": [message],
  108. }