model.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: model.py
  6. @date:2025/4/14 19:25
  7. @desc:
  8. """
  9. from django.db.models import QuerySet
  10. from drf_spectacular.utils import extend_schema
  11. from rest_framework.views import APIView
  12. from django.utils.translation import gettext_lazy as _
  13. from rest_framework.request import Request
  14. from common.auth import TokenAuth
  15. from common.auth.authentication import has_permissions
  16. from common.constants.permission_constants import PermissionConstants, RoleConstants, ViewPermission, CompareConstants
  17. from common.log.log import log
  18. from common.result import result
  19. from common.utils.common import query_params_to_single_dict
  20. from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse, DefaultModelResponse
  21. from models_provider.api.provide import ProvideApi
  22. from models_provider.models import Model, Status
  23. from models_provider.serializers.model_serializer import ModelSerializer, \
  24. WorkspaceSharedModelSerializer
  25. from system_manage.views import encryption_str
  26. def encryption_credential(credential):
  27. if isinstance(credential, dict):
  28. return {key: encryption_str(credential.get(key)) for key in credential}
  29. return credential
  30. def get_edit_model_details(request):
  31. path = request.path
  32. body = request.data
  33. query = request.query_params
  34. credential = body.get('credential', {})
  35. credential_encryption_ed = encryption_credential(credential)
  36. return {
  37. 'path': path,
  38. 'body': {**body, 'credential': credential_encryption_ed},
  39. 'query': query
  40. }
  41. def get_model_operation_object(model_id):
  42. model_model = QuerySet(model=Model).filter(id=model_id).first()
  43. if model_model is not None:
  44. return {
  45. "name": model_model.name
  46. }
  47. return {}
  48. class ModelSetting(APIView):
  49. authentication_classes = [TokenAuth]
  50. @extend_schema(methods=['POST'],
  51. summary=_("Create model"),
  52. description=_("Create model"),
  53. operation_id=_("Create model"), # type: ignore
  54. tags=[_("Model")], # type: ignore
  55. parameters=ModelCreateAPI.get_parameters(),
  56. request=ModelCreateAPI.get_request(),
  57. responses=ModelCreateAPI.get_response())
  58. @has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(),
  59. PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(),
  60. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
  61. @log(menu='model', operate='Create model',
  62. get_operation_object=lambda r, k: {'name': r.date.get('name')},
  63. get_details=get_edit_model_details,
  64. )
  65. def post(self, request: Request, workspace_id: str):
  66. return result.success(
  67. ModelSerializer.Create(
  68. data={**request.data, 'user_id': request.user.id, 'workspace_id': workspace_id}).insert(workspace_id,
  69. with_valid=True))
  70. # @extend_schema(methods=['PUT'],
  71. # summary=_('Update model'),
  72. # operation_id=_('Update model'), # type: ignore
  73. # request=ModelEditApi.get_request(),
  74. # responses=ModelCreateApi.get_response(),
  75. # tags=[_('Model')]) # type: ignore
  76. # @has_permissions(PermissionConstants.MODEL_CREATE)
  77. # def put(self, request: Request):
  78. # return result.success(
  79. # ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
  80. # with_valid=True))
  81. @extend_schema(methods=['GET'],
  82. summary=_('Query model list'),
  83. description=_('Query model list'),
  84. operation_id=_('Query model list'), # type: ignore
  85. parameters=ModelListResponse.get_parameters(),
  86. responses=ModelListResponse.get_response(),
  87. tags=[_('Model')]) # type: ignore
  88. @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
  89. PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(),
  90. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
  91. def get(self, request: Request, workspace_id: str):
  92. return result.success(
  93. ModelSerializer.Query(
  94. data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).list(
  95. workspace_id=workspace_id,
  96. with_valid=True))
  97. class Operate(APIView):
  98. authentication_classes = [TokenAuth]
  99. @extend_schema(methods=['PUT'],
  100. summary=_('Update model'),
  101. description=_('Update model'),
  102. operation_id=_('Update model'), # type: ignore
  103. request=ModelEditApi.get_request(),
  104. parameters=GetModelApi.get_parameters(),
  105. responses=ModelEditApi.get_response(),
  106. tags=[_('Model')]) # type: ignore
  107. @has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_model_permission(),
  108. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  109. PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(),
  110. ViewPermission([RoleConstants.USER.get_workspace_role()],
  111. [PermissionConstants.MODEL.get_workspace_model_permission()],
  112. CompareConstants.AND), )
  113. @log(menu='model', operate='Update model',
  114. get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
  115. get_details=get_edit_model_details,
  116. )
  117. def put(self, request: Request, workspace_id, model_id: str):
  118. return result.success(
  119. ModelSerializer.Operate(
  120. data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).edit(request.data,
  121. str(request.user.id)))
  122. @extend_schema(methods=['DELETE'],
  123. summary=_('Delete model'),
  124. description=_('Delete model'),
  125. operation_id=_('Delete model'), # type: ignore
  126. parameters=GetModelApi.get_parameters(),
  127. responses=DefaultModelResponse.get_response(),
  128. tags=[_('Model')]) # type: ignore
  129. @has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_model_permission(),
  130. PermissionConstants.MODEL_DELETE.get_workspace_permission_workspace_manage_role(),
  131. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  132. ViewPermission([RoleConstants.USER.get_workspace_role()],
  133. [PermissionConstants.MODEL.get_workspace_model_permission()],
  134. CompareConstants.AND), )
  135. @log(menu='model', operate='Delete model',
  136. get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
  137. )
  138. def delete(self, request: Request, workspace_id: str, model_id: str):
  139. return result.success(
  140. ModelSerializer.Operate(
  141. data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).delete())
  142. @extend_schema(methods=['GET'],
  143. summary=_('Query model details'),
  144. description=_('Query model details'),
  145. operation_id=_('Query model details'), # type: ignore
  146. parameters=GetModelApi.get_parameters(),
  147. responses=GetModelApi.get_response(),
  148. tags=[_('Model')]) # type: ignore
  149. @has_permissions(PermissionConstants.MODEL_READ.get_workspace_model_permission(),
  150. PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(),
  151. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  152. ViewPermission([RoleConstants.USER.get_workspace_role()],
  153. [PermissionConstants.MODEL.get_workspace_model_permission()],
  154. CompareConstants.AND), )
  155. def get(self, request: Request, workspace_id: str, model_id: str):
  156. return result.success(
  157. ModelSerializer.Operate(
  158. data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).one(
  159. with_valid=True))
  160. class ModelParamsForm(APIView):
  161. authentication_classes = [TokenAuth]
  162. @extend_schema(methods=['GET'],
  163. summary=_('Get model parameter form'),
  164. description=_('Get model parameter form'),
  165. operation_id=_('Get model parameter form'), # type: ignore
  166. parameters=GetModelApi.get_parameters(),
  167. responses=ProvideApi.ModelParamsForm.get_response(),
  168. tags=[_('Model')]) # type: ignore
  169. @has_permissions(PermissionConstants.MODEL_READ.get_workspace_model_permission(),
  170. PermissionConstants.KNOWLEDGE_READ.get_workspace_permission(),
  171. PermissionConstants.APPLICATION_READ.get_workspace_permission(),
  172. PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(),
  173. PermissionConstants.KNOWLEDGE_READ.get_workspace_permission_workspace_manage_role(),
  174. PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
  175. PermissionConstants.MODEL_READ.get_workspace_permission(),
  176. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  177. RoleConstants.USER.get_workspace_role(),)
  178. def get(self, request: Request, workspace_id: str, model_id: str):
  179. return result.success(
  180. ModelSerializer.ModelParams(data={'id': model_id}).get_model_params())
  181. @extend_schema(methods=['PUT'],
  182. summary=_('Save model parameter form'),
  183. description=_('Save model parameter form'),
  184. operation_id=_('Save model parameter form'), # type: ignore
  185. parameters=GetModelApi.get_parameters(),
  186. request=GetModelApi.get_request(),
  187. responses=ProvideApi.ModelParamsForm.get_response(),
  188. tags=[_('Model')]) # type: ignore
  189. @has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_model_permission(),
  190. PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(),
  191. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  192. PermissionConstants.MODEL_READ.get_workspace_permission(),
  193. ViewPermission([RoleConstants.USER.get_workspace_role()],
  194. [PermissionConstants.MODEL.get_workspace_model_permission()],
  195. CompareConstants.AND), )
  196. @log(menu='model', operate='Save model parameter form',
  197. get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')),
  198. )
  199. def put(self, request: Request, workspace_id: str, model_id: str):
  200. return result.success(
  201. ModelSerializer.ModelParams(data={'id': model_id}).save_model_params_form(request.data))
  202. class ModelMeta(APIView):
  203. authentication_classes = [TokenAuth]
  204. @extend_schema(methods=['GET'],
  205. summary=_(
  206. 'Query model meta information, this interface does not carry authentication information'),
  207. description=_(
  208. 'Query model meta information, this interface does not carry authentication information'),
  209. operation_id=_(
  210. 'Query model meta information, this interface does not carry authentication information'),
  211. parameters=GetModelApi.get_parameters(),
  212. responses=GetModelApi.get_response(),
  213. tags=[_('Model')]) # type: ignore
  214. @has_permissions(PermissionConstants.MODEL_READ.get_workspace_model_permission(),
  215. PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(),
  216. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  217. PermissionConstants.MODEL_READ.get_workspace_permission(),
  218. ViewPermission([RoleConstants.USER.get_workspace_role()],
  219. [PermissionConstants.MODEL.get_workspace_model_permission()],
  220. CompareConstants.AND), )
  221. def get(self, request: Request, workspace_id: str, model_id: str):
  222. return result.success(
  223. ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).one_meta(with_valid=True))
  224. class PauseDownload(APIView):
  225. authentication_classes = [TokenAuth]
  226. @extend_schema(methods=['PUT'],
  227. summary=_('Pause model download'),
  228. description=_('Pause model download'),
  229. operation_id=_('Pause model download'), # type: ignore
  230. parameters=GetModelApi.get_parameters(),
  231. request=GetModelApi.get_request(),
  232. responses=DefaultModelResponse.get_response(),
  233. tags=[_('Model')]) # type: ignore
  234. @has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_model_permission(),
  235. PermissionConstants.MODEL_CREATE.get_workspace_permission_workspace_manage_role(),
  236. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  237. ViewPermission([RoleConstants.USER.get_workspace_role()],
  238. [PermissionConstants.MODEL.get_workspace_model_permission()],
  239. CompareConstants.AND), )
  240. def put(self, request: Request, workspace_id: str, model_id: str):
  241. return result.success(
  242. ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).pause_download())
  243. class BatchDelete(APIView):
  244. """批量删除模型"""
  245. authentication_classes = [TokenAuth]
  246. @extend_schema(methods=['POST'],
  247. summary=_('Batch delete models'),
  248. description=_('Batch delete models'),
  249. operation_id=_('Batch delete models'),
  250. tags=[_('Model')])
  251. @has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission(),
  252. PermissionConstants.MODEL_DELETE.get_workspace_permission_workspace_manage_role(),
  253. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  254. RoleConstants.USER.get_workspace_role())
  255. def post(self, request: Request, workspace_id: str):
  256. model_ids = request.data.get('model_ids', [])
  257. if not model_ids:
  258. return result.success({'deleted_count': 0})
  259. deleted_count = 0
  260. for model_id in model_ids:
  261. try:
  262. ModelSerializer.Operate(
  263. data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}
  264. ).delete()
  265. deleted_count += 1
  266. except Exception:
  267. pass
  268. return result.success({'deleted_count': deleted_count})
  269. class BatchOperate(APIView):
  270. """批量设置模型状态"""
  271. authentication_classes = [TokenAuth]
  272. @extend_schema(methods=['POST'],
  273. summary=_('Batch update model status'),
  274. description=_('Batch update model status'),
  275. operation_id=_('Batch update model status'),
  276. tags=[_('Model')])
  277. @has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission(),
  278. PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(),
  279. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  280. RoleConstants.USER.get_workspace_role())
  281. def post(self, request: Request, workspace_id: str):
  282. model_ids = request.data.get('model_ids', [])
  283. status = request.data.get('status', '')
  284. if not model_ids or status not in [s[0] for s in Status.choices]:
  285. return result.success({'updated_count': 0})
  286. updated_count = QuerySet(Model).filter(
  287. id__in=model_ids, workspace_id=workspace_id
  288. ).update(status=status)
  289. return result.success({'updated_count': updated_count})
  290. class WorkspaceSharedModelSetting(APIView):
  291. authentication_classes = [TokenAuth]
  292. @extend_schema(
  293. methods=['Get'],
  294. summary=_('Get Share model by workspace id'),
  295. description=_('Get Share model by workspace id'),
  296. operation_id=_('Get Share model by workspace id'), # type: ignore
  297. parameters=ModelListResponse.get_parameters(),
  298. responses=DefaultModelResponse.get_response(),
  299. tags=[_('Shared Model')]
  300. ) # type: ignore
  301. @has_permissions(
  302. PermissionConstants.MODEL_READ.get_workspace_permission(),
  303. PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(),
  304. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
  305. RoleConstants.USER.get_workspace_role(),
  306. )
  307. def get(self, request: Request, workspace_id: str):
  308. return result.success(
  309. WorkspaceSharedModelSerializer(data={**query_params_to_single_dict(request.query_params),
  310. 'workspace_id': workspace_id}).get_share_model_list())
  311. class ModelList(APIView):
  312. authentication_classes = [TokenAuth]
  313. @extend_schema(methods=['GET'],
  314. summary=_('Query all model list'),
  315. description=_('Query all model list'),
  316. operation_id=_('Query all model list'), # type: ignore
  317. parameters=ModelListResponse.get_parameters(),
  318. responses=ModelListResponse.get_response(),
  319. tags=[_('Model')]) # type: ignore
  320. @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(),
  321. PermissionConstants.KNOWLEDGE_READ.get_workspace_permission(),
  322. PermissionConstants.APPLICATION_READ.get_workspace_permission(),
  323. PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(),
  324. PermissionConstants.KNOWLEDGE_READ.get_workspace_permission_workspace_manage_role(),
  325. PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
  326. RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role())
  327. def get(self, request: Request, workspace_id: str):
  328. return result.success(
  329. ModelSerializer.Query(
  330. data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).model_list(
  331. workspace_id=workspace_id,
  332. with_valid=True))