| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- """
- 任务管理器
- 提供统一的任务状态查询和管理功能
- """
import logging
from datetime import datetime
from decimal import Decimal
from typing import Optional, List, Dict, Any, Tuple

from fastapi import HTTPException
from sqlalchemy import desc, func, or_
from sqlalchemy.orm import Session

from app.models.audio import (
    ASRRecognitionV2,
    AudioSynthesisV2,
    VoiceCloneV2,
    LongTextAudio
)
from app.schemas.audio_v2 import (
    TaskStatusResponse,
    TaskStatistics,
    UserAudioStatisticsResponse
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class TaskManager:
    """Unified status lookup and management for one user's audio tasks.

    Tasks live in four separate tables (speech recognition, speech synthesis,
    voice cloning, long-text-to-audio). Every query issued by this class is
    filtered on ``user_id``, so it only ever reads or modifies tasks owned by
    that user.
    """

    # Task tables searched, in lookup order, paired with the type name used
    # for each. A task_id resolves to the first table that contains it.
    _TASK_MODELS = (
        (ASRRecognitionV2, "asr_recognition"),
        (AudioSynthesisV2, "audio_synthesis"),
        (VoiceCloneV2, "voice_clone"),
        (LongTextAudio, "long_text_audio"),
    )

    def __init__(self, db: Session, user_id: str):
        """
        Initialize the task manager.

        Args:
            db: Database session used for all queries and commits.
            user_id: ID of the user whose tasks this manager operates on.
        """
        self.db = db
        self.user_id = user_id

    def _find_task(self, task_id: str) -> Tuple[Any, str]:
        """
        Locate a task owned by the current user in any of the task tables.

        Args:
            task_id: Task ID.

        Returns:
            ``(task, task_type)`` from the first table (in ``_TASK_MODELS``
            order) that contains a matching row.

        Raises:
            HTTPException: 404 if no table holds this task for this user.
        """
        for model, task_type in self._TASK_MODELS:
            task = self.db.query(model).filter(
                model.task_id == task_id,
                model.user_id == self.user_id
            ).first()
            if task is not None:
                return task, task_type
        raise HTTPException(status_code=404, detail="任务不存在")

    def _commit(self) -> None:
        """
        Commit the current transaction, rolling back if the commit fails.

        SQLAlchemy requires an explicit rollback after a failed flush/commit
        before the session can be used again; without it, every later
        operation on this session (e.g. the remaining items of a batch
        delete) would also fail.

        Raises:
            Exception: Whatever the commit raised, re-raised after rollback.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    def get_task_status(self, task_id: str) -> TaskStatusResponse:
        """
        Look up the status of a task, searching across all task types.

        Args:
            task_id: Task ID.

        Returns:
            Task status response. ``progress`` is only read for
            long-text-to-audio tasks and is None for every other type.

        Raises:
            HTTPException: 404 if the task does not exist for this user.
        """
        task, task_type = self._find_task(task_id)

        # Only the long-text-to-audio model is read for a progress value.
        progress = task.progress if task_type == "long_text_audio" else None

        return TaskStatusResponse(
            task_id=task.task_id,
            status=task.status,
            progress=progress,
            error_message=task.error_message,
            created_at=task.created_at,
            updated_at=task.updated_at,
            completed_at=task.completed_at
        )

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task. Only tasks still in PENDING status can be cancelled.

        A cancelled task is stored with status FAILED and the error message
        "用户取消" ("cancelled by user").

        Args:
            task_id: Task ID.

        Returns:
            True once the cancellation has been committed.

        Raises:
            HTTPException: 404 if the task does not exist for this user;
                400 if the task is not in PENDING status.
        """
        task, _ = self._find_task(task_id)

        if task.status != "PENDING":
            raise HTTPException(
                status_code=400,
                detail=f"任务状态为{task.status},无法取消"
            )

        # Take a single timestamp so updated_at and completed_at match exactly.
        now = datetime.now()
        task.status = "FAILED"
        task.error_message = "用户取消"
        task.updated_at = now
        task.completed_at = now

        self._commit()
        return True

    def delete_task(self, task_id: str) -> bool:
        """
        Delete a task record.

        Args:
            task_id: Task ID.

        Returns:
            True once the deletion has been committed.

        Raises:
            HTTPException: 404 if the task does not exist for this user.
        """
        task, _ = self._find_task(task_id)
        self.db.delete(task)
        self._commit()
        return True

    def batch_delete_tasks(self, task_ids: List[str]) -> Dict[str, Any]:
        """
        Delete several tasks, continuing past individual failures.

        Args:
            task_ids: IDs of the tasks to delete.

        Returns:
            Dict with ``success_count``, ``failed_count`` and
            ``failed_tasks`` (a list of ``{"task_id", "error"}`` dicts, one
            per task that could not be deleted).
        """
        success_count = 0
        failed_tasks: List[Dict[str, str]] = []

        for task_id in task_ids:
            try:
                self.delete_task(task_id)
            except HTTPException as exc:
                # Expected, client-facing failure (e.g. unknown task ID).
                failed_tasks.append({"task_id": task_id, "error": str(exc)})
            except Exception as exc:
                # Unexpected failure (e.g. a database error): keep going with
                # the remaining IDs, but leave a traceback for diagnosis.
                logger.exception("Failed to delete task %s", task_id)
                failed_tasks.append({"task_id": task_id, "error": str(exc)})
            else:
                success_count += 1

        return {
            "success_count": success_count,
            "failed_count": len(failed_tasks),
            "failed_tasks": failed_tasks
        }

    def get_user_statistics(self) -> UserAudioStatisticsResponse:
        """
        Collect per-task-type statistics for the current user.

        Returns:
            Statistics for speech recognition, speech synthesis, voice
            cloning and long-text-to-audio tasks.
        """
        return UserAudioStatisticsResponse(
            asr_stats=self._get_task_statistics(ASRRecognitionV2),
            tts_stats=self._get_task_statistics(AudioSynthesisV2),
            voice_clone_stats=self._get_task_statistics(VoiceCloneV2),
            long_text_stats=self._get_task_statistics(LongTextAudio)
        )

    def _get_task_statistics(self, model) -> TaskStatistics:
        """
        Compute task statistics for one task table, scoped to this user.

        Args:
            model: ORM model class of the task table.

        Returns:
            Counts per status, the summed ``bill`` over all of the user's
            tasks (any status; ``Decimal('0')`` when there are none), and the
            summed ``duration`` over SUCCEEDED tasks (None when the model has
            no ``duration`` attribute or no matching rows).
        """
        owned = model.user_id == self.user_id
        query = self.db.query(model).filter(owned)

        total = query.count()
        pending = query.filter(model.status == "PENDING").count()
        processing = query.filter(model.status == "PROCESSING").count()
        succeeded = query.filter(model.status == "SUCCEEDED").count()
        failed = query.filter(model.status == "FAILED").count()

        # SUM over zero rows yields NULL; report that as a zero bill.
        total_bill = self.db.query(
            func.sum(model.bill)
        ).filter(owned).scalar() or Decimal('0')

        # Only some task tables record a duration.
        total_duration = None
        if hasattr(model, 'duration'):
            total_duration = self.db.query(
                func.sum(model.duration)
            ).filter(
                owned,
                model.status == "SUCCEEDED"
            ).scalar()

        return TaskStatistics(
            total=total,
            pending=pending,
            processing=processing,
            succeeded=succeeded,
            failed=failed,
            total_bill=total_bill,
            total_duration=total_duration
        )
|