reranker.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import json
  2. from typing import Sequence, Optional, Dict, Any
  3. import requests
  4. from langchain_core.callbacks import Callbacks
  5. from langchain_core.documents import BaseDocumentCompressor, Document
  6. from models_provider.base_model_provider import MaxKBBaseModel
  7. class QfBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
  8. api_key: str
  9. api_url: str
  10. model: str
  11. params: dict
  12. top_n: int = 3
  13. def __init__(self, **kwargs):
  14. super().__init__(**kwargs)
  15. self.api_key = kwargs.get('api_key')
  16. self.model = kwargs.get('model')
  17. self.params = kwargs.get('params', {})
  18. self.api_url = kwargs.get('api_url')
  19. self.top_n = self.params.get('top_n', 3)
  20. @staticmethod
  21. def is_cache_model():
  22. return False
  23. @staticmethod
  24. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  25. return QfBgeReranker(
  26. model=model_name,
  27. api_key=model_credential.get('api_key'),
  28. api_url=model_credential.get('api_url'),
  29. params=model_kwargs,
  30. )
  31. def compress_documents(
  32. self,
  33. documents: Sequence[Document],
  34. query: str,
  35. callbacks: Optional[Callbacks] = None
  36. ) -> Sequence[Document]:
  37. if not documents:
  38. return []
  39. texts = [doc.page_content for doc in documents]
  40. headers = {
  41. "Authorization": f"Bearer {self.api_key}",
  42. "Content-Type": "application/json"
  43. }
  44. top_n = min(self.top_n, len(texts))
  45. payload = {
  46. "model": self.model,
  47. "query": query,
  48. "documents": texts,
  49. "top_n": top_n
  50. }
  51. response = requests.post(f"{self.api_url}/rerank", json=payload, headers=headers)
  52. if response.status_code != 200:
  53. raise RuntimeError(f"千帆 API 请求失败:{response.text}")
  54. res = response.json()
  55. return [
  56. Document(
  57. page_content=item.get('document', ''),
  58. metadata={'relevance_score': item.get('relevance_score')}
  59. )
  60. for item in res.get('results', [])
  61. ]