tti.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # coding=utf-8
  2. from typing import Dict
  3. from django.utils.translation import gettext_lazy as _, gettext
  4. from common import forms
  5. from common.exception.app_exception import AppApiException
  6. from common.forms import BaseForm, TooltipLabel
  7. from models_provider.base_model_provider import BaseModelCredential, ValidCode
  8. class XinferenceTTIModelParams(BaseForm):
  9. size = forms.SingleSelect(
  10. TooltipLabel(_('Image size'),
  11. _('The image generation endpoint allows you to create raw images based on text prompts. The dimensions of the image can be 1024x1024, 1024x1792, or 1792x1024 pixels.')),
  12. required=True,
  13. default_value='1024x1024',
  14. option_list=[
  15. {'value': '1024x1024', 'label': '1024x1024'},
  16. {'value': '1024x1792', 'label': '1024x1792'},
  17. {'value': '1792x1024', 'label': '1792x1024'},
  18. ],
  19. text_field='label',
  20. value_field='value'
  21. )
  22. quality = forms.SingleSelect(
  23. TooltipLabel(_('Picture quality'),
  24. _('By default, images are generated in standard quality, you can set quality: "hd" to enhance detail. Square, standard quality images are generated fastest.')),
  25. required=True,
  26. default_value='standard',
  27. option_list=[
  28. {'value': 'standard', 'label': 'standard'},
  29. {'value': 'hd', 'label': 'hd'},
  30. ],
  31. text_field='label',
  32. value_field='value'
  33. )
  34. n = forms.SliderField(
  35. TooltipLabel(_('Number of pictures'),
  36. _('You can request 1 image at a time (requesting more images by making parallel requests), or up to 10 images at a time using the n parameter.')),
  37. required=True, default_value=1,
  38. _min=1,
  39. _max=10,
  40. _step=1,
  41. precision=0)
  42. class XinferenceTextToImageModelCredential(BaseForm, BaseModelCredential):
  43. api_base = forms.TextInputField('API URL', required=True)
  44. api_key = forms.PasswordInputField('API Key', required=True)
  45. def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
  46. raise_exception=False):
  47. model_type_list = provider.get_model_type_list()
  48. if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
  49. raise AppApiException(ValidCode.valid_error.value,
  50. gettext('{model_type} Model type is not supported').format(model_type=model_type))
  51. for key in ['api_base', 'api_key']:
  52. if key not in model_credential:
  53. if raise_exception:
  54. raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
  55. else:
  56. return False
  57. try:
  58. model = provider.get_model(model_type, model_name, model_credential, **model_params)
  59. res = model.check_auth()
  60. except Exception as e:
  61. if isinstance(e, AppApiException):
  62. raise e
  63. if raise_exception:
  64. raise AppApiException(ValidCode.valid_error.value,
  65. gettext(
  66. 'Verification failed, please check whether the parameters are correct: {error}').format(
  67. error=str(e)))
  68. else:
  69. return False
  70. return True
  71. def encryption_dict(self, model: Dict[str, object]):
  72. return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
  73. def get_model_params_setting_form(self, model_name):
  74. return XinferenceTTIModelParams()