stt.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # -*- coding:utf-8 -*-
  2. #
  3. # 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
  4. # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
  5. import asyncio
  6. import base64
  7. import datetime
  8. import hashlib
  9. import hmac
  10. import json
  11. import logging
  12. import os
  13. import ssl
  14. from datetime import datetime, UTC
  15. from typing import Dict
  16. from urllib.parse import urlencode, urlparse
  17. import websockets
  18. from models_provider.base_model_provider import MaxKBBaseModel
  19. from models_provider.impl.base_stt import BaseSpeechToText
STATUS_FIRST_FRAME = 0  # marker for the first audio frame
STATUS_CONTINUE_FRAME = 1  # marker for an intermediate audio frame
STATUS_LAST_FRAME = 2  # marker for the last audio frame
# TLS context passed to the WebSocket connection opened by the speech-to-text client.
# NOTE(review): hostname checking and certificate verification are both switched off
# here, so the server's identity is not authenticated -- confirm this is intentional.
# check_hostname must be cleared before verify_mode can be set to CERT_NONE.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
  26. class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
  27. spark_app_id: str
  28. spark_api_key: str
  29. spark_api_secret: str
  30. spark_api_url: str
  31. params: dict
  32. model_name: str
  33. def __init__(self, **kwargs):
  34. super().__init__(**kwargs)
  35. self.spark_api_url = kwargs.get('spark_api_url')
  36. self.spark_app_id = kwargs.get('spark_app_id')
  37. self.spark_api_key = kwargs.get('spark_api_key')
  38. self.spark_api_secret = kwargs.get('spark_api_secret')
  39. self.params = kwargs.get('params')
  40. self.model_name = kwargs.get('model_name')
  41. @staticmethod
  42. def is_cache_model():
  43. return False
  44. @staticmethod
  45. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  46. optional_params = {}
  47. if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
  48. optional_params['max_tokens'] = model_kwargs['max_tokens']
  49. if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
  50. optional_params['temperature'] = model_kwargs['temperature']
  51. return XFSparkSpeechToText(
  52. spark_app_id=model_credential.get('spark_app_id'),
  53. spark_api_key=model_credential.get('spark_api_key'),
  54. spark_api_secret=model_credential.get('spark_api_secret'),
  55. spark_api_url=model_credential.get('spark_api_url'),
  56. params=model_kwargs,
  57. model_name=model_name,
  58. **optional_params
  59. )
  60. # 生成url
  61. def create_url(self):
  62. url = self.spark_api_url
  63. host = urlparse(url).hostname
  64. # 生成RFC1123格式的时间戳
  65. gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
  66. date = datetime.now(UTC).strftime(gmt_format)
  67. # 拼接字符串
  68. signature_origin = "host: " + host + "\n"
  69. signature_origin += "date: " + date + "\n"
  70. signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
  71. # 进行hmac-sha256进行加密
  72. signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  73. digestmod=hashlib.sha256).digest()
  74. signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
  75. authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
  76. self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
  77. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  78. # 将请求的鉴权参数组合为字典
  79. v = {
  80. "authorization": authorization,
  81. "date": date,
  82. "host": host
  83. }
  84. # 拼接鉴权参数,生成url
  85. url = url + '?' + urlencode(v)
  86. # print("date: ",date)
  87. # print("v: ",v)
  88. # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
  89. # print('websocket url :', url)
  90. return url
  91. def check_auth(self):
  92. cwd = os.path.dirname(os.path.abspath(__file__))
  93. with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
  94. self.speech_to_text(f)
  95. def speech_to_text(self, file):
  96. async def handle():
  97. async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
  98. # 发送 full client request
  99. await self.send(ws, file)
  100. return await self.handle_message(ws)
  101. return asyncio.run(handle())
  102. @staticmethod
  103. async def handle_message(ws):
  104. res = await ws.recv()
  105. message = json.loads(res)
  106. code = message["code"]
  107. sid = message["sid"]
  108. if code != 0:
  109. errMsg = message["message"]
  110. raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}")
  111. else:
  112. data = message["data"]["result"]["ws"]
  113. result = ""
  114. for i in data:
  115. for w in i["cw"]:
  116. result += w["w"]
  117. # print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False)))
  118. return result
  119. # 收到websocket连接建立的处理
  120. async def send(self, ws, file):
  121. frameSize = 8000 # 每一帧的音频大小
  122. status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧
  123. allowed_params = {'language', 'domain', 'accent', 'vad_eos', 'dwa', 'pd', 'ptt',
  124. 'pcm', 'ltc', 'rlang', 'vinfo', 'nunum', 'speex_size', 'nbest', 'wbest'}
  125. business_params = {k: v for k, v in self.params.items() if k in allowed_params}
  126. if not business_params:
  127. business_params = {
  128. "domain": f'{self.model_name}',
  129. "language": "zh_cn",
  130. "accent": "mandarin",
  131. "vinfo": 1,
  132. "vad_eos": 10000
  133. }
  134. while True:
  135. buf = file.read(frameSize)
  136. # 文件结束
  137. if not buf:
  138. status = STATUS_LAST_FRAME
  139. # 第一帧处理
  140. # 发送第一帧音频,带business 参数
  141. # appid 必须带上,只需第一帧发送
  142. if status == STATUS_FIRST_FRAME:
  143. d = {
  144. "common": {"app_id": self.spark_app_id},
  145. "business": {
  146. **business_params
  147. },
  148. "data": {
  149. "status": 0, "format": "audio/L16;rate=16000",
  150. "audio": str(base64.b64encode(buf), 'utf-8'),
  151. "encoding": "lame"}
  152. }
  153. d = json.dumps(d)
  154. await ws.send(d)
  155. status = STATUS_CONTINUE_FRAME
  156. # 中间帧处理
  157. elif status == STATUS_CONTINUE_FRAME:
  158. d = {"data": {"status": 1, "format": "audio/L16;rate=16000",
  159. "audio": str(base64.b64encode(buf), 'utf-8'),
  160. "encoding": "lame"}}
  161. await ws.send(json.dumps(d))
  162. # 最后一帧处理
  163. elif status == STATUS_LAST_FRAME:
  164. d = {"data": {"status": 2, "format": "audio/L16;rate=16000",
  165. "audio": str(base64.b64encode(buf), 'utf-8'),
  166. "encoding": "lame"}}
  167. await ws.send(json.dumps(d))
  168. break