reranker.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: reranker.py
  6. @date:2024/9/10 9:45
  7. @desc: Document reranker backed by a rerank model served through Xinference
  8. """
  9. from typing import Sequence, Optional, Any, Dict
  10. from langchain_core.callbacks import Callbacks
  11. from langchain_core.documents import BaseDocumentCompressor, Document
  12. from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle
  13. from models_provider.base_model_provider import MaxKBBaseModel
  class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
      """Document compressor that reranks documents with a model served by Xinference.

      The reranking itself is delegated to the rerank model handle obtained from
      an Xinference ``RESTfulClient``; this class only builds the request and
      converts the response back into ``Document`` objects.
      """

      server_url: Optional[str]
      """URL of the xinference server"""
      model_uid: Optional[str]
      """UID of the launched model"""
      # API key handed to the Xinference REST client (second constructor argument).
      api_key: Optional[str]
      # Value forwarded to the rerank call as its ``top_n`` argument.
      top_n: Optional[int] = 3

      @staticmethod
      def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
          """Build a reranker from a credential mapping.

          Args:
              model_type: Not used by this implementation.
              model_name: Used as the ``model_uid`` of the reranker.
              model_credential: Mapping read for the keys ``'server_url'`` and
                  ``'api_key'``; a missing key yields ``None``.
              **model_kwargs: Only ``'top_n'`` is read; it defaults to 3.

          Returns:
              A new ``XInferenceReranker`` instance.
          """
          return XInferenceReranker(
              server_url=model_credential.get('server_url'),
              model_uid=model_name,
              api_key=model_credential.get('api_key'),
              top_n=model_kwargs.get('top_n', 3),
          )

      def compress_documents(
              self,
              documents: Sequence[Document],
              query: str,
              callbacks: Optional[Callbacks] = None,
      ) -> Sequence[Document]:
          """Rerank ``documents`` against ``query`` and return the ranked results.

          Args:
              documents: Candidate documents; only their ``page_content`` is sent
                  to the rerank model.
              query: Query text the documents are ranked against.
              callbacks: Accepted for interface compatibility; not used here.

          Returns:
              An empty list when ``documents`` is ``None`` or empty. Otherwise one
              new ``Document`` per entry of the response's ``'results'`` list,
              whose text comes from the entry's ``'document'`` -> ``'text'`` and
              whose metadata holds only ``'relevance_score'`` (metadata of the
              input documents is not carried over).

          Raises:
              ImportError: If ``RESTfulClient`` can be imported from neither
                  ``xinference`` nor ``xinference_client``.
          """
          # Nothing to rank: skip the client import and the remote call entirely.
          if not documents:
              return []

          # Prefer the full xinference package and fall back to the lightweight
          # client-only distribution.
          try:
              from xinference.client import RESTfulClient
          except ImportError:
              try:
                  from xinference_client import RESTfulClient
              except ImportError as e:
                  raise ImportError(
                      "Could not import RESTfulClient from xinference. Please install it"
                      " with `pip install xinference` or `pip install xinference_client`."
                  ) from e

          client = RESTfulClient(self.server_url, self.api_key)
          model: RESTfulRerankModelHandle = client.get_model(self.model_uid)

          texts = [document.page_content for document in documents]
          # return_documents=True asks the server to echo each document back,
          # which is where the text of the returned Documents is read from.
          res = model.rerank(texts, query, self.top_n, return_documents=True)

          return [
              Document(
                  page_content=item.get('document', {}).get('text'),
                  metadata={'relevance_score': item.get('relevance_score')},
              )
              for item in res.get('results', [])
          ]