tts.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # coding=utf-8
  2. from typing import Dict
  3. from django.utils.translation import gettext_lazy as _, gettext
  4. from common import forms
  5. from common.exception.app_exception import AppApiException
  6. from common.forms import BaseForm, TooltipLabel
  7. from models_provider.base_model_provider import BaseModelCredential, ValidCode
  8. from common.utils.logger import maxkb_logger
  9. class XunFeiTTSModelGeneralParams(BaseForm):
  10. vcn = forms.SingleSelect(
  11. TooltipLabel(_('Speaker'),
  12. _('Speaker, optional value: Please go to the console to add a trial or purchase speaker. After adding, the speaker parameter value will be displayed.')),
  13. required=True, default_value='xiaoyan',
  14. text_field='value',
  15. value_field='value',
  16. option_list=[
  17. {'text': _('iFlytek Xiaoyan'), 'value': 'xiaoyan'},
  18. {'text': _('iFlytek Xujiu'), 'value': 'aisjiuxu'},
  19. {'text': _('iFlytek Xiaoping'), 'value': 'aisxping'},
  20. {'text': _('iFlytek Xiaojing'), 'value': 'aisjinger'},
  21. {'text': _('iFlytek Xuxiaobao'), 'value': 'aisbabyxu'},
  22. ])
  23. speed = forms.SliderField(
  24. TooltipLabel(_('speaking speed'), _('Speech speed, optional value: [0-100], default is 50')),
  25. required=True, default_value=50,
  26. _min=1,
  27. _max=100,
  28. _step=5,
  29. precision=1)
  30. class XunFeiTTSModelCredential(BaseForm, BaseModelCredential):
  31. spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://tts-api.xfyun.cn/v2/tts')
  32. spark_app_id = forms.TextInputField('APP ID', required=True)
  33. spark_api_key = forms.PasswordInputField("API Key", required=True)
  34. spark_api_secret = forms.PasswordInputField('API Secret', required=True)
  35. def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
  36. raise_exception=False):
  37. model_type_list = provider.get_model_type_list()
  38. if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
  39. raise AppApiException(ValidCode.valid_error.value,
  40. gettext('{model_type} Model type is not supported').format(model_type=model_type))
  41. for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
  42. if key not in model_credential:
  43. if raise_exception:
  44. raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
  45. else:
  46. return False
  47. try:
  48. model = provider.get_model(model_type, model_name, model_credential, **model_params)
  49. model.check_auth()
  50. except Exception as e:
  51. maxkb_logger.error(f'Exception: {e}', exc_info=True)
  52. if isinstance(e, AppApiException):
  53. raise e
  54. if raise_exception:
  55. raise AppApiException(ValidCode.valid_error.value,
  56. gettext(
  57. 'Verification failed, please check whether the parameters are correct: {error}').format(
  58. error=str(e)))
  59. else:
  60. return False
  61. return True
  62. def encryption_dict(self, model: Dict[str, object]):
  63. return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
  64. def get_model_params_setting_form(self, model_name):
  65. return XunFeiTTSModelGeneralParams()