tti.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # coding=utf-8
  2. from typing import Dict
  3. from django.utils.translation import gettext_lazy as _, gettext
  4. from common import forms
  5. from common.exception.app_exception import AppApiException
  6. from common.forms import BaseForm, TooltipLabel
  7. from models_provider.base_model_provider import BaseModelCredential, ValidCode
  8. from common.utils.logger import maxkb_logger
  9. class RegoloTTIModelParams(BaseForm):
  10. size = forms.SingleSelect(
  11. TooltipLabel(_('Image size'),
  12. _('The image generation endpoint allows you to create raw images based on text prompts. ')),
  13. required=True,
  14. default_value='1024x1024',
  15. option_list=[
  16. {'value': '1024x1024', 'label': '1024x1024'},
  17. {'value': '1024x1792', 'label': '1024x1792'},
  18. {'value': '1792x1024', 'label': '1792x1024'},
  19. ],
  20. text_field='label',
  21. value_field='value'
  22. )
  23. quality = forms.SingleSelect(
  24. TooltipLabel(_('Picture quality'), _('''
  25. By default, images are produced in standard quality.
  26. ''')),
  27. required=True,
  28. default_value='standard',
  29. option_list=[
  30. {'value': 'standard', 'label': 'standard'},
  31. {'value': 'hd', 'label': 'hd'},
  32. ],
  33. text_field='label',
  34. value_field='value'
  35. )
  36. n = forms.SliderField(
  37. TooltipLabel(_('Number of pictures'),
  38. _('1 as default')),
  39. required=True, default_value=1,
  40. _min=1,
  41. _max=10,
  42. _step=1,
  43. precision=0)
  44. class RegoloTextToImageModelCredential(BaseForm, BaseModelCredential):
  45. api_base = forms.TextInputField('API URL', required=True, default_value='https://api.regolo.ai/v1')
  46. api_key = forms.PasswordInputField('API Key', required=True)
  47. def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
  48. raise_exception=False):
  49. model_type_list = provider.get_model_type_list()
  50. if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
  51. raise AppApiException(ValidCode.valid_error.value,
  52. gettext('{model_type} Model type is not supported').format(model_type=model_type))
  53. for key in ['api_key', 'api_base']:
  54. if key not in model_credential:
  55. if raise_exception:
  56. raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
  57. else:
  58. return False
  59. try:
  60. model = provider.get_model(model_type, model_name, model_credential, **model_params)
  61. res = model.check_auth()
  62. except Exception as e:
  63. maxkb_logger.error(f'Exception: {e}', exc_info=True)
  64. if isinstance(e, AppApiException):
  65. raise e
  66. if raise_exception:
  67. raise AppApiException(ValidCode.valid_error.value,
  68. gettext(
  69. 'Verification failed, please check whether the parameters are correct: {error}').format(
  70. error=str(e)))
  71. else:
  72. return False
  73. return True
  74. def encryption_dict(self, model: Dict[str, object]):
  75. return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
  76. def get_model_params_setting_form(self, model_name):
  77. return RegoloTTIModelParams()