| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- # coding=utf-8
'''
Requires Python 3.9 or later (the dict union operator ``|`` is used below).
asyncio is part of the standard library and must not be installed from PyPI.
pip install websockets
'''
- import asyncio
- import copy
- import gzip
- import json
- import re
- import ssl
- import requests
- import uuid_utils.compat as uuid
- from typing import Dict
- import websockets
- from django.utils.translation import gettext as _
- from common.utils.common import _remove_empty_lines
- from models_provider.base_model_provider import MaxKBBaseModel
- from models_provider.impl.base_tts import BaseTextToSpeech
# Descriptive names for the 4-bit header field values used by the binary protocol.
MESSAGE_TYPES = {
    11: "audio-only server response",
    12: "frontend server response",
    15: "error message from server",
}
MESSAGE_TYPE_SPECIFIC_FLAGS = {
    0: "no sequence number",
    1: "sequence number > 0",
    2: "last message from server (seq < 0)",
    3: "sequence number < 0",
}
MESSAGE_SERIALIZATION_METHODS = {
    0: "no serialization",
    1: "JSON",
    15: "custom type",
}
MESSAGE_COMPRESSIONS = {
    0: "no compression",
    1: "gzip",
    15: "custom compression method",
}

# Default 4-byte header of a client frame, two 4-bit fields per byte:
#   byte 0: version 0b0001              | header size 0b0001
#   byte 1: message type 0b0001 (full client request) | type-specific flags 0b0000 (none)
#   byte 2: serialization 0b0001 (JSON) | compression 0b0001 (gzip)
#   byte 3: reserved 0x00
default_header = bytearray(b'\x11\x10\x11\x00')

# NOTE(security): hostname checking and certificate verification are both
# switched off, so the server's identity is never validated while the bearer
# token is sent over this connection. Kept as-is to preserve behaviour;
# review whether verification can be enabled for the configured endpoint.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False  # must be disabled before verify_mode can be CERT_NONE
ssl_context.verify_mode = ssl.CERT_NONE
class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech model backed by the Volcanic Engine websocket API.

    The text is split into fixed-size chunks; each chunk is sent as one
    binary frame over a single websocket connection and the audio returned
    for every chunk is concatenated into one ``bytes`` result.
    """

    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str
    params: dict

    def __init__(self, **kwargs):
        """Store the connection settings and audio parameters from ``kwargs``."""
        super().__init__(**kwargs)
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from stored credentials plus optional audio parameters.

        Every entry of ``model_kwargs`` except ``model_id``, ``use_local`` and
        ``streaming`` is copied into ``params`` on top of the default voice
        type and speed ratio.
        """
        params = {'voice_type': 'zh_female_cancan_mars_bigtts', 'speed_ratio': 1.0}
        params.update({
            key: value
            for key, value in model_kwargs.items()
            if key not in ('model_id', 'use_local', 'streaming')
        })
        return VolcanicEngineTextToSpeech(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            params=params,
        )

    def check_auth(self):
        """Validate the credentials by synthesising a short greeting.

        Any exception raised by :meth:`text_to_speech` propagates to the caller.
        """
        self.text_to_speech(_('Hello'))

    def text_to_speech(self, text):
        """Synthesise ``text`` and return the concatenated audio bytes.

        Starts its own event loop via ``asyncio.run``, so it cannot be called
        from code that is already running inside an event loop.
        """
        request_json = {
            "app": {
                "appid": self.volcanic_app_id,
                "token": "access_token",
                "cluster": self.volcanic_cluster
            },
            "user": {
                "uid": "uid"
            },
            # User-supplied params override the defaults; tolerate params=None
            # (the instance may have been constructed without them).
            "audio": {
                "encoding": "mp3",
                "volume_ratio": 1.0,
                "pitch_ratio": 1.0,
            } | (self.params or {}),
            "request": {
                "reqid": str(uuid.uuid7()),
                # "text" and "operation" are placeholders filled in by submit().
                "text": '',
                "text_type": "plain",
                "operation": "xxx"
            }
        }
        text = _remove_empty_lines(text)
        return asyncio.run(self.submit(request_json, text))

    def is_cache_model(self):
        """Return ``False``: this model reports itself as not cacheable."""
        return False

    def token_auth(self):
        """Return the ``Authorization`` header carrying the access token."""
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    async def submit(self, request_json, text):
        """Send ``text`` chunk by chunk and return all received audio.

        ``request_json`` is deep-copied, so the caller's dict is not modified.
        Chunks consisting only of table-formatting characters are skipped.
        """
        submit_request_json = copy.deepcopy(request_json)
        submit_request_json["request"]["operation"] = "submit"

        chunk_size = 200  # characters sent per request
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        audio = bytearray()
        async with websockets.connect(self.volcanic_api_url, additional_headers=self.token_auth(),
                                      ping_interval=None, ssl=ssl_context) as ws:
            for chunk in chunks:
                if self.is_table_format_chars_only(chunk):
                    continue
                # Each chunk is an independent request with its own id.
                submit_request_json["request"]["reqid"] = str(uuid.uuid7())
                submit_request_json["request"]["text"] = chunk
                # default_header declares gzip compression, so the JSON payload must be gzipped.
                payload_bytes = gzip.compress(json.dumps(submit_request_json).encode())
                # Frame layout: 4-byte header, 4-byte big-endian payload size, payload.
                full_client_request = bytearray(default_header)
                full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))
                full_client_request.extend(payload_bytes)
                await ws.send(full_client_request)
                audio += await self.parse_response(ws)
        return bytes(audio)

    @staticmethod
    def is_table_format_chars_only(s):
        """Return ``True`` if ``s`` is non-empty and made only of ``|``, ``-`` and whitespace."""
        return bool(s) and re.fullmatch(r'[|\-\s]+', s) is not None

    @staticmethod
    async def parse_response(ws):
        """Read frames from ``ws`` until one request's audio is complete.

        Returns the concatenated audio payload. Raises ``Exception`` when the
        server sends an error frame or a frame that cannot be a binary
        protocol message.
        """
        audio = bytearray()
        while True:
            res = await ws.recv()
            # A text frame or a frame shorter than the fixed 4-byte header
            # cannot be parsed; fail with a clear message.
            if isinstance(res, str) or len(res) < 4:
                raise Exception(f"Unexpected response frame from server: {res!r}")

            header_size = res[0] & 0x0f  # in units of 4 bytes
            message_type = res[1] >> 4
            message_type_specific_flags = res[1] & 0x0f
            message_compression = res[2] & 0x0f
            payload = res[header_size * 4:]

            if message_type == 0xb:  # audio-only server response
                if message_type_specific_flags == 0:  # no sequence number: ACK only
                    continue
                # Payload: 4-byte signed sequence number, 4-byte size, audio data.
                sequence_number = int.from_bytes(payload[:4], "big", signed=True)
                audio += payload[8:]
                if sequence_number < 0:  # negative sequence number marks the last frame
                    break
            elif message_type == 0xf:  # error message from server
                # Payload: 4-byte error code, 4-byte message size, message bytes.
                code = int.from_bytes(payload[:4], "big", signed=False)
                error_msg = payload[8:]
                if message_compression == 1:
                    error_msg = gzip.decompress(error_msg)
                error_msg = error_msg.decode("utf-8")
                raise Exception(f"Error code: {code}, message: {error_msg}")
            elif message_type == 0xc:  # frontend server response: content is not used
                continue
            else:
                # Unknown message type: stop and return what has been collected.
                break
        return bytes(audio)
|