# coding=utf-8
"""
Volcano Engine (Volcengine) streaming speech-to-text client.

Requirements:
    Python 3.7+ (``asyncio.run``); recent websockets releases raise this floor.
    pip install websockets   # ``websockets.connect`` must accept ``additional_headers``
    pip install uuid-utils

``asyncio`` ships with the standard library; do not ``pip install asyncio``
(the PyPI package of that name is an obsolete backport).
"""

import asyncio
import base64
import gzip
import hmac
import json
import logging
import os
import ssl
import wave
from hashlib import sha256
from io import BytesIO
from typing import Dict, Iterator, Tuple
from urllib.parse import urlparse

import uuid_utils.compat as uuid
import websockets

from common.utils.logger import maxkb_logger
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText

audio_format = "mp3"  # "wav" or "mp3"; set according to the actual audio format

PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# Message Type:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message Type Specific Flags
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_SEQUENCE_1 = 0b0011

# Message Serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001
THRIFT = 0b0011
CUSTOM_TYPE = 0b1111

# Message Compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111

# NOTE(review): certificate and hostname verification are disabled, so the
# credentials sent in the websocket handshake travel over a TLS channel that
# does not authenticate the server. Prefer ssl.create_default_context() unless
# an endpoint with a self-signed certificate really requires this.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE


def generate_header(
        version=PROTOCOL_VERSION,
        message_type=CLIENT_FULL_REQUEST,
        message_type_specific_flags=NO_SEQUENCE,
        serial_method=JSON,
        compression_type=GZIP,
        reserved_data=0x00,
        extension_header=bytes()
):
    """
    Build a binary protocol header.

    Layout:
        protocol_version (4 bits), header_size (4 bits),
        message_type (4 bits), message_type_specific_flags (4 bits),
        serialization_method (4 bits), message_compression (4 bits),
        reserved (8 bits) -- reserved field,
        header_extensions -- extension header of 8 * 4 * (header_size - 1) bits.

    ``header_size`` counts 4-byte words, including the fixed 4-byte header.

    :return: the header as a ``bytearray``
    """
    header = bytearray()
    header_size = len(extension_header) // 4 + 1
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    header.extend(extension_header)
    return header


def generate_full_default_header():
    """Header for the initial full-client request (JSON, gzip, no sequence)."""
    return generate_header()


def generate_audio_default_header():
    """Header for an intermediate audio-only request."""
    return generate_header(
        message_type=CLIENT_AUDIO_ONLY_REQUEST
    )


def generate_last_audio_default_header():
    """Header for the final audio-only request (flagged with NEG_SEQUENCE)."""
    return generate_header(
        message_type=CLIENT_AUDIO_ONLY_REQUEST,
        message_type_specific_flags=NEG_SEQUENCE
    )


def parse_response(res):
    """
    Parse a binary server frame.

    The frame starts with the same 4-byte header produced by
    ``generate_header`` (followed by ``4 * (header_size - 1)`` bytes of header
    extensions), then the payload, which plays the role of an HTTP body.

    :param res: raw frame bytes received from the websocket
    :return: dict that may contain ``seq`` (ACK frames), ``code`` (error
        frames), ``payload_msg`` (decompressed/deserialized payload) and
        ``payload_size``; empty dict or no ``payload_msg`` when the frame
        carries no payload
    """
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    # res[0] >> 4 is the protocol version, res[1] & 0x0f the message-specific
    # flags, res[3] the reserved byte and res[4:header_size * 4] the header
    # extensions; none of them are needed to decode the payload.
    payload = res[header_size * 4:]

    result = {}
    payload_msg = None
    payload_size = 0
    if message_type == SERVER_FULL_RESPONSE:
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['seq'] = seq
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
        maxkb_logger.error(f"Error code: {code}, message: {payload_msg}")
    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result


def read_wav_info(data: bytes = None) -> Tuple[int, int, int, int, int]:
    """
    Read basic parameters of an in-memory WAV file.

    :param data: complete WAV file contents
    :return: (nchannels, sampwidth, framerate, nframes, length of the frame data in bytes)
    """
    # Closing the reader explicitly instead of leaving it to garbage collection.
    with BytesIO(data) as _f, wave.open(_f, 'rb') as wave_fp:
        nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
        wave_bytes = wave_fp.readframes(nframes)
    return nchannels, sampwidth, framerate, nframes, len(wave_bytes)


