| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355 |
- # coding=utf-8
- """
- requires Python 3.6 or later
- pip install asyncio
- pip install websockets
- """
- import asyncio
- import base64
- import gzip
- import hmac
- import json
- import logging
- import os
- import ssl
- import uuid_utils.compat as uuid
- import wave
- from hashlib import sha256
- from io import BytesIO
- from typing import Dict
- from urllib.parse import urlparse
- import websockets
- from common.utils.logger import maxkb_logger
- from models_provider.base_model_provider import MaxKBBaseModel
- from models_provider.impl.base_stt import BaseSpeechToText
# Audio container used by default: "wav" or "mp3"; set it to match the actual audio.
audio_format = "mp3"

# ---------------------------------------------------------------------------
# Fixed header values
# ---------------------------------------------------------------------------
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# ---------------------------------------------------------------------------
# Width, in bits, of every field of the 4-byte header
# ---------------------------------------------------------------------------
PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# ---------------------------------------------------------------------------
# Message type (high nibble of header byte 1)
# ---------------------------------------------------------------------------
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# ---------------------------------------------------------------------------
# Message-type-specific flags (low nibble of header byte 1)
# ---------------------------------------------------------------------------
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010  # sent with the final audio chunk
NEG_SEQUENCE_1 = 0b0011

# ---------------------------------------------------------------------------
# Payload serialization (high nibble of header byte 2)
# ---------------------------------------------------------------------------
NO_SERIALIZATION = 0b0000
JSON = 0b0001
THRIFT = 0b0011
CUSTOM_TYPE = 0b1111

# ---------------------------------------------------------------------------
# Payload compression (low nibble of header byte 2)
# ---------------------------------------------------------------------------
NO_COMPRESSION = 0b0000
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111

# Shared TLS context handed to the WebSocket connection.
# NOTE(review): hostname checking and certificate verification are both
# switched off here, so the server's identity is not authenticated. Confirm
# this is intentional before relying on it outside trusted networks.
# check_hostname must be cleared before verify_mode can be set to CERT_NONE.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
def generate_header(
        version=PROTOCOL_VERSION,
        message_type=CLIENT_FULL_REQUEST,
        message_type_specific_flags=NO_SEQUENCE,
        serial_method=JSON,
        compression_type=GZIP,
        reserved_data=0x00,
        extension_header=bytes()
):
    """
    Build the binary message header.

    Layout (one byte per row, high nibble first):
        protocol_version (4 bits) | header_size (4 bits)
        message_type (4 bits)     | message_type_specific_flags (4 bits)
        serialization_method (4 bits) | message_compression (4 bits)
        reserved (8 bits)
        header_extensions, 4 * (header_size - 1) bytes

    ``header_size`` is derived from ``extension_header`` and counts 32-bit
    words, including the fixed first word.

    :param version: protocol version nibble
    :param message_type: message type nibble
    :param message_type_specific_flags: flags nibble
    :param serial_method: payload serialization nibble
    :param compression_type: payload compression nibble
    :param reserved_data: value of the reserved byte
    :param extension_header: optional extension bytes appended after the
        fixed header; its length must be a multiple of 4
    :return: the header as a ``bytearray``
    :raises ValueError: if a 4-bit field does not fit in 4 bits, the reserved
        value does not fit in a byte, or the extension is not 32-bit aligned
    """
    if len(extension_header) % 4:
        # A misaligned extension cannot be described by the word-based
        # header_size field; the receiver would mis-locate the payload.
        raise ValueError("extension_header length must be a multiple of 4 bytes")
    header_size = len(extension_header) // 4 + 1

    four_bit_fields = (
        ('version', version),
        ('header_size', header_size),
        ('message_type', message_type),
        ('message_type_specific_flags', message_type_specific_flags),
        ('serial_method', serial_method),
        ('compression_type', compression_type),
    )
    for field_name, field_value in four_bit_fields:
        # Without this check an oversized low nibble silently corrupts the
        # high nibble it is OR-ed with.
        if not 0 <= field_value <= 0x0F:
            raise ValueError(f"{field_name} must fit in 4 bits, got {field_value}")

    header = bytearray((
        (version << 4) | header_size,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ))
    header.extend(extension_header)
    return header
