reranker.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: siliconcloud_reranker.py
  6. @date:2024/9/10 9:45
  7. @desc: SiliconCloud 文档重排封装
  8. """
  9. from typing import Sequence, Optional, Any, Dict
  10. import requests
  11. from langchain_core.callbacks import Callbacks
  12. from langchain_core.documents import BaseDocumentCompressor, Document
  13. from models_provider.base_model_provider import MaxKBBaseModel
  14. from django.utils.translation import gettext as _
  15. class SiliconCloudReranker(MaxKBBaseModel, BaseDocumentCompressor):
  16. api_base: Optional[str]
  17. """SiliconCloud API URL"""
  18. model: Optional[str]
  19. """SiliconCloud 重排模型 ID"""
  20. api_key: Optional[str]
  21. """API Key"""
  22. top_n: Optional[int] = 3 # 取前 N 个最相关的结果
  23. @staticmethod
  24. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  25. return SiliconCloudReranker(
  26. api_base=model_credential.get('api_base'),
  27. model=model_name,
  28. api_key=model_credential.get('api_key'),
  29. top_n=model_kwargs.get('top_n', 3)
  30. )
  31. def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
  32. Sequence[Document]:
  33. if not documents:
  34. return []
  35. # 预处理文本
  36. texts = [doc.page_content for doc in documents]
  37. # 发送请求到 SiliconCloud API
  38. headers = {
  39. "Authorization": f"Bearer {self.api_key}",
  40. "Content-Type": "application/json"
  41. }
  42. payload = {
  43. "model": self.model,
  44. "query": query,
  45. "documents": texts,
  46. "top_n": self.top_n,
  47. "return_documents": True,
  48. }
  49. response = requests.post(f"{self.api_base}/rerank", json=payload, headers=headers)
  50. if response.status_code != 200:
  51. raise RuntimeError(f"SiliconCloud API 请求失败: {response.text}")
  52. res = response.json()
  53. # 解析返回结果
  54. return [
  55. Document(
  56. page_content=item.get('document', {}).get('text', ''),
  57. metadata={'relevance_score': item.get('relevance_score')}
  58. )
  59. for item in res.get('results', [])
  60. ]