voice_clone_service.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. """
  2. 声音克隆服务V2
  3. 提供异步声音克隆的业务逻辑处理
  4. """
import logging
from datetime import datetime
from typing import List, Optional

from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from sqlalchemy import desc
from sqlalchemy.orm import Session

from app.models.audio import VoiceCloneV2
from app.schemas.audio_v2 import (
    TaskListQueryParams,
    VoiceCloneV2CreateRequest,
    VoiceCloneV2ListResponse,
    VoiceCloneV2Response,
)

from .base_service import BaseV2Service
# Module-level logger; records are emitted under this module's import name.
logger = logging.getLogger(__name__)
  20. class VoiceCloneV2Service(BaseV2Service):
  21. """声音克隆服务V2(异步模式)"""
  22. # 有效的目标模型
  23. VALID_MODELS = ["cosyvoice-v3-plus", "cosyvoice-v3-flash", "cosyvoice-v2"]
  24. def __init__(self, db: Session, user_id: str, api_key: str = None):
  25. """初始化服务"""
  26. super().__init__(db, user_id, api_key)
  27. self._voice_service = None
  28. @property
  29. def voice_service(self):
  30. """延迟初始化VoiceEnrollmentService"""
  31. if self._voice_service is None:
  32. from dashscope.audio.tts_v2 import VoiceEnrollmentService
  33. self._voice_service = VoiceEnrollmentService()
  34. return self._voice_service
  35. async def create_task(
  36. self,
  37. request: VoiceCloneV2CreateRequest
  38. ) -> VoiceCloneV2Response:
  39. """
  40. 创建声音克隆任务
  41. Args:
  42. request: 创建请求
  43. Returns:
  44. 任务响应
  45. Raises:
  46. HTTPException: 创建失败
  47. """
  48. # 验证模型
  49. if request.target_model not in self.VALID_MODELS:
  50. raise HTTPException(
  51. status_code=400,
  52. detail=f"无效的模型,支持的模型: {self.VALID_MODELS}"
  53. )
  54. try:
  55. # 调用DashScope API创建音色
  56. voice_id = self.voice_service.create_voice(
  57. target_model=request.target_model,
  58. prefix=request.prefix,
  59. url=request.audio_url
  60. )
  61. if not voice_id:
  62. raise HTTPException(
  63. status_code=502,
  64. detail="创建音色失败,未返回voice_id"
  65. )
  66. # 保存到数据库(使用voice_id作为task_id)
  67. voice_task = VoiceCloneV2(
  68. user_id=self.user_id,
  69. task_id=voice_id, # 使用voice_id作为task_id
  70. voice_id=None, # 完成后才有
  71. target_model=request.target_model,
  72. prefix=request.prefix,
  73. voice_name=request.voice_name,
  74. audio_url=request.audio_url,
  75. status="PENDING",
  76. bill=bill
  77. )
  78. self.db.add(voice_task)
  79. self.db.commit()
  80. self.db.refresh(voice_task)
  81. return VoiceCloneV2Response.from_orm(voice_task)
  82. except HTTPException:
  83. raise
  84. except Exception as e:
  85. logger.error(f"创建声音克隆任务失败: {type(e).__name__}: {str(e)}")
  86. raise HTTPException(
  87. status_code=502,
  88. detail=f"创建声音克隆任务失败: {str(e)}"
  89. )
  90. async def get_task(self, task_id: str) -> VoiceCloneV2Response:
  91. """
  92. 查询任务详情
  93. Args:
  94. task_id: 任务ID
  95. Returns:
  96. 任务响应
  97. Raises:
  98. HTTPException: 任务不存在
  99. """
  100. task = self.db.query(VoiceCloneV2).filter(
  101. VoiceCloneV2.task_id == task_id,
  102. VoiceCloneV2.user_id == self.user_id
  103. ).first()
  104. if not task:
  105. raise HTTPException(status_code=404, detail="任务不存在")
  106. # 如果任务未完成,查询最新状态
  107. if task.status in ["PENDING", "PROCESSING"]:
  108. await self._update_task_status(task)
  109. return VoiceCloneV2Response.from_orm(task)
  110. async def list_tasks(
  111. self,
  112. params: TaskListQueryParams
  113. ) -> VoiceCloneV2ListResponse:
  114. """
  115. 查询任务列表
  116. Args:
  117. params: 查询参数
  118. Returns:
  119. 任务列表响应
  120. """
  121. query = self.db.query(VoiceCloneV2).filter(
  122. VoiceCloneV2.user_id == self.user_id
  123. )
  124. # 状态筛选
  125. if params.status:
  126. query = query.filter(VoiceCloneV2.status == params.status)
  127. # 总数
  128. total = query.count()
  129. # 排序
  130. if params.order_by == "created_at":
  131. order_column = VoiceCloneV2.created_at
  132. elif params.order_by == "updated_at":
  133. order_column = VoiceCloneV2.updated_at
  134. else:
  135. order_column = VoiceCloneV2.created_at
  136. if params.order == "desc":
  137. query = query.order_by(desc(order_column))
  138. else:
  139. query = query.order_by(order_column)
  140. # 分页
  141. offset = (params.page - 1) * params.page_size
  142. tasks = query.offset(offset).limit(params.page_size).all()
  143. items = [VoiceCloneV2Response.from_orm(task) for task in tasks]
  144. return VoiceCloneV2ListResponse(total=total, items=items)
  145. async def _update_task_status(self, task: VoiceCloneV2) -> None:
  146. """
  147. 更新任务状态(从DashScope查询)
  148. Args:
  149. task: 任务对象
  150. """
  151. try:
  152. # 查询音色状态
  153. result = self.voice_service.query_voice(voice_id=task.task_id)
  154. if not result:
  155. return
  156. # 解析状态
  157. if isinstance(result, dict):
  158. status = result.get('status', 'UNKNOWN')
  159. else:
  160. status = getattr(result, 'status', 'UNKNOWN')
  161. # 映射状态
  162. status_map = {
  163. "DEPLOYING": "PROCESSING",
  164. "OK": "SUCCEEDED",
  165. "DEPLOYED": "SUCCEEDED",
  166. "UNDEPLOYED": "FAILED",
  167. "FAILED": "FAILED"
  168. }
  169. new_status = status_map.get(status, status)
  170. task.status = new_status
  171. task.updated_at = datetime.now()
  172. # 如果成功,设置voice_id
  173. if new_status == "SUCCEEDED" and not task.voice_id:
  174. task.voice_id = task.task_id
  175. task.completed_at = datetime.now()
  176. elif new_status == "FAILED":
  177. task.error_message = "音色训练失败"
  178. task.completed_at = datetime.now()
  179. self.db.commit()
  180. except Exception as e:
  181. logger.error(f"更新任务状态失败: {type(e).__name__}: {str(e)}")