stt.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # coding=utf-8
  2. """
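Requires Python 3.9 or later: the `additional_headers` keyword passed to
websockets.connect needs websockets >= 14, which itself requires Python 3.9.
asyncio is part of the standard library; do NOT `pip install asyncio`.
pip install websockets uuid_utils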
  6. """
  7. import asyncio
  8. import base64
  9. import gzip
  10. import hmac
  11. import json
  12. import logging
  13. import os
  14. import ssl
  15. import uuid_utils.compat as uuid
  16. import wave
  17. from hashlib import sha256
  18. from io import BytesIO
  19. from typing import Dict
  20. from urllib.parse import urlparse
  21. import websockets
  22. from common.utils.logger import maxkb_logger
  23. from models_provider.base_model_provider import MaxKBBaseModel
  24. from models_provider.impl.base_stt import BaseSpeechToText
  25. audio_format = "mp3" # wav 或者 mp3,根据实际音频格式设置
  26. PROTOCOL_VERSION = 0b0001
  27. DEFAULT_HEADER_SIZE = 0b0001
  28. PROTOCOL_VERSION_BITS = 4
  29. HEADER_BITS = 4
  30. MESSAGE_TYPE_BITS = 4
  31. MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
  32. MESSAGE_SERIALIZATION_BITS = 4
  33. MESSAGE_COMPRESSION_BITS = 4
  34. RESERVED_BITS = 8
  35. # Message Type:
  36. CLIENT_FULL_REQUEST = 0b0001
  37. CLIENT_AUDIO_ONLY_REQUEST = 0b0010
  38. SERVER_FULL_RESPONSE = 0b1001
  39. SERVER_ACK = 0b1011
  40. SERVER_ERROR_RESPONSE = 0b1111
  41. # Message Type Specific Flags
  42. NO_SEQUENCE = 0b0000 # no check sequence
  43. POS_SEQUENCE = 0b0001
  44. NEG_SEQUENCE = 0b0010
  45. NEG_SEQUENCE_1 = 0b0011
  46. # Message Serialization
  47. NO_SERIALIZATION = 0b0000
  48. JSON = 0b0001
  49. THRIFT = 0b0011
  50. CUSTOM_TYPE = 0b1111
  51. # Message Compression
  52. NO_COMPRESSION = 0b0000
  53. GZIP = 0b0001
  54. CUSTOM_COMPRESSION = 0b1111
  55. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  56. ssl_context.check_hostname = False
  57. ssl_context.verify_mode = ssl.CERT_NONE
  58. def generate_header(
  59. version=PROTOCOL_VERSION,
  60. message_type=CLIENT_FULL_REQUEST,
  61. message_type_specific_flags=NO_SEQUENCE,
  62. serial_method=JSON,
  63. compression_type=GZIP,
  64. reserved_data=0x00,
  65. extension_header=bytes()
  66. ):
  67. """
  68. protocol_version(4 bits), header_size(4 bits),
  69. message_type(4 bits), message_type_specific_flags(4 bits)
  70. serialization_method(4 bits) message_compression(4 bits)
  71. reserved (8bits) 保留字段
  72. header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) )
  73. """
  74. header = bytearray()
  75. header_size = int(len(extension_header) / 4) + 1
  76. header.append((version << 4) | header_size)
  77. header.append((message_type << 4) | message_type_specific_flags)
  78. header.append((serial_method << 4) | compression_type)
  79. header.append(reserved_data)
  80. header.extend(extension_header)
  81. return header
  82. def generate_full_default_header():
  83. return generate_header()
  84. def generate_audio_default_header():
  85. return generate_header(
  86. message_type=CLIENT_AUDIO_ONLY_REQUEST
  87. )
  88. def generate_last_audio_default_header():
  89. return generate_header(
  90. message_type=CLIENT_AUDIO_ONLY_REQUEST,
  91. message_type_specific_flags=NEG_SEQUENCE
  92. )
  93. def parse_response(res):
  94. """
  95. protocol_version(4 bits), header_size(4 bits),
  96. message_type(4 bits), message_type_specific_flags(4 bits)
  97. serialization_method(4 bits) message_compression(4 bits)
  98. reserved (8bits) 保留字段
  99. header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) )
  100. payload 类似与http 请求体
  101. """
  102. protocol_version = res[0] >> 4
  103. header_size = res[0] & 0x0f
  104. message_type = res[1] >> 4
  105. message_type_specific_flags = res[1] & 0x0f
  106. serialization_method = res[2] >> 4
  107. message_compression = res[2] & 0x0f
  108. reserved = res[3]
  109. header_extensions = res[4:header_size * 4]
  110. payload = res[header_size * 4:]
  111. result = {}
  112. payload_msg = None
  113. payload_size = 0
  114. if message_type == SERVER_FULL_RESPONSE:
  115. payload_size = int.from_bytes(payload[:4], "big", signed=True)
  116. payload_msg = payload[4:]
  117. elif message_type == SERVER_ACK:
  118. seq = int.from_bytes(payload[:4], "big", signed=True)
  119. result['seq'] = seq
  120. if len(payload) >= 8:
  121. payload_size = int.from_bytes(payload[4:8], "big", signed=False)
  122. payload_msg = payload[8:]
  123. elif message_type == SERVER_ERROR_RESPONSE:
  124. code = int.from_bytes(payload[:4], "big", signed=False)
  125. result['code'] = code
  126. payload_size = int.from_bytes(payload[4:8], "big", signed=False)
  127. payload_msg = payload[8:]
  128. maxkb_logger.error(f"Error code: {code}, message: {payload_msg}")
  129. if payload_msg is None:
  130. return result
  131. if message_compression == GZIP:
  132. payload_msg = gzip.decompress(payload_msg)
  133. if serialization_method == JSON:
  134. payload_msg = json.loads(str(payload_msg, "utf-8"))
  135. elif serialization_method != NO_SERIALIZATION:
  136. payload_msg = str(payload_msg, "utf-8")
  137. result['payload_msg'] = payload_msg
  138. result['payload_size'] = payload_size
  139. return result
  140. def read_wav_info(data: bytes = None) -> (int, int, int, int, int):
  141. with BytesIO(data) as _f:
  142. wave_fp = wave.open(_f, 'rb')
  143. nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
  144. wave_bytes = wave_fp.readframes(nframes)
  145. return nchannels, sampwidth, framerate, nframes, len(wave_bytes)
  146. class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
  147. workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate"
  148. show_language: bool = False
  149. show_utterances: bool = False
  150. result_type: str = "full"
  151. format: str = "mp3"
  152. rate: int = 16000
  153. language: str = "zh-CN"
  154. bits: int = 16
  155. channel: int = 1
  156. codec: str = "raw"
  157. audio_type: int = 1
  158. secret: str = "access_secret"
  159. auth_method: str = "token"
  160. mp3_seg_size: int = 10000
  161. success_code: int = 1000 # success code, default is 1000
  162. seg_duration: int = 15000
  163. nbest: int = 1
  164. volcanic_app_id: str
  165. volcanic_cluster: str
  166. volcanic_api_url: str
  167. volcanic_token: str
  168. params: dict
  169. def __init__(self, **kwargs):
  170. super().__init__(**kwargs)
  171. self.volcanic_api_url = kwargs.get('volcanic_api_url')
  172. self.volcanic_token = kwargs.get('volcanic_token')
  173. self.volcanic_app_id = kwargs.get('volcanic_app_id')
  174. self.volcanic_cluster = kwargs.get('volcanic_cluster')
  175. self.params = kwargs.get('params')
  176. @staticmethod
  177. def is_cache_model():
  178. return False
  179. @staticmethod
  180. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  181. optional_params = {}
  182. if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
  183. optional_params['max_tokens'] = model_kwargs['max_tokens']
  184. if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
  185. optional_params['temperature'] = model_kwargs['temperature']
  186. return VolcanicEngineSpeechToText(
  187. volcanic_api_url=model_credential.get('volcanic_api_url'),
  188. volcanic_token=model_credential.get('volcanic_token'),
  189. volcanic_app_id=model_credential.get('volcanic_app_id'),
  190. volcanic_cluster=model_credential.get('volcanic_cluster'),
  191. params=model_kwargs,
  192. **model_kwargs,
  193. **optional_params
  194. )
  195. def construct_request(self, reqid):
  196. params = self.params or {}
  197. req = {
  198. 'app': {
  199. 'appid': self.volcanic_app_id,
  200. 'cluster': self.volcanic_cluster,
  201. 'token': self.volcanic_token,
  202. },
  203. 'user': {
  204. 'uid': params.get("uid", "streaming_asr_demo")
  205. },
  206. 'request': {
  207. 'reqid': reqid,
  208. 'nbest': params.get('nbest', self.nbest),
  209. 'workflow': params.get('workflow', self.workflow),
  210. 'show_language': params.get('show_language', self.show_language),
  211. 'show_utterances': params.get('show_utterances', self.show_utterances),
  212. 'result_type': params.get('result_type', self.result_type),
  213. 'sequence': params.get('sequence', 1)
  214. },
  215. 'audio': {
  216. 'format': params.get('format', self.format),
  217. 'rate': params.get('rate', self.rate),
  218. 'language': params.get('language', self.language),
  219. 'bits': params.get('bits', self.bits),
  220. 'channel': params.get('channel', self.channel),
  221. 'codec': params.get('codec', self.codec)
  222. }
  223. }
  224. return req
  225. @staticmethod
  226. def slice_data(data: bytes, chunk_size: int) -> (list, bool):
  227. """
  228. slice data
  229. :param data: wav data
  230. :param chunk_size: the segment size in one request
  231. :return: segment data, last flag
  232. """
  233. data_len = len(data)
  234. offset = 0
  235. while offset + chunk_size < data_len:
  236. yield data[offset: offset + chunk_size], False
  237. offset += chunk_size
  238. else:
  239. yield data[offset: data_len], True
  240. def _real_processor(self, request_params: dict) -> dict:
  241. pass
  242. def token_auth(self):
  243. return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}
  244. def signature_auth(self, data):
  245. header_dicts = {
  246. 'Custom': 'auth_custom',
  247. }
  248. url_parse = urlparse(self.volcanic_api_url)
  249. input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path)
  250. auth_headers = 'Custom'
  251. for header in auth_headers.split(','):
  252. input_str += '{}\n'.format(header_dicts[header])
  253. input_data = bytearray(input_str, 'utf-8')
  254. input_data += data
  255. mac = base64.urlsafe_b64encode(
  256. hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest())
  257. header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.volcanic_token,
  258. str(mac, 'utf-8'),
  259. auth_headers)
  260. return header_dicts
  261. async def segment_data_processor(self, wav_data: bytes, segment_size: int):
  262. reqid = str(uuid.uuid7())
  263. # 构建 full client request,并序列化压缩
  264. request_params = self.construct_request(reqid)
  265. payload_bytes = str.encode(json.dumps(request_params))
  266. payload_bytes = gzip.compress(payload_bytes)
  267. full_client_request = bytearray(generate_full_default_header())
  268. full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
  269. full_client_request.extend(payload_bytes) # payload
  270. header = None
  271. if self.auth_method == "token":
  272. header = self.token_auth()
  273. elif self.auth_method == "signature":
  274. header = self.signature_auth(full_client_request)
  275. async with websockets.connect(self.volcanic_api_url, additional_headers=header, max_size=1000000000,
  276. ssl=ssl_context) as ws:
  277. # 发送 full client request
  278. await ws.send(full_client_request)
  279. res = await ws.recv()
  280. result = parse_response(res)
  281. if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
  282. raise Exception(
  283. f"Error code: {result['payload_msg']['code']}, message: {result['payload_msg']['message']}")
  284. for seq, (chunk, last) in enumerate(VolcanicEngineSpeechToText.slice_data(wav_data, segment_size), 1):
  285. # if no compression, comment this line
  286. payload_bytes = gzip.compress(chunk)
  287. audio_only_request = bytearray(generate_audio_default_header())
  288. if last:
  289. audio_only_request = bytearray(generate_last_audio_default_header())
  290. audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
  291. audio_only_request.extend(payload_bytes) # payload
  292. # 发送 audio-only client request
  293. await ws.send(audio_only_request)
  294. res = await ws.recv()
  295. result = parse_response(res)
  296. if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
  297. return result
  298. return result['payload_msg']['result'][0]['text']
  299. def check_auth(self):
  300. cwd = os.path.dirname(os.path.abspath(__file__))
  301. with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
  302. self.speech_to_text(f)
  303. def speech_to_text(self, file):
  304. data = file.read()
  305. audio_data = bytes(data)
  306. if self.format == "mp3":
  307. segment_size = self.mp3_seg_size
  308. return asyncio.run(self.segment_data_processor(audio_data, segment_size))
  309. if self.format != "wav":
  310. raise Exception("format should in wav or mp3")
  311. nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info(
  312. audio_data)
  313. size_per_sec = nchannels * sampwidth * framerate
  314. segment_size = int(size_per_sec * self.seg_duration / 1000)
  315. return asyncio.run(self.segment_data_processor(audio_data, segment_size))