embedding.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from typing import Dict, List
  2. from langchain_core.embeddings import Embeddings
  3. from tencentcloud.common import credential
  4. from tencentcloud.hunyuan.v20230901.hunyuan_client import HunyuanClient
  5. from tencentcloud.hunyuan.v20230901.models import GetEmbeddingRequest
  6. from models_provider.base_model_provider import MaxKBBaseModel
  7. class TencentEmbeddingModel(MaxKBBaseModel, Embeddings):
  8. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  9. return [self.embed_query(text) for text in texts]
  10. def embed_query(self, text: str) -> List[float]:
  11. request = GetEmbeddingRequest()
  12. request.Input = text
  13. res = self.client.GetEmbedding(request)
  14. return res.Data[0].Embedding
  15. def __init__(self, secret_id: str, secret_key: str, model_name: str):
  16. self.secret_id = secret_id
  17. self.secret_key = secret_key
  18. self.model_name = model_name
  19. cred = credential.Credential(
  20. secret_id, secret_key
  21. )
  22. self.client = HunyuanClient(cred, "")
  23. @staticmethod
  24. def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
  25. return TencentEmbeddingModel(
  26. secret_id=model_credential.get('SecretId'),
  27. secret_key=model_credential.get('SecretKey'),
  28. model_name=model_name,
  29. )
  30. def _generate_auth_token(self):
  31. # Example method to generate an authentication token for the model API
  32. return f"{self.secret_id}:{self.secret_key}"