task.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717
  1. """
  2. Task API router.
  3. Provides CRUD endpoints for task management.
  4. """
  5. import uuid
  6. import json
  7. from typing import List, Optional
  8. from fastapi import APIRouter, HTTPException, status, Query, Request
  9. from database import get_db_connection
  10. from schemas.task import (
  11. TaskCreate, TaskUpdate, TaskResponse,
  12. TaskAssignRequest, BatchAssignRequest, BatchAssignResponse,
  13. TaskAssignmentResponse, MyTasksResponse,
  14. AssignmentPreviewRequest, AssignmentPreviewResponse,
  15. DispatchRequest, DispatchResponse,
  16. )
  17. from models import Task
  18. from datetime import datetime
  19. from services.assignment_service import assignment_service
  20. router = APIRouter(
  21. prefix="/api/tasks",
  22. tags=["tasks"]
  23. )
  24. def calculate_progress(data_str: str, annotation_count: int) -> float:
  25. """计算任务进度"""
  26. try:
  27. data = json.loads(data_str) if isinstance(data_str, str) else data_str
  28. items = data.get('items', [])
  29. if not items:
  30. return 0.0
  31. return min(annotation_count / len(items), 1.0)
  32. except:
  33. return 0.0
  34. @router.get("", response_model=List[TaskResponse])
  35. async def list_tasks(
  36. request: Request,
  37. project_id: Optional[str] = Query(None, description="Filter by project ID"),
  38. status_filter: Optional[str] = Query(None, alias="status", description="Filter by status"),
  39. assigned_to: Optional[str] = Query(None, description="Filter by assigned user")
  40. ):
  41. """
  42. List tasks with optional filters.
  43. For admin users: Returns all tasks matching the filters.
  44. For annotator users: Returns only tasks assigned to them (ignores assigned_to filter).
  45. Requires authentication.
  46. """
  47. user = request.state.user
  48. user_id = user["id"]
  49. user_role = user["role"]
  50. with get_db_connection() as conn:
  51. cursor = conn.cursor()
  52. # Build query with filters
  53. query = """
  54. SELECT
  55. t.id,
  56. t.project_id,
  57. t.name,
  58. t.data,
  59. t.status,
  60. t.assigned_to,
  61. t.created_at,
  62. COUNT(a.id) as annotation_count
  63. FROM tasks t
  64. LEFT JOIN annotations a ON t.id = a.task_id
  65. WHERE 1=1
  66. """
  67. params = []
  68. if project_id:
  69. query += " AND t.project_id = ?"
  70. params.append(project_id)
  71. if status_filter:
  72. query += " AND t.status = ?"
  73. params.append(status_filter)
  74. # 标注员只能看到分配给自己的任务
  75. if user_role != "admin":
  76. query += " AND t.assigned_to = ?"
  77. params.append(user_id)
  78. elif assigned_to:
  79. # 管理员可以按 assigned_to 过滤
  80. query += " AND t.assigned_to = ?"
  81. params.append(assigned_to)
  82. query += " GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at ORDER BY t.created_at DESC"
  83. cursor.execute(query, tuple(params))
  84. rows = cursor.fetchall()
  85. tasks = []
  86. for row in rows:
  87. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  88. progress = calculate_progress(row["data"], row["annotation_count"])
  89. tasks.append(TaskResponse(
  90. id=row["id"],
  91. project_id=row["project_id"],
  92. name=row["name"],
  93. data=data,
  94. status=row["status"],
  95. assigned_to=row["assigned_to"],
  96. created_at=row["created_at"],
  97. progress=progress
  98. ))
  99. return tasks
  100. @router.post("", response_model=TaskResponse, status_code=status.HTTP_201_CREATED)
  101. async def create_task(request: Request, task: TaskCreate):
  102. """
  103. Create a new task.
  104. Requires authentication.
  105. """
  106. task_id = f"task_{uuid.uuid4().hex[:12]}"
  107. user = request.state.user
  108. assigned_to = task.assigned_to if task.assigned_to else user["id"]
  109. with get_db_connection() as conn:
  110. cursor = conn.cursor()
  111. # Verify project exists
  112. cursor.execute("SELECT id FROM projects WHERE id = ?", (task.project_id,))
  113. if not cursor.fetchone():
  114. raise HTTPException(
  115. status_code=status.HTTP_404_NOT_FOUND,
  116. detail=f"Project with id '{task.project_id}' not found"
  117. )
  118. data_json = json.dumps(task.data)
  119. cursor.execute("""
  120. INSERT INTO tasks (id, project_id, name, data, status, assigned_to)
  121. VALUES (?, ?, ?, ?, 'pending', ?)
  122. """, (task_id, task.project_id, task.name, data_json, assigned_to))
  123. cursor.execute("""
  124. SELECT id, project_id, name, data, status, assigned_to, created_at
  125. FROM tasks WHERE id = ?
  126. """, (task_id,))
  127. row = cursor.fetchone()
  128. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  129. return TaskResponse(
  130. id=row["id"],
  131. project_id=row["project_id"],
  132. name=row["name"],
  133. data=data,
  134. status=row["status"],
  135. assigned_to=row["assigned_to"],
  136. created_at=row["created_at"],
  137. progress=0.0
  138. )
  139. @router.get("/my-tasks", response_model=MyTasksResponse)
  140. async def get_my_tasks(
  141. request: Request,
  142. project_id: Optional[str] = Query(None, description="Filter by project ID"),
  143. status_filter: Optional[str] = Query(None, alias="status", description="Filter by status")
  144. ):
  145. """
  146. Get tasks assigned to the current user.
  147. Requires authentication.
  148. 标注人员只能看到分配给自己的任务。
  149. """
  150. user = request.state.user
  151. user_id = user["id"]
  152. with get_db_connection() as conn:
  153. cursor = conn.cursor()
  154. # 构建查询条件
  155. where_clauses = ["t.assigned_to = ?"]
  156. params = [user_id]
  157. if project_id:
  158. where_clauses.append("t.project_id = ?")
  159. params.append(project_id)
  160. if status_filter:
  161. where_clauses.append("t.status = ?")
  162. params.append(status_filter)
  163. where_sql = " AND ".join(where_clauses)
  164. # 查询任务列表
  165. query = f"""
  166. SELECT
  167. t.id,
  168. t.project_id,
  169. t.name,
  170. t.data,
  171. t.status,
  172. t.assigned_to,
  173. t.created_at,
  174. COUNT(a.id) as annotation_count
  175. FROM tasks t
  176. LEFT JOIN annotations a ON t.id = a.task_id
  177. WHERE {where_sql}
  178. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  179. ORDER BY t.created_at DESC
  180. """
  181. cursor.execute(query, tuple(params))
  182. rows = cursor.fetchall()
  183. tasks = []
  184. completed = 0
  185. in_progress = 0
  186. pending = 0
  187. for row in rows:
  188. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  189. progress = calculate_progress(row["data"], row["annotation_count"])
  190. task_status = row["status"]
  191. if task_status == "completed":
  192. completed += 1
  193. elif task_status == "in_progress":
  194. in_progress += 1
  195. else:
  196. pending += 1
  197. tasks.append(TaskResponse(
  198. id=row["id"],
  199. project_id=row["project_id"],
  200. name=row["name"],
  201. data=data,
  202. status=task_status,
  203. assigned_to=row["assigned_to"],
  204. created_at=row["created_at"],
  205. progress=progress
  206. ))
  207. return MyTasksResponse(
  208. tasks=tasks,
  209. total=len(tasks),
  210. completed=completed,
  211. in_progress=in_progress,
  212. pending=pending
  213. )
  214. @router.post("/batch-assign", response_model=BatchAssignResponse)
  215. async def batch_assign_tasks(request: Request, assign_data: BatchAssignRequest):
  216. """
  217. Batch assign multiple tasks to multiple users.
  218. Requires authentication and admin role.
  219. 支持两种分配模式:
  220. - round_robin: 轮询分配,按顺序将任务分配给用户
  221. - equal: 平均分配,尽量使每个用户分配的任务数量相等
  222. """
  223. user = request.state.user
  224. # 只有管理员可以分配任务
  225. if user["role"] != "admin":
  226. raise HTTPException(
  227. status_code=status.HTTP_403_FORBIDDEN,
  228. detail="只有管理员可以分配任务"
  229. )
  230. if assign_data.mode not in ["round_robin", "equal"]:
  231. raise HTTPException(
  232. status_code=status.HTTP_400_BAD_REQUEST,
  233. detail="分配模式必须是 'round_robin' 或 'equal'"
  234. )
  235. with get_db_connection() as conn:
  236. cursor = conn.cursor()
  237. # 验证所有用户存在
  238. valid_user_ids = []
  239. for user_id in assign_data.user_ids:
  240. cursor.execute("SELECT id FROM users WHERE id = ?", (user_id,))
  241. if cursor.fetchone():
  242. valid_user_ids.append(user_id)
  243. if not valid_user_ids:
  244. raise HTTPException(
  245. status_code=status.HTTP_400_BAD_REQUEST,
  246. detail="没有有效的用户可以分配任务"
  247. )
  248. # 验证所有任务存在
  249. valid_task_ids = []
  250. for task_id in assign_data.task_ids:
  251. cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
  252. if cursor.fetchone():
  253. valid_task_ids.append(task_id)
  254. if not valid_task_ids:
  255. raise HTTPException(
  256. status_code=status.HTTP_400_BAD_REQUEST,
  257. detail="没有有效的任务可以分配"
  258. )
  259. assignments = []
  260. errors = []
  261. assigned_at = datetime.now()
  262. if assign_data.mode == "round_robin":
  263. # 轮询分配
  264. for i, task_id in enumerate(valid_task_ids):
  265. user_index = i % len(valid_user_ids)
  266. target_user_id = valid_user_ids[user_index]
  267. try:
  268. cursor.execute("""
  269. UPDATE tasks SET assigned_to = ? WHERE id = ?
  270. """, (target_user_id, task_id))
  271. assignments.append(TaskAssignmentResponse(
  272. task_id=task_id,
  273. assigned_to=target_user_id,
  274. assigned_by=user["id"],
  275. assigned_at=assigned_at
  276. ))
  277. except Exception as e:
  278. errors.append({
  279. "task_id": task_id,
  280. "error": str(e)
  281. })
  282. else: # equal 模式
  283. # 平均分配:计算每个用户应该分配的任务数
  284. num_tasks = len(valid_task_ids)
  285. num_users = len(valid_user_ids)
  286. base_count = num_tasks // num_users
  287. extra_count = num_tasks % num_users
  288. task_index = 0
  289. for user_index, target_user_id in enumerate(valid_user_ids):
  290. # 前 extra_count 个用户多分配一个任务
  291. count = base_count + (1 if user_index < extra_count else 0)
  292. for _ in range(count):
  293. if task_index >= len(valid_task_ids):
  294. break
  295. task_id = valid_task_ids[task_index]
  296. task_index += 1
  297. try:
  298. cursor.execute("""
  299. UPDATE tasks SET assigned_to = ? WHERE id = ?
  300. """, (target_user_id, task_id))
  301. assignments.append(TaskAssignmentResponse(
  302. task_id=task_id,
  303. assigned_to=target_user_id,
  304. assigned_by=user["id"],
  305. assigned_at=assigned_at
  306. ))
  307. except Exception as e:
  308. errors.append({
  309. "task_id": task_id,
  310. "error": str(e)
  311. })
  312. return BatchAssignResponse(
  313. success_count=len(assignments),
  314. failed_count=len(errors),
  315. assignments=assignments,
  316. errors=errors
  317. )
  318. @router.get("/{task_id}", response_model=TaskResponse)
  319. async def get_task(request: Request, task_id: str):
  320. """
  321. Get task by ID.
  322. Requires authentication.
  323. """
  324. with get_db_connection() as conn:
  325. cursor = conn.cursor()
  326. cursor.execute("""
  327. SELECT
  328. t.id,
  329. t.project_id,
  330. t.name,
  331. t.data,
  332. t.status,
  333. t.assigned_to,
  334. t.created_at,
  335. COUNT(a.id) as annotation_count
  336. FROM tasks t
  337. LEFT JOIN annotations a ON t.id = a.task_id
  338. WHERE t.id = ?
  339. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  340. """, (task_id,))
  341. row = cursor.fetchone()
  342. if not row:
  343. raise HTTPException(
  344. status_code=status.HTTP_404_NOT_FOUND,
  345. detail=f"Task with id '{task_id}' not found"
  346. )
  347. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  348. progress = calculate_progress(row["data"], row["annotation_count"])
  349. return TaskResponse(
  350. id=row["id"],
  351. project_id=row["project_id"],
  352. name=row["name"],
  353. data=data,
  354. status=row["status"],
  355. assigned_to=row["assigned_to"],
  356. created_at=row["created_at"],
  357. progress=progress
  358. )
  359. @router.put("/{task_id}", response_model=TaskResponse)
  360. async def update_task(request: Request, task_id: str, task: TaskUpdate):
  361. """
  362. Update an existing task.
  363. Requires authentication.
  364. """
  365. with get_db_connection() as conn:
  366. cursor = conn.cursor()
  367. cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
  368. if not cursor.fetchone():
  369. raise HTTPException(
  370. status_code=status.HTTP_404_NOT_FOUND,
  371. detail=f"Task with id '{task_id}' not found"
  372. )
  373. update_fields = []
  374. update_values = []
  375. if task.name is not None:
  376. update_fields.append("name = ?")
  377. update_values.append(task.name)
  378. if task.data is not None:
  379. update_fields.append("data = ?")
  380. update_values.append(json.dumps(task.data))
  381. if task.status is not None:
  382. update_fields.append("status = ?")
  383. update_values.append(task.status)
  384. if task.assigned_to is not None:
  385. update_fields.append("assigned_to = ?")
  386. update_values.append(task.assigned_to)
  387. if update_fields:
  388. update_values.append(task_id)
  389. cursor.execute(f"""
  390. UPDATE tasks SET {', '.join(update_fields)} WHERE id = ?
  391. """, tuple(update_values))
  392. cursor.execute("""
  393. SELECT
  394. t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
  395. COUNT(a.id) as annotation_count
  396. FROM tasks t
  397. LEFT JOIN annotations a ON t.id = a.task_id
  398. WHERE t.id = ?
  399. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  400. """, (task_id,))
  401. row = cursor.fetchone()
  402. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  403. progress = calculate_progress(row["data"], row["annotation_count"])
  404. return TaskResponse(
  405. id=row["id"],
  406. project_id=row["project_id"],
  407. name=row["name"],
  408. data=data,
  409. status=row["status"],
  410. assigned_to=row["assigned_to"],
  411. created_at=row["created_at"],
  412. progress=progress
  413. )
  414. @router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
  415. async def delete_task(request: Request, task_id: str):
  416. """
  417. Delete a task and all associated annotations.
  418. Requires authentication and admin role.
  419. """
  420. user = request.state.user
  421. if user["role"] != "admin":
  422. raise HTTPException(
  423. status_code=status.HTTP_403_FORBIDDEN,
  424. detail="只有管理员可以删除任务"
  425. )
  426. with get_db_connection() as conn:
  427. cursor = conn.cursor()
  428. cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
  429. if not cursor.fetchone():
  430. raise HTTPException(
  431. status_code=status.HTTP_404_NOT_FOUND,
  432. detail=f"Task with id '{task_id}' not found"
  433. )
  434. cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
  435. return None
  436. @router.get("/projects/{project_id}/tasks", response_model=List[TaskResponse])
  437. async def get_project_tasks(request: Request, project_id: str):
  438. """
  439. Get tasks for a specific project.
  440. For admin users: Returns all tasks in the project.
  441. For annotator users: Returns only tasks assigned to them.
  442. Requires authentication.
  443. """
  444. user = request.state.user
  445. user_id = user["id"]
  446. user_role = user["role"]
  447. with get_db_connection() as conn:
  448. cursor = conn.cursor()
  449. cursor.execute("SELECT id FROM projects WHERE id = ?", (project_id,))
  450. if not cursor.fetchone():
  451. raise HTTPException(
  452. status_code=status.HTTP_404_NOT_FOUND,
  453. detail=f"Project with id '{project_id}' not found"
  454. )
  455. if user_role == "admin":
  456. # 管理员:返回项目的所有任务
  457. cursor.execute("""
  458. SELECT
  459. t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
  460. COUNT(a.id) as annotation_count
  461. FROM tasks t
  462. LEFT JOIN annotations a ON t.id = a.task_id
  463. WHERE t.project_id = ?
  464. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  465. ORDER BY t.created_at DESC
  466. """, (project_id,))
  467. else:
  468. # 标注员:只返回分配给自己的任务
  469. cursor.execute("""
  470. SELECT
  471. t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
  472. COUNT(a.id) as annotation_count
  473. FROM tasks t
  474. LEFT JOIN annotations a ON t.id = a.task_id
  475. WHERE t.project_id = ? AND t.assigned_to = ?
  476. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  477. ORDER BY t.created_at DESC
  478. """, (project_id, user_id))
  479. rows = cursor.fetchall()
  480. tasks = []
  481. for row in rows:
  482. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  483. progress = calculate_progress(row["data"], row["annotation_count"])
  484. tasks.append(TaskResponse(
  485. id=row["id"],
  486. project_id=row["project_id"],
  487. name=row["name"],
  488. data=data,
  489. status=row["status"],
  490. assigned_to=row["assigned_to"],
  491. created_at=row["created_at"],
  492. progress=progress
  493. ))
  494. return tasks
  495. @router.put("/{task_id}/assign", response_model=TaskAssignmentResponse)
  496. async def assign_task(request: Request, task_id: str, assign_data: TaskAssignRequest):
  497. """
  498. Assign a task to a specific user.
  499. Requires authentication and admin role.
  500. 将任务分配给指定用户,记录分配时间和分配人。
  501. """
  502. user = request.state.user
  503. # 只有管理员可以分配任务
  504. if user["role"] != "admin":
  505. raise HTTPException(
  506. status_code=status.HTTP_403_FORBIDDEN,
  507. detail="只有管理员可以分配任务"
  508. )
  509. with get_db_connection() as conn:
  510. cursor = conn.cursor()
  511. # 验证任务存在
  512. cursor.execute("SELECT id, assigned_to FROM tasks WHERE id = ?", (task_id,))
  513. task_row = cursor.fetchone()
  514. if not task_row:
  515. raise HTTPException(
  516. status_code=status.HTTP_404_NOT_FOUND,
  517. detail=f"任务 '{task_id}' 不存在"
  518. )
  519. # 验证用户存在
  520. cursor.execute("SELECT id FROM users WHERE id = ?", (assign_data.user_id,))
  521. if not cursor.fetchone():
  522. raise HTTPException(
  523. status_code=status.HTTP_404_NOT_FOUND,
  524. detail=f"用户 '{assign_data.user_id}' 不存在"
  525. )
  526. # 更新任务分配
  527. cursor.execute("""
  528. UPDATE tasks SET assigned_to = ? WHERE id = ?
  529. """, (assign_data.user_id, task_id))
  530. assigned_at = datetime.now()
  531. return TaskAssignmentResponse(
  532. task_id=task_id,
  533. assigned_to=assign_data.user_id,
  534. assigned_by=user["id"],
  535. assigned_at=assigned_at
  536. )
  537. @router.post("/preview-assignment/{project_id}", response_model=AssignmentPreviewResponse)
  538. async def preview_assignment(
  539. request: Request,
  540. project_id: str,
  541. preview_request: AssignmentPreviewRequest
  542. ):
  543. """
  544. Preview task assignment distribution for a project.
  545. Shows how tasks would be distributed among selected annotators
  546. without actually performing the assignment.
  547. Requires authentication and admin role.
  548. """
  549. user = request.state.user
  550. if user["role"] != "admin":
  551. raise HTTPException(
  552. status_code=status.HTTP_403_FORBIDDEN,
  553. detail="只有管理员可以预览任务分配"
  554. )
  555. try:
  556. preview = assignment_service.preview_assignment(
  557. project_id=project_id,
  558. annotator_ids=preview_request.user_ids
  559. )
  560. return preview
  561. except ValueError as e:
  562. raise HTTPException(
  563. status_code=status.HTTP_400_BAD_REQUEST,
  564. detail=str(e)
  565. )
  566. @router.post("/dispatch/{project_id}", response_model=DispatchResponse)
  567. async def dispatch_tasks(
  568. request: Request,
  569. project_id: str,
  570. dispatch_request: DispatchRequest
  571. ):
  572. """
  573. Dispatch (assign) all unassigned tasks in a project to selected annotators.
  574. This is a one-click operation that:
  575. 1. Distributes tasks evenly among selected annotators
  576. 2. Updates project status to 'in_progress'
  577. Only works when project is in 'ready' status.
  578. Requires authentication and admin role.
  579. """
  580. user = request.state.user
  581. if user["role"] != "admin":
  582. raise HTTPException(
  583. status_code=status.HTTP_403_FORBIDDEN,
  584. detail="只有管理员可以分发任务"
  585. )
  586. try:
  587. result = assignment_service.dispatch_tasks(
  588. project_id=project_id,
  589. annotator_ids=dispatch_request.user_ids,
  590. admin_id=user["id"]
  591. )
  592. return result
  593. except ValueError as e:
  594. raise HTTPException(
  595. status_code=status.HTTP_400_BAD_REQUEST,
  596. detail=str(e)
  597. )