def generate_full_default_header():
    """Return the header for a full client request, using every default of ``generate_header``."""
    full_request_header = generate_header()
    return full_request_header
def generate_audio_default_header():
    """Return the header for an audio-only client request; all other fields keep their defaults."""
    audio_header = generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST)
    return audio_header
def generate_last_audio_default_header():
    """Return the header for the final audio-only request, flagged with ``NEG_SEQUENCE``."""
    header_fields = {
        'message_type': CLIENT_AUDIO_ONLY_REQUEST,
        'message_type_specific_flags': NEG_SEQUENCE,
    }
    return generate_header(**header_fields)
def parse_response(res):
    """
    Decode one binary server message.

    Header layout (high nibble first):
        protocol_version (4 bits) | header_size (4 bits)
        message_type (4 bits)     | message_type_specific_flags (4 bits)
        serialization_method (4 bits) | message_compression (4 bits)
        reserved (8 bits)
        header_extensions, 4 * (header_size - 1) bytes
        payload, comparable to an HTTP request body

    :param res: raw message bytes
    :return: dict that may contain
        ``seq`` (SERVER_ACK), ``code`` (SERVER_ERROR_RESPONSE),
        ``payload_msg`` (decoded payload) and ``payload_size``; an empty or
        partial dict when the message carries no payload
    :raises ValueError: if ``res`` is shorter than the 4-byte fixed header
    """
    if len(res) < 4:
        raise ValueError(f"response too short for a header: {len(res)} byte(s)")

    # Only the fields that drive decoding are extracted; header_size counts
    # 32-bit words, so the payload starts at header_size * 4.
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    payload = res[header_size * 4:]

    result = {}
    payload_msg = None
    payload_size = 0
    error_code = None
    if message_type == SERVER_FULL_RESPONSE:
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        result['seq'] = int.from_bytes(payload[:4], "big", signed=True)
        # An ACK only carries a payload when a size field follows the sequence.
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        error_code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = error_code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]

    if payload_msg is None:
        return result

    try:
        if message_compression == GZIP:
            payload_msg = gzip.decompress(payload_msg)
        if serialization_method == JSON:
            payload_msg = json.loads(str(payload_msg, "utf-8"))
        elif serialization_method != NO_SERIALIZATION:
            payload_msg = str(payload_msg, "utf-8")
    finally:
        # Log after decoding so the message is readable instead of raw
        # compressed bytes; `finally` keeps the error code in the log even
        # when decoding itself fails.
        if error_code is not None:
            maxkb_logger.error(f"Error code: {error_code}, message: {payload_msg}")

    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result
def read_wav_info(data: bytes = None) -> (int, int, int, int, int):
    """
    Read the basic parameters of an in-memory WAV file.

    :param data: complete WAV file content
    :return: ``(nchannels, sampwidth, framerate, nframes, byte_length)`` where
        ``byte_length`` is the number of bytes returned when reading all frames
    :raises wave.Error / EOFError: propagated from ``wave.open`` when ``data``
        is not a readable WAV stream
    """
    # Both the buffer and the wave reader are context-managed so the reader is
    # closed even if reading the frames fails.
    with BytesIO(data) as buffer, wave.open(buffer, 'rb') as wave_fp:
        nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
        wave_bytes = wave_fp.readframes(nframes)
    return nchannels, sampwidth, framerate, nframes, len(wave_bytes)
