whisper_stt.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # coding=utf-8
  2. import traceback
  3. from typing import Dict
  4. from django.utils.translation import gettext_lazy as _, gettext
  5. from langchain_core.messages import HumanMessage
  6. from common import forms
  7. from common.exception.app_exception import AppApiException
  8. from common.forms import BaseForm, TooltipLabel
  9. from models_provider.base_model_provider import BaseModelCredential, ValidCode
  10. class VLLMWhisperModelParams(BaseForm):
  11. Language = forms.TextInputField(
  12. TooltipLabel(_('language'),
  13. _("If not passed, the default value is 'zh'")),
  14. required=True,
  15. default_value='zh',
  16. )
  17. class VLLMWhisperModelCredential(BaseForm, BaseModelCredential):
  18. api_url = forms.TextInputField('API URL', required=True)
  19. api_key = forms.PasswordInputField('API Key', required=True)
  20. def is_valid(self,
  21. model_type: str,
  22. model_name,
  23. model_credential: Dict[str, object],
  24. model_params,
  25. provider,
  26. raise_exception=False):
  27. model_type_list = provider.get_model_type_list()
  28. if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
  29. raise AppApiException(ValidCode.valid_error.value,
  30. gettext('{model_type} Model type is not supported').format(model_type=model_type))
  31. try:
  32. model_list = provider.get_base_model_list(model_credential.get('api_url'), model_credential.get('api_key'))
  33. except Exception as e:
  34. raise AppApiException(ValidCode.valid_error.value, gettext('API domain name is invalid'))
  35. exist = provider.get_model_info_by_name(model_list, model_name)
  36. if len(exist) == 0:
  37. raise AppApiException(ValidCode.valid_error.value,
  38. gettext('The model does not exist, please download the model first'))
  39. model = provider.get_model(model_type, model_name, model_credential, **model_params)
  40. return True
  41. def encryption_dict(self, model_info: Dict[str, object]):
  42. return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
  43. def build_model(self, model_info: Dict[str, object]):
  44. for key in ['api_key', 'model']:
  45. if key not in model_info:
  46. raise AppApiException(500, gettext('{key} is required').format(key=key))
  47. self.api_key = model_info.get('api_key')
  48. return self
  49. def get_model_params_setting_form(self, model_name):
  50. return VLLMWhisperModelParams()