model_apply_serializers.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: model_apply_serializers.py
  6. @date:2024/8/20 20:39
  7. @desc:
  8. """
  9. import json
  10. import threading
  11. import time
  12. from django.db import connection
  13. from django.db.models import QuerySet
  14. from django.utils.translation import gettext_lazy as _
  15. from langchain_core.documents import Document
  16. from rest_framework import serializers
  17. from local_model.models import Model
  18. from local_model.serializers.rsa_util import rsa_long_decrypt
  19. from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
  20. from common.cache.mem_cache import MemCache
  21. _lock = threading.Lock()
  22. locks = {}
  23. class ModelManage:
  24. cache = MemCache('model', {})
  25. up_clear_time = time.time()
  26. @staticmethod
  27. def _get_lock(_id):
  28. lock = locks.get(_id)
  29. if lock is None:
  30. with _lock:
  31. lock = locks.get(_id)
  32. if lock is None:
  33. lock = threading.Lock()
  34. locks[_id] = lock
  35. return lock
  36. @staticmethod
  37. def get_model(_id, get_model):
  38. model_instance = ModelManage.cache.get(_id)
  39. if model_instance is None:
  40. lock = ModelManage._get_lock(_id)
  41. with lock:
  42. model_instance = ModelManage.cache.get(_id)
  43. if model_instance is None:
  44. model_instance = get_model(_id)
  45. ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
  46. else:
  47. if model_instance.is_cache_model():
  48. ModelManage.cache.touch(_id, timeout=60 * 60 * 8)
  49. else:
  50. model_instance = get_model(_id)
  51. ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
  52. ModelManage.clear_timeout_cache()
  53. return model_instance
  54. @staticmethod
  55. def clear_timeout_cache():
  56. if time.time() - ModelManage.up_clear_time > 60 * 60:
  57. threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data()).start()
  58. ModelManage.up_clear_time = time.time()
  59. @staticmethod
  60. def delete_key(_id):
  61. if ModelManage.cache.has_key(_id):
  62. ModelManage.cache.delete(_id)
  63. def get_local_model(model, **kwargs):
  64. return LocalModelProvider().get_model(model.model_type, model.model_name,
  65. json.loads(
  66. rsa_long_decrypt(model.credential)),
  67. model_id=model.id,
  68. streaming=True, **kwargs)
  69. def get_embedding_model(model_id):
  70. model = QuerySet(Model).filter(id=model_id).first()
  71. # 手动关闭数据库连接
  72. connection.close()
  73. embedding_model = ModelManage.get_model(model_id,
  74. lambda _id: get_local_model(model, use_local=True))
  75. return embedding_model
  76. class EmbedDocuments(serializers.Serializer):
  77. texts = serializers.ListField(required=True, child=serializers.CharField(required=True,
  78. label=_('vector text')),
  79. label=_('vector text list')),
  80. class EmbedQuery(serializers.Serializer):
  81. text = serializers.CharField(required=True, label=_('vector text'))
  82. class CompressDocument(serializers.Serializer):
  83. page_content = serializers.CharField(required=True, label=_('text'))
  84. metadata = serializers.DictField(required=False, label=_('metadata'))
  85. class CompressDocuments(serializers.Serializer):
  86. documents = CompressDocument(required=True, many=True)
  87. query = serializers.CharField(required=True, label=_('query'))
  88. class ValidateModelSerializers(serializers.Serializer):
  89. model_name = serializers.CharField(required=True, label=_('model_name'))
  90. model_type = serializers.CharField(required=True, label=_('model_type'))
  91. model_credential = serializers.DictField(required=True, label="credential")
  92. def validate_model(self, with_valid=True):
  93. if with_valid:
  94. self.is_valid(raise_exception=True)
  95. LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'),
  96. self.data.get('model_credential'), model_params={},
  97. raise_exception=True)
  98. class ModelApplySerializers(serializers.Serializer):
  99. model_id = serializers.UUIDField(required=True, label=_('model id'))
  100. def embed_documents(self, instance, with_valid=True):
  101. if with_valid:
  102. self.is_valid(raise_exception=True)
  103. EmbedDocuments(data=instance).is_valid(raise_exception=True)
  104. model = get_embedding_model(self.data.get('model_id'))
  105. return model.embed_documents(instance.getlist('texts'))
  106. def embed_query(self, instance, with_valid=True):
  107. if with_valid:
  108. self.is_valid(raise_exception=True)
  109. EmbedQuery(data=instance).is_valid(raise_exception=True)
  110. model = get_embedding_model(self.data.get('model_id'))
  111. return model.embed_query(instance.get('text'))
  112. def compress_documents(self, instance, with_valid=True):
  113. if with_valid:
  114. self.is_valid(raise_exception=True)
  115. CompressDocuments(data=instance).is_valid(raise_exception=True)
  116. model = get_embedding_model(self.data.get('model_id'))
  117. return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
  118. [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
  119. instance.get('documents')], instance.get('query'))]
  120. def unload(self, with_valid=True):
  121. if with_valid:
  122. self.is_valid(raise_exception=True)
  123. ModelManage.delete_key(self.data.get('model_id'))
  124. return True