reranker.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: reranker.py
  6. @date:2024/9/10 9:45
  7. @desc: Docker AI 文档重排封装
  8. """
  9. import json
  10. from typing import Sequence, Optional, Any, Dict
  11. import requests
  12. from langchain_core.callbacks import Callbacks
  13. from langchain_core.documents import BaseDocumentCompressor, Document
  14. from models_provider.base_model_provider import MaxKBBaseModel
  15. class DockerAIReranker(MaxKBBaseModel, BaseDocumentCompressor):
  16. api_base: Optional[str]
  17. model: Optional[str]
  18. top_n: Optional[int] = 3 # 取前 N 个最相关的结果
  19. @staticmethod
  20. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  21. return DockerAIReranker(
  22. api_base=model_credential.get('api_base'),
  23. model=model_name,
  24. top_n=model_kwargs.get('top_n', 3)
  25. )
  26. def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
  27. Sequence[Document]:
  28. if not documents:
  29. return []
  30. # 预处理文本
  31. texts = [doc.page_content for doc in documents]
  32. headers = {
  33. "Content-Type": "application/json"
  34. }
  35. payload = {
  36. "model": self.model,
  37. "query": query,
  38. "documents": texts,
  39. "top_n": self.top_n,
  40. }
  41. response = requests.post(f"{self.api_base}/rerank", data=json.dumps(payload), headers=headers)
  42. if response.status_code != 200:
  43. raise RuntimeError(f"Docker AI API 请求失败: {response.text}")
  44. res = response.json()
  45. # 解析返回结果
  46. return [
  47. Document(
  48. page_content=payload['documents'][item.get('index')],
  49. metadata={'relevance_score': item.get('relevance_score')}
  50. )
  51. for item in res.get('results', [])
  52. ]