tti.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # coding=utf-8
  2. from typing import Dict
  3. from django.utils.translation import gettext_lazy as _, gettext
  4. from common import forms
  5. from common.exception.app_exception import AppApiException
  6. from common.forms import BaseForm, TooltipLabel
  7. from models_provider.base_model_provider import BaseModelCredential, ValidCode
  8. from common.utils.logger import maxkb_logger
  9. class AzureOpenAITTIModelParams(BaseForm):
  10. size = forms.SingleSelect(
  11. TooltipLabel(_('Image size'), _('Specify the size of the generated image, such as: 1024x1024')),
  12. required=True,
  13. default_value='1024x1024',
  14. option_list=[
  15. {'value': '1024x1024', 'label': '1024x1024'},
  16. {'value': '1024x1792', 'label': '1024x1792'},
  17. {'value': '1792x1024', 'label': '1792x1024'},
  18. ],
  19. text_field='label',
  20. value_field='value'
  21. )
  22. quality = forms.SingleSelect(
  23. TooltipLabel(_('Picture quality'), ''),
  24. required=True,
  25. default_value='standard',
  26. option_list=[
  27. {'value': 'standard', 'label': 'standard'},
  28. {'value': 'hd', 'label': 'hd'},
  29. ],
  30. text_field='label',
  31. value_field='value'
  32. )
  33. n = forms.SliderField(
  34. TooltipLabel(_('Number of pictures'), _('Specify the number of generated images')),
  35. required=True, default_value=1,
  36. _min=1,
  37. _max=10,
  38. _step=1,
  39. precision=0)
  40. class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
  41. api_version = forms.TextInputField("API Version", required=True)
  42. api_base = forms.TextInputField('Azure Endpoint', required=True)
  43. api_key = forms.PasswordInputField("API Key", required=True)
  44. def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
  45. raise_exception=False):
  46. model_type_list = provider.get_model_type_list()
  47. if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
  48. raise AppApiException(ValidCode.valid_error.value,
  49. gettext('{model_type} Model type is not supported').format(model_type=model_type))
  50. for key in ['api_base', 'api_key', 'api_version']:
  51. if key not in model_credential:
  52. if raise_exception:
  53. raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
  54. else:
  55. return False
  56. try:
  57. model = provider.get_model(model_type, model_name, model_credential, **model_params)
  58. res = model.check_auth()
  59. except Exception as e:
  60. maxkb_logger.error(f'Exception: {e}', exc_info=True)
  61. if isinstance(e, AppApiException):
  62. raise e
  63. if raise_exception:
  64. raise AppApiException(ValidCode.valid_error.value,
  65. gettext(
  66. 'Verification failed, please check whether the parameters are correct: {error}').format(
  67. error=str(e)))
  68. else:
  69. return False
  70. return True
  71. def encryption_dict(self, model: Dict[str, object]):
  72. return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
  73. def get_model_params_setting_form(self, model_name):
  74. return AzureOpenAITTIModelParams()