whisper_sst.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import base64
  2. import os
  3. import traceback
  4. from typing import Dict
  5. from openai import OpenAI
  6. from common.utils.logger import maxkb_logger
  7. from models_provider.base_model_provider import MaxKBBaseModel
  8. from models_provider.impl.base_stt import BaseSpeechToText
  9. class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText):
  10. api_key: str
  11. api_url: str
  12. model: str
  13. params: dict
  14. def __init__(self, **kwargs):
  15. super().__init__(**kwargs)
  16. self.api_key = kwargs.get('api_key')
  17. self.model = kwargs.get('model')
  18. self.params = kwargs.get('params')
  19. self.api_url = kwargs.get('api_url')
  20. @staticmethod
  21. def is_cache_model():
  22. return False
  23. @staticmethod
  24. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  25. return VllmWhisperSpeechToText(
  26. model=model_name,
  27. api_key=model_credential.get('api_key'),
  28. api_url=model_credential.get('api_url'),
  29. params=model_kwargs,
  30. **model_kwargs
  31. )
  32. def check_auth(self):
  33. cwd = os.path.dirname(os.path.abspath(__file__))
  34. with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as audio_file:
  35. self.speech_to_text(audio_file)
  36. def speech_to_text(self, audio_file):
  37. base_url = self.api_url if self.api_url.endswith('v1') else f"{self.api_url}/v1"
  38. try:
  39. client = OpenAI(
  40. api_key=self.api_key,
  41. base_url=base_url
  42. )
  43. buf = audio_file.read()
  44. filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}}
  45. transcription_params = {
  46. 'model': self.model,
  47. 'file': buf,
  48. 'language': 'zh',
  49. }
  50. result = client.audio.transcriptions.create(
  51. **transcription_params, extra_body=filter_params
  52. )
  53. return result.text
  54. except Exception as err:
  55. maxkb_logger.error(f":Error: {str(err)}: {traceback.format_exc()}")
  56. raise err