provide.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # coding=utf-8
  2. from django.utils.translation import gettext_lazy as _
  3. from drf_spectacular.utils import extend_schema
  4. from rest_framework.request import Request
  5. from rest_framework.views import APIView
  6. from common import result
  7. from common.auth import TokenAuth
  8. from common.auth.authentication import has_permissions
  9. from common.constants.permission_constants import PermissionConstants
  10. from models_provider.api.provide import ProvideApi
  11. from models_provider.constants.model_provider_constants import ModelProvideConstants
  12. from models_provider.serializers.model_serializer import get_default_model_params_setting
  13. class Provide(APIView):
  14. authentication_classes = [TokenAuth]
  15. @extend_schema(methods=['GET'],
  16. summary=_('Get a list of model suppliers'),
  17. description=_('Get a list of model suppliers'),
  18. operation_id=_('Get a list of model suppliers'), # type: ignore
  19. responses=ProvideApi.get_response(),
  20. tags=[_('Model')]) # type: ignore
  21. def get(self, request: Request):
  22. model_type = request.query_params.get('model_type')
  23. if model_type:
  24. providers = []
  25. for key in ModelProvideConstants.__members__:
  26. if len([item for item in ModelProvideConstants[key].value.get_model_type_list() if
  27. item['value'] == model_type]) > 0:
  28. providers.append(ModelProvideConstants[key].value.get_model_provide_info().to_dict())
  29. return result.success(providers)
  30. return result.success(
  31. [ModelProvideConstants[key].value.get_model_provide_info().to_dict() for key in
  32. ModelProvideConstants.__members__])
  33. class ModelTypeList(APIView):
  34. authentication_classes = [TokenAuth]
  35. @extend_schema(methods=['GET'],
  36. summary=_('Get a list of model types'),
  37. description=_('Get a list of model types'),
  38. operation_id=_('Get a list of model types'), # type: ignore
  39. parameters=ProvideApi.ModelTypeList.get_query_params_api(),
  40. responses=ProvideApi.ModelTypeList.get_response(),
  41. tags=[_('Model')]) # type: ignore
  42. def get(self, request: Request):
  43. provider = request.query_params.get('provider')
  44. return result.success(ModelProvideConstants[provider].value.get_model_type_list())
  45. class ModelList(APIView):
  46. authentication_classes = [TokenAuth]
  47. @extend_schema(methods=['GET'],
  48. summary=_('Example of obtaining model list'),
  49. description=_('Example of obtaining model list'),
  50. operation_id=_('Example of obtaining model list'), # type: ignore
  51. parameters=ProvideApi.ModelList.get_query_params_api(),
  52. responses=ProvideApi.ModelList.get_response(),
  53. tags=[_('Model')]) # type: ignore
  54. def get(self, request: Request):
  55. provider = request.query_params.get('provider')
  56. model_type = request.query_params.get('model_type')
  57. return result.success(
  58. ModelProvideConstants[provider].value.get_model_list(
  59. model_type))
  60. class ModelParamsForm(APIView):
  61. authentication_classes = [TokenAuth]
  62. @extend_schema(methods=['GET'],
  63. summary=_('Get model default parameters'),
  64. description=_('Get model default parameters'),
  65. operation_id=_('Get model default parameters'), # type: ignore
  66. parameters=ProvideApi.ModelParamsForm.get_query_params_api(),
  67. responses=ProvideApi.ModelParamsForm.get_response(),
  68. tags=[_('Model')]) # type: ignore
  69. def get(self, request: Request):
  70. provider = request.query_params.get('provider')
  71. model_type = request.query_params.get('model_type')
  72. model_name = request.query_params.get('model_name')
  73. return result.success(get_default_model_params_setting(provider, model_type, model_name))
  74. class ModelForm(APIView):
  75. authentication_classes = [TokenAuth]
  76. @extend_schema(methods=['GET'],
  77. summary=_('Get the model creation form'),
  78. description=_('Get the model creation form'),
  79. operation_id=_('Get the model creation form'), # type: ignore
  80. parameters=ProvideApi.ModelParamsForm.get_query_params_api(),
  81. responses=ProvideApi.ModelParamsForm.get_response(),
  82. tags=[_('Model')]) # type: ignore
  83. def get(self, request: Request):
  84. provider = request.query_params.get('provider')
  85. model_type = request.query_params.get('model_type')
  86. model_name = request.query_params.get('model_name')
  87. return result.success(
  88. ModelProvideConstants[provider].value.get_model_credential(model_type, model_name).to_form_list())