task_manager.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. """
  2. 任务管理器
  3. 提供统一的任务状态查询和管理功能
  4. """
import logging
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple

from fastapi import HTTPException
from sqlalchemy import desc, func, or_
from sqlalchemy.orm import Session

from app.models.audio import (
    ASRRecognitionV2,
    AudioSynthesisV2,
    VoiceCloneV2,
    LongTextAudio,
)
from app.schemas.audio_v2 import (
    TaskStatusResponse,
    TaskStatistics,
    UserAudioStatisticsResponse,
)
  23. logger = logging.getLogger(__name__)
  24. class TaskManager:
  25. """任务管理器"""
  26. def __init__(self, db: Session, user_id: str):
  27. """
  28. 初始化任务管理器
  29. Args:
  30. db: 数据库会话
  31. user_id: 用户ID
  32. """
  33. self.db = db
  34. self.user_id = user_id
  35. def get_task_status(self, task_id: str) -> TaskStatusResponse:
  36. """
  37. 查询任务状态(跨所有任务类型)
  38. Args:
  39. task_id: 任务ID
  40. Returns:
  41. 任务状态响应
  42. Raises:
  43. HTTPException: 任务不存在
  44. """
  45. # 依次查询各个表
  46. task = None
  47. task_type = None
  48. # 查询语音识别任务
  49. task = self.db.query(ASRRecognitionV2).filter(
  50. ASRRecognitionV2.task_id == task_id,
  51. ASRRecognitionV2.user_id == self.user_id
  52. ).first()
  53. if task:
  54. task_type = "asr_recognition"
  55. # 查询语音合成任务
  56. if not task:
  57. task = self.db.query(AudioSynthesisV2).filter(
  58. AudioSynthesisV2.task_id == task_id,
  59. AudioSynthesisV2.user_id == self.user_id
  60. ).first()
  61. if task:
  62. task_type = "audio_synthesis"
  63. # 查询声音克隆任务
  64. if not task:
  65. task = self.db.query(VoiceCloneV2).filter(
  66. VoiceCloneV2.task_id == task_id,
  67. VoiceCloneV2.user_id == self.user_id
  68. ).first()
  69. if task:
  70. task_type = "voice_clone"
  71. # 查询长文本转音频任务
  72. if not task:
  73. task = self.db.query(LongTextAudio).filter(
  74. LongTextAudio.task_id == task_id,
  75. LongTextAudio.user_id == self.user_id
  76. ).first()
  77. if task:
  78. task_type = "long_text_audio"
  79. if not task:
  80. raise HTTPException(status_code=404, detail="任务不存在")
  81. # 构建响应
  82. progress = None
  83. if task_type == "long_text_audio":
  84. progress = task.progress
  85. return TaskStatusResponse(
  86. task_id=task.task_id,
  87. status=task.status,
  88. progress=progress,
  89. error_message=task.error_message,
  90. created_at=task.created_at,
  91. updated_at=task.updated_at,
  92. completed_at=task.completed_at
  93. )
  94. def cancel_task(self, task_id: str) -> bool:
  95. """
  96. 取消任务(仅PENDING状态可取消)
  97. Args:
  98. task_id: 任务ID
  99. Returns:
  100. 是否成功取消
  101. Raises:
  102. HTTPException: 任务不存在或无法取消
  103. """
  104. # 查找任务
  105. task = None
  106. task_type = None
  107. for model, type_name in [
  108. (ASRRecognitionV2, "asr_recognition"),
  109. (AudioSynthesisV2, "audio_synthesis"),
  110. (VoiceCloneV2, "voice_clone"),
  111. (LongTextAudio, "long_text_audio")
  112. ]:
  113. task = self.db.query(model).filter(
  114. model.task_id == task_id,
  115. model.user_id == self.user_id
  116. ).first()
  117. if task:
  118. task_type = type_name
  119. break
  120. if not task:
  121. raise HTTPException(status_code=404, detail="任务不存在")
  122. # 检查状态
  123. if task.status != "PENDING":
  124. raise HTTPException(
  125. status_code=400,
  126. detail=f"任务状态为{task.status},无法取消"
  127. )
  128. # 更新状态为FAILED
  129. task.status = "FAILED"
  130. task.error_message = "用户取消"
  131. task.updated_at = datetime.now()
  132. task.completed_at = datetime.now()
  133. self.db.commit()
  134. return True
  135. def delete_task(self, task_id: str) -> bool:
  136. """
  137. 删除任务记录
  138. Args:
  139. task_id: 任务ID
  140. Returns:
  141. 是否成功删除
  142. Raises:
  143. HTTPException: 任务不存在
  144. """
  145. # 查找并删除任务
  146. deleted = False
  147. for model in [ASRRecognitionV2, AudioSynthesisV2, VoiceCloneV2, LongTextAudio]:
  148. task = self.db.query(model).filter(
  149. model.task_id == task_id,
  150. model.user_id == self.user_id
  151. ).first()
  152. if task:
  153. self.db.delete(task)
  154. self.db.commit()
  155. deleted = True
  156. break
  157. if not deleted:
  158. raise HTTPException(status_code=404, detail="任务不存在")
  159. return True
  160. def batch_delete_tasks(self, task_ids: List[str]) -> Dict[str, Any]:
  161. """
  162. 批量删除任务
  163. Args:
  164. task_ids: 任务ID列表
  165. Returns:
  166. 删除结果统计
  167. """
  168. success_count = 0
  169. failed_tasks = []
  170. for task_id in task_ids:
  171. try:
  172. self.delete_task(task_id)
  173. success_count += 1
  174. except Exception as e:
  175. failed_tasks.append({
  176. "task_id": task_id,
  177. "error": str(e)
  178. })
  179. return {
  180. "success_count": success_count,
  181. "failed_count": len(failed_tasks),
  182. "failed_tasks": failed_tasks
  183. }
  184. def get_user_statistics(self) -> UserAudioStatisticsResponse:
  185. """
  186. 获取用户语音统计信息
  187. Returns:
  188. 用户语音统计响应
  189. """
  190. # 语音识别统计
  191. asr_stats = self._get_task_statistics(ASRRecognitionV2)
  192. # 语音合成统计
  193. tts_stats = self._get_task_statistics(AudioSynthesisV2)
  194. # 声音克隆统计
  195. voice_clone_stats = self._get_task_statistics(VoiceCloneV2)
  196. # 长文本转音频统计
  197. long_text_stats = self._get_task_statistics(LongTextAudio)
  198. return UserAudioStatisticsResponse(
  199. asr_stats=asr_stats,
  200. tts_stats=tts_stats,
  201. voice_clone_stats=voice_clone_stats,
  202. long_text_stats=long_text_stats
  203. )
  204. def _get_task_statistics(self, model) -> TaskStatistics:
  205. """
  206. 获取指定模型的任务统计
  207. Args:
  208. model: ORM模型类
  209. Returns:
  210. 任务统计信息
  211. """
  212. query = self.db.query(model).filter(model.user_id == self.user_id)
  213. total = query.count()
  214. pending = query.filter(model.status == "PENDING").count()
  215. processing = query.filter(model.status == "PROCESSING").count()
  216. succeeded = query.filter(model.status == "SUCCEEDED").count()
  217. failed = query.filter(model.status == "FAILED").count()
  218. # 计算总费用
  219. total_bill = self.db.query(
  220. func.sum(model.bill)
  221. ).filter(
  222. model.user_id == self.user_id
  223. ).scalar() or Decimal('0')
  224. # 计算总时长(如果有duration字段)
  225. total_duration = None
  226. if hasattr(model, 'duration'):
  227. total_duration = self.db.query(
  228. func.sum(model.duration)
  229. ).filter(
  230. model.user_id == self.user_id,
  231. model.status == "SUCCEEDED"
  232. ).scalar()
  233. return TaskStatistics(
  234. total=total,
  235. pending=pending,
  236. processing=processing,
  237. succeeded=succeeded,
  238. failed=failed,
  239. total_bill=total_bill,
  240. total_duration=total_duration
  241. )