web.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: web.py
  6. @date:2025/11/5 15:24
  7. @desc:
  8. """
  9. from typing import Dict, List
  10. import requests
  11. from anthropic import BaseModel
  12. from langchain_core.embeddings import Embeddings
  13. from maxkb.const import CONFIG
  14. from models_provider.base_model_provider import MaxKBBaseModel
  15. class LocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
  16. def __init__(self, **kwargs):
  17. super().__init__(**kwargs)
  18. self.model_id = kwargs.get('model_id', None)
  19. @staticmethod
  20. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  21. return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
  22. model_kwargs={'device': model_credential.get('device')},
  23. encode_kwargs={'normalize_embeddings': True},
  24. **model_kwargs)
  25. model_id: str = None
  26. def embed_query(self, text: str) -> List[float]:
  27. bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
  28. prefix = CONFIG.get_admin_path()
  29. res = requests.post(
  30. f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_query',
  31. {'text': text})
  32. result = res.json()
  33. if result.get('code', 500) == 200:
  34. return result.get('data')
  35. raise Exception(result.get('message'))
  36. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  37. bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
  38. prefix = CONFIG.get_admin_path()
  39. res = requests.post(
  40. f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/{prefix}/api/model/{self.model_id}/embed_documents',
  41. {'texts': texts})
  42. result = res.json()
  43. if result.get('code', 500) == 200:
  44. return result.get('data')
  45. raise Exception(result.get('message'))