| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- # -*- coding:utf-8 -*-
- #
- # 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
- # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
- import asyncio
- import base64
- import datetime
- import hashlib
- import hmac
- import json
- import logging
- import os
- import ssl
- from datetime import datetime, UTC
- from typing import Dict
- from urllib.parse import urlencode, urlparse
- import websockets
- from models_provider.base_model_provider import MaxKBBaseModel
- from models_provider.impl.base_stt import BaseSpeechToText
# Audio frame status markers.
STATUS_FIRST_FRAME = 0  # marks the first frame
STATUS_CONTINUE_FRAME = 1  # marks an intermediate frame
STATUS_LAST_FRAME = 2  # marks the last frame

# TLS context shared by the WebSocket connections opened in this module.
# NOTE(review): hostname checking and certificate verification are both
# disabled here, so the peer is not authenticated and the signed request
# URL could be intercepted. Confirm this is intentional (e.g. a TLS-inspecting
# proxy); otherwise prefer ssl.create_default_context().
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# check_hostname must be cleared before verify_mode may be set to CERT_NONE,
# otherwise the ssl module raises ValueError.
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text client for the XFyun (iFlytek) IAT WebSocket API.

    Audio is streamed to ``spark_api_url`` in base64-encoded JSON frames and
    the recognised text is assembled from the JSON results sent back.
    """

    # Credentials and endpoint used to sign and open the WebSocket connection.
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str
    # Extra model parameters; the recognised subset is forwarded as the
    # "business" section of the first frame (see send()).
    params: dict
    # Used as the default "domain" business parameter.
    model_name: str

    def __init__(self, **kwargs):
        """Store credentials, endpoint, params and model name from ``kwargs``."""
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')
        self.params = kwargs.get('params')
        self.model_name = kwargs.get('model_name')

    @staticmethod
    def is_cache_model():
        """Return False: instances of this model are reported as not cacheable."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential mapping and free-form model kwargs.

        ``max_tokens`` and ``temperature`` are forwarded as explicit keyword
        arguments only when present and not None; the full ``model_kwargs``
        dict is always passed as ``params``. ``model_type`` is not used.
        """
        optional_params = {
            key: model_kwargs[key]
            for key in ('max_tokens', 'temperature')
            if model_kwargs.get(key) is not None
        }
        return XFSparkSpeechToText(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            params=model_kwargs,
            model_name=model_name,
            **optional_params
        )

    def create_url(self):
        """Return ``spark_api_url`` with the signed authentication query appended.

        The signature is an HMAC-SHA256 (keyed with ``spark_api_secret``) over
        the host, an RFC 1123 date and the request line, base64-encoded and
        wrapped in a base64-encoded authorization string.

        Raises:
            ValueError: if ``spark_api_url`` has no host component.
        """
        url = self.spark_api_url
        host = urlparse(url).hostname
        if not host:
            raise ValueError(f"spark_api_url has no host: {url!r}")

        # RFC 1123 timestamp in GMT.
        date = datetime.now(UTC).strftime('%a, %d %b %Y %H:%M:%S GMT')

        # String to sign: host, date and request line, newline-separated.
        signature_origin = f"host: {host}\ndate: {date}\nGET /v2/iat HTTP/1.1"
        digest = hmac.new(
            self.spark_api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature_sha = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_sha}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Authentication parameters are passed in the query string.
        query = {
            "authorization": authorization,
            "date": date,
            "host": host,
        }
        return url + '?' + urlencode(query)

    def check_auth(self):
        """Transcribe the bundled sample audio so that bad credentials raise."""
        sample_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'iat_mp3_16k.mp3')
        with open(sample_path, 'rb') as f:
            self.speech_to_text(f)

    def speech_to_text(self, file):
        """Stream ``file`` (a binary file-like object) and return the recognised text.

        Runs its own event loop through ``asyncio.run``, so it must not be
        called from code that is already running inside an event loop.
        """
        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
                await self.send(ws, file)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Read result messages from ``ws`` and return the assembled transcript.

        Messages are consumed until one reports ``data.status == 2`` (final
        result). A message without a ``status`` field is treated as final so a
        single-message reply still terminates. Segments are keyed by the
        result's ``sn``; a result with ``pgs == "rpl"`` first discards the
        segments in its inclusive ``rg`` range before being stored.

        Raises:
            Exception: if any message carries a non-zero ``code``.
        """
        segments = {}
        while True:
            message = json.loads(await ws.recv())
            code = message["code"]
            sid = message["sid"]
            if code != 0:
                err_msg = message["message"]
                raise Exception(f"sid: {sid} call error: {err_msg} code is: {code}")

            data = message.get("data") or {}
            result = data.get("result")
            if result:
                text = "".join(
                    word["w"]
                    for item in result.get("ws", [])
                    for word in item["cw"]
                )
                if result.get("pgs") == "rpl":
                    # Replacement result: drop the earlier segments it supersedes.
                    first, last = result["rg"]
                    for stale in range(first, last + 1):
                        segments.pop(stale, None)
                sn = result.get("sn", max(segments, default=0) + 1)
                segments[sn] = text

            if data.get("status", STATUS_LAST_FRAME) == STATUS_LAST_FRAME:
                break

        return "".join(segments[sn] for sn in sorted(segments))

    async def send(self, ws, file):
        """Stream ``file`` to ``ws`` as first / intermediate / last audio frames.

        The first frame always carries the app id and business parameters,
        even when ``file`` is empty; the stream is closed by a last frame with
        empty audio.
        """
        frame_size = 8000  # bytes of audio read per frame
        allowed_params = {'language', 'domain', 'accent', 'vad_eos', 'dwa', 'pd', 'ptt',
                          'pcm', 'ltc', 'rlang', 'vinfo', 'nunum', 'speex_size', 'nbest', 'wbest'}
        # Only recognised keys are forwarded; tolerate params being None.
        business_params = {k: v for k, v in (self.params or {}).items() if k in allowed_params}
        if not business_params:
            business_params = {
                "domain": f'{self.model_name}',
                "language": "zh_cn",
                "accent": "mandarin",
                "vinfo": 1,
                "vad_eos": 10000
            }

        def audio_frame(status, chunk):
            # NOTE(review): "lame" presumably denotes MP3 audio (the bundled
            # sample is an .mp3) - confirm against the XFyun documentation.
            return {
                "status": status, "format": "audio/L16;rate=16000",
                "audio": str(base64.b64encode(chunk), 'utf-8'),
                "encoding": "lame",
            }

        # First frame: the app id and business parameters are sent only here.
        first_chunk = file.read(frame_size)
        await ws.send(json.dumps({
            "common": {"app_id": self.spark_app_id},
            "business": {**business_params},
            "data": audio_frame(STATUS_FIRST_FRAME, first_chunk),
        }))

        # Intermediate frames: every further non-empty chunk.
        if first_chunk:
            while chunk := file.read(frame_size):
                await ws.send(json.dumps({"data": audio_frame(STATUS_CONTINUE_FRAME, chunk)}))

        # Last frame: empty audio signals end of stream.
        await ws.send(json.dumps({"data": audio_frame(STATUS_LAST_FRAME, b"")}))
|