model_serializer.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. # -*- coding: utf-8 -*-
  2. import json
  3. import os
  4. import threading
  5. import time
  6. from typing import Dict
  7. import uuid_utils.compat as uuid
  8. from django.core.cache import cache
  9. from django.db import transaction
  10. from django.db.models import QuerySet
  11. from django.utils.translation import gettext_lazy as _
  12. from rest_framework import serializers
  13. from common.config.embedding_config import ModelManage
  14. from common.constants.cache_version import Cache_Version
  15. from common.constants.permission_constants import ResourcePermission, ResourceAuthType
  16. from common.database_model_manage.database_model_manage import DatabaseModelManage
  17. from common.db.search import native_search
  18. from common.exception.app_exception import AppApiException
  19. from common.utils.common import get_file_content
  20. from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
  21. from maxkb.conf import PROJECT_DIR
  22. from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
  23. from models_provider.constants.model_provider_constants import ModelProvideConstants
  24. from models_provider.models import Model, Status
  25. from models_provider.tools import get_model_credential
  26. from system_manage.models import WorkspaceUserResourcePermission, AuthTargetType
  27. from system_manage.models.resource_mapping import ResourceMapping
  28. from system_manage.serializers.resource_mapping_serializers import ResourceMappingSerializer
  29. from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
  30. from users.serializers.user import is_workspace_manage
  31. def get_default_model_params_setting(provider, model_type, model_name):
  32. credential = get_model_credential(provider, model_type, model_name)
  33. setting_form = credential.get_model_params_setting_form(model_name)
  34. if setting_form is not None:
  35. return setting_form.to_form_list()
  36. return []
  37. class ModelModelSerializer(serializers.ModelSerializer):
  38. class Meta:
  39. model = Model
  40. fields = [
  41. 'id', 'name', 'status', 'model_type', 'model_name',
  42. 'user', 'provider', 'credential', 'meta',
  43. 'model_params_form', 'workspace_id', 'create_time', 'update_time'
  44. ]
  45. class ModelCreateRequest(serializers.Serializer):
  46. name = serializers.CharField(required=True, max_length=64, label=_("model name"))
  47. provider = serializers.CharField(required=True, label=_("provider"))
  48. model_type = serializers.CharField(required=True, label=_("model type"))
  49. model_name = serializers.CharField(required=True, label=_("base model"))
  50. model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
  51. credential = serializers.DictField(required=True, label=_("certification information"))
  52. class ModelPullManage:
  53. @staticmethod
  54. def pull(model: Model, credential: Dict):
  55. try:
  56. response = ModelProvideConstants[model.provider].value.down_model(
  57. model.model_type, model.model_name, credential
  58. )
  59. down_model_chunk = {}
  60. last_update_time = time.time()
  61. for chunk in response:
  62. down_model_chunk[chunk.digest] = chunk.to_dict()
  63. if time.time() - last_update_time > 5:
  64. current_model = QuerySet(Model).filter(id=model.id).first()
  65. if current_model and current_model.status == Status.PAUSE_DOWNLOAD:
  66. return
  67. QuerySet(Model).filter(id=model.id).update(
  68. meta={"down_model_chunk": list(down_model_chunk.values())}
  69. )
  70. last_update_time = time.time()
  71. status = Status.ERROR
  72. message = ""
  73. for chunk in down_model_chunk.values():
  74. if chunk.get('status') == DownModelChunkStatus.success.value:
  75. status = Status.SUCCESS
  76. elif chunk.get('status') == DownModelChunkStatus.error.value:
  77. message = chunk.get("digest")
  78. QuerySet(Model).filter(id=model.id).update(
  79. meta={"down_model_chunk": [], "message": message},
  80. status=status
  81. )
  82. except Exception as e:
  83. QuerySet(Model).filter(id=model.id).update(
  84. meta={"down_model_chunk": [], "message": str(e)},
  85. status=Status.ERROR
  86. )
  87. class ModelSerializer(serializers.Serializer):
  88. @staticmethod
  89. def model_to_dict(model: Model):
  90. credential = json.loads(rsa_long_decrypt(model.credential))
  91. return {
  92. 'id': str(model.id),
  93. 'provider': model.provider,
  94. 'name': model.name,
  95. 'model_type': model.model_type,
  96. 'model_name': model.model_name,
  97. 'status': model.status,
  98. 'meta': model.meta,
  99. 'credential': ModelProvideConstants[model.provider].value.get_model_credential(
  100. model.model_type, model.model_name
  101. ).encryption_dict(credential),
  102. 'workspace_id': model.workspace_id,
  103. 'nick_name': model.user.nick_name if model.user else '',
  104. 'username': model.user.username if model.user else ''
  105. }
  106. class Operate(serializers.Serializer):
  107. id = serializers.UUIDField(required=True, label=_("model id"))
  108. user_id = serializers.UUIDField(required=False, label=_("user id"))
  109. workspace_id = serializers.CharField(required=False, label=_("workspace id"))
  110. def is_valid(self, *, raise_exception=False):
  111. super().is_valid(raise_exception=True)
  112. workspace_id = self.data.get("workspace_id")
  113. model_query = QuerySet(Model).filter(id=self.data.get("id"))
  114. if workspace_id is not None:
  115. model_query = model_query.filter(workspace_id=workspace_id)
  116. model = model_query.first()
  117. if model is None:
  118. raise AppApiException(500, _('Model does not exist'))
  119. if model.workspace_id == 'None':
  120. raise AppApiException(500, _('Shared models cannot be deleted or modified'))
  121. def one(self, with_valid=False):
  122. if with_valid:
  123. super().is_valid(raise_exception=True)
  124. model = QuerySet(Model).get(
  125. id=self.data.get('id'), workspace_id=self.data.get('workspace_id', 'None')
  126. )
  127. return ModelSerializer.model_to_dict(model)
  128. def one_meta(self, with_valid=False):
  129. model = None
  130. if with_valid:
  131. super().is_valid(raise_exception=True)
  132. model = QuerySet(Model).filter(id=self.data.get("id"),
  133. workspace_id=self.data.get('workspace_id', 'None')).first()
  134. if model is None:
  135. raise AppApiException(500, _('Model does not exist'))
  136. return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
  137. 'model_name': model.model_name,
  138. 'status': model.status,
  139. 'meta': model.meta,
  140. 'workspace_id': model.workspace_id,
  141. }
  142. def pause_download(self, with_valid=True):
  143. if with_valid:
  144. self.is_valid(raise_exception=True)
  145. QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD)
  146. return True
  147. @transaction.atomic
  148. def delete(self, with_valid=True):
  149. if with_valid:
  150. super().is_valid(raise_exception=True)
  151. model_id = self.data.get('id')
  152. model = Model.objects.filter(id=model_id).first()
  153. if model is None:
  154. return True
  155. QuerySet(WorkspaceUserResourcePermission).filter(target=model_id).delete()
  156. # TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系
  157. # if model.model_type == 'LLM':
  158. # application_count = Application.objects.filter(model_id=model_id).count()
  159. # if application_count > 0:
  160. # raise AppApiException(500, f"该模型关联了{application_count} 个应用,无法删除该模型。")
  161. # elif model.model_type == 'EMBEDDING':
  162. # dataset_count = DataSet.objects.filter(embedding_model_id=model_id).count()
  163. # if dataset_count > 0:
  164. # raise AppApiException(500, f"该模型关联了{dataset_count} 个知识库,无法删除该模型。")
  165. # elif model.model_type == 'TTS':
  166. # dataset_count = Application.objects.filter(tts_model_id=model_id).count()
  167. # if dataset_count > 0:
  168. # raise AppApiException(500, f"该模型关联了{dataset_count} 个应用,无法删除该模型。")
  169. # elif model.model_type == 'STT':
  170. # dataset_count = Application.objects.filter(stt_model_id=model_id).count()
  171. # if dataset_count > 0:
  172. # raise AppApiException(500, f"该模型关联了{dataset_count} 个应用,无法删除该模型。")
  173. model.delete()
  174. ResourceMapping.objects.filter(target_id=model_id).delete()
  175. return True
  176. def edit(self, instance: Dict, user_id: str, with_valid=True):
  177. if with_valid:
  178. super().is_valid(raise_exception=True)
  179. model = QuerySet(Model).filter(id=self.data.get('id')).first()
  180. credential, model_credential, provider_handler = ModelSerializer.Edit(
  181. data={**instance}).is_valid(
  182. model=model)
  183. try:
  184. model.status = Status.SUCCESS
  185. default_params = {item['field']: item['default_value'] for item in model.model_params_form}
  186. # 校验模型认证数据
  187. provider_handler.is_valid_credential(model.model_type,
  188. instance.get("model_name"),
  189. credential,
  190. default_params,
  191. raise_exception=True)
  192. except AppApiException as e:
  193. if e.code == ValidCode.model_not_fount:
  194. model.status = Status.DOWNLOAD
  195. else:
  196. raise e
  197. update_keys = ['credential', 'name', 'model_type', 'model_name']
  198. for update_key in update_keys:
  199. if update_key in instance and instance.get(update_key) is not None:
  200. if update_key == 'credential':
  201. model_credential_str = json.dumps(credential)
  202. model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
  203. else:
  204. model.__setattr__(update_key, instance.get(update_key))
  205. ModelManage.delete_key(str(model.id))
  206. model.save()
  207. if model.status == Status.DOWNLOAD:
  208. thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
  209. thread.start()
  210. return self.one(with_valid=False)
  211. class Edit(serializers.Serializer):
  212. user_id = serializers.CharField(required=False, label=(_('user id')))
  213. name = serializers.CharField(required=False, max_length=64,
  214. label=(_("model name")))
  215. model_type = serializers.CharField(required=False, label=(_("model type")))
  216. model_name = serializers.CharField(required=False, label=(_("base model")))
  217. credential = serializers.DictField(required=False,
  218. label=(_("certification information")))
  219. workspace_id = serializers.CharField(required=False, label=(_("workspace id")))
  220. def is_valid(self, model=None, raise_exception=False):
  221. super().is_valid(raise_exception=True)
  222. filter_params = {'workspace_id': model.workspace_id}
  223. if 'name' in self.data and self.data.get('name') is not None:
  224. filter_params['name'] = self.data.get('name')
  225. if QuerySet(Model).exclude(id=model.id).filter(**filter_params).exists():
  226. raise AppApiException(500, _('base model【{model_name}】already exists').format(
  227. model_name=self.data.get("name")))
  228. ModelSerializer.model_to_dict(model)
  229. provider = model.provider
  230. model_type = self.data.get('model_type')
  231. model_name = self.data.get(
  232. 'model_name')
  233. credential = self.data.get('credential')
  234. provider_handler = ModelProvideConstants[provider].value
  235. model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
  236. model_name)
  237. source_model_credential = json.loads(rsa_long_decrypt(model.credential))
  238. source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
  239. if credential is not None:
  240. for k in source_encryption_model_credential.keys():
  241. if k in credential and credential[k] == source_encryption_model_credential[k]:
  242. credential[k] = source_model_credential[k]
  243. return credential, model_credential, provider_handler
  244. class Create(serializers.Serializer):
  245. user_id = serializers.UUIDField(required=True, label=_('user id'))
  246. name = serializers.CharField(required=True, max_length=64, label=_("model name"))
  247. provider = serializers.CharField(required=True, label=_("provider"))
  248. model_type = serializers.CharField(required=True, label=_("model type"))
  249. model_name = serializers.CharField(required=True, label=_("base model"))
  250. model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
  251. credential = serializers.DictField(required=True, label=_("certification information"))
  252. workspace_id = serializers.CharField(required=False, label=_("workspace id"), max_length=128)
  253. def is_valid(self, *, raise_exception=False):
  254. super().is_valid(raise_exception=True)
  255. if QuerySet(Model).filter(
  256. name=self.data.get('name'),
  257. workspace_id=self.data.get('workspace_id', 'None')
  258. ).exists():
  259. raise AppApiException(
  260. 500,
  261. _('base model【{model_name}】already exists').format(model_name=self.data.get("name"))
  262. )
  263. default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')}
  264. ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(
  265. self.data.get('model_type'),
  266. self.data.get('model_name'),
  267. self.data.get('credential'),
  268. default_params,
  269. raise_exception=True
  270. )
  271. def insert(self, workspace_id, with_valid=True):
  272. status = Status.SUCCESS
  273. if with_valid:
  274. try:
  275. self.is_valid(raise_exception=True)
  276. except AppApiException as e:
  277. if e.code == ValidCode.model_not_fount:
  278. status = Status.DOWNLOAD
  279. else:
  280. raise e
  281. credential = self.data.get('credential')
  282. model_data = {
  283. 'id': uuid.uuid7(),
  284. 'status': status,
  285. 'user_id': self.data.get('user_id'),
  286. 'name': self.data.get('name'),
  287. 'credential': rsa_long_encrypt(json.dumps(credential)),
  288. 'provider': self.data.get('provider'),
  289. 'model_type': self.data.get('model_type'),
  290. 'model_name': self.data.get('model_name'),
  291. 'model_params_form': self.data.get('model_params_form'),
  292. 'workspace_id': workspace_id
  293. }
  294. model = Model(**model_data)
  295. try:
  296. model.save()
  297. if workspace_id != 'None':
  298. UserResourcePermissionSerializer(data={
  299. 'workspace_id': workspace_id,
  300. 'user_id': self.data.get('user_id'),
  301. 'auth_target_type': AuthTargetType.MODEL.value
  302. }).auth_resource(str(model.id))
  303. except Exception as save_error:
  304. # 可添加日志记录
  305. raise AppApiException(500, _("Model saving failed")) from save_error
  306. if status == Status.DOWNLOAD:
  307. thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
  308. thread.start()
  309. return ModelModelSerializer(model).data
  310. class Query(serializers.Serializer):
  311. user_id = serializers.CharField(required=True, label=_("User ID"))
  312. name = serializers.CharField(required=False, max_length=64, label=_('model name'))
  313. model_type = serializers.CharField(required=False, label=_('model type'))
  314. model_name = serializers.CharField(required=False, label=_('base model'))
  315. provider = serializers.CharField(required=False, label=_('provider'))
  316. create_user = serializers.CharField(required=False, label=_('create user'))
  317. workspace_id = serializers.CharField(required=False, label=_('workspace id'))
  318. @staticmethod
  319. def is_x_pack_ee():
  320. workspace_user_role_mapping_model = DatabaseModelManage.get_model("workspace_user_role_mapping")
  321. role_permission_mapping_model = DatabaseModelManage.get_model("role_permission_mapping_model")
  322. return workspace_user_role_mapping_model is not None and role_permission_mapping_model is not None
  323. def list(self, workspace_id, with_valid):
  324. if with_valid:
  325. self.is_valid(raise_exception=True)
  326. user_id = self.data.get("user_id")
  327. workspace_manage = is_workspace_manage(user_id, workspace_id)
  328. query_params = self._build_query_params(workspace_id, workspace_manage, user_id)
  329. is_x_pack_ee = self.is_x_pack_ee()
  330. result = native_search(query_params,
  331. select_string=get_file_content(
  332. os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
  333. 'list_model.sql' if workspace_manage else (
  334. 'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
  335. )))
  336. return ResourceMappingSerializer().get_resource_count(result)
  337. def share_list(self, workspace_id, with_valid=True):
  338. if with_valid:
  339. self.is_valid(raise_exception=True)
  340. user_id = self.data.get("user_id")
  341. query_params = self._build_query_params(workspace_id, False, user_id)
  342. result = [
  343. self._build_model_data(
  344. model
  345. ) for model in query_params.get('model_query_set')
  346. ]
  347. return ResourceMappingSerializer().get_resource_count(result)
  348. def model_list(self, workspace_id, with_valid=True):
  349. if with_valid:
  350. self.is_valid(raise_exception=True)
  351. user_id = self.data.get("user_id")
  352. workspace_manage = is_workspace_manage(user_id, workspace_id)
  353. queryset = self._build_query_params(workspace_id, workspace_manage, user_id)
  354. get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
  355. shared_queryset = QuerySet(Model).none()
  356. if get_authorized_model is not None:
  357. shared_queryset = self._build_query_params('None', False, user_id)['model_query_set']
  358. shared_queryset = get_authorized_model(shared_queryset, workspace_id)
  359. # 构建共享模型和普通模型列表
  360. shared_model = [self._build_model_data(model) for model in shared_queryset]
  361. is_x_pack_ee = self.is_x_pack_ee()
  362. normal_model = native_search(
  363. queryset,
  364. select_string=get_file_content(
  365. os.path.join(
  366. PROJECT_DIR, "apps", "models_provider", 'sql',
  367. 'list_model.sql' if workspace_manage else (
  368. 'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
  369. )
  370. )
  371. )
  372. return {
  373. "shared_model": shared_model,
  374. "model": normal_model
  375. }
  376. def _build_query_params(self, workspace_id, workspace_manage: bool, user_id):
  377. queryset = QuerySet(Model)
  378. if workspace_id:
  379. queryset = queryset.filter(workspace_id=workspace_id)
  380. for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
  381. value = self.data.get(field)
  382. if value is not None:
  383. if field == 'name':
  384. queryset = queryset.filter(**{f'{field}__icontains': value})
  385. elif field == 'create_user':
  386. queryset = queryset.filter(user_id=value)
  387. else:
  388. queryset = queryset.filter(**{field: value})
  389. queryset = queryset.order_by("-create_time")
  390. return {
  391. 'model_query_set': queryset,
  392. 'workspace_user_resource_permission_query_set': QuerySet(WorkspaceUserResourcePermission).filter(
  393. auth_target_type="MODEL",
  394. workspace_id=workspace_id,
  395. user_id=user_id)} if (
  396. not workspace_manage) else {
  397. 'model_query_set': queryset,
  398. }
  399. def _build_model_data(self, model):
  400. return {
  401. 'id': str(model.id),
  402. 'provider': model.provider,
  403. 'name': model.name,
  404. 'model_type': model.model_type,
  405. 'model_name': model.model_name,
  406. 'status': model.status,
  407. 'meta': model.meta,
  408. 'user_id': model.user_id,
  409. 'username': model.user.username,
  410. 'nick_name': model.user.nick_name,
  411. }
  412. def page(self, current_page, page_size):
  413. pass
  414. class ModelParams(serializers.Serializer):
  415. id = serializers.UUIDField(required=True, label=_('model id'))
  416. def is_valid(self, *, raise_exception=False):
  417. super().is_valid(raise_exception=True)
  418. model = QuerySet(Model).filter(id=self.data.get("id")).first()
  419. if model is None:
  420. raise AppApiException(500, _("Model does not exist"))
  421. def get_model_params(self, with_valid=True):
  422. if with_valid:
  423. self.is_valid(raise_exception=True)
  424. model_id = self.data.get('id')
  425. model = QuerySet(Model).filter(id=model_id).first()
  426. return model.model_params_form
  427. def save_model_params_form(self, model_params_form, with_valid=True):
  428. if with_valid:
  429. self.is_valid(raise_exception=True)
  430. if model_params_form is None:
  431. model_params_form = []
  432. model_id = self.data.get('id')
  433. model = QuerySet(Model).filter(id=model_id).first()
  434. model.model_params_form = model_params_form
  435. model.save()
  436. return True
  437. class WorkspaceSharedModelSerializer(serializers.Serializer):
  438. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  439. name = serializers.CharField(required=False, max_length=64, label=_('model name'))
  440. model_type = serializers.CharField(required=False, label=_('model type'))
  441. model_name = serializers.CharField(required=False, label=_('base model'))
  442. provider = serializers.CharField(required=False, label=_('provider'))
  443. create_user = serializers.CharField(required=False, label=_('create user'))
  444. def get_share_model_list(self):
  445. self.is_valid(raise_exception=True)
  446. workspace_id = self.data.get('workspace_id')
  447. queryset = self._build_queryset(workspace_id)
  448. return [
  449. {
  450. 'id': str(model.id),
  451. 'provider': model.provider,
  452. 'name': model.name,
  453. 'model_type': model.model_type,
  454. 'model_name': model.model_name,
  455. 'status': model.status,
  456. 'meta': model.meta,
  457. 'user_id': model.user_id,
  458. 'nick_name': model.user.nick_name,
  459. 'username': model.user.username
  460. }
  461. for model in queryset.order_by("-create_time")
  462. ]
  463. def _build_queryset(self, workspace_id):
  464. queryset = QuerySet(Model)
  465. if workspace_id:
  466. get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
  467. if get_authorized_model is not None:
  468. queryset = get_authorized_model(queryset, workspace_id)
  469. for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
  470. value = self.data.get(field)
  471. if value is not None:
  472. if field == 'name':
  473. queryset = queryset.filter(**{f'{field}__icontains': value})
  474. elif field == 'create_user':
  475. queryset = queryset.filter(user_id=value)
  476. else:
  477. queryset = queryset.filter(**{field: value})
  478. return queryset