tts.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # coding=utf-8
  2. '''
  3. requires Python 3.9 or later (uses the dict union operator `|`)
  4. asyncio is part of the standard library; do not pip install it
  5. pip install websockets
  6. '''
  7. import asyncio
  8. import copy
  9. import gzip
  10. import json
  11. import re
  12. import ssl
  13. import requests
  14. import uuid_utils.compat as uuid
  15. from typing import Dict
  16. import websockets
  17. from django.utils.translation import gettext as _
  18. from common.utils.common import _remove_empty_lines
  19. from models_provider.base_model_provider import MaxKBBaseModel
  20. from models_provider.impl.base_tts import BaseTextToSpeech
  21. MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"}
  22. MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0",
  23. 2: "last message from server (seq < 0)", 3: "sequence number < 0"}
  24. MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"}
  25. MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"}
  26. # version: b0001 (4 bits)
  27. # header size: b0001 (4 bits)
  28. # message type: b0001 (Full client request) (4bits)
  29. # message type specific flags: b0000 (none) (4bits)
  30. # message serialization method: b0001 (JSON) (4 bits)
  31. # message compression: b0001 (gzip) (4bits)
  32. # reserved data: 0x00 (1 byte)
  33. default_header = bytearray(b'\x11\x10\x11\x00')
  34. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  35. ssl_context.check_hostname = False
  36. ssl_context.verify_mode = ssl.CERT_NONE
  37. class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
  38. volcanic_app_id: str
  39. volcanic_cluster: str
  40. volcanic_api_url: str
  41. volcanic_token: str
  42. params: dict
  43. def __init__(self, **kwargs):
  44. super().__init__(**kwargs)
  45. self.volcanic_api_url = kwargs.get('volcanic_api_url')
  46. self.volcanic_token = kwargs.get('volcanic_token')
  47. self.volcanic_app_id = kwargs.get('volcanic_app_id')
  48. self.volcanic_cluster = kwargs.get('volcanic_cluster')
  49. self.params = kwargs.get('params')
  50. @staticmethod
  51. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  52. optional_params = {'params': {'voice_type': 'zh_female_cancan_mars_bigtts', 'speed_ratio': 1.0}}
  53. for key, value in model_kwargs.items():
  54. if key not in ['model_id', 'use_local', 'streaming']:
  55. optional_params['params'][key] = value
  56. return VolcanicEngineTextToSpeech(
  57. volcanic_api_url=model_credential.get('volcanic_api_url'),
  58. volcanic_token=model_credential.get('volcanic_token'),
  59. volcanic_app_id=model_credential.get('volcanic_app_id'),
  60. volcanic_cluster=model_credential.get('volcanic_cluster'),
  61. **optional_params
  62. )
  63. def check_auth(self):
  64. self.text_to_speech(_('Hello'))
  65. def text_to_speech(self, text):
  66. request_json = {
  67. "app": {
  68. "appid": self.volcanic_app_id,
  69. "token": "access_token",
  70. "cluster": self.volcanic_cluster
  71. },
  72. "user": {
  73. "uid": "uid"
  74. },
  75. "audio": {
  76. "encoding": "mp3",
  77. "volume_ratio": 1.0,
  78. "pitch_ratio": 1.0,
  79. } | self.params,
  80. "request": {
  81. "reqid": str(uuid.uuid7()),
  82. "text": '',
  83. "text_type": "plain",
  84. "operation": "xxx"
  85. }
  86. }
  87. text = _remove_empty_lines(text)
  88. return asyncio.run(self.submit(request_json, text))
  89. def is_cache_model(self):
  90. return False
  91. def token_auth(self):
  92. return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}
  93. async def submit(self, request_json, text):
  94. submit_request_json = copy.deepcopy(request_json)
  95. submit_request_json["request"]["operation"] = "submit"
  96. header = {"Authorization": f"Bearer; {self.volcanic_token}"}
  97. result = b''
  98. async with websockets.connect(self.volcanic_api_url, additional_headers=header, ping_interval=None,
  99. ssl=ssl_context) as ws:
  100. lines = [text[i:i + 200] for i in range(0, len(text), 200)]
  101. for line in lines:
  102. if self.is_table_format_chars_only(line):
  103. continue
  104. submit_request_json["request"]["reqid"] = str(uuid.uuid7())
  105. submit_request_json["request"]["text"] = line
  106. payload_bytes = str.encode(json.dumps(submit_request_json))
  107. payload_bytes = gzip.compress(payload_bytes) # if no compression, comment this line
  108. full_client_request = bytearray(default_header)
  109. full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes)
  110. full_client_request.extend(payload_bytes) # payload
  111. await ws.send(full_client_request)
  112. result += await self.parse_response(ws)
  113. return result
  114. @staticmethod
  115. def is_table_format_chars_only(s):
  116. # 检查是否仅包含 "|", "-", 和空格字符
  117. return bool(s) and re.fullmatch(r'[|\-\s]+', s)
  118. @staticmethod
  119. async def parse_response(ws):
  120. result = b''
  121. while True:
  122. res = await ws.recv()
  123. protocol_version = res[0] >> 4
  124. header_size = res[0] & 0x0f
  125. message_type = res[1] >> 4
  126. message_type_specific_flags = res[1] & 0x0f
  127. serialization_method = res[2] >> 4
  128. message_compression = res[2] & 0x0f
  129. reserved = res[3]
  130. header_extensions = res[4:header_size * 4]
  131. payload = res[header_size * 4:]
  132. if header_size != 1:
  133. # print(f" Header extensions: {header_extensions}")
  134. pass
  135. if message_type == 0xb: # audio-only server response
  136. if message_type_specific_flags == 0: # no sequence number as ACK
  137. continue
  138. else:
  139. sequence_number = int.from_bytes(payload[:4], "big", signed=True)
  140. payload_size = int.from_bytes(payload[4:8], "big", signed=False)
  141. payload = payload[8:]
  142. result += payload
  143. if sequence_number < 0:
  144. break
  145. else:
  146. continue
  147. elif message_type == 0xf:
  148. code = int.from_bytes(payload[:4], "big", signed=False)
  149. msg_size = int.from_bytes(payload[4:8], "big", signed=False)
  150. error_msg = payload[8:]
  151. if message_compression == 1:
  152. error_msg = gzip.decompress(error_msg)
  153. error_msg = str(error_msg, "utf-8")
  154. raise Exception(f"Error code: {code}, message: {error_msg}")
  155. elif message_type == 0xc:
  156. msg_size = int.from_bytes(payload[:4], "big", signed=False)
  157. payload = payload[4:]
  158. if message_compression == 1:
  159. payload = gzip.decompress(payload)
  160. else:
  161. break
  162. return result