tts.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # -*- coding:utf-8 -*-
  2. #
  3. # author: iflytek
  4. #
  5. # 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
  6. # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
  7. import asyncio
  8. import base64
  9. import datetime
  10. import hashlib
  11. import hmac
  12. import json
  13. import logging
  14. import ssl
  15. from datetime import datetime, UTC
  16. from typing import Dict
  17. from urllib.parse import urlencode, urlparse
  18. import websockets
  19. from django.utils.translation import gettext as _
  20. from common.utils.common import _remove_empty_lines
  21. from models_provider.base_model_provider import MaxKBBaseModel
  22. from models_provider.impl.base_tts import BaseTextToSpeech
  23. STATUS_FIRST_FRAME = 0 # 第一帧的标识
  24. STATUS_CONTINUE_FRAME = 1 # 中间帧标识
  25. STATUS_LAST_FRAME = 2 # 最后一帧的标识
  26. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  27. ssl_context.check_hostname = False
  28. ssl_context.verify_mode = ssl.CERT_NONE
  29. class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
  30. spark_app_id: str
  31. spark_api_key: str
  32. spark_api_secret: str
  33. spark_api_url: str
  34. params: dict
  35. def __init__(self, **kwargs):
  36. super().__init__(**kwargs)
  37. self.spark_api_url = kwargs.get('spark_api_url')
  38. self.spark_app_id = kwargs.get('spark_app_id')
  39. self.spark_api_key = kwargs.get('spark_api_key')
  40. self.spark_api_secret = kwargs.get('spark_api_secret')
  41. self.params = kwargs.get('params')
  42. @staticmethod
  43. def is_cache_model():
  44. return False
  45. @staticmethod
  46. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  47. optional_params = {'params': {'vcn': 'xiaoyan', 'speed': 50}}
  48. for key, value in model_kwargs.items():
  49. if key not in ['model_id', 'use_local', 'streaming']:
  50. optional_params['params'][key] = value
  51. return XFSparkTextToSpeech(
  52. spark_app_id=model_credential.get('spark_app_id'),
  53. spark_api_key=model_credential.get('spark_api_key'),
  54. spark_api_secret=model_credential.get('spark_api_secret'),
  55. spark_api_url=model_credential.get('spark_api_url'),
  56. **optional_params
  57. )
  58. # 生成url
  59. def create_url(self):
  60. url = self.spark_api_url
  61. host = urlparse(url).hostname
  62. # 生成RFC1123格式的时间戳
  63. gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
  64. date = datetime.now(UTC).strftime(gmt_format)
  65. # 拼接字符串
  66. signature_origin = "host: " + host + "\n"
  67. signature_origin += "date: " + date + "\n"
  68. signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
  69. # 进行hmac-sha256进行加密
  70. signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  71. digestmod=hashlib.sha256).digest()
  72. signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
  73. authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
  74. self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
  75. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  76. # 将请求的鉴权参数组合为字典
  77. v = {
  78. "authorization": authorization,
  79. "date": date,
  80. "host": host
  81. }
  82. # 拼接鉴权参数,生成url
  83. url = url + '?' + urlencode(v)
  84. # print("date: ",date)
  85. # print("v: ",v)
  86. # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
  87. # print('websocket url :', url)
  88. return url
  89. def check_auth(self):
  90. self.text_to_speech(_('Hello'))
  91. def text_to_speech(self, text):
  92. # 使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"”
  93. # self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")}
  94. text = _remove_empty_lines(text)
  95. async def handle():
  96. async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
  97. # 发送 full client request
  98. await self.send(ws, text)
  99. return await self.handle_message(ws)
  100. return asyncio.run(handle())
  101. def is_cache_model(self):
  102. return False
  103. @staticmethod
  104. async def handle_message(ws):
  105. audio_bytes: bytes = b''
  106. while True:
  107. res = await ws.recv()
  108. message = json.loads(res)
  109. # print(message)
  110. code = message["code"]
  111. sid = message["sid"]
  112. if code != 0:
  113. errMsg = message["message"]
  114. raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}")
  115. else:
  116. audio = message["data"]["audio"]
  117. audio = base64.b64decode(audio)
  118. audio_bytes += audio
  119. # 退出
  120. if message["data"]["status"] == 2:
  121. break
  122. return audio_bytes
  123. async def send(self, ws, text):
  124. business = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "tte": "utf8"}
  125. d = {
  126. "common": {"app_id": self.spark_app_id},
  127. "business": business | self.params,
  128. "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")},
  129. }
  130. d = json.dumps(d)
  131. await ws.send(d)