audio_synthesis_service.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. """
  2. 语音合成服务V2
  3. 提供异步语音合成的业务逻辑处理
  4. """
  5. import logging
  6. import requests
  7. from datetime import datetime
  8. from typing import List
  9. from sqlalchemy.orm import Session
  10. from sqlalchemy import desc
  11. from fastapi import HTTPException
  12. from app.models.audio import AudioSynthesisV2
  13. from app.schemas.audio_v2 import (
  14. AudioSynthesisV2CreateRequest,
  15. AudioSynthesisV2Response,
  16. AudioSynthesisV2ListResponse,
  17. TaskListQueryParams
  18. )
  19. from .base_service import BaseV2Service
  20. logger = logging.getLogger(__name__)
  21. class AudioSynthesisV2Service(BaseV2Service):
  22. """语音合成服务V2(异步模式)"""
  23. # DashScope API基础URL
  24. DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"
  25. # 有效的TTS模型
  26. VALID_MODELS = ["cosyvoice-v3-flash", "cosyvoice-v3-plus", "cosyvoice-v2"]
  27. async def create_task(
  28. self,
  29. request: AudioSynthesisV2CreateRequest
  30. ) -> AudioSynthesisV2Response:
  31. """
  32. 创建语音合成任务
  33. Args:
  34. request: 创建请求
  35. Returns:
  36. 任务响应
  37. Raises:
  38. HTTPException: 创建失败
  39. """
  40. # 验证模型
  41. if request.model not in self.VALID_MODELS:
  42. raise HTTPException(
  43. status_code=400,
  44. detail=f"无效的模型,支持的模型: {self.VALID_MODELS}"
  45. )
  46. try:
  47. # 调用DashScope API提交异步任务
  48. url = f"{self.DASHSCOPE_BASE_URL}/services/audio/tts/synthesis"
  49. headers = {
  50. "Authorization": f"Bearer {self.api_key}",
  51. "Content-Type": "application/json",
  52. "X-DashScope-Async": "enable"
  53. }
  54. payload = {
  55. "model": request.model,
  56. "input": {"text": request.text},
  57. "parameters": {
  58. "voice": request.voice,
  59. "format": request.format
  60. }
  61. }
  62. response = requests.post(url, headers=headers, json=payload, timeout=30)
  63. if response.status_code != 200:
  64. error_data = response.json() if response.text else {}
  65. error_msg = error_data.get("message", f"HTTP {response.status_code}")
  66. logger.error(f"提交TTS任务失败: {error_msg}")
  67. raise HTTPException(
  68. status_code=502,
  69. detail=f"提交合成任务失败: {error_msg}"
  70. )
  71. data = response.json()
  72. output = data.get("output", {})
  73. task_id = output.get("task_id")
  74. if not task_id:
  75. raise HTTPException(
  76. status_code=502,
  77. detail="提交合成任务失败,未返回task_id"
  78. )
  79. # 保存到数据库
  80. tts_task = AudioSynthesisV2(
  81. user_id=self.user_id,
  82. task_id=task_id,
  83. model=request.model,
  84. voice=request.voice,
  85. text=request.text,
  86. format=request.format,
  87. characters=len(request.text),
  88. custom_name=request.custom_name,
  89. status="PENDING"
  90. )
  91. self.db.add(tts_task)
  92. self.db.commit()
  93. self.db.refresh(tts_task)
  94. return AudioSynthesisV2Response.from_orm(tts_task)
  95. except HTTPException:
  96. raise
  97. except requests.exceptions.Timeout:
  98. raise HTTPException(status_code=504, detail="提交合成任务超时")
  99. except Exception as e:
  100. logger.error(f"创建TTS任务失败: {type(e).__name__}: {str(e)}")
  101. raise HTTPException(
  102. status_code=502,
  103. detail=f"创建合成任务失败: {str(e)}"
  104. )
  105. async def get_task(self, task_id: str) -> AudioSynthesisV2Response:
  106. """
  107. 查询任务详情
  108. Args:
  109. task_id: 任务ID
  110. Returns:
  111. 任务响应
  112. Raises:
  113. HTTPException: 任务不存在
  114. """
  115. task = self.db.query(AudioSynthesisV2).filter(
  116. AudioSynthesisV2.task_id == task_id,
  117. AudioSynthesisV2.user_id == self.user_id
  118. ).first()
  119. if not task:
  120. raise HTTPException(status_code=404, detail="任务不存在")
  121. # 如果任务未完成,查询最新状态
  122. if task.status in ["PENDING", "PROCESSING"]:
  123. await self._update_task_status(task)
  124. return AudioSynthesisV2Response.from_orm(task)
  125. async def list_tasks(
  126. self,
  127. params: TaskListQueryParams
  128. ) -> AudioSynthesisV2ListResponse:
  129. """
  130. 查询任务列表
  131. Args:
  132. params: 查询参数
  133. Returns:
  134. 任务列表响应
  135. """
  136. query = self.db.query(AudioSynthesisV2).filter(
  137. AudioSynthesisV2.user_id == self.user_id
  138. )
  139. # 状态筛选
  140. if params.status:
  141. query = query.filter(AudioSynthesisV2.status == params.status)
  142. # 总数
  143. total = query.count()
  144. # 排序
  145. if params.order_by == "created_at":
  146. order_column = AudioSynthesisV2.created_at
  147. elif params.order_by == "updated_at":
  148. order_column = AudioSynthesisV2.updated_at
  149. else:
  150. order_column = AudioSynthesisV2.created_at
  151. if params.order == "desc":
  152. query = query.order_by(desc(order_column))
  153. else:
  154. query = query.order_by(order_column)
  155. # 分页
  156. offset = (params.page - 1) * params.page_size
  157. tasks = query.offset(offset).limit(params.page_size).all()
  158. items = [AudioSynthesisV2Response.from_orm(task) for task in tasks]
  159. return AudioSynthesisV2ListResponse(total=total, items=items)
  160. async def _update_task_status(self, task: AudioSynthesisV2) -> None:
  161. """
  162. 更新任务状态(从DashScope查询)
  163. Args:
  164. task: 任务对象
  165. """
  166. try:
  167. url = f"{self.DASHSCOPE_BASE_URL}/tasks/{task.task_id}"
  168. headers = {
  169. "Authorization": f"Bearer {self.api_key}",
  170. "X-DashScope-Async": "enable"
  171. }
  172. response = requests.get(url, headers=headers, timeout=30)
  173. if response.status_code != 200:
  174. logger.warning(f"查询任务状态失败: {response.status_code}")
  175. return
  176. data = response.json()
  177. output = data.get("output", {})
  178. # 更新状态
  179. new_status = output.get("task_status", task.status)
  180. task.status = new_status
  181. task.updated_at = datetime.now()
  182. # 如果任务完成,提取结果
  183. if new_status == "SUCCEEDED":
  184. result = output.get("result", {})
  185. # 提取音频URL
  186. task.audio_url = result.get("audio_url")
  187. # 提取时长
  188. usage = data.get("usage", {})
  189. duration = usage.get("duration", 0)
  190. task.duration = duration
  191. task.completed_at = datetime.now()
  192. elif new_status == "FAILED":
  193. # 提取错误信息
  194. task.error_message = output.get("message", "合成失败")
  195. task.completed_at = datetime.now()
  196. self.db.commit()
  197. except Exception as e:
  198. logger.error(f"更新任务状态失败: {type(e).__name__}: {str(e)}")