""" 审查进度轮询接口 支持Celery任务状态查询和进度展示 """ import time import random from datetime import datetime from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel from typing import Optional from celery.result import AsyncResult from foundation.base.celery_app import app task_progress_router = APIRouter(prefix="/sgsc", tags=["进度轮询"]) # 导入错误码定义 from .schemas.error_schemas import TaskProgressErrors class TaskProgressResponse(BaseModel): code: int data: dict def update_task_progress(callback_task_id: str) -> dict: """更新任务进度(模拟真实的处理过程)""" if callback_task_id not in uploaded_files: return None task_info = uploaded_files[callback_task_id] current_time = int(time.time()) # 根据时间模拟进度推进 time_elapsed = current_time - task_info.get("updated_at", current_time) # 定义各阶段的时间分配(总时长约30分钟) stage_durations = { "格式校验": 60, # 1分钟 "内容提取": 900, # 15分钟 "智能审查": 840 # 14分钟 } total_duration = sum(stage_durations.values()) # 计算当前应该处于哪个阶段 accumulated_time = 0 overall_progress = 0 stages = [] for stage_name, duration in stage_durations.items(): if time_elapsed > accumulated_time + duration: # 阶段已完成 stages.append({ "stage_name": stage_name, "progress": 100, "stage_status": "completed" }) accumulated_time += duration elif time_elapsed > accumulated_time: # 阶段进行中 stage_progress = min(100, int((time_elapsed - accumulated_time) / duration * 100)) stages.append({ "stage_name": stage_name, "progress": stage_progress, "stage_status": "processing" }) accumulated_time += duration else: # 阶段未开始 stages.append({ "stage_name": stage_name, "progress": 0, "stage_status": "pending" }) # 计算总进度 overall_progress = min(100, int(time_elapsed / total_duration * 100)) # 确定任务状态 if overall_progress >= 100: review_task_status = "completed" estimated_remaining = 0 else: review_task_status = "processing" estimated_remaining = max(0, total_duration - time_elapsed) # 更新任务信息 task_info.update({ "review_task_status": review_task_status, "overall_progress": overall_progress, "stages": stages, "updated_at": current_time, "estimated_remaining": estimated_remaining }) return task_info @task_progress_router.get("/task_progress/{callback_task_id}", response_model=TaskProgressResponse) async def task_progress( callback_task_id: str, user: str = Query(None) ): """ 任务进度轮询接口 """ try: # 验证参数 if user is None or not isinstance(user, str): raise TaskProgressErrors.missing_parameters() if not callback_task_id or not isinstance(callback_task_id, str): raise TaskProgressErrors.missing_parameters() # 检查callback_task_id格式(应该是UUID-时间戳格式) if len(callback_task_id) < 20 or callback_task_id.count('-') < 4: raise TaskProgressErrors.invalid_param_format() # 验证用户标识(应该是指定用户如user-001) valid_users = {"user-001", "user-002", "user-003"} # 可以配置化 if user == "" or user not in valid_users: raise TaskProgressErrors.invalid_user() # 检查任务是否存在 if callback_task_id not in uploaded_files: raise TaskProgressErrors.task_not_found() # 验证用户权限 task_info = uploaded_files[callback_task_id] if task_info.get("user") != user: raise TaskProgressErrors.invalid_user() # 更新进度 updated_task = update_task_progress(callback_task_id) return TaskProgressResponse( code=200, data={ "callback_task_id": callback_task_id, "user": user, "review_task_status": updated_task["review_task_status"], "overall_progress": updated_task["overall_progress"], "stages": updated_task["stages"], "updated_at": updated_task["updated_at"] } ) except HTTPException: raise except Exception as e: raise TaskProgressErrors.server_internal_error(e) @task_progress_router.post("/mock/advance_time") async def advance_time(seconds: int = 300): """Mock接口:推进时间(用于测试)""" for callback_task_id in list(uploaded_files.keys()): if "review_task_status" in uploaded_files[callback_task_id]: uploaded_files[callback_task_id]["updated_at"] -= seconds return {"message": f"时间推进了 {seconds} 秒"}