embedding.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from langchain_aws import BedrockEmbeddings
  2. from models_provider.base_model_provider import MaxKBBaseModel
  3. from typing import Dict, List
  4. from models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials
  5. class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
  6. def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
  7. **kwargs):
  8. super().__init__(model_id=model_id, region_name=region_name,
  9. credentials_profile_name=credentials_profile_name, **kwargs)
  10. @classmethod
  11. def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
  12. **model_kwargs) -> 'BedrockModel':
  13. _update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
  14. model_credential['secret_access_key'])
  15. return cls(
  16. model_id=model_name,
  17. region_name=model_credential['region_name'],
  18. credentials_profile_name=model_credential['access_key_id'],
  19. )
  20. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  21. """Compute doc embeddings using a Bedrock model.
  22. Args:
  23. texts: The list of texts to embed
  24. Returns:
  25. List of embeddings, one for each text.
  26. """
  27. results = []
  28. for text in texts:
  29. response = self._embedding_func(text)
  30. if self.normalize:
  31. response = self._normalize_vector(response)
  32. results.append(response)
  33. return results
  34. def embed_query(self, text: str) -> List[float]:
  35. """Compute query embeddings using a Bedrock model.
  36. Args:
  37. text: The text to embed.
  38. Returns:
  39. Embeddings for the text.
  40. """
  41. embedding = self._embedding_func(text)
  42. if self.normalize:
  43. return self._normalize_vector(embedding)
  44. return embedding