task_progress.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. """
  2. 审查进度轮询接口Mock实现
  3. 模拟任务进度更新,支持多阶段进度展示
  4. """
  5. import time
  6. import random
  7. from datetime import datetime
  8. from fastapi import APIRouter, HTTPException, Query
  9. from pydantic import BaseModel
  10. from typing import Optional
  11. task_progress_router = APIRouter(prefix="/sgsc", tags=["进度轮询Mock"])
  12. # 导入文件上传模块的存储
  13. try:
  14. from .file_upload import uploaded_files
  15. except ImportError:
  16. from views.construction_review.file_upload import uploaded_files
  17. # 导入错误码定义
  18. from .schemas.error_schemas import TaskProgressErrors
class TaskProgressResponse(BaseModel):
    """Response envelope returned by the task-progress polling endpoint."""

    # Status code carried in the response body (200 on success).
    code: int
    # Progress snapshot payload; its keys are chosen by the endpoint that builds it.
    data: dict
  22. def update_task_progress(callback_task_id: str) -> dict:
  23. """更新任务进度(模拟真实的处理过程)"""
  24. if callback_task_id not in uploaded_files:
  25. return None
  26. task_info = uploaded_files[callback_task_id]
  27. current_time = int(time.time())
  28. # 根据时间模拟进度推进
  29. time_elapsed = current_time - task_info.get("updated_at", current_time)
  30. # 定义各阶段的时间分配(总时长约30分钟)
  31. stage_durations = {
  32. "格式校验": 60, # 1分钟
  33. "内容提取": 900, # 15分钟
  34. "智能审查": 840 # 14分钟
  35. }
  36. total_duration = sum(stage_durations.values())
  37. # 计算当前应该处于哪个阶段
  38. accumulated_time = 0
  39. overall_progress = 0
  40. stages = []
  41. for stage_name, duration in stage_durations.items():
  42. if time_elapsed > accumulated_time + duration:
  43. # 阶段已完成
  44. stages.append({
  45. "stage_name": stage_name,
  46. "progress": 100,
  47. "stage_status": "completed"
  48. })
  49. accumulated_time += duration
  50. elif time_elapsed > accumulated_time:
  51. # 阶段进行中
  52. stage_progress = min(100, int((time_elapsed - accumulated_time) / duration * 100))
  53. stages.append({
  54. "stage_name": stage_name,
  55. "progress": stage_progress,
  56. "stage_status": "processing"
  57. })
  58. accumulated_time += duration
  59. else:
  60. # 阶段未开始
  61. stages.append({
  62. "stage_name": stage_name,
  63. "progress": 0,
  64. "stage_status": "pending"
  65. })
  66. # 计算总进度
  67. overall_progress = min(100, int(time_elapsed / total_duration * 100))
  68. # 确定任务状态
  69. if overall_progress >= 100:
  70. review_task_status = "completed"
  71. estimated_remaining = 0
  72. else:
  73. review_task_status = "processing"
  74. estimated_remaining = max(0, total_duration - time_elapsed)
  75. # 更新任务信息
  76. task_info.update({
  77. "review_task_status": review_task_status,
  78. "overall_progress": overall_progress,
  79. "stages": stages,
  80. "updated_at": current_time,
  81. "estimated_remaining": estimated_remaining
  82. })
  83. return task_info
  84. @task_progress_router.get("/task_progress/{callback_task_id}", response_model=TaskProgressResponse)
  85. async def task_progress(
  86. callback_task_id: str,
  87. user: str = Query(None)
  88. ):
  89. """
  90. Mock任务进度轮询接口
  91. """
  92. try:
  93. # 验证参数
  94. if user is None or not isinstance(user, str):
  95. raise TaskProgressErrors.missing_parameters()
  96. if not callback_task_id or not isinstance(callback_task_id, str):
  97. raise TaskProgressErrors.missing_parameters()
  98. # 检查callback_task_id格式(应该是UUID-时间戳格式)
  99. if len(callback_task_id) < 20 or callback_task_id.count('-') < 4:
  100. raise TaskProgressErrors.invalid_param_format()
  101. # 验证用户标识(应该是指定用户如user-001)
  102. valid_users = {"user-001", "user-002", "user-003"} # 可以配置化
  103. if user == "" or user not in valid_users:
  104. raise TaskProgressErrors.invalid_user()
  105. # 检查任务是否存在
  106. if callback_task_id not in uploaded_files:
  107. raise TaskProgressErrors.task_not_found()
  108. # 验证用户权限
  109. task_info = uploaded_files[callback_task_id]
  110. if task_info.get("user") != user:
  111. raise TaskProgressErrors.invalid_user()
  112. # 更新进度
  113. updated_task = update_task_progress(callback_task_id)
  114. return TaskProgressResponse(
  115. code=200,
  116. data={
  117. "callback_task_id": callback_task_id,
  118. "user": user,
  119. "review_task_status": updated_task["review_task_status"],
  120. "overall_progress": updated_task["overall_progress"],
  121. "stages": updated_task["stages"],
  122. "updated_at": updated_task["updated_at"],
  123. "estimated_remaining": updated_task["estimated_remaining"]
  124. }
  125. )
  126. except HTTPException:
  127. raise
  128. except Exception as e:
  129. raise TaskProgressErrors.server_internal_error(e)
  130. @task_progress_router.post("/mock/advance_time")
  131. async def advance_time(seconds: int = 300):
  132. """Mock接口:推进时间(用于测试)"""
  133. for callback_task_id in list(uploaded_files.keys()):
  134. if "review_task_status" in uploaded_files[callback_task_id]:
  135. uploaded_files[callback_task_id]["updated_at"] -= seconds
  136. return {"message": f"时间推进了 {seconds} 秒"}