tti.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # coding=utf-8
  2. from typing import Dict
  3. from django.utils.translation import gettext_lazy as _, gettext
  4. from common import forms
  5. from common.exception.app_exception import AppApiException
  6. from common.forms import BaseForm, TooltipLabel
  7. from models_provider.base_model_provider import BaseModelCredential, ValidCode
  8. from common.utils.logger import maxkb_logger
  9. class DockerAITTIModelParams(BaseForm):
  10. size = forms.SingleSelect(
  11. TooltipLabel(_('Image size'),
  12. _('The image generation endpoint allows you to create raw images based on text prompts. When using the DALL·E 3, the image size can be 1024x1024, 1024x1792 or 1792x1024 pixels.')),
  13. required=True,
  14. default_value='1024x1024',
  15. option_list=[
  16. {'value': '1024x1024', 'label': '1024x1024'},
  17. {'value': '1024x1792', 'label': '1024x1792'},
  18. {'value': '1792x1024', 'label': '1792x1024'},
  19. ],
  20. text_field='label',
  21. value_field='value'
  22. )
  23. quality = forms.SingleSelect(
  24. TooltipLabel(_('Picture quality'), _('''
  25. By default, images are produced in standard quality, but with DALL·E 3 you can set quality: "hd" to enhance detail. Square, standard quality images are generated fastest.
  26. ''')),
  27. required=True,
  28. default_value='standard',
  29. option_list=[
  30. {'value': 'standard', 'label': 'standard'},
  31. {'value': 'hd', 'label': 'hd'},
  32. ],
  33. text_field='label',
  34. value_field='value'
  35. )
  36. n = forms.SliderField(
  37. TooltipLabel(_('Number of pictures'),
  38. _('You can use DALL·E 3 to request 1 image at a time (requesting more images by issuing parallel requests), or use DALL·E 2 with the n parameter to request up to 10 images at a time.')),
  39. required=True, default_value=1,
  40. _min=1,
  41. _max=10,
  42. _step=1,
  43. precision=0)
  44. class DockerAITextToImageModelCredential(BaseForm, BaseModelCredential):
  45. api_base = forms.TextInputField('API URL', required=True)
  46. api_key = forms.PasswordInputField('API Key', required=True)
  47. def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
  48. raise_exception=False):
  49. model_type_list = provider.get_model_type_list()
  50. if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
  51. raise AppApiException(ValidCode.valid_error.value,
  52. gettext('{model_type} Model type is not supported').format(model_type=model_type))
  53. for key in ['api_base', 'api_key']:
  54. if key not in model_credential:
  55. if raise_exception:
  56. raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
  57. else:
  58. return False
  59. try:
  60. model = provider.get_model(model_type, model_name, model_credential, **model_params)
  61. res = model.check_auth()
  62. except Exception as e:
  63. maxkb_logger.error(f'Exception: {e}', exc_info=True)
  64. if isinstance(e, AppApiException):
  65. raise e
  66. if raise_exception:
  67. raise AppApiException(ValidCode.valid_error.value,
  68. gettext(
  69. 'Verification failed, please check whether the parameters are correct: {error}').format(
  70. error=str(e)))
  71. else:
  72. return False
  73. return True
  74. def encryption_dict(self, model: Dict[str, object]):
  75. return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
  76. def get_model_params_setting_form(self, model_name):
  77. return DockerAITTIModelParams()