| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- import base64
- import os
- import traceback
- from typing import Dict
- from openai import OpenAI
- from common.utils.logger import maxkb_logger
- from models_provider.base_model_provider import MaxKBBaseModel
- from models_provider.impl.base_stt import BaseSpeechToText
class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text model that sends audio to an OpenAI-compatible endpoint.

    Transcription is performed through the ``openai`` client's
    ``audio.transcriptions.create`` call against ``api_url``.
    Presumably the endpoint is a vLLM server hosting a Whisper model
    (inferred from the class name only) -- TODO confirm.
    """

    # API key handed to the OpenAI client.
    api_key: str
    # Server root URL; ``/v1`` is appended on use when it is missing.
    api_url: str
    # Model name sent with every transcription request.
    model: str
    # Extra model parameters, forwarded to the server via ``extra_body``.
    params: dict

    def __init__(self, **kwargs):
        """Store connection settings taken from ``kwargs``.

        Recognised keys: ``api_key``, ``model``, ``params`` and ``api_url``.
        Keys that are absent are stored as ``None``.
        """
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')
        self.api_url = kwargs.get('api_url')

    @staticmethod
    def is_cache_model():
        """Return ``False``.

        NOTE(review): presumably tells the provider framework not to cache
        instances of this model -- confirm against ``MaxKBBaseModel``.
        """
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential mapping and model keyword args.

        Args:
            model_type: Not used by this implementation.
            model_name: Model name to send with transcription requests.
            model_credential: Mapping read for ``api_key`` and ``api_url``.
            **model_kwargs: Stored as ``params`` and also forwarded as
                individual keyword arguments to the constructor.

        Returns:
            A new ``VllmWhisperSpeechToText``.
        """
        return VllmWhisperSpeechToText(
            model=model_name,
            api_key=model_credential.get('api_key'),
            api_url=model_credential.get('api_url'),
            params=model_kwargs,
            **model_kwargs
        )

    @staticmethod
    def _resolve_base_url(api_url):
        """Return ``api_url`` normalised so that it ends with ``v1``.

        Trailing slashes are stripped first so that ``http://host/v1/`` is
        recognised as already versioned instead of becoming
        ``http://host/v1//v1``.
        """
        trimmed = api_url.rstrip('/')
        return trimmed if trimmed.endswith('v1') else f"{trimmed}/v1"

    def check_auth(self):
        """Validate the configuration by transcribing a bundled sample file.

        Opens ``iat_mp3_16k.mp3`` from this module's directory and passes it
        to ``speech_to_text``; any exception raised there propagates.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as audio_file:
            self.speech_to_text(audio_file)

    def speech_to_text(self, audio_file):
        """Transcribe ``audio_file`` and return the recognised text.

        Args:
            audio_file: Object with a ``read()`` method returning the audio
                content; the result is passed as the ``file`` argument.

        Returns:
            The ``text`` attribute of the transcription result.

        Raises:
            Exception: Any error raised while creating the client or calling
                the endpoint is logged and then re-raised unchanged.
        """
        base_url = self._resolve_base_url(self.api_url)
        try:
            client = OpenAI(
                api_key=self.api_key,
                base_url=base_url
            )
            buf = audio_file.read()
            # These keys are excluded from the request's ``extra_body``.
            excluded_keys = {'model_id', 'use_local', 'streaming'}
            # ``params`` is ``None`` when the instance was built without it.
            filter_params = {
                k: v for k, v in (self.params or {}).items()
                if k not in excluded_keys
            }
            transcription_params = {
                'model': self.model,
                'file': buf,
                'language': 'zh',
            }
            result = client.audio.transcriptions.create(
                **transcription_params, extra_body=filter_params
            )
            return result.text
        except Exception as err:
            maxkb_logger.error(f":Error: {str(err)}: {traceback.format_exc()}")
            # Bare ``raise`` re-raises the active exception as-is.
            raise
|