"""
Task manager.

Provides unified task status lookup and management across all audio task
types (ASR recognition, audio synthesis, voice clone, long-text audio).
"""
import logging
from typing import Optional, List, Dict, Any, Tuple
from datetime import datetime

from sqlalchemy.orm import Session
from sqlalchemy import desc, or_, func
from fastapi import HTTPException
from decimal import Decimal

from app.models.audio import (
    ASRRecognitionV2,
    AudioSynthesisV2,
    VoiceCloneV2,
    LongTextAudio
)
from app.schemas.audio_v2 import (
    TaskStatusResponse,
    TaskStatistics,
    UserAudioStatisticsResponse
)

logger = logging.getLogger(__name__)


class TaskManager:
    """Query and manage the tasks owned by a single user."""

    # (ORM model, task type name) pairs in lookup order. All lookups go
    # through this table so every method searches the tables identically.
    _TASK_MODELS = (
        (ASRRecognitionV2, "asr_recognition"),
        (AudioSynthesisV2, "audio_synthesis"),
        (VoiceCloneV2, "voice_clone"),
        (LongTextAudio, "long_text_audio"),
    )

    # Status values reported individually in TaskStatistics.
    _STATUSES = ("PENDING", "PROCESSING", "SUCCEEDED", "FAILED")

    def __init__(self, db: Session, user_id: str):
        """
        Initialize the task manager.

        Args:
            db: Database session used for all queries and commits.
            user_id: ID of the user whose tasks this manager may access;
                every query is filtered by it.
        """
        self.db = db
        self.user_id = user_id

    def _find_task(self, task_id: str) -> Tuple[Any, str]:
        """
        Locate a task owned by the current user across all task tables.

        Args:
            task_id: Task ID.

        Returns:
            ``(task, task_type)`` for the first table (in ``_TASK_MODELS``
            order) that contains a matching row.

        Raises:
            HTTPException: 404 if no table contains the task.
        """
        for model, type_name in self._TASK_MODELS:
            task = self.db.query(model).filter(
                model.task_id == task_id,
                model.user_id == self.user_id
            ).first()
            if task is not None:
                return task, type_name
        raise HTTPException(status_code=404, detail="任务不存在")

    def _commit(self) -> None:
        """
        Commit the session, rolling back if the commit fails.

        Rolling back keeps the shared session usable after a failure (e.g.
        for the remaining items of a batch delete); the error is re-raised.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    def get_task_status(self, task_id: str) -> TaskStatusResponse:
        """
        Look up a task's status across all task types.

        Args:
            task_id: Task ID.

        Returns:
            Task status response.

        Raises:
            HTTPException: 404 if the task does not exist for this user.
        """
        task, task_type = self._find_task(task_id)

        # Only long-text audio tasks expose a progress value.
        progress = task.progress if task_type == "long_text_audio" else None

        return TaskStatusResponse(
            task_id=task.task_id,
            status=task.status,
            progress=progress,
            error_message=task.error_message,
            created_at=task.created_at,
            updated_at=task.updated_at,
            completed_at=task.completed_at
        )

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task; only tasks in PENDING status can be cancelled.

        The task is marked FAILED with a "user cancelled" error message.

        Args:
            task_id: Task ID.

        Returns:
            True when the task was cancelled.

        Raises:
            HTTPException: 404 if the task does not exist; 400 if its status
                is not PENDING.
        """
        task, _ = self._find_task(task_id)

        if task.status != "PENDING":
            raise HTTPException(
                status_code=400,
                detail=f"任务状态为{task.status},无法取消"
            )

        # Cancellation is recorded as FAILED. Use one timestamp so
        # updated_at and completed_at agree.
        now = datetime.now()
        task.status = "FAILED"
        task.error_message = "用户取消"
        task.updated_at = now
        task.completed_at = now
        self._commit()
        return True

    def delete_task(self, task_id: str) -> bool:
        """
        Delete a task record.

        Args:
            task_id: Task ID.

        Returns:
            True when the task was deleted.

        Raises:
            HTTPException: 404 if the task does not exist for this user.
        """
        task, _ = self._find_task(task_id)
        self.db.delete(task)
        self._commit()
        return True

    def batch_delete_tasks(self, task_ids: List[str]) -> Dict[str, Any]:
        """
        Delete several tasks, continuing past individual failures.

        Args:
            task_ids: List of task IDs.

        Returns:
            Dict with ``success_count``, ``failed_count`` and
            ``failed_tasks`` (a list of ``{"task_id", "error"}`` entries).
        """
        success_count = 0
        failed_tasks = []

        for task_id in task_ids:
            try:
                self.delete_task(task_id)
                success_count += 1
            except Exception as e:
                # Best-effort by design: record the failure and move on.
                logger.warning("Failed to delete task %s: %s", task_id, e)
                failed_tasks.append({
                    "task_id": task_id,
                    "error": str(e)
                })

        return {
            "success_count": success_count,
            "failed_count": len(failed_tasks),
            "failed_tasks": failed_tasks
        }

    def get_user_statistics(self) -> UserAudioStatisticsResponse:
        """
        Collect per-task-type statistics for the current user.

        Returns:
            User audio statistics response.
        """
        return UserAudioStatisticsResponse(
            asr_stats=self._get_task_statistics(ASRRecognitionV2),
            tts_stats=self._get_task_statistics(AudioSynthesisV2),
            voice_clone_stats=self._get_task_statistics(VoiceCloneV2),
            long_text_stats=self._get_task_statistics(LongTextAudio)
        )

    def _get_task_statistics(self, model) -> TaskStatistics:
        """
        Compute task statistics for one ORM model, scoped to the current user.

        Args:
            model: ORM model class. It must have ``user_id``, ``status`` and
                ``bill`` columns; ``duration`` is optional.

        Returns:
            Task statistics. ``total_bill`` falls back to Decimal('0') when
            the sum is empty. ``total_duration`` is None when the model has
            no ``duration`` column or no SUCCEEDED rows.
        """
        query = self.db.query(model).filter(model.user_id == self.user_id)
        total = query.count()
        counts = {
            status: query.filter(model.status == status).count()
            for status in self._STATUSES
        }

        # Total bill across all of the user's tasks, regardless of status.
        total_bill = self.db.query(
            func.sum(model.bill)
        ).filter(
            model.user_id == self.user_id
        ).scalar() or Decimal('0')

        # Total duration of succeeded tasks, only for models that have one.
        total_duration = None
        if hasattr(model, 'duration'):
            total_duration = self.db.query(
                func.sum(model.duration)
            ).filter(
                model.user_id == self.user_id,
                model.status == "SUCCEEDED"
            ).scalar()

        return TaskStatistics(
            total=total,
            pending=counts["PENDING"],
            processing=counts["PROCESSING"],
            succeeded=counts["SUCCEEDED"],
            failed=counts["FAILED"],
            total_bill=total_bill,
            total_duration=total_duration
        )