tts.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # coding=utf-8
  2. from typing import Dict
  3. from django.utils.translation import gettext_lazy as _, gettext
  4. from common import forms
  5. from common.exception.app_exception import AppApiException
  6. from common.forms import BaseForm, TooltipLabel
  7. from models_provider.base_model_provider import BaseModelCredential, ValidCode
  8. from common.utils.logger import maxkb_logger
  9. class AzureOpenAITTSModelGeneralParams(BaseForm):
  10. # alloy, echo, fable, onyx, nova, shimmer
  11. voice = forms.SingleSelect(
  12. TooltipLabel('Voice',
  13. _('Try out the different sounds (Alloy, Echo, Fable, Onyx, Nova, and Sparkle) to find one that suits your desired tone and audience. The current voiceover is optimized for English.')),
  14. required=True, default_value='alloy',
  15. text_field='value',
  16. value_field='value',
  17. option_list=[
  18. {'text': 'alloy', 'value': 'alloy'},
  19. {'text': 'echo', 'value': 'echo'},
  20. {'text': 'fable', 'value': 'fable'},
  21. {'text': 'onyx', 'value': 'onyx'},
  22. {'text': 'nova', 'value': 'nova'},
  23. {'text': 'shimmer', 'value': 'shimmer'},
  24. ])
  25. class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
  26. api_version = forms.TextInputField("API Version", required=True)
  27. api_base = forms.TextInputField('Azure Endpoint', required=True)
  28. api_key = forms.PasswordInputField("API Key", required=True)
  29. def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
  30. raise_exception=False):
  31. model_type_list = provider.get_model_type_list()
  32. if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
  33. raise AppApiException(ValidCode.valid_error.value,
  34. gettext('{model_type} Model type is not supported').format(model_type=model_type))
  35. for key in ['api_base', 'api_key', 'api_version']:
  36. if key not in model_credential:
  37. if raise_exception:
  38. raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
  39. else:
  40. return False
  41. try:
  42. model = provider.get_model(model_type, model_name, model_credential, **model_params)
  43. model.check_auth()
  44. except Exception as e:
  45. maxkb_logger.error(f'Exception: {e}', exc_info=True)
  46. if isinstance(e, AppApiException):
  47. raise e
  48. if raise_exception:
  49. raise AppApiException(ValidCode.valid_error.value, gettext(
  50. 'Verification failed, please check whether the parameters are correct: {error}').format(
  51. error=str(e)))
  52. else:
  53. return False
  54. return True
  55. def encryption_dict(self, model: Dict[str, object]):
  56. return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
  57. def get_model_params_setting_form(self, model_name):
  58. return AzureOpenAITTSModelGeneralParams()