tts_service.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853
  1. """
  2. TTS语音合成服务
  3. 提供语音合成的业务逻辑处理,集成阿里云百炼平台DashScope
  4. 需求: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8
  5. 支持: 非流式合成、流式合成、长文本合成、文本切割
  6. """
  7. import io
  8. import logging
  9. import os
  10. import re
  11. import uuid
  12. from datetime import date, datetime
  13. from typing import AsyncGenerator, List, Optional
  14. from sqlalchemy.orm import Session
  15. from fastapi import HTTPException
  16. from app.models.audio import AudioSynthesis
  17. from app.schemas.audio_schema import (
  18. TTSRequest, TTSResponse, LongTTSResponse, TTSModelResponse
  19. )
  20. from app.services.oss_service import get_oss_service
  21. from decimal import Decimal
  22. logger = logging.getLogger(__name__)
  23. class TTSService:
  24. """TTS语音合成服务类"""
  25. # TTS模型配置
  26. TTS_MODELS = [
  27. {
  28. "id": 1,
  29. "title": "cosyvoice-v3-flash",
  30. "name": "CosyVoice V3 Flash",
  31. "description": "平衡效果与成本,性价比高",
  32. "price": "0.14335/万字符",
  33. "features": ["快速合成", "支持SSML", "支持Instruct"]
  34. },
  35. {
  36. "id": 2,
  37. "title": "cosyvoice-v3-plus",
  38. "name": "CosyVoice V3 Plus",
  39. "description": "最高质量,最佳表现力",
  40. "price": "0.286706/万字符",
  41. "features": ["高质量", "支持SSML", "支持Instruct"]
  42. },
  43. {
  44. "id": 3,
  45. "title": "cosyvoice-v2",
  46. "name": "CosyVoice V2",
  47. "description": "兼容旧版,稳定可靠",
  48. "price": "0.286706/万字符",
  49. "features": ["稳定", "支持SSML"]
  50. }
  51. ]
  52. # 有效的TTS模型名称
  53. VALID_MODELS = ["cosyvoice-v3-flash", "cosyvoice-v3-plus", "cosyvoice-v2"]
  54. def __init__(self, db: Session, user_id: str, api_key: str = None):
  55. """
  56. 初始化TTS服务
  57. Args:
  58. db: 数据库会话
  59. user_id: 用户ID
  60. api_key: 用户的API密钥(从用户数据动态加载)
  61. """
  62. self.db = db
  63. self.user_id = user_id
  64. self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
  65. self.oss_service = get_oss_service()
  66. def _calculate_char_count(self, text: str) -> int:
  67. """
  68. 计算字符数(汉字算两个字符)
  69. Args:
  70. text: 待计算的文本
  71. Returns:
  72. 字符数(汉字算两个字符)
  73. """
  74. count = 0
  75. for char in text:
  76. # 判断是否为汉字(CJK统一汉字范围)
  77. if '\u4e00' <= char <= '\u9fff':
  78. count += 2
  79. else:
  80. count += 1
  81. return count
  82. def _validate_request(self, request: TTSRequest) -> None:
  83. """
  84. 验证TTS请求参数
  85. Args:
  86. request: TTS请求对象
  87. Raises:
  88. HTTPException: 参数验证失败
  89. """
  90. # 验证文本不为空
  91. if not request.text or not request.text.strip():
  92. raise HTTPException(status_code=400, detail="文本不能为空或仅包含空白字符")
  93. # 验证模型(动态查库)
  94. from app.models.model import ModelNew, ModelCategory
  95. from sqlalchemy import cast
  96. from sqlalchemy.dialects.postgresql import ARRAY, INTEGER
  97. valid_model = self.db.query(ModelNew).filter(
  98. ModelNew.model_code == request.model,
  99. ModelNew.categories.contains(cast([int(ModelCategory.TTS)], ARRAY(INTEGER))),
  100. ModelNew.is_api_enabled == True,
  101. ).first()
  102. if not valid_model:
  103. raise HTTPException(
  104. status_code=400,
  105. detail=f"无效的模型名称: {request.model}"
  106. )
  107. # 验证文本长度(非长文本合成时,汉字算两个字符)
  108. char_count = self._calculate_char_count(request.text)
  109. if char_count > 2000:
  110. raise HTTPException(
  111. status_code=400,
  112. detail="文本长度超过限制(最大2000字符,汉字算两个字符),请使用长文本转语音功能"
  113. )
  114. def _get_audio_format(self, format_str: str, sample_rate: int):
  115. """
  116. 获取DashScope音频格式枚举
  117. Args:
  118. format_str: 格式字符串 (mp3, wav, pcm, opus)
  119. sample_rate: 采样率
  120. Returns:
  121. AudioFormat枚举值
  122. """
  123. from dashscope.audio.tts_v2 import AudioFormat
  124. format_map = {
  125. ("mp3", 16000): AudioFormat.MP3_16000HZ_MONO_128KBPS,
  126. ("mp3", 22050): AudioFormat.MP3_22050HZ_MONO_256KBPS,
  127. ("mp3", 24000): AudioFormat.MP3_24000HZ_MONO_256KBPS,
  128. ("mp3", 44100): AudioFormat.MP3_44100HZ_MONO_256KBPS,
  129. ("mp3", 48000): AudioFormat.MP3_48000HZ_MONO_256KBPS,
  130. ("wav", 8000): AudioFormat.WAV_8000HZ_MONO_16BIT,
  131. ("wav", 16000): AudioFormat.WAV_16000HZ_MONO_16BIT,
  132. ("wav", 22050): AudioFormat.WAV_22050HZ_MONO_16BIT,
  133. ("wav", 24000): AudioFormat.WAV_24000HZ_MONO_16BIT,
  134. ("wav", 44100): AudioFormat.WAV_44100HZ_MONO_16BIT,
  135. ("wav", 48000): AudioFormat.WAV_48000HZ_MONO_16BIT,
  136. ("pcm", 8000): AudioFormat.PCM_8000HZ_MONO_16BIT,
  137. ("pcm", 16000): AudioFormat.PCM_16000HZ_MONO_16BIT,
  138. ("pcm", 22050): AudioFormat.PCM_22050HZ_MONO_16BIT,
  139. ("pcm", 24000): AudioFormat.PCM_24000HZ_MONO_16BIT,
  140. ("pcm", 44100): AudioFormat.PCM_44100HZ_MONO_16BIT,
  141. ("pcm", 48000): AudioFormat.PCM_48000HZ_MONO_16BIT,
  142. }
  143. key = (format_str.lower(), sample_rate)
  144. if key in format_map:
  145. return format_map[key]
  146. # 默认返回MP3 22050Hz
  147. return AudioFormat.MP3_22050HZ_MONO_256KBPS
  148. def _generate_oss_path(self, format_str: str) -> str:
  149. """
  150. 生成OSS存储路径
  151. Args:
  152. format_str: 音频格式
  153. Returns:
  154. OSS路径
  155. """
  156. date_str = date.today().strftime('%Y%m%d')
  157. unique_id = uuid.uuid4().hex
  158. return f"audio/tts/{date_str}/{unique_id}.{format_str}"
  159. def _estimate_duration(self, audio_data: bytes, format_str: str, sample_rate: int) -> float:
  160. """
  161. 估算音频时长
  162. Args:
  163. audio_data: 音频数据
  164. format_str: 音频格式
  165. sample_rate: 采样率
  166. Returns:
  167. 估算的时长(秒)
  168. """
  169. if format_str == "pcm":
  170. # PCM: 16bit mono = 2 bytes per sample
  171. return len(audio_data) / (sample_rate * 2)
  172. elif format_str == "wav":
  173. # WAV: 跳过44字节头,16bit mono
  174. return (len(audio_data) - 44) / (sample_rate * 2)
  175. elif format_str == "mp3":
  176. # MP3: 粗略估算,128kbps = 16KB/s
  177. return len(audio_data) / 16000
  178. else:
  179. # 默认估算
  180. return len(audio_data) / 16000
  181. async def synthesize(self, request: TTSRequest) -> TTSResponse:
  182. """
  183. 非流式语音合成
  184. Args:
  185. request: TTS请求对象
  186. Returns:
  187. TTS响应对象
  188. Raises:
  189. HTTPException: 合成失败
  190. """
  191. import dashscope
  192. from dashscope.audio.tts_v2 import SpeechSynthesizer
  193. # 验证请求
  194. self._validate_request(request)
  195. # 设置API Key
  196. dashscope.api_key = self.api_key
  197. try:
  198. # 获取音频格式
  199. audio_format = self._get_audio_format(request.format, request.sample_rate)
  200. # 创建合成器
  201. synthesizer = SpeechSynthesizer(
  202. model=request.model,
  203. voice=request.voice,
  204. format=audio_format,
  205. volume=request.volume,
  206. speech_rate=request.speech_rate,
  207. pitch_rate=request.pitch_rate
  208. )
  209. # 合成音频(同步阻塞调用放到线程池,避免阻塞 event loop)
  210. import asyncio
  211. loop = asyncio.get_event_loop()
  212. audio_data = await loop.run_in_executor(None, synthesizer.call, request.text)
  213. if not audio_data:
  214. raise HTTPException(status_code=502, detail="语音合成失败,未返回音频数据")
  215. # 上传到OSS
  216. oss_path = self._generate_oss_path(request.format)
  217. audio_url = self.oss_service.upload_file(
  218. audio_data,
  219. prefix="audio/tts",
  220. original_filename=f"audio.{request.format}"
  221. )
  222. # 估算时长
  223. duration = self._estimate_duration(audio_data, request.format, request.sample_rate)
  224. # 计算费用(API调用免费)
  225. bill = Decimal("0")
  226. # 保存合成记录
  227. synthesis_record = AudioSynthesis(
  228. user_id=self.user_id,
  229. model=request.model,
  230. voice=request.voice,
  231. text=request.text,
  232. audio_url=audio_url,
  233. duration=duration,
  234. format=request.format,
  235. characters=len(request.text),
  236. bill=bill,
  237. completed_at=datetime.now()
  238. )
  239. self.db.add(synthesis_record)
  240. self.db.commit()
  241. self.db.refresh(synthesis_record)
  242. return TTSResponse(
  243. audio_url=audio_url,
  244. duration=round(duration, 2),
  245. format=request.format,
  246. sample_rate=request.sample_rate,
  247. characters=len(request.text)
  248. )
  249. except HTTPException:
  250. raise
  251. except Exception as e:
  252. logger.error(f"TTS合成失败: {type(e).__name__}: {str(e)}")
  253. raise HTTPException(status_code=502, detail=f"语音合成失败: {str(e)}")
  254. async def synthesize_stream(self, request: TTSRequest) -> AsyncGenerator[bytes, None]:
  255. """
  256. 流式语音合成
  257. Args:
  258. request: TTS请求对象
  259. Yields:
  260. 音频数据块
  261. Raises:
  262. HTTPException: 合成失败
  263. """
  264. import dashscope
  265. from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback
  266. # 验证请求(流式模式下不限制文本长度)
  267. if not request.text or not request.text.strip():
  268. raise HTTPException(status_code=400, detail="文本不能为空或仅包含空白字符")
  269. # 动态查库验证模型
  270. from app.models.model import ModelNew, ModelCategory
  271. from sqlalchemy import cast
  272. from sqlalchemy.dialects.postgresql import ARRAY, INTEGER
  273. valid_model = self.db.query(ModelNew).filter(
  274. ModelNew.model_code == request.model,
  275. ModelNew.categories.contains(cast([int(ModelCategory.TTS)], ARRAY(INTEGER))),
  276. ModelNew.is_api_enabled == True,
  277. ).first()
  278. if not valid_model:
  279. raise HTTPException(status_code=400, detail=f"无效的模型名称: {request.model}")
  280. # 设置API Key
  281. dashscope.api_key = self.api_key
  282. # 用于收集音频数据的队列
  283. import asyncio
  284. audio_queue = asyncio.Queue()
  285. class StreamCallback(ResultCallback):
  286. def on_open(self):
  287. pass
  288. def on_complete(self):
  289. asyncio.get_event_loop().call_soon_threadsafe(
  290. audio_queue.put_nowait, None
  291. )
  292. def on_error(self, message: str):
  293. asyncio.get_event_loop().call_soon_threadsafe(
  294. audio_queue.put_nowait, Exception(message)
  295. )
  296. def on_event(self, message):
  297. pass
  298. def on_data(self, data: bytes):
  299. asyncio.get_event_loop().call_soon_threadsafe(
  300. audio_queue.put_nowait, data
  301. )
  302. try:
  303. # 获取音频格式
  304. audio_format = self._get_audio_format(request.format, request.sample_rate)
  305. # 创建合成器
  306. callback = StreamCallback()
  307. synthesizer = SpeechSynthesizer(
  308. model=request.model,
  309. voice=request.voice,
  310. format=audio_format,
  311. volume=request.volume,
  312. speech_rate=request.speech_rate,
  313. pitch_rate=request.pitch_rate,
  314. callback=callback
  315. )
  316. # 启动流式合成
  317. synthesizer.streaming_call(request.text)
  318. synthesizer.streaming_complete()
  319. # 从队列中读取数据
  320. while True:
  321. data = await audio_queue.get()
  322. if data is None:
  323. break
  324. if isinstance(data, Exception):
  325. raise HTTPException(status_code=502, detail=str(data))
  326. yield data
  327. except HTTPException:
  328. raise
  329. except Exception as e:
  330. logger.error(f"TTS流式合成失败: {type(e).__name__}: {str(e)}")
  331. raise HTTPException(status_code=502, detail=f"语音合成失败: {str(e)}")
  332. def split_text(self, text: str, max_length: int = 2000) -> List[str]:
  333. """
  334. 按句子边界智能切割文本
  335. 根据文档要求,文本长度限制按实际字符数(Unicode字符数)计算,每段不超过max_length字符。
  336. 优先在句子边界(。!?;\n)处切割,避免截断句子。
  337. 如果单句超过限制,则在逗号处切割;如果仍超过,则强制按字符数切割。
  338. Args:
  339. text: 待切割的文本
  340. max_length: 每段最大长度(Unicode字符数),默认2000
  341. Returns:
  342. 切割后的文本列表
  343. """
  344. if not text:
  345. return []
  346. # 按实际字符数计算(Unicode字符数)
  347. text_length = len(text)
  348. if text_length <= max_length:
  349. return [text]
  350. segments = []
  351. current = ""
  352. current_length = 0
  353. # 按句子分隔符切割:。!?;\n
  354. # 使用正则表达式保留分隔符
  355. sentence_pattern = r'([。!?;\n])'
  356. parts = re.split(sentence_pattern, text)
  357. i = 0
  358. while i < len(parts):
  359. part = parts[i]
  360. # 获取分隔符(如果存在)
  361. delimiter = parts[i + 1] if i + 1 < len(parts) and re.match(sentence_pattern, parts[i + 1]) else ""
  362. if delimiter:
  363. i += 1
  364. full_sentence = part + delimiter
  365. sentence_length = len(full_sentence)
  366. # 如果当前段加上新句子不超过限制
  367. if current_length + sentence_length <= max_length:
  368. current += full_sentence
  369. current_length += sentence_length
  370. else:
  371. # 如果当前段不为空,保存它
  372. if current:
  373. segments.append(current)
  374. current = ""
  375. current_length = 0
  376. # 如果单个句子超过限制,需要进一步切割
  377. if sentence_length > max_length:
  378. # 按逗号切割
  379. sub_segments = self._split_long_sentence(full_sentence, max_length)
  380. segments.extend(sub_segments[:-1])
  381. if sub_segments:
  382. current = sub_segments[-1]
  383. current_length = len(current)
  384. else:
  385. current = full_sentence
  386. current_length = sentence_length
  387. i += 1
  388. # 添加最后一段
  389. if current:
  390. segments.append(current)
  391. return segments
  392. def _split_long_sentence(self, sentence: str, max_length: int) -> List[str]:
  393. """
  394. 切割超长句子(按逗号或强制切割)
  395. 按实际字符数(Unicode字符数)进行切割,优先在逗号处切割,避免截断词语。
  396. 如果单个部分仍超过限制,则强制按字符数切割。
  397. Args:
  398. sentence: 超长句子
  399. max_length: 最大长度(Unicode字符数)
  400. Returns:
  401. 切割后的片段列表
  402. """
  403. sentence_length = len(sentence)
  404. if sentence_length <= max_length:
  405. return [sentence]
  406. segments = []
  407. current = ""
  408. current_length = 0
  409. # 按逗号切割
  410. comma_pattern = r'([,,、])'
  411. parts = re.split(comma_pattern, sentence)
  412. i = 0
  413. while i < len(parts):
  414. part = parts[i]
  415. delimiter = parts[i + 1] if i + 1 < len(parts) and re.match(comma_pattern, parts[i + 1]) else ""
  416. if delimiter:
  417. i += 1
  418. full_part = part + delimiter
  419. part_length = len(full_part)
  420. if current_length + part_length <= max_length:
  421. current += full_part
  422. current_length += part_length
  423. else:
  424. if current:
  425. segments.append(current)
  426. current = ""
  427. current_length = 0
  428. # 如果单个部分仍然超过限制,强制按字符数切割
  429. if part_length > max_length:
  430. # 按字符数切割
  431. for j in range(0, part_length, max_length):
  432. chunk = full_part[j:j + max_length]
  433. if j + max_length < part_length:
  434. segments.append(chunk)
  435. else:
  436. current = chunk
  437. current_length = len(chunk)
  438. else:
  439. current = full_part
  440. current_length = part_length
  441. i += 1
  442. if current:
  443. segments.append(current)
  444. return segments
  445. async def synthesize_long(self, request: TTSRequest) -> LongTTSResponse:
  446. """
  447. 长文本语音合成(非流式输出)
  448. 根据文档要求,由于声音合成限制输入字符(单次不超过2000字符),
  449. 长文本转音频通过分割文本方式实现,使用非流式输出(call方法)。
  450. 实现流程:
  451. 1. 将长文本按句子边界智能切割,每段不超过2000字符
  452. 2. 对每个文本段使用非流式调用(call方法)进行合成
  453. 3. 合并所有音频片段
  454. 4. 转换为目标格式并上传到OSS
  455. Args:
  456. request: TTS请求对象(text可超过2000字符)
  457. Returns:
  458. 长文本TTS响应对象
  459. Raises:
  460. HTTPException: 合成失败
  461. """
  462. import dashscope
  463. from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat
  464. # 验证文本不为空
  465. if not request.text or not request.text.strip():
  466. raise HTTPException(status_code=400, detail="文本不能为空或仅包含空白字符")
  467. # 验证模型(动态查库)
  468. from app.models.model import ModelNew, ModelCategory
  469. from sqlalchemy import cast
  470. from sqlalchemy.dialects.postgresql import ARRAY, INTEGER
  471. valid_model = self.db.query(ModelNew).filter(
  472. ModelNew.model_code == request.model,
  473. ModelNew.categories.contains(cast([int(ModelCategory.TTS)], ARRAY(INTEGER))),
  474. ModelNew.is_api_enabled == True,
  475. ).first()
  476. if not valid_model:
  477. raise HTTPException(
  478. status_code=400,
  479. detail=f"无效的模型名称: {request.model}"
  480. )
  481. # 设置API Key
  482. dashscope.api_key = self.api_key
  483. try:
  484. # 切割文本(按实际字符数,每段不超过2000字符)
  485. segments = self.split_text(request.text, max_length=2000)
  486. if not segments:
  487. raise HTTPException(status_code=400, detail="文本切割失败")
  488. logger.info(f"长文本已切割为 {len(segments)} 段,总字符数: {len(request.text)}")
  489. # 根据用户指定的采样率选择PCM格式用于中间处理
  490. # 使用PCM格式便于合并音频片段
  491. sample_rate = request.sample_rate
  492. pcm_format_map = {
  493. 8000: AudioFormat.PCM_8000HZ_MONO_16BIT,
  494. 16000: AudioFormat.PCM_16000HZ_MONO_16BIT,
  495. 22050: AudioFormat.PCM_22050HZ_MONO_16BIT,
  496. 24000: AudioFormat.PCM_24000HZ_MONO_16BIT,
  497. 44100: AudioFormat.PCM_44100HZ_MONO_16BIT,
  498. 48000: AudioFormat.PCM_48000HZ_MONO_16BIT,
  499. }
  500. pcm_format = pcm_format_map.get(sample_rate, AudioFormat.PCM_22050HZ_MONO_16BIT)
  501. # 逐段使用非流式调用进行合成
  502. # 注意:synthesizer.call() 是同步阻塞调用,必须放到线程池执行,
  503. # 否则会阻塞 event loop 导致 gunicorn 心跳超时被 SIGABRT 杀掉
  504. import asyncio
  505. loop = asyncio.get_event_loop()
  506. def _synthesize_segment(segment: str) -> bytes:
  507. """在线程池中执行同步 TTS 调用"""
  508. synthesizer = SpeechSynthesizer(
  509. model=request.model,
  510. voice=request.voice,
  511. format=pcm_format,
  512. volume=request.volume,
  513. speech_rate=request.speech_rate,
  514. pitch_rate=request.pitch_rate
  515. )
  516. return synthesizer.call(segment)
  517. audio_parts = []
  518. total_characters = 0
  519. for idx, segment in enumerate(segments, 1):
  520. try:
  521. # 在线程池中执行阻塞调用,不阻塞 event loop
  522. audio_data = await loop.run_in_executor(None, _synthesize_segment, segment)
  523. if not audio_data:
  524. logger.warning(f"第 {idx}/{len(segments)} 段合成失败,返回空数据")
  525. continue
  526. audio_parts.append(audio_data)
  527. total_characters += len(segment)
  528. logger.debug(f"第 {idx}/{len(segments)} 段合成成功,字符数: {len(segment)}, 音频大小: {len(audio_data)} 字节")
  529. except Exception as e:
  530. logger.error(f"第 {idx}/{len(segments)} 段合成失败: {type(e).__name__}: {str(e)}")
  531. # 继续处理下一段,不中断整个流程
  532. continue
  533. if not audio_parts:
  534. raise HTTPException(status_code=502, detail="语音合成失败,所有片段均未返回音频数据")
  535. # 合并PCM音频片段(直接拼接字节)
  536. merged_pcm = b''.join(audio_parts)
  537. # 转换为目标格式
  538. final_audio = self._convert_pcm_to_format(
  539. merged_pcm,
  540. request.format,
  541. sample_rate
  542. )
  543. # 上传到OSS
  544. audio_url = self.oss_service.upload_file(
  545. final_audio,
  546. prefix="audio/tts",
  547. original_filename=f"audio.{request.format}"
  548. )
  549. # 估算时长(PCM: 16bit mono = 2 bytes per sample)
  550. duration = len(merged_pcm) / (sample_rate * 2)
  551. # 计算费用(API调用免费)
  552. bill = Decimal("0")
  553. # 保存合成记录
  554. text_preview = request.text[:1000] + "..." if len(request.text) > 1000 else request.text
  555. synthesis_record = AudioSynthesis(
  556. user_id=self.user_id,
  557. model=request.model,
  558. voice=request.voice,
  559. text=text_preview,
  560. audio_url=audio_url,
  561. duration=duration,
  562. format=request.format,
  563. characters=len(request.text),
  564. bill=bill,
  565. completed_at=datetime.now()
  566. )
  567. self.db.add(synthesis_record)
  568. self.db.commit()
  569. self.db.refresh(synthesis_record)
  570. logger.info(f"长文本合成完成: {len(segments)} 段, 总字符数: {len(request.text)}, 时长: {duration:.2f}秒")
  571. return LongTTSResponse(
  572. audio_url=audio_url,
  573. duration=round(duration, 2),
  574. format=request.format,
  575. total_characters=len(request.text),
  576. segments=len(segments)
  577. )
  578. except HTTPException:
  579. raise
  580. except Exception as e:
  581. logger.error(f"长文本TTS合成失败: {type(e).__name__}: {str(e)}", exc_info=True)
  582. raise HTTPException(status_code=502, detail=f"语音合成失败: {str(e)}")
  583. def _convert_pcm_to_format(self, pcm_data: bytes, format_str: str, sample_rate: int) -> bytes:
  584. """
  585. 将PCM数据转换为目标格式
  586. Args:
  587. pcm_data: PCM音频数据(16bit mono)
  588. format_str: 目标格式(mp3、wav、pcm、opus)
  589. sample_rate: 采样率
  590. Returns:
  591. 转换后的音频数据
  592. """
  593. if format_str == "pcm":
  594. return pcm_data
  595. if format_str == "wav":
  596. return self._pcm_to_wav(pcm_data, sample_rate)
  597. if format_str == "mp3":
  598. # 先转WAV,再转MP3(需要pydub和ffmpeg)
  599. try:
  600. from pydub import AudioSegment
  601. wav_data = self._pcm_to_wav(pcm_data, sample_rate)
  602. audio = AudioSegment.from_wav(io.BytesIO(wav_data))
  603. mp3_buffer = io.BytesIO()
  604. # 根据采样率选择合适的码率
  605. bitrate_map = {
  606. 8000: "64k",
  607. 16000: "96k",
  608. 22050: "128k",
  609. 24000: "128k",
  610. 44100: "192k",
  611. 48000: "192k",
  612. }
  613. bitrate = bitrate_map.get(sample_rate, "128k")
  614. audio.export(mp3_buffer, format="mp3", bitrate=bitrate)
  615. return mp3_buffer.getvalue()
  616. except ImportError:
  617. logger.warning("pydub未安装,无法转换为MP3格式,返回WAV格式")
  618. return self._pcm_to_wav(pcm_data, sample_rate)
  619. except Exception as e:
  620. logger.error(f"MP3转换失败: {str(e)},返回WAV格式")
  621. return self._pcm_to_wav(pcm_data, sample_rate)
  622. if format_str == "opus":
  623. # OPUS格式转换(需要pydub和ffmpeg)
  624. try:
  625. from pydub import AudioSegment
  626. wav_data = self._pcm_to_wav(pcm_data, sample_rate)
  627. audio = AudioSegment.from_wav(io.BytesIO(wav_data))
  628. opus_buffer = io.BytesIO()
  629. audio.export(opus_buffer, format="opus", bitrate="32k")
  630. return opus_buffer.getvalue()
  631. except ImportError:
  632. logger.warning("pydub未安装,无法转换为OPUS格式,返回WAV格式")
  633. return self._pcm_to_wav(pcm_data, sample_rate)
  634. except Exception as e:
  635. logger.error(f"OPUS转换失败: {str(e)},返回WAV格式")
  636. return self._pcm_to_wav(pcm_data, sample_rate)
  637. # 默认返回WAV
  638. logger.warning(f"不支持的格式: {format_str},返回WAV格式")
  639. return self._pcm_to_wav(pcm_data, sample_rate)
  640. def _pcm_to_wav(self, pcm_data: bytes, sample_rate: int) -> bytes:
  641. """
  642. 将PCM数据转换为WAV格式
  643. Args:
  644. pcm_data: PCM音频数据
  645. sample_rate: 采样率
  646. Returns:
  647. WAV格式音频数据
  648. """
  649. import struct
  650. # WAV文件头参数
  651. channels = 1
  652. bits_per_sample = 16
  653. byte_rate = sample_rate * channels * bits_per_sample // 8
  654. block_align = channels * bits_per_sample // 8
  655. data_size = len(pcm_data)
  656. # 构建WAV头
  657. wav_header = struct.pack(
  658. '<4sI4s4sIHHIIHH4sI',
  659. b'RIFF',
  660. 36 + data_size,
  661. b'WAVE',
  662. b'fmt ',
  663. 16,
  664. 1, # PCM
  665. channels,
  666. sample_rate,
  667. byte_rate,
  668. block_align,
  669. bits_per_sample,
  670. b'data',
  671. data_size
  672. )
  673. return wav_header + pcm_data
  674. def get_tts_models(self) -> List[TTSModelResponse]:
  675. """获取TTS模型列表(从数据库动态查询)"""
  676. from app.models.model import ModelNew, ModelPriceNew, ModelCategory
  677. from sqlalchemy import cast
  678. from sqlalchemy.dialects.postgresql import ARRAY, INTEGER
  679. models = self.db.query(ModelNew).filter(
  680. ModelNew.categories.contains(cast([int(ModelCategory.TTS)], ARRAY(INTEGER))),
  681. ModelNew.is_api_enabled == True,
  682. ModelNew.is_show_enabled == True,
  683. ).all()
  684. result = []
  685. for i, m in enumerate(models):
  686. # clone 模型是声音复刻专用,不出现在普通 TTS 列表
  687. # realtime 模型使用 WebSocket 接口,当前 TTS 服务不支持
  688. if 'clone' in m.model_code.lower() or 'realtime' in m.model_code.lower():
  689. continue
  690. price_row = self.db.query(ModelPriceNew).filter(
  691. ModelPriceNew.model_code == m.model_code,
  692. ModelPriceNew.is_active == True,
  693. ).first()
  694. price_str = ""
  695. if price_row:
  696. # TTS 按字符计费,用 input_price
  697. price_val = price_row.input_price_discounted
  698. # 去掉多余的0
  699. price_normalized = price_val.normalize()
  700. unit_str = price_row.unit.replace('元/', '')
  701. price_str = f"{price_normalized}/{unit_str}"
  702. features = []
  703. if m.features:
  704. if isinstance(m.features, dict):
  705. features = [k for k, v in m.features.items() if v]
  706. elif isinstance(m.features, list):
  707. features = m.features
  708. result.append(TTSModelResponse(
  709. id=m.id,
  710. title=m.model_code,
  711. name=m.display_name or m.model_code,
  712. description=m.custom_description or m.description or "",
  713. price=price_str,
  714. features=features,
  715. ))
  716. return result