web.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: web.py
  6. @date:2025/11/5 15:30
  7. @desc:
  8. """
  9. from typing import Sequence, Optional, Dict
  10. import requests
  11. from anthropic import BaseModel
  12. from langchain_core.callbacks import Callbacks
  13. from langchain_core.documents import Document, BaseDocumentCompressor
  14. from maxkb.const import CONFIG
  15. from models_provider.base_model_provider import MaxKBBaseModel
  class LocalReranker(MaxKBBaseModel, BaseModel, BaseDocumentCompressor):
      """Document compressor (a reranker, per its name) backed by the local model service.

      Instead of working in-process, ``compress_documents`` POSTs the documents and the
      query to ``/api/model/<model_id>/compress_documents`` on the host/port/protocol
      configured by ``LOCAL_MODEL_HOST``, ``LOCAL_MODEL_PORT`` and ``LOCAL_MODEL_PROTOCOL``
      (below the admin path prefix), and returns the documents the service sends back.

      NOTE(review): ``BaseModel`` is imported from the ``anthropic`` SDK rather than from
      ``pydantic`` directly; presumably pydantic's model was intended — confirm before
      switching, since the two may differ in model config (e.g. extra-field handling).
      """

      @staticmethod
      def is_cache_model():
          """Return ``False``; presumably tells the provider layer not to cache
          instances of this model — confirm against ``MaxKBBaseModel``."""
          return False

      @staticmethod
      def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
          """Build a ``LocalReranker``, forwarding every argument to the constructor.

          ``model_id`` is not a named parameter here; it is only set when the caller
          supplies it through ``model_kwargs``.
          """
          return LocalReranker(model_type=model_type, model_name=model_name, model_credential=model_credential,
                               **model_kwargs)

      # Id of the model on the local model service; interpolated into the request URL.
      # Optional because it defaults to None when no model_id is supplied.
      model_id: Optional[str] = None

      def __init__(self, **kwargs):
          super().__init__(**kwargs)
          # Set model_id from the constructor kwargs (None when absent) after the
          # base-class initialisers have run.
          self.model_id = kwargs.get('model_id', None)

      def compress_documents(
          self,
          documents: Sequence[Document],
          query: str,
          callbacks: Optional[Callbacks] = None,
      ) -> Sequence[Document]:
          """Send ``documents`` and ``query`` to the local model service and return its result.

          :param documents: documents to send; only ``page_content`` and ``metadata``
              are transmitted.
          :param query: query string forwarded to the service.
          :param callbacks: accepted for the ``BaseDocumentCompressor`` interface; unused.
          :return: ``[]`` (without making a request) when ``documents`` is ``None`` or
              empty; otherwise the documents in the response's ``data`` field. A missing
              ``data`` yields ``[]`` and a missing per-item ``metadata`` yields ``{}``.
          :raises Exception: when the response body is not JSON, or when its ``code`` is
              not 200 (the response's ``message`` becomes the error text).
          """
          if not documents:
              return []
          prefix = CONFIG.get_admin_path()
          bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
          url = f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents'
          payload = {
              'documents': [{'page_content': document.page_content, 'metadata': document.metadata}
                            for document in documents],
              'query': query,
          }
          # NOTE(review): no timeout is passed, so a hung local model service blocks the
          # caller indefinitely; consider adding one.
          res = requests.post(url, json=payload, headers={'Content-Type': 'application/json'})
          try:
              result = res.json()
          except ValueError as err:
              # A non-JSON body (e.g. an HTML error page from a proxy) would otherwise
              # surface as a bare decode error; report the HTTP status instead.
              raise Exception(
                  f'Local model service returned a non-JSON response (HTTP {res.status_code}): {res.text[:200]}'
              ) from err
          if result.get('code', 500) == 200:
              # Document.metadata must be a dict, so a missing/None value becomes {}.
              return [Document(page_content=document.get('page_content'),
                               metadata=document.get('metadata') or {})
                      for document in result.get('data') or []]
          raise Exception(result.get('message'))