model.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: model.py
  6. @date:2025/11/5 15:26
  7. @desc:
  8. """
  9. from typing import Dict
  10. from langchain_huggingface import HuggingFaceEmbeddings
  11. from common.utils.logger import maxkb_logger
  12. from models_provider.base_model_provider import MaxKBBaseModel
  13. max_retries = 3
  14. class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
  15. @staticmethod
  16. def is_cache_model():
  17. return True
  18. @staticmethod
  19. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  20. for attempt in range(max_retries):
  21. try:
  22. embedding = LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
  23. model_kwargs={'device': model_credential.get('device')},
  24. encode_kwargs={'normalize_embeddings': True}
  25. )
  26. # 测试一下是否真的能用
  27. embedding.embed_query("test")
  28. return embedding
  29. except Exception as e:
  30. if 'meta tensor' in str(e).lower() and attempt < max_retries - 1:
  31. maxkb_logger.warning(
  32. f"Test failed with meta tensor error, retrying... (attempt {attempt + 1}/{max_retries})")
  33. import time
  34. time.sleep(1)
  35. continue
  36. raise e