embedding.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: embedding.py
  6. @date:2024/10/16 16:34
  7. @desc:
  8. """
  9. from typing import Dict, List
  10. from common.utils.logger import maxkb_logger
  11. import requests
  12. from models_provider.base_model_provider import MaxKBBaseModel
  13. class SiliconCloudEmbeddingModel(MaxKBBaseModel):
  14. model_name: str
  15. openai_api_key: str
  16. base_url: str
  17. optional_params: dict
  18. def __init__(self, api_key, model_name: str, base_url, optional_params: dict):
  19. self.openai_api_key = api_key
  20. self.base_url = base_url
  21. self.model_name = model_name
  22. self.optional_params = optional_params
  23. def is_cache_model(self):
  24. return False
  25. @staticmethod
  26. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  27. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  28. return SiliconCloudEmbeddingModel(
  29. api_key=model_credential.get('api_key'),
  30. model_name=model_name,
  31. optional_params=optional_params,
  32. base_url=model_credential.get('api_base'),
  33. )
  34. def embed_query(self, text: str) -> list:
  35. payload = {
  36. "model": self.model_name,
  37. "input": text,
  38. **self.optional_params
  39. }
  40. headers = {
  41. "Authorization": f"Bearer {self.openai_api_key}",
  42. "Content-Type": "application/json"
  43. }
  44. response = requests.post(self.base_url + '/embeddings', json=payload, headers=headers)
  45. data = response.json()
  46. if isinstance(data, dict):
  47. if data['data'] is None or 'code' in data:
  48. raise ValueError(f"Embedding API returned no data: {data}")
  49. # 假设返回结构中有 'data[0].embedding'
  50. return data["data"][0]["embedding"]
  51. else:
  52. maxkb_logger.error(f"Unexpected response from Embedding API: {data}")
  53. def embed_documents(self, texts: list) -> list:
  54. return [self.embed_query(text) for text in texts]