embedding.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. from typing import Dict, List
  2. from models_provider.base_model_provider import MaxKBBaseModel
  3. from volcenginesdkarkruntime import Ark
  4. class VolcanicEngineEmbeddingModel(MaxKBBaseModel):
  5. api_key: str
  6. model_name: str
  7. api_base: str
  8. params: Dict[str, object]
  9. def __init__(self, api_key: str, model: str, api_base: str, params: Dict[str, object] = None):
  10. self.client = Ark(
  11. api_key=api_key,
  12. base_url=api_base
  13. )
  14. self.model_name = model
  15. self.params = params
  16. @staticmethod
  17. def is_cache_model():
  18. return False
  19. @staticmethod
  20. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  21. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  22. return VolcanicEngineEmbeddingModel(
  23. api_key=model_credential.get("api_key"),
  24. model=model_name,
  25. api_base=model_credential.get("api_base"),
  26. **optional_params
  27. )
  28. def embed_query(self, text: str):
  29. res = self.embed_documents([text])
  30. return res[0]
  31. def embed_documents(
  32. self, texts: List[str], chunk_size: int | None = None
  33. ) -> List[List[float]]:
  34. if self.model_name.startswith("doubao-embedding-vision-"):
  35. multimodal_inputs = []
  36. for text in texts:
  37. multimodal_inputs.append({
  38. "type": "text",
  39. "text": text
  40. })
  41. resp = self.client.multimodal_embeddings.create(
  42. model=self.model_name,
  43. input=multimodal_inputs,
  44. **(self.params or {})
  45. )
  46. return [resp.data.get('embedding')]
  47. else:
  48. resp = self.client.embeddings.create(
  49. model=self.model_name,
  50. input=texts,
  51. **(self.params or {})
  52. )
  53. return [e.embedding for e in resp.data]