| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475 |
- """
- 长文本转音频服务
- 提供长文本分段合成和拼接的业务逻辑处理
- """
- import asyncio
- import io
- import logging
- import re
- import uuid
- from datetime import datetime
- from typing import List
- from sqlalchemy.orm import Session
- from sqlalchemy import desc, text
- from sqlalchemy.exc import ProgrammingError
- from fastapi import HTTPException
- from decimal import Decimal
- from app.database import SessionLocal
- from app.models.audio import LongTextAudio
- from app.schemas.audio_v2 import (
- LongTextAudioCreateRequest,
- LongTextAudioResponse,
- LongTextAudioListResponse,
- TaskListQueryParams,
- SegmentInfo
- )
- from .base_service import BaseV2Service
- logger = logging.getLogger(__name__)
class LongTextAudioService(BaseV2Service):
    """Long-text-to-audio service."""

    # TTS model identifiers accepted by create_task; any other model is rejected with HTTP 400.
    VALID_MODELS = ["cosyvoice-v3-flash", "cosyvoice-v3-plus", "cosyvoice-v2"]

    # Maximum number of characters in one segment produced by _split_text.
    MAX_SEGMENT_LENGTH = 500
- def _ensure_storage_ready(self) -> None:
- """Ensure schema/table exist for long text tasks."""
- try:
- self.db.execute(text("CREATE SCHEMA IF NOT EXISTS aigcspace"))
- self.db.commit()
- except Exception as e:
- self.db.rollback()
- logger.warning(f"创建schema失败: {type(e).__name__}: {str(e)}")
- try:
- LongTextAudio.__table__.create(self.db.get_bind(), checkfirst=True)
- except Exception as e:
- logger.warning(f"创建long_text_audio表失败: {type(e).__name__}: {str(e)}")
-
- async def create_task(
- self,
- request: LongTextAudioCreateRequest
- ) -> LongTextAudioResponse:
- """
- 创建长文本转音频任务
-
- Args:
- request: 创建请求
-
- Returns:
- 任务响应
-
- Raises:
- HTTPException: 创建失败
- """
- # 验证模型
- if request.model not in self.VALID_MODELS:
- raise HTTPException(
- status_code=400,
- detail=f"无效的模型,支持的模型: {self.VALID_MODELS}"
- )
-
- try:
- self._ensure_storage_ready()
- # 生成任务ID
- task_id = str(uuid.uuid4())
-
- # 分段文本
- segments = self._split_text(request.text)
-
- # 创建分段信息
- segment_infos = []
- for idx, segment_text in enumerate(segments, 1):
- segment_infos.append({
- "index": idx,
- "text": segment_text,
- "task_id": None,
- "audio_url": None,
- "duration": None,
- "status": "PENDING"
- })
-
- # 保存到数据库
- long_text_task = LongTextAudio(
- user_id=self.user_id,
- task_id=task_id,
- model=request.model,
- voice=request.voice,
- text=request.text,
- text_length=len(request.text),
- segment_count=len(segments),
- segments=segment_infos,
- format=request.format,
- custom_name=request.custom_name,
- status="PENDING",
- progress=0
- )
-
- try:
- self.db.add(long_text_task)
- self.db.commit()
- self.db.refresh(long_text_task)
- except ProgrammingError:
- self.db.rollback()
- self._ensure_storage_ready()
- self.db.add(long_text_task)
- self.db.commit()
- self.db.refresh(long_text_task)
-
- # 转换segments为SegmentInfo对象
- response = LongTextAudioResponse.from_orm(long_text_task)
- response.segments = [SegmentInfo(**seg) for seg in segment_infos]
- # 启动后台处理(不阻塞当前请求)
- asyncio.create_task(
- self._process_task(
- task_id=task_id,
- request=request,
- segments=segments,
- )
- )
- return response
-
- except HTTPException:
- raise
- except Exception as e:
- logger.error(f"创建长文本任务失败: {type(e).__name__}: {str(e)}")
- raise HTTPException(
- status_code=502,
- detail=f"创建长文本任务失败: {str(e)}"
- )
-
- async def get_task(self, task_id: str) -> LongTextAudioResponse:
- """
- 查询任务详情
-
- Args:
- task_id: 任务ID
-
- Returns:
- 任务响应
-
- Raises:
- HTTPException: 任务不存在
- """
- task = self.db.query(LongTextAudio).filter(
- LongTextAudio.task_id == task_id,
- LongTextAudio.user_id == self.user_id
- ).first()
-
- if not task:
- raise HTTPException(status_code=404, detail="任务不存在")
-
- # 转换segments
- response = LongTextAudioResponse.from_orm(task)
- if task.segments:
- response.segments = [SegmentInfo(**seg) for seg in task.segments]
-
- return response
-
- async def list_tasks(
- self,
- params: TaskListQueryParams
- ) -> LongTextAudioListResponse:
- """
- 查询任务列表
-
- Args:
- params: 查询参数
-
- Returns:
- 任务列表响应
- """
- query = self.db.query(LongTextAudio).filter(
- LongTextAudio.user_id == self.user_id
- )
-
- # 状态筛选
- if params.status:
- query = query.filter(LongTextAudio.status == params.status)
-
- # 总数
- total = query.count()
-
- # 排序
- if params.order_by == "created_at":
- order_column = LongTextAudio.created_at
- elif params.order_by == "updated_at":
- order_column = LongTextAudio.updated_at
- else:
- order_column = LongTextAudio.created_at
-
- if params.order == "desc":
- query = query.order_by(desc(order_column))
- else:
- query = query.order_by(order_column)
-
- # 分页
- offset = (params.page - 1) * params.page_size
- tasks = query.offset(offset).limit(params.page_size).all()
-
- items = []
- for task in tasks:
- response = LongTextAudioResponse.from_orm(task)
- if task.segments:
- response.segments = [SegmentInfo(**seg) for seg in task.segments]
- items.append(response)
-
- return LongTextAudioListResponse(total=total, items=items)
-
- async def _process_task(
- self,
- task_id: str,
- request: LongTextAudioCreateRequest,
- segments: List[str],
- ) -> None:
- """
- 后台执行长文本 TTS 合成(逐段合成 → 合并 → 上传 OSS)
- 使用独立的数据库会话,避免与请求会话冲突。
- 所有阻塞的 DashScope 调用都放到线程池执行,不阻塞 event loop。
- """
- db = SessionLocal()
- try:
- import dashscope
- from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat as DashAudioFormat
- dashscope.api_key = self.api_key
- # 更新状态为 PROCESSING
- task = db.query(LongTextAudio).filter(LongTextAudio.task_id == task_id).first()
- if not task:
- return
- task.status = "PROCESSING"
- task.updated_at = datetime.now()
- db.commit()
- # 采样率 → PCM 格式映射
- sample_rate = getattr(request, 'sample_rate', 22050) or 22050
- pcm_format_map = {
- 8000: DashAudioFormat.PCM_8000HZ_MONO_16BIT,
- 16000: DashAudioFormat.PCM_16000HZ_MONO_16BIT,
- 22050: DashAudioFormat.PCM_22050HZ_MONO_16BIT,
- 24000: DashAudioFormat.PCM_24000HZ_MONO_16BIT,
- 44100: DashAudioFormat.PCM_44100HZ_MONO_16BIT,
- 48000: DashAudioFormat.PCM_48000HZ_MONO_16BIT,
- }
- pcm_fmt = pcm_format_map.get(sample_rate, DashAudioFormat.PCM_22050HZ_MONO_16BIT)
- loop = asyncio.get_event_loop()
- audio_parts: List[bytes] = []
- seg_infos = list(task.segments) if task.segments else []
- for idx, seg_text in enumerate(segments):
- def _synth(text=seg_text):
- synth = SpeechSynthesizer(
- model=request.model,
- voice=request.voice,
- format=pcm_fmt,
- volume=getattr(request, 'volume', 50),
- speech_rate=getattr(request, 'speech_rate', 1.0),
- pitch_rate=getattr(request, 'pitch_rate', 1.0),
- )
- return synth.call(text)
- try:
- audio_data = await loop.run_in_executor(None, _synth)
- except Exception as e:
- logger.error(f"[长文本TTS] 第{idx+1}段合成失败: {e}")
- audio_data = None
- # 更新分段状态
- if idx < len(seg_infos):
- seg_infos[idx] = dict(seg_infos[idx])
- if audio_data:
- seg_infos[idx]['status'] = 'SUCCEEDED'
- audio_parts.append(audio_data)
- else:
- seg_infos[idx]['status'] = 'FAILED'
- elif audio_data:
- audio_parts.append(audio_data)
- # 更新进度
- done = sum(1 for s in seg_infos if s.get('status') in ('SUCCEEDED', 'FAILED'))
- task = db.query(LongTextAudio).filter(LongTextAudio.task_id == task_id).first()
- if task:
- task.segments = seg_infos
- task.progress = int(done / len(segments) * 100)
- task.updated_at = datetime.now()
- db.commit()
- if not audio_parts:
- raise RuntimeError("所有分段合成均失败")
- # 合并 PCM → 目标格式
- merged_pcm = b''.join(audio_parts)
- fmt = getattr(request, 'format', 'mp3') or 'mp3'
- final_audio = self._convert_pcm(merged_pcm, fmt, sample_rate)
- # 上传 OSS
- audio_url = self.oss_service.upload_file(
- final_audio,
- prefix="audio/long-text",
- original_filename=f"audio.{fmt}",
- )
- # 估算时长
- duration = len(merged_pcm) / (sample_rate * 2)
- # 费用(API调用免费)
- bill = Decimal("0")
- # 写入最终结果到 LongTextAudio
- task = db.query(LongTextAudio).filter(LongTextAudio.task_id == task_id).first()
- if task:
- task.status = "SUCCEEDED"
- task.audio_url = audio_url
- task.duration = round(duration, 2)
- task.bill = bill
- task.progress = 100
- task.completed_at = datetime.now()
- task.updated_at = datetime.now()
- db.commit()
- # 写入 AudioSynthesis 记录,供创作历史使用
- from app.models.audio import AudioSynthesis
- text_preview = request.text[:1000] + "..." if len(request.text) > 1000 else request.text
- synthesis_record = AudioSynthesis(
- user_id=self.user_id,
- model=request.model,
- voice=request.voice,
- text=text_preview,
- audio_url=audio_url,
- duration=round(duration, 2),
- format=fmt,
- characters=len(request.text),
- bill=bill,
- completed_at=datetime.now(),
- )
- db.add(synthesis_record)
- db.commit()
- db.refresh(synthesis_record)
- logger.info(f"[长文本TTS] 任务 {task_id} 完成,时长 {duration:.1f}s")
- except Exception as e:
- logger.error(f"[长文本TTS] 任务 {task_id} 失败: {e}", exc_info=True)
- try:
- task = db.query(LongTextAudio).filter(LongTextAudio.task_id == task_id).first()
- if task:
- task.status = "FAILED"
- task.error_message = str(e)
- task.updated_at = datetime.now()
- task.completed_at = datetime.now()
- db.commit()
- except Exception:
- db.rollback()
- finally:
- db.close()
- def _convert_pcm(self, pcm_data: bytes, fmt: str, sample_rate: int) -> bytes:
- """PCM → 目标格式转换(复用 tts_service 的逻辑)"""
- import struct
- def to_wav(pcm: bytes) -> bytes:
- channels, bits = 1, 16
- byte_rate = sample_rate * channels * bits // 8
- block_align = channels * bits // 8
- data_size = len(pcm)
- header = struct.pack(
- '<4sI4s4sIHHIIHH4sI',
- b'RIFF', 36 + data_size, b'WAVE', b'fmt ',
- 16, 1, channels, sample_rate,
- byte_rate, block_align, bits, b'data', data_size,
- )
- return header + pcm
- if fmt == 'pcm':
- return pcm_data
- if fmt == 'wav':
- return to_wav(pcm_data)
- if fmt in ('mp3', 'opus'):
- try:
- from pydub import AudioSegment
- wav = to_wav(pcm_data)
- audio = AudioSegment.from_wav(io.BytesIO(wav))
- buf = io.BytesIO()
- bitrate = '128k' if fmt == 'mp3' else '32k'
- audio.export(buf, format=fmt, bitrate=bitrate)
- return buf.getvalue()
- except Exception as e:
- logger.warning(f"转换 {fmt} 失败({e}),回退为 WAV")
- return to_wav(pcm_data)
- return to_wav(pcm_data)
- def _split_text(self, text: str) -> List[str]:
- """
- 按句子边界智能切割文本
-
- Args:
- text: 待切割的文本
-
- Returns:
- 切割后的文本列表
- """
- if not text:
- return []
-
- text_length = len(text)
- if text_length <= self.MAX_SEGMENT_LENGTH:
- return [text]
-
- segments = []
- current = ""
- current_length = 0
-
- # 按句子分隔符切割
- sentence_pattern = r'([。!?;\n])'
- parts = re.split(sentence_pattern, text)
-
- i = 0
- while i < len(parts):
- part = parts[i]
- delimiter = parts[i + 1] if i + 1 < len(parts) and re.match(sentence_pattern, parts[i + 1]) else ""
- if delimiter:
- i += 1
-
- full_sentence = part + delimiter
- sentence_length = len(full_sentence)
-
- if current_length + sentence_length <= self.MAX_SEGMENT_LENGTH:
- current += full_sentence
- current_length += sentence_length
- else:
- if current:
- segments.append(current)
- current = ""
- current_length = 0
-
- if sentence_length > self.MAX_SEGMENT_LENGTH:
- # 强制按字符数切割
- for j in range(0, sentence_length, self.MAX_SEGMENT_LENGTH):
- chunk = full_sentence[j:j + self.MAX_SEGMENT_LENGTH]
- if j + self.MAX_SEGMENT_LENGTH < sentence_length:
- segments.append(chunk)
- else:
- current = chunk
- current_length = len(chunk)
- else:
- current = full_sentence
- current_length = sentence_length
-
- i += 1
-
- if current:
- segments.append(current)
-
- return segments
|