reranker.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from typing import Sequence, Optional, Dict
  2. from langchain_ollama import OllamaEmbeddings
  3. from langchain_core.callbacks import Callbacks
  4. from langchain_core.documents import Document
  5. from pydantic import BaseModel, Field
  6. from models_provider.base_model_provider import MaxKBBaseModel
  7. class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
  8. top_n: Optional[int] = Field(3, description="Number of top documents to return")
  9. @staticmethod
  10. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  11. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  12. return OllamaReranker(
  13. model=model_name,
  14. base_url=model_credential.get('api_base'),
  15. **optional_params
  16. )
  17. def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
  18. Sequence[Document]:
  19. from sklearn.metrics.pairwise import cosine_similarity
  20. """Rank documents based on their similarity to the query.
  21. Args:
  22. query: The query text.
  23. documents: The list of document texts to rank.
  24. Returns:
  25. List of documents sorted by relevance to the query.
  26. """
  27. # 获取查询和文档的嵌入
  28. query_embedding = self.embed_query(query)
  29. documents = [doc.page_content for doc in documents]
  30. document_embeddings = self.embed_documents(documents)
  31. # 计算相似度
  32. similarities = cosine_similarity([query_embedding], document_embeddings)[0]
  33. ranked_docs = [(doc, _) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
  34. return [
  35. Document(
  36. page_content=doc, # 第一个值是文档内容
  37. metadata={'relevance_score': score} # 第二个值是相似度分数
  38. )
  39. for doc, score in ranked_docs
  40. ]