embedding.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: embedding.py
  6. @date:2024/7/12 15:02
  7. @desc:
  8. """
  9. from typing import Dict, List
  10. from langchain_ollama import OllamaEmbeddings
  11. from models_provider.base_model_provider import MaxKBBaseModel
  12. class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings):
  13. @staticmethod
  14. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  15. return OllamaEmbedding(
  16. model=model_name,
  17. base_url=model_credential.get('api_base'),
  18. )
  19. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  20. """Embed documents using an Ollama deployed embedding model.
  21. Args:
  22. texts: The list of texts to embed.
  23. Returns:
  24. List of embeddings, one for each text.
  25. """
  26. return self._client.embed(
  27. self.model, texts, options=self._default_params, keep_alive=self.keep_alive
  28. )["embeddings"]
  29. def embed_query(self, text: str) -> List[float]:
  30. """Embed a query using a Ollama deployed embedding model.
  31. Args:
  32. text: The text to embed.
  33. Returns:
  34. Embeddings for the text.
  35. """
  36. return self.embed_documents([text])[0]