| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853 |
- """
- TTS语音合成服务
- 提供语音合成的业务逻辑处理,集成阿里云百炼平台DashScope
- 需求: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8
- 支持: 非流式合成、流式合成、长文本合成、文本切割
- """
- import io
- import logging
- import os
- import re
- import uuid
- from datetime import date, datetime
- from typing import AsyncGenerator, List, Optional
- from sqlalchemy.orm import Session
- from fastapi import HTTPException
- from app.models.audio import AudioSynthesis
- from app.schemas.audio_schema import (
- TTSRequest, TTSResponse, LongTTSResponse, TTSModelResponse
- )
- from app.services.oss_service import get_oss_service
- from decimal import Decimal
# Logger for this module; named after the module so log config can target it.
logger: logging.Logger = logging.getLogger(name=__name__)
class TTSService:
    """Text-to-speech service backed by Alibaba Cloud Bailian (DashScope) CosyVoice.

    Provides non-streaming synthesis, streaming synthesis, long-text synthesis
    (split into segments whose audio is concatenated) and the text splitting it
    relies on. Non-streaming results are uploaded to OSS and recorded as
    ``AudioSynthesis`` rows.
    """

    # Static catalogue of TTS models. The methods of this class no longer read it
    # (listing and validation query the database); kept for backward
    # compatibility in case other modules import it -- verify before removing.
    TTS_MODELS = [
        {
            "id": 1,
            "title": "cosyvoice-v3-flash",
            "name": "CosyVoice V3 Flash",
            "description": "平衡效果与成本,性价比高",
            "price": "0.14335/万字符",
            "features": ["快速合成", "支持SSML", "支持Instruct"]
        },
        {
            "id": 2,
            "title": "cosyvoice-v3-plus",
            "name": "CosyVoice V3 Plus",
            "description": "最高质量,最佳表现力",
            "price": "0.286706/万字符",
            "features": ["高质量", "支持SSML", "支持Instruct"]
        },
        {
            "id": 3,
            "title": "cosyvoice-v2",
            "name": "CosyVoice V2",
            "description": "兼容旧版,稳定可靠",
            "price": "0.286706/万字符",
            "features": ["稳定", "支持SSML"]
        }
    ]

    # Legacy whitelist of model codes; likewise unused here (validation is done
    # against the database in _validate_model).
    VALID_MODELS = ["cosyvoice-v3-flash", "cosyvoice-v3-plus", "cosyvoice-v2"]

    # Code point ranges counted as two characters by _calculate_char_count
    # ("Chinese characters count as two"). Covers all CJK ideograph blocks, not
    # just the basic U+4E00-U+9FFF block.
    _CJK_IDEOGRAPH_RANGES = (
        (0x3400, 0x4DBF),    # CJK Unified Ideographs Extension A
        (0x4E00, 0x9FFF),    # CJK Unified Ideographs
        (0xF900, 0xFAFF),    # CJK Compatibility Ideographs
        (0x20000, 0x3FFFF),  # Supplementary and Tertiary Ideographic Planes (Ext. B+)
    )

    # Sentence delimiters for split_text, in both ASCII and full-width (Chinese)
    # forms. Exactly one capturing group, so re.split keeps the delimiters.
    _SENTENCE_SPLIT_RE = re.compile(r'([。!?;！？；\n])')
    # Clause delimiters used when a single sentence exceeds the length limit.
    _CLAUSE_SPLIT_RE = re.compile(r'([,，、])')

    def __init__(self, db: Session, user_id: str, api_key: Optional[str] = None):
        """Create a service bound to one database session and one user.

        Args:
            db: SQLAlchemy session used for model lookups and synthesis records.
            user_id: ID of the user that synthesis records are attributed to.
            api_key: The user's DashScope API key; when not given, the
                ``DASHSCOPE_API_KEY`` environment variable is used.
        """
        self.db = db
        self.user_id = user_id
        self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
        self.oss_service = get_oss_service()

    # ------------------------------------------------------------------ #
    # Validation helpers
    # ------------------------------------------------------------------ #

    def _calculate_char_count(self, text: str) -> int:
        """Return the weighted length of ``text``: 2 per CJK ideograph, 1 otherwise.

        Args:
            text: Text to measure.

        Returns:
            The weighted character count (0 for empty text).
        """
        ranges = self._CJK_IDEOGRAPH_RANGES
        return sum(
            2 if any(low <= ord(ch) <= high for low, high in ranges) else 1
            for ch in text
        )

    def _tts_category_filter(self):
        """Return the SQLAlchemy criterion matching models tagged with the TTS category."""
        from app.models.model import ModelNew, ModelCategory
        from sqlalchemy import cast
        from sqlalchemy.dialects.postgresql import ARRAY, INTEGER
        # `categories` is a PostgreSQL integer array; match rows containing TTS.
        return ModelNew.categories.contains(cast([int(ModelCategory.TTS)], ARRAY(INTEGER)))

    def _validate_model(self, model: str) -> None:
        """Ensure ``model`` is an API-enabled TTS model in the database.

        Args:
            model: Model code taken from the request.

        Raises:
            HTTPException: 400 if no matching model exists.
        """
        from app.models.model import ModelNew
        valid_model = self.db.query(ModelNew).filter(
            ModelNew.model_code == model,
            self._tts_category_filter(),
            ModelNew.is_api_enabled == True,  # noqa: E712 -- SQL expression, `is` would not build SQL
        ).first()
        if not valid_model:
            raise HTTPException(
                status_code=400,
                detail=f"无效的模型名称: {model}"
            )

    def _validate_request(self, request: TTSRequest) -> None:
        """Validate a non-streaming TTS request.

        Checks, in order: the text is not blank, the model exists, and the
        weighted length (see ``_calculate_char_count``) does not exceed 2000.

        Args:
            request: TTS request object.

        Raises:
            HTTPException: 400 on the first failed check.
        """
        if not request.text or not request.text.strip():
            raise HTTPException(status_code=400, detail="文本不能为空或仅包含空白字符")

        self._validate_model(request.model)

        if self._calculate_char_count(request.text) > 2000:
            raise HTTPException(
                status_code=400,
                detail="文本长度超过限制(最大2000字符,汉字算两个字符),请使用长文本转语音功能"
            )

    def _configure_api_key(self) -> None:
        """Install this user's key as the DashScope SDK's global API key.

        NOTE(review): ``dashscope.api_key`` is process-global. Concurrent requests
        from users with different keys can overwrite it while this request's SDK
        call is still pending in the thread pool. Confirm whether the SDK reads
        it at construction or call time, and whether it accepts a per-call key.
        """
        import dashscope
        dashscope.api_key = self.api_key

    # ------------------------------------------------------------------ #
    # Format / storage helpers
    # ------------------------------------------------------------------ #

    def _get_audio_format(self, format_str: str, sample_rate: int):
        """Map a format name and sample rate to a DashScope ``AudioFormat``.

        Unsupported combinations fall back to MP3 22050 Hz 256 kbps and log a
        warning. This includes every ``opus`` request and ``mp3`` at 8000 Hz.

        Args:
            format_str: Format name (mp3, wav, pcm, opus); case-insensitive.
            sample_rate: Sample rate in Hz.

        Returns:
            The matching ``AudioFormat`` member.
        """
        from dashscope.audio.tts_v2 import AudioFormat

        format_map = {
            ("mp3", 16000): AudioFormat.MP3_16000HZ_MONO_128KBPS,
            ("mp3", 22050): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            ("mp3", 24000): AudioFormat.MP3_24000HZ_MONO_256KBPS,
            ("mp3", 44100): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            ("mp3", 48000): AudioFormat.MP3_48000HZ_MONO_256KBPS,
            ("wav", 8000): AudioFormat.WAV_8000HZ_MONO_16BIT,
            ("wav", 16000): AudioFormat.WAV_16000HZ_MONO_16BIT,
            ("wav", 22050): AudioFormat.WAV_22050HZ_MONO_16BIT,
            ("wav", 24000): AudioFormat.WAV_24000HZ_MONO_16BIT,
            ("wav", 44100): AudioFormat.WAV_44100HZ_MONO_16BIT,
            ("wav", 48000): AudioFormat.WAV_48000HZ_MONO_16BIT,
            ("pcm", 8000): AudioFormat.PCM_8000HZ_MONO_16BIT,
            ("pcm", 16000): AudioFormat.PCM_16000HZ_MONO_16BIT,
            ("pcm", 22050): AudioFormat.PCM_22050HZ_MONO_16BIT,
            ("pcm", 24000): AudioFormat.PCM_24000HZ_MONO_16BIT,
            ("pcm", 44100): AudioFormat.PCM_44100HZ_MONO_16BIT,
            ("pcm", 48000): AudioFormat.PCM_48000HZ_MONO_16BIT,
        }

        audio_format = format_map.get((format_str.lower(), sample_rate))
        if audio_format is None:
            # The synthesizer will emit MP3 although the caller asked for
            # something else; make that visible in the logs.
            logger.warning(f"不支持的音频格式: {format_str}/{sample_rate}Hz,回退为MP3 22050Hz")
            audio_format = AudioFormat.MP3_22050HZ_MONO_256KBPS
        return audio_format

    def _generate_oss_path(self, format_str: str) -> str:
        """Build a unique OSS object key: ``audio/tts/<YYYYMMDD>/<uuid hex>.<format>``.

        Args:
            format_str: Audio format, used as the file extension.

        Returns:
            The OSS path.
        """
        day = date.today().strftime('%Y%m%d')
        return f"audio/tts/{day}/{uuid.uuid4().hex}.{format_str}"

    def _estimate_duration(self, audio_data: bytes, format_str: str, sample_rate: int) -> float:
        """Estimate the playback length of ``audio_data`` in seconds.

        PCM and WAV are computed for 16-bit mono at ``sample_rate``; for WAV the
        44-byte header is excluded. Anything else is estimated from the bitrate
        that ``_get_audio_format`` requests. That is 128 kbps MP3 (16000 B/s) at
        16000 Hz. It is 256 kbps MP3 (32000 B/s) for the other MP3 rates and for
        its MP3 fallback, which is used e.g. for opus.

        Args:
            audio_data: Encoded audio bytes.
            format_str: Audio format (case-insensitive).
            sample_rate: Sample rate in Hz.

        Returns:
            Estimated duration in seconds (never negative).
        """
        fmt = format_str.lower()
        size = len(audio_data)
        if fmt == "pcm":
            return size / (sample_rate * 2)
        if fmt == "wav":
            return max(size - 44, 0) / (sample_rate * 2)
        bytes_per_second = 16000 if (fmt == "mp3" and sample_rate == 16000) else 32000
        return size / bytes_per_second

    def _save_synthesis_record(self, record: AudioSynthesis) -> None:
        """Persist ``record`` and refresh it; roll the session back if the commit fails.

        SQLAlchemy requires a rollback before a session can be used again after
        a failed flush/commit, so the error is re-raised only after rolling back.
        """
        try:
            self.db.add(record)
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        self.db.refresh(record)

    # ------------------------------------------------------------------ #
    # Synthesis entry points
    # ------------------------------------------------------------------ #

    async def synthesize(self, request: TTSRequest) -> TTSResponse:
        """Synthesize ``request.text`` in one non-streaming call and upload the result.

        The blocking SDK call runs in the default thread pool so the event loop
        stays responsive. Note: if ``_get_audio_format`` falls back to MP3, the
        upload filename and the record still use ``request.format``.

        Args:
            request: TTS request object (weighted text length <= 2000).

        Returns:
            TTSResponse with the OSS URL, estimated duration, format, sample
            rate and character count.

        Raises:
            HTTPException: 400 for invalid input, 502 when synthesis fails.
        """
        import asyncio
        from dashscope.audio.tts_v2 import SpeechSynthesizer

        self._validate_request(request)
        self._configure_api_key()

        try:
            audio_format = self._get_audio_format(request.format, request.sample_rate)

            synthesizer = SpeechSynthesizer(
                model=request.model,
                voice=request.voice,
                format=audio_format,
                volume=request.volume,
                speech_rate=request.speech_rate,
                pitch_rate=request.pitch_rate
            )

            # synthesizer.call blocks on network I/O; keep it off the event loop.
            loop = asyncio.get_running_loop()
            audio_data = await loop.run_in_executor(None, synthesizer.call, request.text)

            if not audio_data:
                raise HTTPException(status_code=502, detail="语音合成失败,未返回音频数据")

            audio_url = self.oss_service.upload_file(
                audio_data,
                prefix="audio/tts",
                original_filename=f"audio.{request.format}"
            )

            duration = self._estimate_duration(audio_data, request.format, request.sample_rate)
            # API calls are not billed here.
            bill = Decimal("0")
            self._save_synthesis_record(AudioSynthesis(
                user_id=self.user_id,
                model=request.model,
                voice=request.voice,
                text=request.text,
                audio_url=audio_url,
                duration=duration,
                format=request.format,
                characters=len(request.text),
                bill=bill,
                completed_at=datetime.now()
            ))
            return TTSResponse(
                audio_url=audio_url,
                duration=round(duration, 2),
                format=request.format,
                sample_rate=request.sample_rate,
                characters=len(request.text)
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"TTS合成失败: {type(e).__name__}: {str(e)}", exc_info=True)
            raise HTTPException(status_code=502, detail=f"语音合成失败: {str(e)}") from e

    async def synthesize_stream(self, request: TTSRequest) -> AsyncGenerator[bytes, None]:
        """Yield audio chunks as DashScope produces them.

        Only blankness and the model are validated; text length is not limited.
        The SDK calls ``streaming_call``/``streaming_complete`` perform blocking
        network I/O, so they run in the default thread pool. The SDK delivers
        audio to ``ResultCallback`` methods on its own thread. Those methods hand
        each item to this coroutine via ``loop.call_soon_threadsafe`` and an
        ``asyncio.Queue``.

        Args:
            request: TTS request object.

        Yields:
            Audio data chunks.

        Raises:
            HTTPException: 400 for invalid input; 502 when synthesis fails. The
                502 may occur after some chunks have already been yielded.
        """
        import asyncio
        from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback

        if not request.text or not request.text.strip():
            raise HTTPException(status_code=400, detail="文本不能为空或仅包含空白字符")
        self._validate_model(request.model)

        self._configure_api_key()

        # Capture the loop on the loop thread. The callbacks below run on SDK
        # threads, where asyncio.get_event_loop() has no loop and raises.
        loop = asyncio.get_running_loop()
        audio_queue = asyncio.Queue()
        # Queued once the worker thread returns or raises, so the consumer can
        # never wait forever if the SDK fails without calling on_complete/on_error.
        worker_finished = object()

        def enqueue(item) -> None:
            """Thread-safe hand-off from an SDK thread to the event loop."""
            loop.call_soon_threadsafe(audio_queue.put_nowait, item)

        class StreamCallback(ResultCallback):
            def on_open(self):
                pass

            def on_complete(self):
                enqueue(None)

            def on_error(self, message: str):
                enqueue(Exception(message))

            def on_event(self, message):
                pass

            def on_data(self, data: bytes):
                enqueue(data)

        try:
            audio_format = self._get_audio_format(request.format, request.sample_rate)

            synthesizer = SpeechSynthesizer(
                model=request.model,
                voice=request.voice,
                format=audio_format,
                volume=request.volume,
                speech_rate=request.speech_rate,
                pitch_rate=request.pitch_rate,
                callback=StreamCallback()
            )

            def run_synthesis() -> None:
                synthesizer.streaming_call(request.text)
                synthesizer.streaming_complete()

            worker = loop.run_in_executor(None, run_synthesis)
            worker.add_done_callback(lambda _fut: audio_queue.put_nowait(worker_finished))

            while True:
                data = await audio_queue.get()
                if data is None:
                    break
                if data is worker_finished:
                    # Re-raises an SDK exception from the worker (handled below);
                    # otherwise synthesis simply ended.
                    worker.result()
                    break
                if isinstance(data, Exception):
                    raise HTTPException(status_code=502, detail=str(data))
                yield data

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"TTS流式合成失败: {type(e).__name__}: {str(e)}", exc_info=True)
            raise HTTPException(status_code=502, detail=f"语音合成失败: {str(e)}") from e

    # ------------------------------------------------------------------ #
    # Text splitting
    # ------------------------------------------------------------------ #

    @staticmethod
    def _split_keep_delimiters(text: str, pattern) -> List[str]:
        """Split ``text`` after every delimiter matched by ``pattern``.

        Each delimiter stays attached to the end of the piece it terminates, and
        empty pieces are dropped, so ``"".join(result) == text``. ``pattern``
        must be a compiled regex with exactly one capturing group. Then
        ``pattern.split`` alternates text and delimiter: ``[t0, d0, ..., tn]``.
        """
        parts = pattern.split(text)
        pieces = [t + d for t, d in zip(parts[0::2], parts[1::2] + [""])]
        return [piece for piece in pieces if piece]

    @staticmethod
    def _pack_pieces(pieces: List[str], max_length: int, split_oversized) -> List[str]:
        """Greedily merge consecutive ``pieces`` into segments of at most ``max_length``.

        A piece that does not fit closes the current segment. A piece that is
        longer than ``max_length`` on its own is passed to ``split_oversized``.
        That callable returns chunks no longer than ``max_length``. All chunks
        except the last become segments, and the last chunk starts the next
        segment.
        """
        segments = []
        current = ""
        for piece in pieces:
            if len(current) + len(piece) <= max_length:
                current += piece
                continue
            if current:
                segments.append(current)
            if len(piece) > max_length:
                chunks = split_oversized(piece)
                segments.extend(chunks[:-1])
                current = chunks[-1] if chunks else ""
            else:
                current = piece
        if current:
            segments.append(current)
        return segments

    def split_text(self, text: str, max_length: int = 2000) -> List[str]:
        """Split ``text`` into segments of at most ``max_length`` characters.

        Length is measured in Unicode code points (``len``). The method prefers
        to break after a sentence delimiter: 。!?; (ASCII or full-width) or a
        newline. A sentence longer than the limit is split at commas or
        enumeration commas (``_split_long_sentence``), and failing that it is
        cut hard every ``max_length`` characters. Concatenating the segments
        reproduces ``text``.

        NOTE(review): ``_validate_request`` counts CJK ideographs as two
        characters against the same 2000 limit, but this method counts them as
        one. Confirm which rule the API applies to long-text segments.

        Args:
            text: Text to split.
            max_length: Maximum segment length in code points; must be >= 1.

        Returns:
            The segments in order; ``[]`` for empty text.

        Raises:
            ValueError: If ``max_length`` < 1 for non-empty text.
        """
        if not text:
            return []
        if len(text) <= max_length:
            return [text]
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")

        sentences = self._split_keep_delimiters(text, self._SENTENCE_SPLIT_RE)
        return self._pack_pieces(
            sentences,
            max_length,
            lambda sentence: self._split_long_sentence(sentence, max_length),
        )

    def _split_long_sentence(self, sentence: str, max_length: int) -> List[str]:
        """Split one over-long sentence at commas, or hard-cut it if needed.

        Clauses ending in ``,``, ``，`` or ``、`` are packed greedily. A clause
        that alone exceeds ``max_length`` is cut into ``max_length``-sized
        chunks.

        Args:
            sentence: The sentence to split.
            max_length: Maximum piece length in code points; must be >= 1.

        Returns:
            The pieces in order; ``[sentence]`` if it already fits.

        Raises:
            ValueError: If ``max_length`` < 1 and the sentence needs splitting.
        """
        if len(sentence) <= max_length:
            return [sentence]
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")

        clauses = self._split_keep_delimiters(sentence, self._CLAUSE_SPLIT_RE)
        return self._pack_pieces(
            clauses,
            max_length,
            lambda clause: [clause[i:i + max_length] for i in range(0, len(clause), max_length)],
        )

    async def synthesize_long(self, request: TTSRequest) -> LongTTSResponse:
        """Synthesize text longer than a single request allows (non-streaming output).

        Steps:
        1. Split the text at sentence boundaries into segments of at most 2000
           characters (``split_text``).
        2. Synthesize each segment with a blocking ``SpeechSynthesizer.call`` in
           the default thread pool, requesting 16-bit mono PCM so that segments
           can be concatenated byte-wise.
        3. Concatenate the PCM, convert it to ``request.format`` and upload it.
        4. Record the synthesis, storing at most the first 1000 characters of text.

        Segments that raise or return no audio are logged and skipped, so the
        result may miss parts of the text. An error is raised only when every
        segment fails.

        Args:
            request: TTS request object; ``text`` may exceed 2000 characters.

        Returns:
            LongTTSResponse with URL, duration, format, total length and segment count.

        Raises:
            HTTPException: 400 for blank text or an unknown model, 502 when synthesis fails.
        """
        import asyncio
        from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat

        if not request.text or not request.text.strip():
            raise HTTPException(status_code=400, detail="文本不能为空或仅包含空白字符")
        self._validate_model(request.model)

        self._configure_api_key()

        try:
            segments = self.split_text(request.text, max_length=2000)
            if not segments:
                raise HTTPException(status_code=400, detail="文本切割失败")

            logger.info(f"长文本已切割为 {len(segments)} 段,总字符数: {len(request.text)}")

            # Intermediate PCM format per sample rate (PCM concatenates trivially).
            pcm_formats = {
                8000: AudioFormat.PCM_8000HZ_MONO_16BIT,
                16000: AudioFormat.PCM_16000HZ_MONO_16BIT,
                22050: AudioFormat.PCM_22050HZ_MONO_16BIT,
                24000: AudioFormat.PCM_24000HZ_MONO_16BIT,
                44100: AudioFormat.PCM_44100HZ_MONO_16BIT,
                48000: AudioFormat.PCM_48000HZ_MONO_16BIT,
            }
            sample_rate = request.sample_rate
            if sample_rate not in pcm_formats:
                # Synthesize at 22050 Hz and use that same rate for the WAV
                # header, encoder bitrate and duration so the audio is labelled
                # with the rate it was actually produced at.
                logger.warning(f"不支持的采样率: {sample_rate}Hz,使用22050Hz合成")
                sample_rate = 22050
            pcm_format = pcm_formats[sample_rate]

            # synthesizer.call() blocks; running it on the event loop would stall
            # the worker (e.g. gunicorn heartbeat timeout -> SIGABRT).
            loop = asyncio.get_running_loop()

            def _synthesize_segment(segment: str) -> bytes:
                """Blocking TTS call for one segment; executed in the thread pool."""
                synthesizer = SpeechSynthesizer(
                    model=request.model,
                    voice=request.voice,
                    format=pcm_format,
                    volume=request.volume,
                    speech_rate=request.speech_rate,
                    pitch_rate=request.pitch_rate
                )
                return synthesizer.call(segment)

            audio_parts = []
            for idx, segment in enumerate(segments, 1):
                try:
                    audio_data = await loop.run_in_executor(None, _synthesize_segment, segment)
                except Exception as e:
                    # Best effort: skip this segment and continue with the rest.
                    logger.error(f"第 {idx}/{len(segments)} 段合成失败: {type(e).__name__}: {str(e)}")
                    continue

                if not audio_data:
                    logger.warning(f"第 {idx}/{len(segments)} 段合成失败,返回空数据")
                    continue

                audio_parts.append(audio_data)
                logger.debug(f"第 {idx}/{len(segments)} 段合成成功,字符数: {len(segment)}, 音频大小: {len(audio_data)} 字节")

            if not audio_parts:
                raise HTTPException(status_code=502, detail="语音合成失败,所有片段均未返回音频数据")

            merged_pcm = b''.join(audio_parts)
            final_audio = self._convert_pcm_to_format(merged_pcm, request.format, sample_rate)

            audio_url = self.oss_service.upload_file(
                final_audio,
                prefix="audio/tts",
                original_filename=f"audio.{request.format}"
            )

            # 16-bit mono PCM: 2 bytes per sample.
            duration = len(merged_pcm) / (sample_rate * 2)

            # API calls are not billed here.
            bill = Decimal("0")
            text_preview = request.text[:1000] + "..." if len(request.text) > 1000 else request.text
            self._save_synthesis_record(AudioSynthesis(
                user_id=self.user_id,
                model=request.model,
                voice=request.voice,
                text=text_preview,
                audio_url=audio_url,
                duration=duration,
                format=request.format,
                characters=len(request.text),
                bill=bill,
                completed_at=datetime.now()
            ))

            logger.info(f"长文本合成完成: {len(segments)} 段, 总字符数: {len(request.text)}, 时长: {duration:.2f}秒")

            return LongTTSResponse(
                audio_url=audio_url,
                duration=round(duration, 2),
                format=request.format,
                total_characters=len(request.text),
                segments=len(segments)
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"长文本TTS合成失败: {type(e).__name__}: {str(e)}", exc_info=True)
            raise HTTPException(status_code=502, detail=f"语音合成失败: {str(e)}") from e

    # ------------------------------------------------------------------ #
    # PCM conversion
    # ------------------------------------------------------------------ #

    def _convert_pcm_to_format(self, pcm_data: bytes, format_str: str, sample_rate: int) -> bytes:
        """Convert 16-bit mono PCM to ``format_str``.

        ``pcm`` is returned unchanged and ``wav`` gets a header. ``mp3``/``opus``
        are encoded with pydub (ffmpeg), which falls back to WAV if pydub is
        missing or encoding fails. Unknown formats also yield WAV with a
        warning. NOTE(review): in every fallback, callers still label the data
        with the requested format.

        Args:
            pcm_data: 16-bit mono PCM audio.
            format_str: Target format (mp3, wav, pcm, opus); case-insensitive.
            sample_rate: Sample rate of ``pcm_data`` in Hz.

        Returns:
            The converted audio bytes.
        """
        fmt = format_str.lower()
        if fmt == "pcm":
            return pcm_data
        if fmt == "wav":
            return self._pcm_to_wav(pcm_data, sample_rate)
        if fmt == "mp3":
            # Bitrate scaled with the sample rate.
            bitrate = {
                8000: "64k",
                16000: "96k",
                22050: "128k",
                24000: "128k",
                44100: "192k",
                48000: "192k",
            }.get(sample_rate, "128k")
            return self._encode_with_pydub(pcm_data, sample_rate, "mp3", bitrate, "MP3")
        if fmt == "opus":
            return self._encode_with_pydub(pcm_data, sample_rate, "opus", "32k", "OPUS")

        logger.warning(f"不支持的格式: {format_str},返回WAV格式")
        return self._pcm_to_wav(pcm_data, sample_rate)

    def _encode_with_pydub(self, pcm_data: bytes, sample_rate: int,
                           export_format: str, bitrate: str, label: str) -> bytes:
        """Encode PCM via pydub into ``export_format``; return WAV on any failure.

        Args:
            pcm_data: 16-bit mono PCM audio.
            sample_rate: Sample rate in Hz.
            export_format: pydub/ffmpeg format name (e.g. "mp3", "opus").
            bitrate: Encoder bitrate such as "128k".
            label: Format name used in log messages.
        """
        wav_data = self._pcm_to_wav(pcm_data, sample_rate)
        try:
            from pydub import AudioSegment
            audio = AudioSegment.from_wav(io.BytesIO(wav_data))
            buffer = io.BytesIO()
            audio.export(buffer, format=export_format, bitrate=bitrate)
            return buffer.getvalue()
        except ImportError:
            logger.warning(f"pydub未安装,无法转换为{label}格式,返回WAV格式")
            return wav_data
        except Exception as e:
            logger.error(f"{label}转换失败: {str(e)},返回WAV格式")
            return wav_data

    def _pcm_to_wav(self, pcm_data: bytes, sample_rate: int) -> bytes:
        """Prefix 16-bit mono PCM with a canonical 44-byte RIFF/WAVE header.

        Args:
            pcm_data: Raw 16-bit mono PCM samples.
            sample_rate: Sample rate in Hz.

        Returns:
            The WAV file bytes.
        """
        import struct

        channels = 1
        bits_per_sample = 16
        byte_rate = sample_rate * channels * bits_per_sample // 8
        block_align = channels * bits_per_sample // 8
        data_size = len(pcm_data)

        wav_header = struct.pack(
            '<4sI4s4sIHHIIHH4sI',
            b'RIFF',
            36 + data_size,   # RIFF chunk size: header remainder + data
            b'WAVE',
            b'fmt ',
            16,               # fmt chunk size for PCM
            1,                # audio format 1 = uncompressed PCM
            channels,
            sample_rate,
            byte_rate,
            block_align,
            bits_per_sample,
            b'data',
            data_size
        )

        return wav_header + pcm_data

    # ------------------------------------------------------------------ #
    # Model listing
    # ------------------------------------------------------------------ #

    @staticmethod
    def _format_price(price_row) -> str:
        """Format a price row as ``"<price>/<unit>"``, or ``""`` if there is no price.

        TTS is billed per character, so the discounted input price is shown and
        the ``"元/"`` prefix is stripped from the unit. Trailing zeros are
        removed without switching to scientific notation: ``normalize()``
        alone renders ``Decimal("10.00")`` as ``1E+1``.
        """
        if price_row is None or price_row.input_price_discounted is None:
            return ""
        price_text = format(price_row.input_price_discounted.normalize(), "f")
        unit_str = (price_row.unit or "").replace('元/', '')
        return f"{price_text}/{unit_str}"

    @staticmethod
    def _extract_features(features) -> List[str]:
        """Return the feature list stored on a model.

        For a dict this is the keys with truthy values, a list is returned
        as-is, and anything else gives an empty list.
        """
        if not features:
            return []
        if isinstance(features, dict):
            return [name for name, enabled in features.items() if enabled]
        if isinstance(features, list):
            return features
        return []

    def get_tts_models(self) -> List[TTSModelResponse]:
        """List API-enabled, visible TTS models from the database with their prices.

        Clone models (voice cloning only) and realtime models (WebSocket API,
        not supported by this service) are excluded.

        Returns:
            One TTSModelResponse per remaining model, in the order the database
            returns them.
        """
        from app.models.model import ModelNew, ModelPriceNew

        models = self.db.query(ModelNew).filter(
            self._tts_category_filter(),
            ModelNew.is_api_enabled == True,  # noqa: E712
            ModelNew.is_show_enabled == True,  # noqa: E712
        ).all()
        visible = [
            m for m in models
            if 'clone' not in m.model_code.lower() and 'realtime' not in m.model_code.lower()
        ]
        if not visible:
            return []

        # One query for all prices instead of one query per model.
        price_rows = self.db.query(ModelPriceNew).filter(
            ModelPriceNew.model_code.in_([m.model_code for m in visible]),
            ModelPriceNew.is_active == True,  # noqa: E712
        ).all()
        prices = {}
        for row in price_rows:
            # Keep one active row per model: the first returned (no ordering applied).
            prices.setdefault(row.model_code, row)

        return [
            TTSModelResponse(
                id=m.id,
                title=m.model_code,
                name=m.display_name or m.model_code,
                description=m.custom_description or m.description or "",
                price=self._format_price(prices.get(m.model_code)),
                features=self._extract_features(m.features),
            )
            for m in visible
        ]
|