class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """
    Speech-to-text model that streams audio to a Volcano Engine ASR endpoint
    over a WebSocket and returns the recognized text.

    The audio is sent as one full client request (JSON, gzip-compressed)
    followed by gzip-compressed audio-only requests, one per chunk.
    """

    # Defaults for the 'request' section; each can be overridden via ``params``.
    workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate"
    show_language: bool = False
    show_utterances: bool = False
    result_type: str = "full"
    # Defaults for the 'audio' section; each can be overridden via ``params``.
    # ``format`` also selects how speech_to_text sizes the chunks.
    format: str = "mp3"
    rate: int = 16000
    language: str = "zh-CN"
    bits: int = 16
    channel: int = 1
    codec: str = "raw"
    audio_type: int = 1
    # Key used by signature_auth to compute the HMAC.
    secret: str = "access_secret"
    # "token" or "signature"; selects the handshake headers.
    auth_method: str = "token"
    # Chunk size in bytes when the format is mp3.
    mp3_seg_size: int = 10000
    success_code: int = 1000  # success code, default is 1000
    # Chunk duration in milliseconds when the format is wav.
    seg_duration: int = 15000
    nbest: int = 1
    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str
    params: dict

    def __init__(self, **kwargs):
        """Store the connection credentials and request overrides found in ``kwargs``."""
        super().__init__(**kwargs)
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')
        self.params = kwargs.get('params')

    @staticmethod
    def is_cache_model():
        """Report that instances of this model are not cached."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """
        Create an instance from stored credentials.

        :param model_type: unused here
        :param model_name: unused here
        :param model_credential: mapping holding ``volcanic_api_url``,
            ``volcanic_token``, ``volcanic_app_id`` and ``volcanic_cluster``
        :param model_kwargs: extra options; passed both as ``params`` and as
            individual keyword arguments
        :return: a new ``VolcanicEngineSpeechToText``
        """
        # model_kwargs is unpacked exactly once. Unpacking a second dict that
        # repeats 'max_tokens' / 'temperature' raises
        # "TypeError: got multiple values for keyword argument".
        return VolcanicEngineSpeechToText(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            params=model_kwargs,
            **model_kwargs
        )

    def construct_request(self, reqid):
        """
        Build the JSON-serializable body of the full client request.

        Values found in ``self.params`` take precedence over the class-level
        defaults.

        :param reqid: request identifier placed in ``request.reqid``
        :return: dict with the ``app``, ``user``, ``request`` and ``audio`` sections
        """
        params = self.params or {}
        req = {
            'app': {
                'appid': self.volcanic_app_id,
                'cluster': self.volcanic_cluster,
                'token': self.volcanic_token,
            },
            'user': {
                'uid': params.get("uid", "streaming_asr_demo")
            },
            'request': {
                'reqid': reqid,
                'nbest': params.get('nbest', self.nbest),
                'workflow': params.get('workflow', self.workflow),
                'show_language': params.get('show_language', self.show_language),
                'show_utterances': params.get('show_utterances', self.show_utterances),
                'result_type': params.get('result_type', self.result_type),
                'sequence': params.get('sequence', 1)
            },
            'audio': {
                'format': params.get('format', self.format),
                'rate': params.get('rate', self.rate),
                'language': params.get('language', self.language),
                'bits': params.get('bits', self.bits),
                'channel': params.get('channel', self.channel),
                'codec': params.get('codec', self.codec)
            }
        }
        return req

    @staticmethod
    def slice_data(data: bytes, chunk_size: int) -> (list, bool):
        """
        Generate consecutive chunks of ``data``.

        This is a generator: it yields ``(chunk, is_last)`` pairs. Exactly one
        pair has ``is_last`` set to True, and it is always the final one, so
        empty input yields a single empty chunk flagged as last.

        :param data: audio data
        :param chunk_size: maximum number of bytes per chunk; must be positive
        :return: iterator of ``(chunk, is_last)`` pairs
        :raises ValueError: if ``chunk_size`` is not positive (raised when the
            generator is first advanced)
        """
        if chunk_size <= 0:
            # A non-positive step would never advance the offset.
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        data_len = len(data)
        offset = 0
        while offset + chunk_size < data_len:
            yield data[offset: offset + chunk_size], False
            offset += chunk_size
        yield data[offset: data_len], True

    def _real_processor(self, request_params: dict) -> dict:
        """Placeholder; performs no work and returns None."""
        pass

    def _raise_on_error(self, result: dict) -> None:
        """Raise if ``result`` carries a payload whose code differs from ``success_code``."""
        if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
            raise Exception(
                f"Error code: {result['payload_msg']['code']}, message: {result['payload_msg']['message']}")

    def token_auth(self):
        """Return the handshake headers for token authentication."""
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    def signature_auth(self, data):
        """
        Return the handshake headers for HMAC-SHA256 signature authentication.

        The signed input is the request line, the value of each signed header
        (one per line) and then ``data``.

        :param data: bytes of the full client request to sign
        :return: dict with the ``Custom`` and ``Authorization`` headers
        """
        header_dicts = {
            'Custom': 'auth_custom',
        }
        url_parse = urlparse(self.volcanic_api_url)
        input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path)
        auth_headers = 'Custom'
        for header in auth_headers.split(','):
            input_str += '{}\n'.format(header_dicts[header])
        input_data = bytearray(input_str, 'utf-8')
        input_data += data
        digest = hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest()
        mac = base64.urlsafe_b64encode(digest)
        header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(
            self.volcanic_token, str(mac, 'utf-8'), auth_headers)
        return header_dicts

    async def segment_data_processor(self, wav_data: bytes, segment_size: int):
        """
        Stream ``wav_data`` to the service and return the recognized text.

        :param wav_data: complete audio content
        :param segment_size: number of bytes sent per audio-only request
        :return: ``text`` of the first entry of ``result`` in the last response
        :raises Exception: if any response reports a code other than
            ``success_code``
        :raises KeyError / IndexError: if the last response does not contain
            ``payload_msg.result[0].text``
        """
        reqid = str(uuid.uuid7())
        # Build the full client request, then serialize and compress it.
        request_params = self.construct_request(reqid)
        payload_bytes = gzip.compress(json.dumps(request_params).encode())
        full_client_request = bytearray(generate_full_default_header())
        full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size (4 bytes)
        full_client_request.extend(payload_bytes)  # payload

        header = None
        if self.auth_method == "token":
            header = self.token_auth()
        elif self.auth_method == "signature":
            header = self.signature_auth(full_client_request)

        async with websockets.connect(self.volcanic_api_url, additional_headers=header, max_size=1000000000,
                                      ssl=ssl_context) as ws:
            # Send the full client request.
            await ws.send(full_client_request)
            result = parse_response(await ws.recv())
            self._raise_on_error(result)

            for chunk, last in VolcanicEngineSpeechToText.slice_data(wav_data, segment_size):
                # if no compression, comment this line
                payload_bytes = gzip.compress(chunk)
                if last:
                    audio_only_request = bytearray(generate_last_audio_default_header())
                else:
                    audio_only_request = bytearray(generate_audio_default_header())
                audio_only_request.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size (4 bytes)
                audio_only_request.extend(payload_bytes)  # payload
                # Send the audio-only client request.
                await ws.send(audio_only_request)
                result = parse_response(await ws.recv())
                # Fail the same way as for the first response, so callers
                # never receive a raw response dict in place of text.
                self._raise_on_error(result)
        return result['payload_msg']['result'][0]['text']

    def check_auth(self):
        """Validate the configuration by transcribing the bundled ``iat_mp3_16k.mp3`` sample."""
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
            self.speech_to_text(f)

    def speech_to_text(self, file):
        """
        Transcribe an audio file.

        :param file: binary file-like object; it is read completely
        :return: the recognized text
        :raises Exception: if ``self.format`` is neither "mp3" nor "wav"
        """
        audio_data = bytes(file.read())
        if self.format == "mp3":
            segment_size = self.mp3_seg_size
        elif self.format == "wav":
            nchannels, sampwidth, framerate, _nframes, _wav_len = read_wav_info(audio_data)
            # Bytes per second of audio, scaled to seg_duration milliseconds.
            size_per_sec = nchannels * sampwidth * framerate
            segment_size = int(size_per_sec * self.seg_duration / 1000)
        else:
            raise Exception("format should in wav or mp3")
        return asyncio.run(self.segment_data_processor(audio_data, segment_size))
|