embedding_config.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: embedding_config.py
  6. @date:2023/10/23 16:03
  7. @desc:
  8. """
  9. import threading
  10. import time
  11. from common.cache.mem_cache import MemCache
  12. _lock = threading.Lock()
  13. locks = {}
  14. class ModelManage:
  15. cache = MemCache('model', {})
  16. up_clear_time = time.time()
  17. @staticmethod
  18. def _get_lock(_id):
  19. lock = locks.get(_id)
  20. if lock is None:
  21. with _lock:
  22. lock = locks.get(_id)
  23. if lock is None:
  24. lock = threading.Lock()
  25. locks[_id] = lock
  26. return lock
  27. @staticmethod
  28. def get_model(_id, get_model):
  29. model_instance = ModelManage.cache.get(_id)
  30. if model_instance is None:
  31. lock = ModelManage._get_lock(_id)
  32. with lock:
  33. model_instance = ModelManage.cache.get(_id)
  34. if model_instance is None:
  35. model_instance = get_model(_id)
  36. ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
  37. else:
  38. if model_instance.is_cache_model():
  39. ModelManage.cache.touch(_id, timeout=60 * 60 * 8)
  40. else:
  41. model_instance = get_model(_id)
  42. ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
  43. ModelManage.clear_timeout_cache()
  44. return model_instance
  45. @staticmethod
  46. def clear_timeout_cache():
  47. if time.time() - ModelManage.up_clear_time > 60 * 60:
  48. threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data()).start()
  49. ModelManage.up_clear_time = time.time()
  50. @staticmethod
  51. def delete_key(_id):
  52. if ModelManage.cache.has_key(_id):
  53. ModelManage.cache.delete(_id)
  54. class VectorStore:
  55. from knowledge.vector.pg_vector import PGVector
  56. from knowledge.vector.base_vector import BaseVectorStore
  57. instance_map = {
  58. 'pg_vector': PGVector,
  59. }
  60. instance = None
  61. @staticmethod
  62. def get_embedding_vector() -> BaseVectorStore:
  63. from knowledge.vector.pg_vector import PGVector
  64. if VectorStore.instance is None:
  65. from maxkb.const import CONFIG
  66. vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
  67. PGVector)
  68. VectorStore.instance = vector_store_class()
  69. return VectorStore.instance