class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text model that streams audio to Volcano Engine over a websocket."""

    workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate"
    show_language: bool = False
    show_utterances: bool = False
    result_type: str = "full"
    format: str = "mp3"
    rate: int = 16000
    language: str = "zh-CN"
    bits: int = 16
    channel: int = 1
    codec: str = "raw"
    audio_type: int = 1
    secret: str = "access_secret"
    auth_method: str = "token"
    mp3_seg_size: int = 10000  # bytes per request when format == "mp3"
    success_code: int = 1000  # success code, default is 1000
    seg_duration: int = 15000  # milliseconds of audio per request when format == "wav"
    nbest: int = 1
    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str
    params: dict  # per-request overrides read by construct_request

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')
        self.params = kwargs.get('params')

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """
        Create an instance from stored credentials and model parameters.

        :param model_credential: dict with ``volcanic_api_url``, ``volcanic_token``,
            ``volcanic_app_id`` and ``volcanic_cluster``
        :param model_kwargs: model parameters; kept as ``params`` and also passed
            through as keyword arguments
        """
        # model_kwargs is forwarded exactly once. Re-adding keys taken from it
        # (such as max_tokens / temperature) would make the call fail with
        # "got multiple values for keyword argument".
        return VolcanicEngineSpeechToText(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            params=model_kwargs,
            **model_kwargs
        )

    def construct_request(self, reqid):
        """
        Build the JSON body of the initial full-client request.

        Values in ``self.params`` override the class-level defaults.

        :param reqid: unique request id
        :return: request dict
        """
        params = self.params or {}
        req = {
            'app': {
                'appid': self.volcanic_app_id,
                'cluster': self.volcanic_cluster,
                'token': self.volcanic_token,
            },
            'user': {
                'uid': params.get("uid", "streaming_asr_demo")
            },
            'request': {
                'reqid': reqid,
                'nbest': params.get('nbest', self.nbest),
                'workflow': params.get('workflow', self.workflow),
                'show_language': params.get('show_language', self.show_language),
                'show_utterances': params.get('show_utterances', self.show_utterances),
                'result_type': params.get('result_type', self.result_type),
                'sequence': params.get('sequence', 1)
            },
            'audio': {
                'format': params.get('format', self.format),
                'rate': params.get('rate', self.rate),
                'language': params.get('language', self.language),
                'bits': params.get('bits', self.bits),
                'channel': params.get('channel', self.channel),
                'codec': params.get('codec', self.codec)
            }
        }
        return req

    @staticmethod
    def slice_data(data: bytes, chunk_size: int) -> Iterator[Tuple[bytes, bool]]:
        """
        Split ``data`` into consecutive chunks.

        :param data: audio data
        :param chunk_size: the segment size in one request; must be positive
        :return: generator of (segment data, last flag); always yields at least
            one (possibly empty) segment flagged as last
        :raises ValueError: if ``chunk_size`` is not positive (it would
            otherwise loop forever)
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        data_len = len(data)
        offset = 0
        while offset + chunk_size < data_len:
            yield data[offset: offset + chunk_size], False
            offset += chunk_size
        yield data[offset: data_len], True

    def _real_processor(self, request_params: dict) -> dict:
        pass

    def token_auth(self):
        """Handshake headers for ``auth_method == "token"``."""
        # NOTE(review): the "Bearer; " prefix (with a semicolon) is kept as-is;
        # it appears to match Volcengine's sample client — confirm before changing.
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    def signature_auth(self, data):
        """
        Handshake headers for ``auth_method == "signature"``.

        Computes an HMAC-SHA256 (keyed with ``self.secret``) over the request
        line of ``volcanic_api_url``, the custom header values and ``data``.

        :param data: the framed full-client request bytes
        :return: dict of headers including ``Authorization``
        """
        header_dicts = {
            'Custom': 'auth_custom',
        }

        url_parse = urlparse(self.volcanic_api_url)
        input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path)
        auth_headers = 'Custom'
        for header in auth_headers.split(','):
            input_str += '{}\n'.format(header_dicts[header])
        input_data = bytearray(input_str, 'utf-8')
        input_data += data
        mac = base64.urlsafe_b64encode(
            hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest())
        header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.volcanic_token,
                                                                                               str(mac, 'utf-8'),
                                                                                               auth_headers)
        return header_dicts

    def _build_full_client_request(self, reqid: str) -> bytearray:
        """Serialize, gzip-compress and frame the initial full-client request."""
        payload_bytes = gzip.compress(str.encode(json.dumps(self.construct_request(reqid))))
        request = bytearray(generate_full_default_header())
        request.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size (4 bytes)
        request.extend(payload_bytes)  # payload
        return request

    @staticmethod
    def _build_audio_only_request(chunk: bytes, last: bool) -> bytearray:
        """Gzip-compress and frame one audio chunk; the last chunk uses NEG_SEQUENCE."""
        # The audio headers declare GZIP compression, so the chunk is compressed to match.
        payload_bytes = gzip.compress(chunk)
        header = generate_last_audio_default_header() if last else generate_audio_default_header()
        request = bytearray(header)
        request.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size (4 bytes)
        request.extend(payload_bytes)  # payload
        return request

    def _raise_for_error(self, result: dict) -> None:
        """
        Raise if a parsed server response reports a non-success status.

        :param result: dict returned by ``parse_response``
        :raises Exception: when the JSON payload's ``code`` differs from
            ``success_code``, or when the server sent an error frame whose
            payload is not a JSON object
        """
        payload_msg = result.get('payload_msg')
        if isinstance(payload_msg, dict):
            code = payload_msg.get('code')
            if code != self.success_code:
                raise Exception(f"Error code: {code}, message: {payload_msg.get('message')}")
        elif 'code' in result:
            # SERVER_ERROR_RESPONSE frame with a non-JSON (or empty) payload.
            raise Exception(f"Error code: {result['code']}, message: {payload_msg}")

    @staticmethod
    def _extract_text(result: dict, default: str) -> str:
        """Return the first n-best text in ``result``, or ``default`` if it carries none."""
        payload_msg = result.get('payload_msg')
        if isinstance(payload_msg, dict) and payload_msg.get('result'):
            return payload_msg['result'][0]['text']
        return default

    async def segment_data_processor(self, wav_data: bytes, segment_size: int):
        """
        Stream audio to the ASR service and return the recognized text.

        :param wav_data: complete audio file contents
        :param segment_size: bytes of audio sent per audio-only request
        :return: text from the most recent response that carried a result,
            or an empty string if no response did
        :raises Exception: if the server reports a non-success code for the
            initial request or for any audio chunk
        """
        reqid = str(uuid.uuid7())
        full_client_request = self._build_full_client_request(reqid)

        header = None
        if self.auth_method == "token":
            header = self.token_auth()
        elif self.auth_method == "signature":
            header = self.signature_auth(full_client_request)

        text = ''
        async with websockets.connect(self.volcanic_api_url, additional_headers=header, max_size=1000000000,
                                      ssl=ssl_context) as ws:
            # send the full client request
            await ws.send(full_client_request)
            self._raise_for_error(parse_response(await ws.recv()))

            for chunk, last in self.slice_data(wav_data, segment_size):
                # send an audio-only client request
                await ws.send(self._build_audio_only_request(chunk, last))
                result = parse_response(await ws.recv())
                # Raise instead of handing a raw response dict to callers that expect text.
                self._raise_for_error(result)
                text = self._extract_text(result, text)
        return text

    def check_auth(self):
        """Validate the credentials by transcribing the bundled sample MP3."""
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
            self.speech_to_text(f)

    def speech_to_text(self, file):
        """
        Transcribe an audio file object.

        :param file: binary file-like object; read fully into memory
        :return: recognized text
        :raises Exception: if ``self.format`` is neither "wav" nor "mp3", or the
            service reports an error
        """
        audio_data = bytes(file.read())
        if self.format == "mp3":
            segment_size = self.mp3_seg_size
        elif self.format == "wav":
            nchannels, sampwidth, framerate, _, _ = read_wav_info(audio_data)
            size_per_sec = nchannels * sampwidth * framerate
            # bytes covering seg_duration milliseconds of audio
            segment_size = int(size_per_sec * self.seg_duration / 1000)
        else:
            raise Exception("format should in wav or mp3")
        return asyncio.run(self.segment_data_processor(audio_data, segment_size))