| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- # -*- coding:utf-8 -*-
- #
- # author: iflytek
- #
- # 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
- # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
- import asyncio
- import base64
- import datetime
- import hashlib
- import hmac
- import json
- import logging
- import ssl
- from datetime import datetime, UTC
- from typing import Dict
- from urllib.parse import urlencode, urlparse
- import websockets
- from django.utils.translation import gettext as _
- from common.utils.common import _remove_empty_lines
- from models_provider.base_model_provider import MaxKBBaseModel
- from models_provider.impl.base_tts import BaseTextToSpeech
# Frame-status markers used in the "status" field of the protocol:
# 0 = first frame, 1 = intermediate frame, 2 = last frame.
STATUS_FIRST_FRAME, STATUS_CONTINUE_FRAME, STATUS_LAST_FRAME = 0, 1, 2

# Shared TLS context for the WebSocket connection.
# NOTE(review): certificate and hostname verification are both switched off
# here, so the peer's identity is never checked. Confirm this is intentional
# before reusing this context anywhere else.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# check_hostname has to be cleared first: the ssl module rejects
# verify_mode = CERT_NONE while hostname checking is still enabled.
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech model backed by the iFlytek Spark WebSocket service.

    A synthesis call opens one WebSocket connection to a signed URL, sends the
    whole text in a single frame, and concatenates the base64-encoded audio
    chunks the service streams back.
    """

    # Application id; sent in the "common" section of every request.
    spark_app_id: str
    # API key; embedded in the authorization query parameter.
    spark_api_key: str
    # API secret; used as the HMAC-SHA256 key when signing the handshake.
    spark_api_secret: str
    # WebSocket endpoint URL of the synthesis service.
    spark_api_url: str
    # Extra "business" options (for example vcn, speed) merged into requests.
    params: dict

    def __init__(self, **kwargs):
        """Store the credentials and synthesis options passed as keywords.

        Recognised keys: ``spark_api_url``, ``spark_app_id``,
        ``spark_api_key``, ``spark_api_secret`` and ``params``.
        """
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')
        self.params = kwargs.get('params')

    # Defined exactly once. The original file re-declared this further down as
    # an instance method, which replaced the static version and made
    # ``XFSparkTextToSpeech.is_cache_model()`` fail when called on the class.
    @staticmethod
    def is_cache_model():
        """Return False: instances of this model are not meant to be cached."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from stored credentials and per-call options.

        Args:
            model_type: Unused here; part of the shared factory signature.
            model_name: Unused here; part of the shared factory signature.
            model_credential: Mapping that provides ``spark_app_id``,
                ``spark_api_key``, ``spark_api_secret`` and ``spark_api_url``.
            **model_kwargs: Synthesis options. Everything except ``model_id``,
                ``use_local`` and ``streaming`` is forwarded to the service
                and overrides the defaults (``vcn='xiaoyan'``, ``speed=50``).

        Returns:
            A configured ``XFSparkTextToSpeech``.
        """
        params = {'vcn': 'xiaoyan', 'speed': 50}
        params.update(
            (key, value)
            for key, value in model_kwargs.items()
            if key not in ('model_id', 'use_local', 'streaming')
        )
        return XFSparkTextToSpeech(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            params=params,
        )

    def create_url(self):
        """Return ``spark_api_url`` with the signed authentication query added.

        The string ``host``, ``date`` and the HTTP request line is signed with
        HMAC-SHA256 using ``spark_api_secret``; the resulting ``authorization``
        value is appended to the URL together with ``date`` and ``host``.
        """
        url = self.spark_api_url
        parsed = urlparse(url)
        host = parsed.hostname
        # Sign the path that will really be requested instead of a fixed
        # "/v2/tts", so endpoints with another path authenticate correctly.
        # The previous value is kept as the fallback for a URL without a path.
        path = parsed.path or '/v2/tts'

        # RFC 1123 timestamp in GMT.
        date = datetime.now(UTC).strftime('%a, %d %b %Y %H:%M:%S GMT')

        # Canonical string to sign: host, date and request line.
        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + path + " HTTP/1.1"

        digest = hmac.new(
            self.spark_api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            self.spark_api_key, "hmac-sha256", "host date request-line", signature)
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Authentication parameters travel in the query string.
        query = {
            "authorization": authorization,
            "date": date,
            "host": host,
        }
        return url + '?' + urlencode(query)

    def check_auth(self):
        """Validate the credentials by synthesising a short greeting.

        Any exception raised by ``text_to_speech`` propagates to the caller.
        """
        self.text_to_speech(_('Hello'))

    def text_to_speech(self, text):
        """Synthesise ``text`` and return the audio as ``bytes``.

        The text is passed through ``_remove_empty_lines`` first. The call
        blocks: it runs its own event loop via ``asyncio.run`` for the
        duration of the WebSocket exchange.
        """
        # Note carried over from the vendor sample: minor languages need the
        # text encoded as UTF-16LE rather than the UTF-8 used here.
        text = _remove_empty_lines(text)

        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
                # The whole request goes out as a single frame.
                await self.send(ws, text)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Read response frames from ``ws`` until the last one arrives.

        Returns:
            The decoded audio of all frames, concatenated, as ``bytes``.

        Raises:
            Exception: If a frame reports a ``code`` other than 0. The message
                carries the session id, the service message and the code.
        """
        audio_buffer = bytearray()
        while True:
            message = json.loads(await ws.recv())
            # Tolerant lookups: an error frame that lacks a field must still
            # surface the service error rather than a KeyError.
            code = message.get("code")
            sid = message.get("sid")
            if code != 0:
                err_msg = message.get("message")
                raise Exception(f"sid: {sid} call error: {err_msg} code is: {code}")

            data = message.get("data") or {}
            audio = data.get("audio")
            if audio:
                audio_buffer += base64.b64decode(audio)
            # The service marks its final frame with the last-frame status.
            if data.get("status") == STATUS_LAST_FRAME:
                break
        return bytes(audio_buffer)

    async def send(self, ws, text):
        """Send one synthesis request for ``text`` over ``ws``.

        Default business options are merged with ``self.params``; on a key
        clash the value from ``self.params`` wins.
        """
        business = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "tte": "utf8"}
        payload = {
            "common": {"app_id": self.spark_app_id},
            # ``params`` may be None when nothing was configured.
            "business": business | (self.params or {}),
            # The full text is sent at once, so this frame is also the last.
            "data": {
                "status": STATUS_LAST_FRAME,
                "text": base64.b64encode(text.encode('utf-8')).decode('utf-8'),
            },
        }
        await ws.send(json.dumps(payload))
|