long_text_audio_service.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. """
  2. 长文本转音频服务
  3. 提供长文本分段合成和拼接的业务逻辑处理
  4. """
  5. import asyncio
  6. import io
  7. import logging
  8. import re
  9. import uuid
  10. from datetime import datetime
  11. from typing import List
  12. from sqlalchemy.orm import Session
  13. from sqlalchemy import desc, text
  14. from sqlalchemy.exc import ProgrammingError
  15. from fastapi import HTTPException
  16. from decimal import Decimal
  17. from app.database import SessionLocal
  18. from app.models.audio import LongTextAudio
  19. from app.schemas.audio_v2 import (
  20. LongTextAudioCreateRequest,
  21. LongTextAudioResponse,
  22. LongTextAudioListResponse,
  23. TaskListQueryParams,
  24. SegmentInfo
  25. )
  26. from .base_service import BaseV2Service
  27. logger = logging.getLogger(__name__)
  28. class LongTextAudioService(BaseV2Service):
  29. """长文本转音频服务"""
  30. # 有效的TTS模型
  31. VALID_MODELS = ["cosyvoice-v3-flash", "cosyvoice-v3-plus", "cosyvoice-v2"]
  32. # 每段最大字符数
  33. MAX_SEGMENT_LENGTH = 500
  def _ensure_storage_ready(self) -> None:
      """Best-effort bootstrap of the storage used by long-text tasks.

      Runs ``CREATE SCHEMA IF NOT EXISTS aigcspace`` through the request
      session, then asks SQLAlchemy to create the ``LongTextAudio`` table
      with ``checkfirst=True``. A failure in either step is logged as a
      warning and never propagated to the caller.
      """
      # Step 1: schema. Roll back on failure so the session stays usable.
      try:
          create_schema = text("CREATE SCHEMA IF NOT EXISTS aigcspace")
          self.db.execute(create_schema)
          self.db.commit()
      except Exception as schema_err:
          self.db.rollback()
          logger.warning(f"创建schema失败: {type(schema_err).__name__}: {str(schema_err)}")

      # Step 2: table, created directly on the session's bind.
      try:
          bind = self.db.get_bind()
          LongTextAudio.__table__.create(bind, checkfirst=True)
      except Exception as table_err:
          logger.warning(f"创建long_text_audio表失败: {type(table_err).__name__}: {str(table_err)}")
  async def create_task(
      self,
      request: LongTextAudioCreateRequest
  ) -> LongTextAudioResponse:
      """Create a long-text-to-audio task and start processing it in the background.

      The text is split into segments, a ``LongTextAudio`` row is stored with
      every segment in ``PENDING`` state, and ``_process_task`` is scheduled on
      the running event loop. The response is returned without waiting for
      synthesis to finish.

      Args:
          request: Creation request (model, voice, text, format, custom name).

      Returns:
          The freshly stored task, with its segments as ``SegmentInfo`` objects.

      Raises:
          HTTPException: 400 for an unsupported model or text that yields no
              segments; 502 for any other failure while creating the task.
      """
      if request.model not in self.VALID_MODELS:
          raise HTTPException(
              status_code=400,
              detail=f"无效的模型,支持的模型: {self.VALID_MODELS}"
          )
      try:
          task_id = str(uuid.uuid4())

          # Split before touching storage: a request that produces no segments
          # could only ever end as a FAILED background task, so reject it now.
          segments = self._split_text(request.text)
          if not segments:
              raise HTTPException(status_code=400, detail="文本内容不能为空")

          self._ensure_storage_ready()

          segment_infos = [
              {
                  "index": idx,
                  "text": segment_text,
                  "task_id": None,
                  "audio_url": None,
                  "duration": None,
                  "status": "PENDING",
              }
              for idx, segment_text in enumerate(segments, 1)
          ]

          long_text_task = LongTextAudio(
              user_id=self.user_id,
              task_id=task_id,
              model=request.model,
              voice=request.voice,
              text=request.text,
              text_length=len(request.text),
              segment_count=len(segments),
              segments=segment_infos,
              format=request.format,
              custom_name=request.custom_name,
              status="PENDING",
              progress=0
          )
          try:
              self.db.add(long_text_task)
              self.db.commit()
              self.db.refresh(long_text_task)
          except ProgrammingError:
              # Typically a missing schema/table: re-run the bootstrap and retry once.
              self.db.rollback()
              self._ensure_storage_ready()
              self.db.add(long_text_task)
              self.db.commit()
              self.db.refresh(long_text_task)

          response = LongTextAudioResponse.from_orm(long_text_task)
          response.segments = [SegmentInfo(**seg) for seg in segment_infos]

          # Schedule background processing without blocking this request.
          background = asyncio.create_task(
              self._process_task(
                  task_id=task_id,
                  request=request,
                  segments=segments,
              )
          )
          # The event loop only keeps weak references to tasks, so an
          # unreferenced task can be garbage-collected mid-execution. Hold a
          # strong reference on the class until the task is done.
          pending = getattr(LongTextAudioService, "_background_tasks", None)
          if pending is None:
              pending = set()
              LongTextAudioService._background_tasks = pending
          pending.add(background)
          background.add_done_callback(pending.discard)

          return response
      except HTTPException:
          raise
      except Exception as e:
          logger.error(
              f"创建长文本任务失败: {type(e).__name__}: {str(e)}",
              exc_info=True,
          )
          raise HTTPException(
              status_code=502,
              detail=f"创建长文本任务失败: {str(e)}"
          ) from e
  async def get_task(self, task_id: str) -> LongTextAudioResponse:
      """Return the details of one long-text task owned by the current user.

      Args:
          task_id: Identifier of the task to look up.

      Returns:
          The task as a response model; stored segment dicts are rebuilt as
          ``SegmentInfo`` objects when the row has any.

      Raises:
          HTTPException: 404 when no row matches both ``task_id`` and
              ``self.user_id``.
      """
      ownership_filter = (
          LongTextAudio.task_id == task_id,
          LongTextAudio.user_id == self.user_id,
      )
      record = self.db.query(LongTextAudio).filter(*ownership_filter).first()
      if not record:
          raise HTTPException(status_code=404, detail="任务不存在")

      result = LongTextAudioResponse.from_orm(record)
      if record.segments:
          result.segments = [SegmentInfo(**entry) for entry in record.segments]
      return result
  async def list_tasks(
      self,
      params: TaskListQueryParams
  ) -> LongTextAudioListResponse:
      """Return one page of the current user's long-text tasks.

      Args:
          params: Query parameters; this method reads ``status``, ``order_by``,
              ``order``, ``page`` and ``page_size``.

      Returns:
          The total number of matching rows (before paging) and the requested
          page of tasks as response models.
      """
      def _to_response(row):
          # Convert one ORM row, rebuilding segment dicts as SegmentInfo objects.
          converted = LongTextAudioResponse.from_orm(row)
          if row.segments:
              converted.segments = [SegmentInfo(**entry) for entry in row.segments]
          return converted

      selection = self.db.query(LongTextAudio).filter(
          LongTextAudio.user_id == self.user_id
      )
      if params.status:
          selection = selection.filter(LongTextAudio.status == params.status)

      # Count before ordering/paging so the total covers every matching row.
      total = selection.count()

      # Only "updated_at" switches the sort column; everything else sorts by created_at.
      sort_column = (
          LongTextAudio.updated_at
          if params.order_by == "updated_at"
          else LongTextAudio.created_at
      )
      ordering = desc(sort_column) if params.order == "desc" else sort_column
      selection = selection.order_by(ordering)

      skip = (params.page - 1) * params.page_size
      rows = selection.offset(skip).limit(params.page_size).all()

      items = [_to_response(row) for row in rows]
      return LongTextAudioListResponse(total=total, items=items)
  async def _process_task(
      self,
      task_id: str,
      request: LongTextAudioCreateRequest,
      segments: List[str],
  ) -> None:
      """Run long-text TTS in the background: synthesize each segment, merge, upload.

      Uses its own database session (``SessionLocal()``) so it does not share
      the request session. DashScope synthesis calls and the PCM conversion are
      run in the default thread-pool executor. Progress and per-segment status
      are written to the ``LongTextAudio`` row after every segment. On any
      failure the row is marked ``FAILED`` with the error message.

      Args:
          task_id: Identifier of the ``LongTextAudio`` row to update.
          request: The original creation request (model, voice, format, ...).
          segments: Text segments to synthesize, in playback order.
      """
      db = SessionLocal()

      def _load_task():
          # Re-read the task row on the background session.
          return db.query(LongTextAudio).filter(LongTextAudio.task_id == task_id).first()

      try:
          import dashscope
          from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat as DashAudioFormat
          dashscope.api_key = self.api_key

          task = _load_task()
          if not task:
              return
          task.status = "PROCESSING"
          task.updated_at = datetime.now()
          db.commit()

          # Sample rate -> DashScope PCM format.
          pcm_format_map = {
              8000: DashAudioFormat.PCM_8000HZ_MONO_16BIT,
              16000: DashAudioFormat.PCM_16000HZ_MONO_16BIT,
              22050: DashAudioFormat.PCM_22050HZ_MONO_16BIT,
              24000: DashAudioFormat.PCM_24000HZ_MONO_16BIT,
              44100: DashAudioFormat.PCM_44100HZ_MONO_16BIT,
              48000: DashAudioFormat.PCM_48000HZ_MONO_16BIT,
          }
          default_rate = 22050
          sample_rate = getattr(request, 'sample_rate', default_rate) or default_rate
          if sample_rate not in pcm_format_map:
              # Synthesis falls back to 22050 Hz, so the WAV header and the
              # duration estimate below must use the same rate.
              sample_rate = default_rate
          pcm_fmt = pcm_format_map[sample_rate]

          loop = asyncio.get_running_loop()
          audio_parts: List[bytes] = []
          seg_infos = [dict(info) for info in task.segments] if task.segments else []

          for idx, seg_text in enumerate(segments):
              # Bind the segment as a default argument so the closure is not late-bound.
              def _synth(segment_text=seg_text):
                  synth = SpeechSynthesizer(
                      model=request.model,
                      voice=request.voice,
                      format=pcm_fmt,
                      volume=getattr(request, 'volume', 50),
                      speech_rate=getattr(request, 'speech_rate', 1.0),
                      pitch_rate=getattr(request, 'pitch_rate', 1.0),
                  )
                  return synth.call(segment_text)

              try:
                  audio_data = await loop.run_in_executor(None, _synth)
              except Exception as e:
                  logger.error(f"[长文本TTS] 第{idx+1}段合成失败: {e}")
                  audio_data = None

              # A failed segment is skipped; the remaining audio is still merged.
              if audio_data:
                  audio_parts.append(audio_data)
              if idx < len(seg_infos):
                  seg_infos[idx]['status'] = 'SUCCEEDED' if audio_data else 'FAILED'

              done = sum(1 for s in seg_infos if s.get('status') in ('SUCCEEDED', 'FAILED'))
              task = _load_task()
              if task:
                  # Assign a fresh copy each time so the ORM always sees a new
                  # value rather than the same list object mutated in place.
                  task.segments = [dict(info) for info in seg_infos]
                  task.progress = int(done / len(segments) * 100)
                  task.updated_at = datetime.now()
                  db.commit()

          if not audio_parts:
              raise RuntimeError("所有分段合成均失败")

          # Merge PCM and convert to the target format off the event loop.
          merged_pcm = b''.join(audio_parts)
          fmt = getattr(request, 'format', 'mp3') or 'mp3'
          final_audio = await loop.run_in_executor(
              None, self._convert_pcm, merged_pcm, fmt, sample_rate
          )
          if fmt not in ('pcm', 'wav') and final_audio[:4] == b'RIFF':
              # _convert_pcm returns WAV when encoding fails or the format is
              # unknown; name and record the file by what it actually contains.
              logger.warning(f"[长文本TTS] 任务 {task_id} 输出已回退为 WAV(请求格式: {fmt})")
              fmt = 'wav'

          # Upload to OSS.
          # NOTE(review): upload_file is called synchronously here; if it blocks
          # on network I/O it could also move to the executor - confirm it is
          # safe to call from a worker thread first.
          audio_url = self.oss_service.upload_file(
              final_audio,
              prefix="audio/long-text",
              original_filename=f"audio.{fmt}",
          )

          # Duration estimate: mono 16-bit PCM is 2 bytes per sample.
          duration = len(merged_pcm) / (sample_rate * 2)
          # Billing amount recorded for this task (currently zero).
          bill = Decimal("0")

          task = _load_task()
          if task:
              finished_at = datetime.now()
              task.status = "SUCCEEDED"
              task.audio_url = audio_url
              task.duration = round(duration, 2)
              task.bill = bill
              task.progress = 100
              task.completed_at = finished_at
              task.updated_at = finished_at
              db.commit()

          # Also store an AudioSynthesis row for the creation-history feature.
          # The task has already succeeded, so a failure here is logged and
          # must not flip the task to FAILED.
          try:
              from app.models.audio import AudioSynthesis
              full_text = request.text
              text_preview = full_text[:1000] + "..." if len(full_text) > 1000 else full_text
              synthesis_record = AudioSynthesis(
                  user_id=self.user_id,
                  model=request.model,
                  voice=request.voice,
                  text=text_preview,
                  audio_url=audio_url,
                  duration=round(duration, 2),
                  format=fmt,
                  characters=len(full_text),
                  bill=bill,
                  completed_at=datetime.now(),
              )
              db.add(synthesis_record)
              db.commit()
          except Exception:
              db.rollback()
              logger.warning(f"[长文本TTS] 任务 {task_id} 写入创作历史失败", exc_info=True)

          logger.info(f"[长文本TTS] 任务 {task_id} 完成,时长 {duration:.1f}s")
      except Exception as e:
          logger.error(f"[长文本TTS] 任务 {task_id} 失败: {e}", exc_info=True)
          try:
              # If a flush/commit failed, the session rejects further queries
              # until it is rolled back; without this the task could stay
              # stuck in PROCESSING.
              db.rollback()
              task = _load_task()
              if task:
                  failed_at = datetime.now()
                  task.status = "FAILED"
                  task.error_message = str(e)
                  task.updated_at = failed_at
                  task.completed_at = failed_at
                  db.commit()
          except Exception:
              logger.error(f"[长文本TTS] 任务 {task_id} 标记失败状态时出错", exc_info=True)
              db.rollback()
      finally:
          db.close()
  319. def _convert_pcm(self, pcm_data: bytes, fmt: str, sample_rate: int) -> bytes:
  320. """PCM → 目标格式转换(复用 tts_service 的逻辑)"""
  321. import struct
  322. def to_wav(pcm: bytes) -> bytes:
  323. channels, bits = 1, 16
  324. byte_rate = sample_rate * channels * bits // 8
  325. block_align = channels * bits // 8
  326. data_size = len(pcm)
  327. header = struct.pack(
  328. '<4sI4s4sIHHIIHH4sI',
  329. b'RIFF', 36 + data_size, b'WAVE', b'fmt ',
  330. 16, 1, channels, sample_rate,
  331. byte_rate, block_align, bits, b'data', data_size,
  332. )
  333. return header + pcm
  334. if fmt == 'pcm':
  335. return pcm_data
  336. if fmt == 'wav':
  337. return to_wav(pcm_data)
  338. if fmt in ('mp3', 'opus'):
  339. try:
  340. from pydub import AudioSegment
  341. wav = to_wav(pcm_data)
  342. audio = AudioSegment.from_wav(io.BytesIO(wav))
  343. buf = io.BytesIO()
  344. bitrate = '128k' if fmt == 'mp3' else '32k'
  345. audio.export(buf, format=fmt, bitrate=bitrate)
  346. return buf.getvalue()
  347. except Exception as e:
  348. logger.warning(f"转换 {fmt} 失败({e}),回退为 WAV")
  349. return to_wav(pcm_data)
  350. return to_wav(pcm_data)
  351. def _split_text(self, text: str) -> List[str]:
  352. """
  353. 按句子边界智能切割文本
  354. Args:
  355. text: 待切割的文本
  356. Returns:
  357. 切割后的文本列表
  358. """
  359. if not text:
  360. return []
  361. text_length = len(text)
  362. if text_length <= self.MAX_SEGMENT_LENGTH:
  363. return [text]
  364. segments = []
  365. current = ""
  366. current_length = 0
  367. # 按句子分隔符切割
  368. sentence_pattern = r'([。!?;\n])'
  369. parts = re.split(sentence_pattern, text)
  370. i = 0
  371. while i < len(parts):
  372. part = parts[i]
  373. delimiter = parts[i + 1] if i + 1 < len(parts) and re.match(sentence_pattern, parts[i + 1]) else ""
  374. if delimiter:
  375. i += 1
  376. full_sentence = part + delimiter
  377. sentence_length = len(full_sentence)
  378. if current_length + sentence_length <= self.MAX_SEGMENT_LENGTH:
  379. current += full_sentence
  380. current_length += sentence_length
  381. else:
  382. if current:
  383. segments.append(current)
  384. current = ""
  385. current_length = 0
  386. if sentence_length > self.MAX_SEGMENT_LENGTH:
  387. # 强制按字符数切割
  388. for j in range(0, sentence_length, self.MAX_SEGMENT_LENGTH):
  389. chunk = full_sentence[j:j + self.MAX_SEGMENT_LENGTH]
  390. if j + self.MAX_SEGMENT_LENGTH < sentence_length:
  391. segments.append(chunk)
  392. else:
  393. current = chunk
  394. current_length = len(chunk)
  395. else:
  396. current = full_sentence
  397. current_length = sentence_length
  398. i += 1
  399. if current:
  400. segments.append(current)
  401. return segments