task_progress.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
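"""
Review-progress polling endpoints.

Intended to support Celery task-status queries and progress display.
NOTE(review): nothing in this module currently calls Celery (the
``AsyncResult`` and ``app`` imports are unused); task progress is simulated
from elapsed wall-clock time.
"""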
  5. import time
  6. import random
  7. from datetime import datetime
  8. from fastapi import APIRouter, HTTPException, Query
  9. from pydantic import BaseModel
  10. from typing import Optional
  11. from celery.result import AsyncResult
  12. from foundation.base.celery_app import app
  13. task_progress_router = APIRouter(prefix="/sgsc", tags=["进度轮询"])
  14. # 导入错误码定义
  15. from .schemas.error_schemas import TaskProgressErrors
# Response envelope returned by the progress-polling endpoint.
class TaskProgressResponse(BaseModel):
    code: int  # status code echoed in the body (the endpoint sets 200)
    data: dict  # progress payload assembled by the endpoint
  19. def update_task_progress(callback_task_id: str) -> dict:
  20. """更新任务进度(模拟真实的处理过程)"""
  21. if callback_task_id not in uploaded_files:
  22. return None
  23. task_info = uploaded_files[callback_task_id]
  24. current_time = int(time.time())
  25. # 根据时间模拟进度推进
  26. time_elapsed = current_time - task_info.get("updated_at", current_time)
  27. # 定义各阶段的时间分配(总时长约30分钟)
  28. stage_durations = {
  29. "格式校验": 60, # 1分钟
  30. "内容提取": 900, # 15分钟
  31. "智能审查": 840 # 14分钟
  32. }
  33. total_duration = sum(stage_durations.values())
  34. # 计算当前应该处于哪个阶段
  35. accumulated_time = 0
  36. overall_progress = 0
  37. stages = []
  38. for stage_name, duration in stage_durations.items():
  39. if time_elapsed > accumulated_time + duration:
  40. # 阶段已完成
  41. stages.append({
  42. "stage_name": stage_name,
  43. "progress": 100,
  44. "stage_status": "completed"
  45. })
  46. accumulated_time += duration
  47. elif time_elapsed > accumulated_time:
  48. # 阶段进行中
  49. stage_progress = min(100, int((time_elapsed - accumulated_time) / duration * 100))
  50. stages.append({
  51. "stage_name": stage_name,
  52. "progress": stage_progress,
  53. "stage_status": "processing"
  54. })
  55. accumulated_time += duration
  56. else:
  57. # 阶段未开始
  58. stages.append({
  59. "stage_name": stage_name,
  60. "progress": 0,
  61. "stage_status": "pending"
  62. })
  63. # 计算总进度
  64. overall_progress = min(100, int(time_elapsed / total_duration * 100))
  65. # 确定任务状态
  66. if overall_progress >= 100:
  67. review_task_status = "completed"
  68. estimated_remaining = 0
  69. else:
  70. review_task_status = "processing"
  71. estimated_remaining = max(0, total_duration - time_elapsed)
  72. # 更新任务信息
  73. task_info.update({
  74. "review_task_status": review_task_status,
  75. "overall_progress": overall_progress,
  76. "stages": stages,
  77. "updated_at": current_time,
  78. "estimated_remaining": estimated_remaining
  79. })
  80. return task_info
  81. @task_progress_router.get("/task_progress/{callback_task_id}", response_model=TaskProgressResponse)
  82. async def task_progress(
  83. callback_task_id: str,
  84. user: str = Query(None)
  85. ):
  86. """
  87. 任务进度轮询接口
  88. """
  89. try:
  90. # 验证参数
  91. if user is None or not isinstance(user, str):
  92. raise TaskProgressErrors.missing_parameters()
  93. if not callback_task_id or not isinstance(callback_task_id, str):
  94. raise TaskProgressErrors.missing_parameters()
  95. # 检查callback_task_id格式(应该是UUID-时间戳格式)
  96. if len(callback_task_id) < 20 or callback_task_id.count('-') < 4:
  97. raise TaskProgressErrors.invalid_param_format()
  98. # 验证用户标识(应该是指定用户如user-001)
  99. valid_users = {"user-001", "user-002", "user-003"} # 可以配置化
  100. if user == "" or user not in valid_users:
  101. raise TaskProgressErrors.invalid_user()
  102. # 检查任务是否存在
  103. if callback_task_id not in uploaded_files:
  104. raise TaskProgressErrors.task_not_found()
  105. # 验证用户权限
  106. task_info = uploaded_files[callback_task_id]
  107. if task_info.get("user") != user:
  108. raise TaskProgressErrors.invalid_user()
  109. # 更新进度
  110. updated_task = update_task_progress(callback_task_id)
  111. return TaskProgressResponse(
  112. code=200,
  113. data={
  114. "callback_task_id": callback_task_id,
  115. "user": user,
  116. "review_task_status": updated_task["review_task_status"],
  117. "overall_progress": updated_task["overall_progress"],
  118. "stages": updated_task["stages"],
  119. "updated_at": updated_task["updated_at"]
  120. }
  121. )
  122. except HTTPException:
  123. raise
  124. except Exception as e:
  125. raise TaskProgressErrors.server_internal_error(e)
  126. @task_progress_router.post("/mock/advance_time")
  127. async def advance_time(seconds: int = 300):
  128. """Mock接口:推进时间(用于测试)"""
  129. for callback_task_id in list(uploaded_files.keys()):
  130. if "review_task_status" in uploaded_files[callback_task_id]:
  131. uploaded_files[callback_task_id]["updated_at"] -= seconds
  132. return {"message": f"时间推进了 {seconds} 秒"}