reranker.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from typing import Sequence, Optional, Dict, Any
  2. import cohere
  3. from langchain_core.callbacks import Callbacks
  4. from langchain_core.documents import BaseDocumentCompressor, Document
  5. from models_provider.base_model_provider import MaxKBBaseModel
  6. class VllmBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
  7. api_key: str
  8. api_url: str
  9. model: str
  10. params: dict
  11. client: Any = None
  12. def __init__(self, **kwargs):
  13. super().__init__(**kwargs)
  14. self.api_key = kwargs.get('api_key')
  15. self.model = kwargs.get('model')
  16. self.params = kwargs.get('params')
  17. self.api_url = kwargs.get('api_url')
  18. self.client = cohere.ClientV2(kwargs.get('api_key'), base_url=kwargs.get('api_url'))
  19. @staticmethod
  20. def is_cache_model():
  21. return False
  22. @staticmethod
  23. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  24. r_url = model_credential.get('api_url')[:-3] if model_credential.get('api_url').endswith('/v1') else model_credential.get('api_url')
  25. return VllmBgeReranker(
  26. model=model_name,
  27. api_key=model_credential.get('api_key'),
  28. api_url=r_url,
  29. params=model_kwargs,
  30. **model_kwargs
  31. )
  32. def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
  33. Sequence[Document]:
  34. if documents is None or len(documents) == 0:
  35. return []
  36. ds = [d.page_content for d in documents]
  37. result = self.client.rerank(model=self.model, query=query, documents=ds)
  38. return [Document(page_content=d.document.get('text'), metadata={'relevance_score': d.relevance_score}) for d in
  39. result.results]