task.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796
  1. """
  2. Task API router.
  3. Provides CRUD endpoints for task management.
  4. """
  5. import uuid
  6. import json
  7. from typing import List, Optional
  8. from fastapi import APIRouter, HTTPException, status, Query, Request
  9. from database import get_db_connection
  10. from schemas.task import (
  11. TaskCreate, TaskUpdate, TaskResponse,
  12. TaskAssignRequest, BatchAssignRequest, BatchAssignResponse,
  13. TaskAssignmentResponse, MyTasksResponse,
  14. AssignmentPreviewRequest, AssignmentPreviewResponse,
  15. DispatchRequest, DispatchResponse,
  16. TaskListPaginationResponse,
  17. )
  18. from models import Task
  19. from datetime import datetime
  20. from services.assignment_service import assignment_service
  21. router = APIRouter(
  22. prefix="/api/tasks",
  23. tags=["tasks"]
  24. )
  25. def calculate_progress(data_str: str, annotation_count: int) -> float:
  26. """计算任务进度"""
  27. try:
  28. data = json.loads(data_str) if isinstance(data_str, str) else data_str
  29. items = data.get('items', [])
  30. if not items:
  31. return 0.0
  32. return min(annotation_count / len(items), 1.0)
  33. except:
  34. return 0.0
  35. @router.get("", response_model=TaskListPaginationResponse)
  36. async def list_tasks(
  37. request: Request,
  38. project_id: Optional[str] = Query(None, description="Filter by project ID"),
  39. status_filter: Optional[str] = Query(None, alias="status", description="Filter by status"),
  40. assigned_to: Optional[str] = Query(None, description="Filter by assigned user"),
  41. page: int = Query(1, ge=1, description="Page number"),
  42. page_size: int = Query(20, ge=1, le=100, description="Items per page")
  43. ):
  44. """
  45. List tasks with optional filters and pagination.
  46. For admin users: Returns all tasks matching the filters.
  47. For annotator users: Returns only tasks assigned to them (ignores assigned_to filter).
  48. Requires authentication.
  49. """
  50. user = request.state.user
  51. user_id = user["id"]
  52. user_role = user["role"]
  53. with get_db_connection() as conn:
  54. cursor = conn.cursor()
  55. # Build query with filters
  56. base_query = """
  57. SELECT
  58. t.id,
  59. t.project_id,
  60. p.name as project_name,
  61. t.name,
  62. t.data,
  63. t.status,
  64. t.assigned_to,
  65. t.created_at,
  66. COUNT(a.id) as annotation_count
  67. FROM tasks t
  68. LEFT JOIN annotations a ON t.id = a.task_id
  69. LEFT JOIN projects p ON t.project_id = p.id
  70. WHERE 1=1
  71. """
  72. params = []
  73. if project_id:
  74. base_query += " AND t.project_id = ?"
  75. params.append(project_id)
  76. if status_filter:
  77. base_query += " AND t.status = ?"
  78. params.append(status_filter)
  79. # 标注员只能看到分配给自己的任务
  80. if user_role != "admin":
  81. base_query += " AND t.assigned_to = ?"
  82. params.append(user_id)
  83. elif assigned_to:
  84. # 管理员可以按 assigned_to 过滤
  85. base_query += " AND t.assigned_to = ?"
  86. params.append(assigned_to)
  87. # 计算总数
  88. count_query = f"""
  89. SELECT COUNT(DISTINCT t.id) as total
  90. FROM tasks t
  91. LEFT JOIN projects p ON t.project_id = p.id
  92. WHERE 1=1
  93. """
  94. count_params = []
  95. if project_id:
  96. count_query += " AND t.project_id = ?"
  97. count_params.append(project_id)
  98. if status_filter:
  99. count_query += " AND t.status = ?"
  100. count_params.append(status_filter)
  101. if user_role != "admin":
  102. count_query += " AND t.assigned_to = ?"
  103. count_params.append(user_id)
  104. elif assigned_to:
  105. count_query += " AND t.assigned_to = ?"
  106. count_params.append(assigned_to)
  107. cursor.execute(count_query, tuple(count_params))
  108. total = cursor.fetchone()["total"]
  109. # 计算分页
  110. total_pages = (total + page_size - 1) // page_size
  111. offset = (page - 1) * page_size
  112. # 添加排序和分页
  113. base_query += " GROUP BY t.id, t.project_id, p.name, t.name, t.data, t.status, t.assigned_to, t.created_at ORDER BY t.created_at DESC LIMIT ? OFFSET ?"
  114. params.extend([page_size, offset])
  115. cursor.execute(base_query, tuple(params))
  116. rows = cursor.fetchall()
  117. tasks = []
  118. for row in rows:
  119. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  120. progress = calculate_progress(row["data"], row["annotation_count"])
  121. tasks.append(TaskResponse(
  122. id=row["id"],
  123. project_id=row["project_id"],
  124. project_name=row["project_name"],
  125. name=row["name"],
  126. data=data,
  127. status=row["status"],
  128. assigned_to=row["assigned_to"],
  129. created_at=row["created_at"],
  130. progress=progress
  131. ))
  132. return TaskListPaginationResponse(
  133. tasks=tasks,
  134. total=total,
  135. page=page,
  136. page_size=page_size,
  137. total_pages=total_pages,
  138. has_next=page < total_pages,
  139. has_prev=page > 1
  140. )
  141. @router.post("", response_model=TaskResponse, status_code=status.HTTP_201_CREATED)
  142. async def create_task(request: Request, task: TaskCreate):
  143. """
  144. Create a new task.
  145. Requires authentication.
  146. """
  147. task_id = f"task_{uuid.uuid4().hex[:12]}"
  148. user = request.state.user
  149. assigned_to = task.assigned_to if task.assigned_to else user["id"]
  150. with get_db_connection() as conn:
  151. cursor = conn.cursor()
  152. # Verify project exists
  153. cursor.execute("SELECT id FROM projects WHERE id = ?", (task.project_id,))
  154. if not cursor.fetchone():
  155. raise HTTPException(
  156. status_code=status.HTTP_404_NOT_FOUND,
  157. detail=f"Project with id '{task.project_id}' not found"
  158. )
  159. data_json = json.dumps(task.data)
  160. cursor.execute("""
  161. INSERT INTO tasks (id, project_id, name, data, status, assigned_to)
  162. VALUES (?, ?, ?, ?, 'pending', ?)
  163. """, (task_id, task.project_id, task.name, data_json, assigned_to))
  164. cursor.execute("""
  165. SELECT id, project_id, name, data, status, assigned_to, created_at
  166. FROM tasks WHERE id = ?
  167. """, (task_id,))
  168. row = cursor.fetchone()
  169. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  170. return TaskResponse(
  171. id=row["id"],
  172. project_id=row["project_id"],
  173. name=row["name"],
  174. data=data,
  175. status=row["status"],
  176. assigned_to=row["assigned_to"],
  177. created_at=row["created_at"],
  178. progress=0.0
  179. )
  180. @router.get("/my-tasks", response_model=MyTasksResponse)
  181. async def get_my_tasks(
  182. request: Request,
  183. project_id: Optional[str] = Query(None, description="Filter by project ID"),
  184. status_filter: Optional[str] = Query(None, alias="status", description="Filter by status"),
  185. page: int = Query(1, ge=1, description="Page number"),
  186. page_size: int = Query(20, ge=1, le=100, description="Items per page")
  187. ):
  188. """
  189. Get tasks assigned to the current user.
  190. Requires authentication.
  191. 标注人员只能看到分配给自己的任务。
  192. """
  193. user = request.state.user
  194. user_id = user["id"]
  195. with get_db_connection() as conn:
  196. cursor = conn.cursor()
  197. # 构建查询条件
  198. where_clauses = ["t.assigned_to = ?"]
  199. params = [user_id]
  200. if project_id:
  201. where_clauses.append("t.project_id = ?")
  202. params.append(project_id)
  203. if status_filter:
  204. where_clauses.append("t.status = ?")
  205. params.append(status_filter)
  206. where_sql = " AND ".join(where_clauses)
  207. # 查询总数
  208. count_query = f"""
  209. SELECT COUNT(*) as total
  210. FROM tasks t
  211. WHERE {where_sql}
  212. """
  213. cursor.execute(count_query, tuple(params))
  214. total = cursor.fetchone()["total"]
  215. # 查询各状态的任务数(不受分页影响)
  216. status_count_query = f"""
  217. SELECT
  218. SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed,
  219. SUM(CASE WHEN status = 'in_progress' THEN 1 ELSE 0 END) as in_progress,
  220. SUM(CASE WHEN status = 'pending' THEN 1 ELSE 0 END) as pending
  221. FROM tasks t
  222. WHERE {where_sql}
  223. """
  224. cursor.execute(status_count_query, tuple(params))
  225. status_row = cursor.fetchone()
  226. completed = int(status_row["completed"] or 0)
  227. in_progress = int(status_row["in_progress"] or 0)
  228. pending = int(status_row["pending"] or 0)
  229. # 计算分页信息
  230. total_pages = (total + page_size - 1) // page_size
  231. has_next = page < total_pages
  232. has_prev = page > 1
  233. skip = (page - 1) * page_size
  234. # 查询任务列表(带项目名称)
  235. query = f"""
  236. SELECT
  237. t.id,
  238. t.project_id,
  239. p.name as project_name,
  240. t.name,
  241. t.data,
  242. t.status,
  243. t.assigned_to,
  244. t.created_at,
  245. COUNT(a.id) as annotation_count
  246. FROM tasks t
  247. LEFT JOIN annotations a ON t.id = a.task_id
  248. LEFT JOIN projects p ON t.project_id = p.id
  249. WHERE {where_sql}
  250. GROUP BY t.id, t.project_id, p.name, t.name, t.data, t.status, t.assigned_to, t.created_at
  251. ORDER BY t.created_at DESC
  252. LIMIT ? OFFSET ?
  253. """
  254. params.extend([page_size, skip])
  255. cursor.execute(query, tuple(params))
  256. rows = cursor.fetchall()
  257. tasks = []
  258. for row in rows:
  259. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  260. progress = calculate_progress(row["data"], row["annotation_count"])
  261. task_status = row["status"]
  262. tasks.append(TaskResponse(
  263. id=row["id"],
  264. project_id=row["project_id"],
  265. project_name=row["project_name"],
  266. name=row["name"],
  267. data=data,
  268. status=task_status,
  269. assigned_to=row["assigned_to"],
  270. created_at=row["created_at"],
  271. progress=progress
  272. ))
  273. return MyTasksResponse(
  274. tasks=tasks,
  275. total=total,
  276. completed=completed,
  277. in_progress=in_progress,
  278. pending=pending,
  279. page=page,
  280. page_size=page_size,
  281. total_pages=total_pages,
  282. has_next=has_next,
  283. has_prev=has_prev
  284. )
  285. @router.post("/batch-assign", response_model=BatchAssignResponse)
  286. async def batch_assign_tasks(request: Request, assign_data: BatchAssignRequest):
  287. """
  288. Batch assign multiple tasks to multiple users.
  289. Requires authentication and admin role.
  290. 支持两种分配模式:
  291. - round_robin: 轮询分配,按顺序将任务分配给用户
  292. - equal: 平均分配,尽量使每个用户分配的任务数量相等
  293. """
  294. user = request.state.user
  295. # 只有管理员可以分配任务
  296. if user["role"] != "admin":
  297. raise HTTPException(
  298. status_code=status.HTTP_403_FORBIDDEN,
  299. detail="只有管理员可以分配任务"
  300. )
  301. if assign_data.mode not in ["round_robin", "equal"]:
  302. raise HTTPException(
  303. status_code=status.HTTP_400_BAD_REQUEST,
  304. detail="分配模式必须是 'round_robin' 或 'equal'"
  305. )
  306. with get_db_connection() as conn:
  307. cursor = conn.cursor()
  308. # 验证所有用户存在
  309. valid_user_ids = []
  310. for user_id in assign_data.user_ids:
  311. cursor.execute("SELECT id FROM users WHERE id = ?", (user_id,))
  312. if cursor.fetchone():
  313. valid_user_ids.append(user_id)
  314. if not valid_user_ids:
  315. raise HTTPException(
  316. status_code=status.HTTP_400_BAD_REQUEST,
  317. detail="没有有效的用户可以分配任务"
  318. )
  319. # 验证所有任务存在
  320. valid_task_ids = []
  321. for task_id in assign_data.task_ids:
  322. cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
  323. if cursor.fetchone():
  324. valid_task_ids.append(task_id)
  325. if not valid_task_ids:
  326. raise HTTPException(
  327. status_code=status.HTTP_400_BAD_REQUEST,
  328. detail="没有有效的任务可以分配"
  329. )
  330. assignments = []
  331. errors = []
  332. assigned_at = datetime.now()
  333. if assign_data.mode == "round_robin":
  334. # 轮询分配
  335. for i, task_id in enumerate(valid_task_ids):
  336. user_index = i % len(valid_user_ids)
  337. target_user_id = valid_user_ids[user_index]
  338. try:
  339. cursor.execute("""
  340. UPDATE tasks SET assigned_to = ? WHERE id = ?
  341. """, (target_user_id, task_id))
  342. assignments.append(TaskAssignmentResponse(
  343. task_id=task_id,
  344. assigned_to=target_user_id,
  345. assigned_by=user["id"],
  346. assigned_at=assigned_at
  347. ))
  348. except Exception as e:
  349. errors.append({
  350. "task_id": task_id,
  351. "error": str(e)
  352. })
  353. else: # equal 模式
  354. # 平均分配:计算每个用户应该分配的任务数
  355. num_tasks = len(valid_task_ids)
  356. num_users = len(valid_user_ids)
  357. base_count = num_tasks // num_users
  358. extra_count = num_tasks % num_users
  359. task_index = 0
  360. for user_index, target_user_id in enumerate(valid_user_ids):
  361. # 前 extra_count 个用户多分配一个任务
  362. count = base_count + (1 if user_index < extra_count else 0)
  363. for _ in range(count):
  364. if task_index >= len(valid_task_ids):
  365. break
  366. task_id = valid_task_ids[task_index]
  367. task_index += 1
  368. try:
  369. cursor.execute("""
  370. UPDATE tasks SET assigned_to = ? WHERE id = ?
  371. """, (target_user_id, task_id))
  372. assignments.append(TaskAssignmentResponse(
  373. task_id=task_id,
  374. assigned_to=target_user_id,
  375. assigned_by=user["id"],
  376. assigned_at=assigned_at
  377. ))
  378. except Exception as e:
  379. errors.append({
  380. "task_id": task_id,
  381. "error": str(e)
  382. })
  383. return BatchAssignResponse(
  384. success_count=len(assignments),
  385. failed_count=len(errors),
  386. assignments=assignments,
  387. errors=errors
  388. )
  389. @router.get("/{task_id}", response_model=TaskResponse)
  390. async def get_task(request: Request, task_id: str):
  391. """
  392. Get task by ID.
  393. Requires authentication.
  394. """
  395. with get_db_connection() as conn:
  396. cursor = conn.cursor()
  397. cursor.execute("""
  398. SELECT
  399. t.id,
  400. t.project_id,
  401. t.name,
  402. t.data,
  403. t.status,
  404. t.assigned_to,
  405. t.created_at,
  406. COUNT(a.id) as annotation_count
  407. FROM tasks t
  408. LEFT JOIN annotations a ON t.id = a.task_id
  409. WHERE t.id = ?
  410. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  411. """, (task_id,))
  412. row = cursor.fetchone()
  413. if not row:
  414. raise HTTPException(
  415. status_code=status.HTTP_404_NOT_FOUND,
  416. detail=f"Task with id '{task_id}' not found"
  417. )
  418. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  419. progress = calculate_progress(row["data"], row["annotation_count"])
  420. return TaskResponse(
  421. id=row["id"],
  422. project_id=row["project_id"],
  423. name=row["name"],
  424. data=data,
  425. status=row["status"],
  426. assigned_to=row["assigned_to"],
  427. created_at=row["created_at"],
  428. progress=progress
  429. )
  430. @router.put("/{task_id}", response_model=TaskResponse)
  431. async def update_task(request: Request, task_id: str, task: TaskUpdate):
  432. """
  433. Update an existing task.
  434. Requires authentication.
  435. """
  436. with get_db_connection() as conn:
  437. cursor = conn.cursor()
  438. cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
  439. if not cursor.fetchone():
  440. raise HTTPException(
  441. status_code=status.HTTP_404_NOT_FOUND,
  442. detail=f"Task with id '{task_id}' not found"
  443. )
  444. update_fields = []
  445. update_values = []
  446. if task.name is not None:
  447. update_fields.append("name = ?")
  448. update_values.append(task.name)
  449. if task.data is not None:
  450. update_fields.append("data = ?")
  451. update_values.append(json.dumps(task.data))
  452. if task.status is not None:
  453. update_fields.append("status = ?")
  454. update_values.append(task.status)
  455. if task.assigned_to is not None:
  456. update_fields.append("assigned_to = ?")
  457. update_values.append(task.assigned_to)
  458. if update_fields:
  459. update_values.append(task_id)
  460. cursor.execute(f"""
  461. UPDATE tasks SET {', '.join(update_fields)} WHERE id = ?
  462. """, tuple(update_values))
  463. cursor.execute("""
  464. SELECT
  465. t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
  466. COUNT(a.id) as annotation_count
  467. FROM tasks t
  468. LEFT JOIN annotations a ON t.id = a.task_id
  469. WHERE t.id = ?
  470. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  471. """, (task_id,))
  472. row = cursor.fetchone()
  473. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  474. progress = calculate_progress(row["data"], row["annotation_count"])
  475. return TaskResponse(
  476. id=row["id"],
  477. project_id=row["project_id"],
  478. name=row["name"],
  479. data=data,
  480. status=row["status"],
  481. assigned_to=row["assigned_to"],
  482. created_at=row["created_at"],
  483. progress=progress
  484. )
  485. @router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
  486. async def delete_task(request: Request, task_id: str):
  487. """
  488. Delete a task and all associated annotations.
  489. Requires authentication and admin role.
  490. """
  491. user = request.state.user
  492. if user["role"] != "admin":
  493. raise HTTPException(
  494. status_code=status.HTTP_403_FORBIDDEN,
  495. detail="只有管理员可以删除任务"
  496. )
  497. with get_db_connection() as conn:
  498. cursor = conn.cursor()
  499. cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
  500. if not cursor.fetchone():
  501. raise HTTPException(
  502. status_code=status.HTTP_404_NOT_FOUND,
  503. detail=f"Task with id '{task_id}' not found"
  504. )
  505. cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
  506. return None
  507. @router.get("/projects/{project_id}/tasks", response_model=List[TaskResponse])
  508. async def get_project_tasks(request: Request, project_id: str):
  509. """
  510. Get tasks for a specific project.
  511. For admin users: Returns all tasks in the project.
  512. For annotator users: Returns only tasks assigned to them.
  513. Requires authentication.
  514. """
  515. user = request.state.user
  516. user_id = user["id"]
  517. user_role = user["role"]
  518. with get_db_connection() as conn:
  519. cursor = conn.cursor()
  520. cursor.execute("SELECT id FROM projects WHERE id = ?", (project_id,))
  521. if not cursor.fetchone():
  522. raise HTTPException(
  523. status_code=status.HTTP_404_NOT_FOUND,
  524. detail=f"Project with id '{project_id}' not found"
  525. )
  526. if user_role == "admin":
  527. # 管理员:返回项目的所有任务
  528. cursor.execute("""
  529. SELECT
  530. t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
  531. COUNT(a.id) as annotation_count
  532. FROM tasks t
  533. LEFT JOIN annotations a ON t.id = a.task_id
  534. WHERE t.project_id = ?
  535. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  536. ORDER BY t.created_at DESC
  537. """, (project_id,))
  538. else:
  539. # 标注员:只返回分配给自己的任务
  540. cursor.execute("""
  541. SELECT
  542. t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
  543. COUNT(a.id) as annotation_count
  544. FROM tasks t
  545. LEFT JOIN annotations a ON t.id = a.task_id
  546. WHERE t.project_id = ? AND t.assigned_to = ?
  547. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  548. ORDER BY t.created_at DESC
  549. """, (project_id, user_id))
  550. rows = cursor.fetchall()
  551. tasks = []
  552. for row in rows:
  553. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  554. progress = calculate_progress(row["data"], row["annotation_count"])
  555. tasks.append(TaskResponse(
  556. id=row["id"],
  557. project_id=row["project_id"],
  558. name=row["name"],
  559. data=data,
  560. status=row["status"],
  561. assigned_to=row["assigned_to"],
  562. created_at=row["created_at"],
  563. progress=progress
  564. ))
  565. return tasks
  566. @router.put("/{task_id}/assign", response_model=TaskAssignmentResponse)
  567. async def assign_task(request: Request, task_id: str, assign_data: TaskAssignRequest):
  568. """
  569. Assign a task to a specific user.
  570. Requires authentication and admin role.
  571. 将任务分配给指定用户,记录分配时间和分配人。
  572. """
  573. user = request.state.user
  574. # 只有管理员可以分配任务
  575. if user["role"] != "admin":
  576. raise HTTPException(
  577. status_code=status.HTTP_403_FORBIDDEN,
  578. detail="只有管理员可以分配任务"
  579. )
  580. with get_db_connection() as conn:
  581. cursor = conn.cursor()
  582. # 验证任务存在
  583. cursor.execute("SELECT id, assigned_to FROM tasks WHERE id = ?", (task_id,))
  584. task_row = cursor.fetchone()
  585. if not task_row:
  586. raise HTTPException(
  587. status_code=status.HTTP_404_NOT_FOUND,
  588. detail=f"任务 '{task_id}' 不存在"
  589. )
  590. # 验证用户存在
  591. cursor.execute("SELECT id FROM users WHERE id = ?", (assign_data.user_id,))
  592. if not cursor.fetchone():
  593. raise HTTPException(
  594. status_code=status.HTTP_404_NOT_FOUND,
  595. detail=f"用户 '{assign_data.user_id}' 不存在"
  596. )
  597. # 更新任务分配
  598. cursor.execute("""
  599. UPDATE tasks SET assigned_to = ? WHERE id = ?
  600. """, (assign_data.user_id, task_id))
  601. assigned_at = datetime.now()
  602. return TaskAssignmentResponse(
  603. task_id=task_id,
  604. assigned_to=assign_data.user_id,
  605. assigned_by=user["id"],
  606. assigned_at=assigned_at
  607. )
  608. @router.post("/preview-assignment/{project_id}", response_model=AssignmentPreviewResponse)
  609. async def preview_assignment(
  610. request: Request,
  611. project_id: str,
  612. preview_request: AssignmentPreviewRequest
  613. ):
  614. """
  615. Preview task assignment distribution for a project.
  616. Shows how tasks would be distributed among selected annotators
  617. without actually performing the assignment.
  618. Requires authentication and admin role.
  619. """
  620. user = request.state.user
  621. if user["role"] != "admin":
  622. raise HTTPException(
  623. status_code=status.HTTP_403_FORBIDDEN,
  624. detail="只有管理员可以预览任务分配"
  625. )
  626. try:
  627. preview = assignment_service.preview_assignment(
  628. project_id=project_id,
  629. annotator_ids=preview_request.user_ids
  630. )
  631. return preview
  632. except ValueError as e:
  633. raise HTTPException(
  634. status_code=status.HTTP_400_BAD_REQUEST,
  635. detail=str(e)
  636. )
  637. @router.post("/dispatch/{project_id}", response_model=DispatchResponse)
  638. async def dispatch_tasks(
  639. request: Request,
  640. project_id: str,
  641. dispatch_request: DispatchRequest
  642. ):
  643. """
  644. Dispatch (assign) all unassigned tasks in a project to selected annotators.
  645. This is a one-click operation that:
  646. 1. Distributes tasks evenly among selected annotators
  647. 2. Updates project status to 'in_progress'
  648. Only works when project is in 'ready' status.
  649. Requires authentication and admin role.
  650. """
  651. user = request.state.user
  652. if user["role"] != "admin":
  653. raise HTTPException(
  654. status_code=status.HTTP_403_FORBIDDEN,
  655. detail="只有管理员可以分发任务"
  656. )
  657. try:
  658. result = assignment_service.dispatch_tasks(
  659. project_id=project_id,
  660. annotator_ids=dispatch_request.user_ids,
  661. admin_id=user["id"]
  662. )
  663. return result
  664. except ValueError as e:
  665. raise HTTPException(
  666. status_code=status.HTTP_400_BAD_REQUEST,
  667. detail=str(e)
  668. )