"""
Long-text-to-audio service.

Business logic for splitting long text into segments, synthesizing each
segment with TTS, and concatenating the results into one audio file.
"""
import asyncio
import io
import logging
import re
import uuid
from datetime import datetime
from decimal import Decimal
from typing import List, Set, Tuple

from sqlalchemy.orm import Session
from sqlalchemy import desc, text
from sqlalchemy.exc import ProgrammingError
from fastapi import HTTPException

from app.database import SessionLocal
from app.models.audio import LongTextAudio
from app.schemas.audio_v2 import (
    LongTextAudioCreateRequest,
    LongTextAudioResponse,
    LongTextAudioListResponse,
    TaskListQueryParams,
    SegmentInfo
)
from .base_service import BaseV2Service

logger = logging.getLogger(__name__)

# Strong references to in-flight background synthesis tasks. The event loop
# keeps only weak references to tasks, so a fire-and-forget task created with
# asyncio.create_task() could be garbage-collected before it finishes. Each
# task removes itself from this set when it completes.
_background_tasks: Set["asyncio.Task[None]"] = set()


class LongTextAudioService(BaseV2Service):
    """Long-text-to-audio service.

    Uses attributes provided by ``BaseV2Service`` (not visible here):
    ``db`` (the request's SQLAlchemy session), ``user_id``, ``api_key`` and
    ``oss_service``.
    """

    # Supported TTS models.
    VALID_MODELS = ["cosyvoice-v3-flash", "cosyvoice-v3-plus", "cosyvoice-v2"]

    # Maximum number of characters per synthesized segment.
    MAX_SEGMENT_LENGTH = 500

    # Sentence delimiters. The capturing group makes re.split() keep each
    # delimiter as its own list element.
    _SENTENCE_RE = re.compile(r'([。!?;\n])')

    def _ensure_storage_ready(self) -> None:
        """Ensure schema/table exist for long text tasks (best effort, logs on failure)."""
        try:
            self.db.execute(text("CREATE SCHEMA IF NOT EXISTS aigcspace"))
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            logger.warning(f"创建schema失败: {type(e).__name__}: {str(e)}")
        try:
            LongTextAudio.__table__.create(self.db.get_bind(), checkfirst=True)
        except Exception as e:
            logger.warning(f"创建long_text_audio表失败: {type(e).__name__}: {str(e)}")

    @staticmethod
    def _to_response(task: LongTextAudio) -> LongTextAudioResponse:
        """Build a response from an ORM task, converting segment dicts to SegmentInfo."""
        response = LongTextAudioResponse.from_orm(task)
        if task.segments:
            response.segments = [SegmentInfo(**seg) for seg in task.segments]
        return response

    async def create_task(
        self,
        request: LongTextAudioCreateRequest
    ) -> LongTextAudioResponse:
        """
        Create a long-text-to-audio task and start processing it in the background.

        Args:
            request: the creation request.

        Returns:
            The newly created task (status PENDING, one PENDING entry per segment).

        Raises:
            HTTPException: 400 for an unsupported model; 502 if the task could
                not be created.
        """
        if request.model not in self.VALID_MODELS:
            raise HTTPException(
                status_code=400,
                detail=f"无效的模型,支持的模型: {self.VALID_MODELS}"
            )

        try:
            self._ensure_storage_ready()

            task_id = str(uuid.uuid4())
            segments = self._split_text(request.text)

            # Per-segment bookkeeping; indices are 1-based.
            segment_infos = [
                {
                    "index": idx,
                    "text": segment_text,
                    "task_id": None,
                    "audio_url": None,
                    "duration": None,
                    "status": "PENDING",
                }
                for idx, segment_text in enumerate(segments, 1)
            ]

            long_text_task = LongTextAudio(
                user_id=self.user_id,
                task_id=task_id,
                model=request.model,
                voice=request.voice,
                text=request.text,
                text_length=len(request.text),
                segment_count=len(segments),
                segments=segment_infos,
                format=request.format,
                custom_name=request.custom_name,
                status="PENDING",
                progress=0
            )
            try:
                self.db.add(long_text_task)
                self.db.commit()
                self.db.refresh(long_text_task)
            except ProgrammingError:
                # Presumably the schema/table does not exist yet: recreate
                # the storage and retry the insert once.
                self.db.rollback()
                self._ensure_storage_ready()
                self.db.add(long_text_task)
                self.db.commit()
                self.db.refresh(long_text_task)

            response = LongTextAudioResponse.from_orm(long_text_task)
            response.segments = [SegmentInfo(**seg) for seg in segment_infos]

            # Start processing without blocking this request. Keep a strong
            # reference so the task is not garbage-collected while running.
            background = asyncio.create_task(
                self._process_task(
                    task_id=task_id,
                    request=request,
                    segments=segments,
                )
            )
            _background_tasks.add(background)
            background.add_done_callback(_background_tasks.discard)

            return response

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"创建长文本任务失败: {type(e).__name__}: {str(e)}")
            raise HTTPException(
                status_code=502,
                detail=f"创建长文本任务失败: {str(e)}"
            ) from e

    async def get_task(self, task_id: str) -> LongTextAudioResponse:
        """
        Fetch one task owned by the current user.

        Args:
            task_id: the task ID.

        Returns:
            The task response.

        Raises:
            HTTPException: 404 if no such task exists for this user.
        """
        task = self.db.query(LongTextAudio).filter(
            LongTextAudio.task_id == task_id,
            LongTextAudio.user_id == self.user_id
        ).first()
        if not task:
            raise HTTPException(status_code=404, detail="任务不存在")
        return self._to_response(task)

    async def list_tasks(
        self,
        params: TaskListQueryParams
    ) -> LongTextAudioListResponse:
        """
        List the current user's tasks with optional status filter, ordering and paging.

        Args:
            params: query parameters. ``order_by`` accepts "created_at" or
                "updated_at" (anything else sorts by created_at); ``order`` is
                "desc" for descending, anything else ascending.

        Returns:
            The total matching count and the requested page of tasks.
        """
        query = self.db.query(LongTextAudio).filter(
            LongTextAudio.user_id == self.user_id
        )
        if params.status:
            query = query.filter(LongTextAudio.status == params.status)

        # Count before ordering/paging so it reflects all matches.
        total = query.count()

        order_column = (
            LongTextAudio.updated_at
            if params.order_by == "updated_at"
            else LongTextAudio.created_at
        )
        query = query.order_by(
            desc(order_column) if params.order == "desc" else order_column
        )

        # Clamp to page 1 so a non-positive page never yields a negative OFFSET.
        page = max(params.page, 1)
        offset = (page - 1) * params.page_size
        tasks = query.offset(offset).limit(params.page_size).all()

        return LongTextAudioListResponse(
            total=total,
            items=[self._to_response(task) for task in tasks],
        )

    async def _process_task(
        self,
        task_id: str,
        request: LongTextAudioCreateRequest,
        segments: List[str],
    ) -> None:
        """
        Run long-text TTS in the background: synthesize each segment, merge,
        upload to OSS, and record the result.

        Uses its own database session so it does not interfere with the
        request's session. Blocking work (DashScope synthesis, PCM encoding,
        OSS upload) runs in the default executor so the event loop is not
        blocked. Any ``Exception`` is logged and recorded on the task as
        FAILED instead of being raised.

        Args:
            task_id: ID of the LongTextAudio row to update.
            request: the original creation request (model, voice, format, ...).
            segments: the text segments produced by ``_split_text``.
        """
        db = SessionLocal()
        try:
            import dashscope
            from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat as DashAudioFormat

            # NOTE(review): this sets a process-wide key; concurrent tasks for
            # users with different keys can race on it.
            dashscope.api_key = self.api_key

            task = db.query(LongTextAudio).filter(LongTextAudio.task_id == task_id).first()
            if not task:
                return
            task.status = "PROCESSING"
            task.updated_at = datetime.now()
            db.commit()

            # Sample rate -> DashScope raw PCM output format (16-bit mono).
            sample_rate = getattr(request, 'sample_rate', 22050) or 22050
            pcm_format_map = {
                8000: DashAudioFormat.PCM_8000HZ_MONO_16BIT,
                16000: DashAudioFormat.PCM_16000HZ_MONO_16BIT,
                22050: DashAudioFormat.PCM_22050HZ_MONO_16BIT,
                24000: DashAudioFormat.PCM_24000HZ_MONO_16BIT,
                44100: DashAudioFormat.PCM_44100HZ_MONO_16BIT,
                48000: DashAudioFormat.PCM_48000HZ_MONO_16BIT,
            }
            if sample_rate not in pcm_format_map:
                # Fall back to 22050 Hz, and use that same rate for the WAV
                # header and duration so the output does not play at the
                # wrong speed.
                logger.warning(f"[长文本TTS] 不支持的采样率 {sample_rate},回退为 22050")
                sample_rate = 22050
            pcm_fmt = pcm_format_map[sample_rate]

            volume = getattr(request, 'volume', 50)
            speech_rate = getattr(request, 'speech_rate', 1.0)
            pitch_rate = getattr(request, 'pitch_rate', 1.0)

            def _synth(segment_text: str):
                """Synthesize one segment (blocking; runs in the executor)."""
                synthesizer = SpeechSynthesizer(
                    model=request.model,
                    voice=request.voice,
                    format=pcm_fmt,
                    volume=volume,
                    speech_rate=speech_rate,
                    pitch_rate=pitch_rate,
                )
                return synthesizer.call(segment_text)

            loop = asyncio.get_running_loop()
            audio_parts: List[bytes] = []
            seg_infos = list(task.segments) if task.segments else []

            for idx, seg_text in enumerate(segments):
                try:
                    audio_data = await loop.run_in_executor(None, _synth, seg_text)
                except Exception as e:
                    logger.error(f"[长文本TTS] 第{idx+1}段合成失败: {e}")
                    audio_data = None

                if idx < len(seg_infos):
                    # Copy rather than mutate: seg_infos was shallow-copied
                    # from task.segments, so in-place edits would also alter
                    # the ORM's loaded value.
                    seg_infos[idx] = dict(seg_infos[idx])
                    seg_infos[idx]['status'] = 'SUCCEEDED' if audio_data else 'FAILED'
                if audio_data:
                    audio_parts.append(audio_data)

                # Persist per-segment status and progress.
                done = sum(1 for s in seg_infos if s.get('status') in ('SUCCEEDED', 'FAILED'))
                task = db.query(LongTextAudio).filter(LongTextAudio.task_id == task_id).first()
                if task:
                    # Assign a fresh list each time. Re-assigning the same
                    # mutated list object could compare equal to the
                    # committed value and skip the UPDATE.
                    task.segments = list(seg_infos)
                    task.progress = int(done / len(segments) * 100)
                    task.updated_at = datetime.now()
                    db.commit()

            if not audio_parts:
                raise RuntimeError("所有分段合成均失败")

            # Merge PCM, encode to the target format, and upload.
            merged_pcm = b''.join(audio_parts)
            fmt = getattr(request, 'format', 'mp3') or 'mp3'
            # out_fmt is the format actually produced (mp3/opus fall back to
            # WAV if encoding fails), so the file name and record match.
            final_audio, out_fmt = await loop.run_in_executor(
                None, self._encode_pcm, merged_pcm, fmt, sample_rate
            )
            audio_url = await loop.run_in_executor(
                None,
                lambda: self.oss_service.upload_file(
                    final_audio,
                    prefix="audio/long-text",
                    original_filename=f"audio.{out_fmt}",
                ),
            )

            # 16-bit mono PCM: 2 bytes per sample.
            duration = len(merged_pcm) / (sample_rate * 2)
            # The API call is free of charge.
            bill = Decimal("0")

            now = datetime.now()
            task = db.query(LongTextAudio).filter(LongTextAudio.task_id == task_id).first()
            if task:
                task.status = "SUCCEEDED"
                task.audio_url = audio_url
                task.duration = round(duration, 2)
                task.bill = bill
                task.progress = 100
                task.completed_at = now
                task.updated_at = now
                db.commit()

            # Also record an AudioSynthesis row for the creation history.
            # Best effort: a failure here must not flip the already
            # SUCCEEDED task to FAILED.
            try:
                from app.models.audio import AudioSynthesis
                text_preview = request.text[:1000] + "..." if len(request.text) > 1000 else request.text
                synthesis_record = AudioSynthesis(
                    user_id=self.user_id,
                    model=request.model,
                    voice=request.voice,
                    text=text_preview,
                    audio_url=audio_url,
                    duration=round(duration, 2),
                    format=out_fmt,
                    characters=len(request.text),
                    bill=bill,
                    completed_at=datetime.now(),
                )
                db.add(synthesis_record)
                db.commit()
            except Exception as e:
                db.rollback()
                logger.error(f"[长文本TTS] 任务 {task_id} 写入创作历史失败: {e}", exc_info=True)

            logger.info(f"[长文本TTS] 任务 {task_id} 完成,时长 {duration:.1f}s")

        except Exception as e:
            logger.error(f"[长文本TTS] 任务 {task_id} 失败: {e}", exc_info=True)
            try:
                # The session may be unusable if the error came from the DB
                # itself (e.g. a failed commit); reset it before writing.
                db.rollback()
                task = db.query(LongTextAudio).filter(LongTextAudio.task_id == task_id).first()
                if task:
                    now = datetime.now()
                    task.status = "FAILED"
                    task.error_message = str(e)
                    task.updated_at = now
                    task.completed_at = now
                    db.commit()
            except Exception:
                db.rollback()
                logger.error(f"[长文本TTS] 任务 {task_id} 状态更新失败", exc_info=True)
        finally:
            db.close()

    def _convert_pcm(self, pcm_data: bytes, fmt: str, sample_rate: int) -> bytes:
        """
        Convert 16-bit mono PCM to ``fmt`` (mirrors the logic in tts_service).

        Returns the encoded bytes only. mp3/opus fall back to WAV if encoding
        fails, and unknown formats produce WAV. Use ``_encode_pcm`` to also
        learn which format was produced.
        """
        audio, _ = self._encode_pcm(pcm_data, fmt, sample_rate)
        return audio

    @staticmethod
    def _encode_pcm(pcm_data: bytes, fmt: str, sample_rate: int) -> Tuple[bytes, str]:
        """
        Encode raw 16-bit mono PCM as ``fmt`` and report the format produced.

        Args:
            pcm_data: raw little-endian 16-bit mono samples.
            fmt: requested format: 'pcm', 'wav', 'mp3' or 'opus'.
            sample_rate: sample rate written into the WAV header.

        Returns:
            (audio bytes, actual format). mp3/opus fall back to ('wav') when
            pydub/encoding fails; unknown formats also yield WAV.
        """
        import struct

        def to_wav(pcm: bytes) -> bytes:
            """Prepend a canonical 44-byte RIFF/WAVE header to raw PCM."""
            channels, bits = 1, 16
            byte_rate = sample_rate * channels * bits // 8
            block_align = channels * bits // 8
            data_size = len(pcm)
            header = struct.pack(
                '<4sI4s4sIHHIIHH4sI',
                b'RIFF', 36 + data_size, b'WAVE',
                b'fmt ', 16, 1, channels, sample_rate, byte_rate, block_align, bits,
                b'data', data_size,
            )
            return header + pcm

        if fmt == 'pcm':
            return pcm_data, 'pcm'
        if fmt in ('mp3', 'opus'):
            try:
                from pydub import AudioSegment
                audio = AudioSegment.from_wav(io.BytesIO(to_wav(pcm_data)))
                buf = io.BytesIO()
                bitrate = '128k' if fmt == 'mp3' else '32k'
                audio.export(buf, format=fmt, bitrate=bitrate)
                return buf.getvalue(), fmt
            except Exception as e:
                logger.warning(f"转换 {fmt} 失败({e}),回退为 WAV")
        # 'wav', unknown formats, and failed mp3/opus encodes all yield WAV.
        return to_wav(pcm_data), 'wav'

    def _split_text(self, text: str) -> List[str]:
        """
        Split text into segments of at most MAX_SEGMENT_LENGTH characters,
        preferring sentence boundaries.

        Sentences (each keeps its trailing delimiter from ``_SENTENCE_RE``)
        are packed greedily into segments. A single sentence longer than the
        limit is hard-split every MAX_SEGMENT_LENGTH characters; its final
        piece stays open so following sentences can join it.

        Args:
            text: the text to split.

        Returns:
            The segments: [] for empty text, [text] when it already fits.
        """
        if not text:
            return []
        limit = self.MAX_SEGMENT_LENGTH
        if len(text) <= limit:
            return [text]

        # With a capturing group, re.split alternates body, delimiter, body,
        # ..., body. Even indices are sentence bodies and odd indices are
        # their delimiters; the final body has no delimiter.
        parts = self._SENTENCE_RE.split(text)
        bodies = parts[0::2]
        delimiters = parts[1::2] + [""]

        segments: List[str] = []
        current = ""
        for body, delimiter in zip(bodies, delimiters):
            sentence = body + delimiter
            if len(current) + len(sentence) <= limit:
                current += sentence
                continue

            if current:
                segments.append(current)
            if len(sentence) > limit:
                # Hard-split by character count; the last piece stays open.
                chunks = [sentence[j:j + limit] for j in range(0, len(sentence), limit)]
                segments.extend(chunks[:-1])
                current = chunks[-1]
            else:
                current = sentence

        if current:
            segments.append(current)
        return segments