task.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. """
  2. Task API router.
  3. Provides CRUD endpoints for task management.
  4. """
  5. import uuid
  6. import json
  7. from typing import List, Optional
  8. from fastapi import APIRouter, HTTPException, status, Query, Request
  9. from database import get_db_connection
  10. from schemas.task import (
  11. TaskCreate, TaskUpdate, TaskResponse,
  12. TaskAssignRequest, BatchAssignRequest, BatchAssignResponse,
  13. TaskAssignmentResponse, MyTasksResponse,
  14. AssignmentPreviewRequest, AssignmentPreviewResponse,
  15. DispatchRequest, DispatchResponse,
  16. )
  17. from models import Task
  18. from datetime import datetime
  19. from services.assignment_service import assignment_service
  # Shared router for every task endpoint below; their paths are relative to
  # the /api/tasks prefix and all carry the "tasks" tag.
  router = APIRouter(prefix="/api/tasks", tags=["tasks"])
  def calculate_progress(data_str: str, annotation_count: int) -> float:
      """Compute a task's completion ratio.

      Progress is ``annotation_count / len(items)`` capped at 1.0, where
      ``items`` is the ``"items"`` entry of the task payload.

      Args:
          data_str: The task payload, either as a JSON string (as read from
              the ``tasks.data`` column) or an already-decoded mapping.
          annotation_count: Number of annotations recorded for the task.

      Returns:
          The ratio capped at 1.0. Returns 0.0 when the payload is malformed,
          when it has no (or an empty) ``"items"`` entry, or when the inputs
          have unexpected types.
      """
      try:
          data = json.loads(data_str) if isinstance(data_str, str) else data_str
          items = data.get("items", [])
          if not items:
              return 0.0
          return min(annotation_count / len(items), 1.0)
      except (ValueError, TypeError, AttributeError, ZeroDivisionError):
          # Best-effort: invalid JSON (ValueError), non-dict payloads
          # (AttributeError on .get), unsized "items" or a None count
          # (TypeError) report zero progress instead of failing the whole
          # listing. Unlike a bare ``except:``, this does not swallow
          # KeyboardInterrupt/SystemExit or unrelated programming errors.
          return 0.0
  @router.get("", response_model=List[TaskResponse])
  async def list_tasks(
      request: Request,
      project_id: Optional[str] = Query(None, description="Filter by project ID"),
      status_filter: Optional[str] = Query(None, alias="status", description="Filter by status"),
      assigned_to: Optional[str] = Query(None, description="Filter by assigned user")
  ):
      """
      List all tasks, newest first, with optional equality filters.
      Requires authentication.

      Each filter (project, status, assignee) is applied only when a non-empty
      value is supplied. Every task carries its annotation-based progress.
      """
      with get_db_connection() as conn:
          cursor = conn.cursor()
          sql = """
              SELECT
              t.id,
              t.project_id,
              t.name,
              t.data,
              t.status,
              t.assigned_to,
              t.created_at,
              COUNT(a.id) as annotation_count
              FROM tasks t
              LEFT JOIN annotations a ON t.id = a.task_id
              WHERE 1=1
              """
          bind_values = []
          # Fixed SQL fragments paired with the (untrusted) values bound to them.
          optional_filters = (
              (" AND t.project_id = ?", project_id),
              (" AND t.status = ?", status_filter),
              (" AND t.assigned_to = ?", assigned_to),
          )
          for fragment, value in optional_filters:
              if value:
                  sql += fragment
                  bind_values.append(value)
          sql += " GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at ORDER BY t.created_at DESC"
          cursor.execute(sql, tuple(bind_values))

          def to_response(record):
              # The data column may hold a JSON string or an already-decoded value.
              raw = record["data"]
              return TaskResponse(
                  id=record["id"],
                  project_id=record["project_id"],
                  name=record["name"],
                  data=json.loads(raw) if isinstance(raw, str) else raw,
                  status=record["status"],
                  assigned_to=record["assigned_to"],
                  created_at=record["created_at"],
                  progress=calculate_progress(raw, record["annotation_count"]),
              )

          return [to_response(record) for record in cursor.fetchall()]
  @router.post("", response_model=TaskResponse, status_code=status.HTTP_201_CREATED)
  async def create_task(request: Request, task: TaskCreate):
      """
      Create a new task in 'pending' status.
      Requires authentication.

      The task is assigned to ``task.assigned_to`` when given, otherwise to the
      requesting user. Responds 404 if the referenced project does not exist.
      """
      new_id = f"task_{uuid.uuid4().hex[:12]}"
      current_user = request.state.user
      # Any falsy assignee (None, "") falls back to the creator.
      assignee = task.assigned_to or current_user["id"]
      with get_db_connection() as conn:
          cursor = conn.cursor()
          cursor.execute("SELECT id FROM projects WHERE id = ?", (task.project_id,))
          if not cursor.fetchone():
              raise HTTPException(
                  status_code=status.HTTP_404_NOT_FOUND,
                  detail=f"Project with id '{task.project_id}' not found"
              )
          encoded_payload = json.dumps(task.data)
          cursor.execute(
              """
              INSERT INTO tasks (id, project_id, name, data, status, assigned_to)
              VALUES (?, ?, ?, ?, 'pending', ?)
              """,
              (new_id, task.project_id, task.name, encoded_payload, assignee),
          )
          # Read the row back so DB-side defaults (e.g. created_at) are returned.
          cursor.execute(
              """
              SELECT id, project_id, name, data, status, assigned_to, created_at
              FROM tasks WHERE id = ?
              """,
              (new_id,),
          )
          created = cursor.fetchone()
          stored = created["data"]
          return TaskResponse(
              id=created["id"],
              project_id=created["project_id"],
              name=created["name"],
              data=json.loads(stored) if isinstance(stored, str) else stored,
              status=created["status"],
              assigned_to=created["assigned_to"],
              created_at=created["created_at"],
              progress=0.0,
          )
  @router.get("/my-tasks", response_model=MyTasksResponse)
  async def get_my_tasks(
      request: Request,
      project_id: Optional[str] = Query(None, description="Filter by project ID"),
      status_filter: Optional[str] = Query(None, alias="status", description="Filter by status")
  ):
      """
      Get tasks assigned to the current user, newest first.
      Requires authentication.

      Annotators only see the tasks assigned to themselves. The response also
      carries per-status totals; any status other than "completed" or
      "in_progress" is counted as pending.
      """
      me = request.state.user["id"]
      with get_db_connection() as conn:
          cursor = conn.cursor()
          # The assignee condition is mandatory; the others apply only when given.
          conditions = ["t.assigned_to = ?"]
          values = [me]
          for clause, value in (("t.project_id = ?", project_id), ("t.status = ?", status_filter)):
              if value:
                  conditions.append(clause)
                  values.append(value)
          where_sql = " AND ".join(conditions)
          cursor.execute(
              f"""
              SELECT
              t.id,
              t.project_id,
              t.name,
              t.data,
              t.status,
              t.assigned_to,
              t.created_at,
              COUNT(a.id) as annotation_count
              FROM tasks t
              LEFT JOIN annotations a ON t.id = a.task_id
              WHERE {where_sql}
              GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
              ORDER BY t.created_at DESC
              """,
              tuple(values),
          )
          totals = {"completed": 0, "in_progress": 0, "pending": 0}
          tasks = []
          for record in cursor.fetchall():
              raw = record["data"]
              parsed = json.loads(raw) if isinstance(raw, str) else raw
              progress = calculate_progress(raw, record["annotation_count"])
              task_status = record["status"]
              bucket = task_status if task_status in ("completed", "in_progress") else "pending"
              totals[bucket] += 1
              tasks.append(TaskResponse(
                  id=record["id"],
                  project_id=record["project_id"],
                  name=record["name"],
                  data=parsed,
                  status=task_status,
                  assigned_to=record["assigned_to"],
                  created_at=record["created_at"],
                  progress=progress,
              ))
          return MyTasksResponse(
              tasks=tasks,
              total=len(tasks),
              completed=totals["completed"],
              in_progress=totals["in_progress"],
              pending=totals["pending"],
          )
  def _plan_batch_assignments(task_ids, user_ids, mode):
      """Return ``(task_id, user_id)`` pairs distributing *task_ids* over *user_ids*.

      - ``"round_robin"``: task ``i`` goes to ``user_ids[i % len(user_ids)]``.
      - any other mode (the caller only passes ``"equal"``): each user gets a
        contiguous run of tasks. Run lengths differ by at most one, and the
        first ``len(task_ids) % len(user_ids)`` users take the extra task.

      *user_ids* must be non-empty.
      """
      if mode == "round_robin":
          return [
              (task_id, user_ids[index % len(user_ids)])
              for index, task_id in enumerate(task_ids)
          ]
      base_count, extra_count = divmod(len(task_ids), len(user_ids))
      plan = []
      start = 0
      for user_index, user_id in enumerate(user_ids):
          count = base_count + (1 if user_index < extra_count else 0)
          plan.extend((task_id, user_id) for task_id in task_ids[start:start + count])
          start += count
      return plan


  @router.post("/batch-assign", response_model=BatchAssignResponse)
  async def batch_assign_tasks(request: Request, assign_data: BatchAssignRequest):
      """
      Batch assign multiple tasks to multiple users.
      Requires authentication and admin role.

      Two distribution modes are supported:
      - round_robin: tasks are dealt out to the users in turn.
      - equal: each user receives a contiguous run of tasks, with run lengths
        differing by at most one.

      Duplicate task/user IDs are ignored (first occurrence wins). Unknown user
      IDs are skipped. Unknown task IDs are reported in ``errors`` rather than
      silently dropped. Responds 403 for non-admins. Responds 400 for an
      invalid mode or when no valid users/tasks remain.
      """
      user = request.state.user
      # Only admins may assign tasks.
      if user["role"] != "admin":
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="只有管理员可以分配任务"
          )
      if assign_data.mode not in ["round_robin", "equal"]:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="分配模式必须是 'round_robin' 或 'equal'"
          )
      # De-duplicate while preserving request order. A repeated task would be
      # updated twice and double-counted in success_count. A repeated user
      # would skew the distribution.
      requested_user_ids = list(dict.fromkeys(assign_data.user_ids))
      requested_task_ids = list(dict.fromkeys(assign_data.task_ids))
      with get_db_connection() as conn:
          cursor = conn.cursor()
          # Keep only users that exist.
          valid_user_ids = []
          for user_id in requested_user_ids:
              cursor.execute("SELECT id FROM users WHERE id = ?", (user_id,))
              if cursor.fetchone():
                  valid_user_ids.append(user_id)
          if not valid_user_ids:
              raise HTTPException(
                  status_code=status.HTTP_400_BAD_REQUEST,
                  detail="没有有效的用户可以分配任务"
              )
          # Keep only tasks that exist; report the rest as per-task errors.
          valid_task_ids = []
          errors = []
          for task_id in requested_task_ids:
              cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
              if cursor.fetchone():
                  valid_task_ids.append(task_id)
              else:
                  errors.append({
                      "task_id": task_id,
                      "error": f"任务 '{task_id}' 不存在"
                  })
          if not valid_task_ids:
              raise HTTPException(
                  status_code=status.HTTP_400_BAD_REQUEST,
                  detail="没有有效的任务可以分配"
              )
          assignments = []
          assigned_at = datetime.now()
          plan = _plan_batch_assignments(valid_task_ids, valid_user_ids, assign_data.mode)
          for task_id, target_user_id in plan:
              try:
                  cursor.execute("""
                      UPDATE tasks SET assigned_to = ? WHERE id = ?
                      """, (target_user_id, task_id))
              except Exception as e:  # driver-specific DB errors; report per task
                  errors.append({
                      "task_id": task_id,
                      "error": str(e)
                  })
              else:
                  assignments.append(TaskAssignmentResponse(
                      task_id=task_id,
                      assigned_to=target_user_id,
                      assigned_by=user["id"],
                      assigned_at=assigned_at
                  ))
          return BatchAssignResponse(
              success_count=len(assignments),
              failed_count=len(errors),
              assignments=assignments,
              errors=errors
          )
  @router.get("/{task_id}", response_model=TaskResponse)
  async def get_task(request: Request, task_id: str):
      """
      Fetch a single task by ID, including its annotation-based progress.
      Requires authentication. Responds 404 if the task does not exist.
      """
      with get_db_connection() as conn:
          cursor = conn.cursor()
          cursor.execute(
              """
              SELECT
              t.id,
              t.project_id,
              t.name,
              t.data,
              t.status,
              t.assigned_to,
              t.created_at,
              COUNT(a.id) as annotation_count
              FROM tasks t
              LEFT JOIN annotations a ON t.id = a.task_id
              WHERE t.id = ?
              GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
              """,
              (task_id,),
          )
          found = cursor.fetchone()
          if not found:
              raise HTTPException(
                  status_code=status.HTTP_404_NOT_FOUND,
                  detail=f"Task with id '{task_id}' not found"
              )
          payload = found["data"]
          return TaskResponse(
              id=found["id"],
              project_id=found["project_id"],
              name=found["name"],
              data=json.loads(payload) if isinstance(payload, str) else payload,
              status=found["status"],
              assigned_to=found["assigned_to"],
              created_at=found["created_at"],
              progress=calculate_progress(payload, found["annotation_count"]),
          )
  @router.put("/{task_id}", response_model=TaskResponse)
  async def update_task(request: Request, task_id: str, task: TaskUpdate):
      """
      Partially update a task.
      Requires authentication.

      Only fields that are not None in the request body are written; ``data``
      is stored JSON-encoded. Responds 404 if the task does not exist, and
      returns the task as stored after the update.
      """
      with get_db_connection() as conn:
          cursor = conn.cursor()
          cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
          if not cursor.fetchone():
              raise HTTPException(
                  status_code=status.HTTP_404_NOT_FOUND,
                  detail=f"Task with id '{task_id}' not found"
              )
          # Column names are fixed here (never user input); only values are bound.
          set_parts = []
          params = []
          for column, value in (
              ("name", task.name),
              ("data", task.data),
              ("status", task.status),
              ("assigned_to", task.assigned_to),
          ):
              if value is None:
                  continue
              set_parts.append(f"{column} = ?")
              params.append(json.dumps(value) if column == "data" else value)
          if set_parts:
              params.append(task_id)
              cursor.execute(f"""
                  UPDATE tasks SET {', '.join(set_parts)} WHERE id = ?
                  """, tuple(params))
          cursor.execute(
              """
              SELECT
              t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
              COUNT(a.id) as annotation_count
              FROM tasks t
              LEFT JOIN annotations a ON t.id = a.task_id
              WHERE t.id = ?
              GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
              """,
              (task_id,),
          )
          current = cursor.fetchone()
          raw = current["data"]
          return TaskResponse(
              id=current["id"],
              project_id=current["project_id"],
              name=current["name"],
              data=json.loads(raw) if isinstance(raw, str) else raw,
              status=current["status"],
              assigned_to=current["assigned_to"],
              created_at=current["created_at"],
              progress=calculate_progress(raw, current["annotation_count"]),
          )
  @router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
  async def delete_task(request: Request, task_id: str):
      """
      Delete a task and all associated annotations.
      Requires authentication and admin role.

      Responds 403 for non-admins and 404 if the task does not exist.
      """
      user = request.state.user
      if user["role"] != "admin":
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="只有管理员可以删除任务"
          )
      with get_db_connection() as conn:
          cursor = conn.cursor()
          cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
          if not cursor.fetchone():
              raise HTTPException(
                  status_code=status.HTTP_404_NOT_FOUND,
                  detail=f"Task with id '{task_id}' not found"
              )
          # Delete child annotations explicitly instead of relying on an
          # ON DELETE CASCADE foreign key. If the backend is SQLite (the "?"
          # placeholders suggest so), FK enforcement is off unless the
          # connection sets PRAGMA foreign_keys=ON, which would leave orphaned
          # annotation rows. Deleting children first also succeeds when the FK
          # is enforced without a cascade action.
          cursor.execute("DELETE FROM annotations WHERE task_id = ?", (task_id,))
          cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
      return None
  @router.get("/projects/{project_id}/tasks", response_model=List[TaskResponse])
  async def get_project_tasks(request: Request, project_id: str):
      """
      List every task of one project, newest first, with progress.
      Requires authentication. Responds 404 if the project does not exist.
      """
      with get_db_connection() as conn:
          cursor = conn.cursor()
          cursor.execute("SELECT id FROM projects WHERE id = ?", (project_id,))
          if not cursor.fetchone():
              raise HTTPException(
                  status_code=status.HTTP_404_NOT_FOUND,
                  detail=f"Project with id '{project_id}' not found"
              )
          cursor.execute(
              """
              SELECT
              t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
              COUNT(a.id) as annotation_count
              FROM tasks t
              LEFT JOIN annotations a ON t.id = a.task_id
              WHERE t.project_id = ?
              GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
              ORDER BY t.created_at DESC
              """,
              (project_id,),
          )
          project_tasks = []
          for record in cursor.fetchall():
              raw = record["data"]
              project_tasks.append(TaskResponse(
                  id=record["id"],
                  project_id=record["project_id"],
                  name=record["name"],
                  data=json.loads(raw) if isinstance(raw, str) else raw,
                  status=record["status"],
                  assigned_to=record["assigned_to"],
                  created_at=record["created_at"],
                  progress=calculate_progress(raw, record["annotation_count"]),
              ))
          return project_tasks
  @router.put("/{task_id}/assign", response_model=TaskAssignmentResponse)
  async def assign_task(request: Request, task_id: str, assign_data: TaskAssignRequest):
      """
      Assign one task to a specific user.
      Requires authentication and admin role.

      Responds 403 for non-admins, and 404 if the task or the target user does
      not exist. The response records who made the assignment and when.
      """
      admin = request.state.user
      # Assignment is an admin-only operation.
      if admin["role"] != "admin":
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="只有管理员可以分配任务"
          )
      target_user_id = assign_data.user_id
      with get_db_connection() as conn:
          cursor = conn.cursor()
          cursor.execute("SELECT id, assigned_to FROM tasks WHERE id = ?", (task_id,))
          if not cursor.fetchone():
              raise HTTPException(
                  status_code=status.HTTP_404_NOT_FOUND,
                  detail=f"任务 '{task_id}' 不存在"
              )
          cursor.execute("SELECT id FROM users WHERE id = ?", (target_user_id,))
          if not cursor.fetchone():
              raise HTTPException(
                  status_code=status.HTTP_404_NOT_FOUND,
                  detail=f"用户 '{assign_data.user_id}' 不存在"
              )
          cursor.execute("""
              UPDATE tasks SET assigned_to = ? WHERE id = ?
              """, (target_user_id, task_id))
          return TaskAssignmentResponse(
              task_id=task_id,
              assigned_to=target_user_id,
              assigned_by=admin["id"],
              assigned_at=datetime.now(),
          )
  @router.post("/preview-assignment/{project_id}", response_model=AssignmentPreviewResponse)
  async def preview_assignment(
      request: Request,
      project_id: str,
      preview_request: AssignmentPreviewRequest
  ):
      """
      Preview task assignment distribution for a project.
      Shows how tasks would be distributed among selected annotators
      without actually performing the assignment.
      Requires authentication and admin role.

      Responds 403 for non-admins. A ValueError raised by
      ``assignment_service.preview_assignment`` is returned as a 400 carrying
      the error message.
      """
      user = request.state.user
      if user["role"] != "admin":
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="只有管理员可以预览任务分配"
          )
      try:
          return assignment_service.preview_assignment(
              project_id=project_id,
              annotator_ids=preview_request.user_ids
          )
      except ValueError as e:
          # Chain the cause so server-side tracebacks show the original error.
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=str(e)
          ) from e
  @router.post("/dispatch/{project_id}", response_model=DispatchResponse)
  async def dispatch_tasks(
      request: Request,
      project_id: str,
      dispatch_request: DispatchRequest
  ):
      """
      Dispatch (assign) all unassigned tasks in a project to selected annotators.
      This is a one-click operation that:
      1. Distributes tasks evenly among selected annotators
      2. Updates project status to 'in_progress'
      Only works when project is in 'ready' status.
      Requires authentication and admin role.

      NOTE(review): the work and the 'ready' status requirement presumably
      live in ``assignment_service.dispatch_tasks``; this handler does not
      check project status itself. Responds 403 for non-admins. A ValueError
      from the service is returned as a 400 carrying the error message.
      """
      user = request.state.user
      if user["role"] != "admin":
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="只有管理员可以分发任务"
          )
      try:
          return assignment_service.dispatch_tasks(
              project_id=project_id,
              annotator_ids=dispatch_request.user_ids,
              admin_id=user["id"]
          )
      except ValueError as e:
          # Chain the cause so server-side tracebacks show the original error.
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=str(e)
          ) from e