embedding.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # coding=utf-8
  2. import threading
  3. from typing import Dict, Optional, List, Any
  4. from langchain_core.embeddings import Embeddings
  5. from models_provider.base_model_provider import MaxKBBaseModel
  6. class XinferenceEmbedding(MaxKBBaseModel, Embeddings):
  7. client: Any
  8. server_url: Optional[str]
  9. """URL of the xinference server"""
  10. model_uid: Optional[str]
  11. """UID of the launched model"""
  12. @staticmethod
  13. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  14. return XinferenceEmbedding(
  15. model_uid=model_name,
  16. server_url=model_credential.get('api_base'),
  17. api_key=model_credential.get('api_key'),
  18. )
  19. def down_model(self):
  20. self.client.launch_model(model_name=self.model_uid, model_type="embedding")
  21. def start_down_model_thread(self):
  22. thread = threading.Thread(target=self.down_model)
  23. thread.daemon = True
  24. thread.start()
  25. def __init__(
  26. self, server_url: Optional[str] = None, model_uid: Optional[str] = None,
  27. api_key: Optional[str] = None
  28. ):
  29. try:
  30. from xinference.client import RESTfulClient
  31. except ImportError:
  32. try:
  33. from xinference_client import RESTfulClient
  34. except ImportError as e:
  35. raise ImportError(
  36. "Could not import RESTfulClient from xinference. Please install it"
  37. " with `pip install xinference` or `pip install xinference_client`."
  38. ) from e
  39. if server_url is None:
  40. raise ValueError("Please provide server URL")
  41. if model_uid is None:
  42. raise ValueError("Please provide the model UID")
  43. self.server_url = server_url
  44. self.model_uid = model_uid
  45. self.api_key = api_key
  46. self.client = RESTfulClient(server_url, api_key)
  47. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  48. """Embed a list of documents using Xinference.
  49. Args:
  50. texts: The list of texts to embed.
  51. Returns:
  52. List of embeddings, one for each text.
  53. """
  54. model = self.client.get_model(self.model_uid)
  55. embeddings = [
  56. model.create_embedding(text)["data"][0]["embedding"] for text in texts
  57. ]
  58. return [list(map(float, e)) for e in embeddings]
  59. def embed_query(self, text: str) -> List[float]:
  60. """Embed a query of documents using Xinference.
  61. Args:
  62. text: The text to embed.
  63. Returns:
  64. Embeddings for the text.
  65. """
  66. model = self.client.get_model(self.model_uid)
  67. embedding_res = model.create_embedding(text)
  68. embedding = embedding_res["data"][0]["embedding"]
  69. return list(map(float, embedding))