task.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. """
  2. Task API router.
  3. Provides CRUD endpoints for task management.
  4. """
  5. import uuid
  6. import json
  7. from typing import List, Optional
  8. from fastapi import APIRouter, HTTPException, status, Query, Request
  9. from database import get_db_connection
  10. from schemas.task import TaskCreate, TaskUpdate, TaskResponse
  11. from models import Task
  12. router = APIRouter(
  13. prefix="/api/tasks",
  14. tags=["tasks"]
  15. )
  16. def calculate_progress(data_str: str, annotation_count: int) -> float:
  17. """计算任务进度"""
  18. try:
  19. data = json.loads(data_str) if isinstance(data_str, str) else data_str
  20. items = data.get('items', [])
  21. if not items:
  22. return 0.0
  23. return min(annotation_count / len(items), 1.0)
  24. except:
  25. return 0.0
  26. @router.get("", response_model=List[TaskResponse])
  27. async def list_tasks(
  28. request: Request,
  29. project_id: Optional[str] = Query(None, description="Filter by project ID"),
  30. status_filter: Optional[str] = Query(None, alias="status", description="Filter by status"),
  31. assigned_to: Optional[str] = Query(None, description="Filter by assigned user")
  32. ):
  33. """
  34. List all tasks with optional filters.
  35. Requires authentication.
  36. """
  37. with get_db_connection() as conn:
  38. cursor = conn.cursor()
  39. # Build query with filters
  40. query = """
  41. SELECT
  42. t.id,
  43. t.project_id,
  44. t.name,
  45. t.data,
  46. t.status,
  47. t.assigned_to,
  48. t.created_at,
  49. COUNT(a.id) as annotation_count
  50. FROM tasks t
  51. LEFT JOIN annotations a ON t.id = a.task_id
  52. WHERE 1=1
  53. """
  54. params = []
  55. if project_id:
  56. query += " AND t.project_id = ?"
  57. params.append(project_id)
  58. if status_filter:
  59. query += " AND t.status = ?"
  60. params.append(status_filter)
  61. if assigned_to:
  62. query += " AND t.assigned_to = ?"
  63. params.append(assigned_to)
  64. query += " GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at ORDER BY t.created_at DESC"
  65. cursor.execute(query, tuple(params))
  66. rows = cursor.fetchall()
  67. tasks = []
  68. for row in rows:
  69. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  70. progress = calculate_progress(row["data"], row["annotation_count"])
  71. tasks.append(TaskResponse(
  72. id=row["id"],
  73. project_id=row["project_id"],
  74. name=row["name"],
  75. data=data,
  76. status=row["status"],
  77. assigned_to=row["assigned_to"],
  78. created_at=row["created_at"],
  79. progress=progress
  80. ))
  81. return tasks
  82. @router.post("", response_model=TaskResponse, status_code=status.HTTP_201_CREATED)
  83. async def create_task(request: Request, task: TaskCreate):
  84. """
  85. Create a new task.
  86. Requires authentication.
  87. """
  88. task_id = f"task_{uuid.uuid4().hex[:12]}"
  89. user = request.state.user
  90. assigned_to = task.assigned_to if task.assigned_to else user["id"]
  91. with get_db_connection() as conn:
  92. cursor = conn.cursor()
  93. # Verify project exists
  94. cursor.execute("SELECT id FROM projects WHERE id = ?", (task.project_id,))
  95. if not cursor.fetchone():
  96. raise HTTPException(
  97. status_code=status.HTTP_404_NOT_FOUND,
  98. detail=f"Project with id '{task.project_id}' not found"
  99. )
  100. data_json = json.dumps(task.data)
  101. cursor.execute("""
  102. INSERT INTO tasks (id, project_id, name, data, status, assigned_to)
  103. VALUES (?, ?, ?, ?, 'pending', ?)
  104. """, (task_id, task.project_id, task.name, data_json, assigned_to))
  105. cursor.execute("""
  106. SELECT id, project_id, name, data, status, assigned_to, created_at
  107. FROM tasks WHERE id = ?
  108. """, (task_id,))
  109. row = cursor.fetchone()
  110. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  111. return TaskResponse(
  112. id=row["id"],
  113. project_id=row["project_id"],
  114. name=row["name"],
  115. data=data,
  116. status=row["status"],
  117. assigned_to=row["assigned_to"],
  118. created_at=row["created_at"],
  119. progress=0.0
  120. )
  121. @router.get("/{task_id}", response_model=TaskResponse)
  122. async def get_task(request: Request, task_id: str):
  123. """
  124. Get task by ID.
  125. Requires authentication.
  126. """
  127. with get_db_connection() as conn:
  128. cursor = conn.cursor()
  129. cursor.execute("""
  130. SELECT
  131. t.id,
  132. t.project_id,
  133. t.name,
  134. t.data,
  135. t.status,
  136. t.assigned_to,
  137. t.created_at,
  138. COUNT(a.id) as annotation_count
  139. FROM tasks t
  140. LEFT JOIN annotations a ON t.id = a.task_id
  141. WHERE t.id = ?
  142. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  143. """, (task_id,))
  144. row = cursor.fetchone()
  145. if not row:
  146. raise HTTPException(
  147. status_code=status.HTTP_404_NOT_FOUND,
  148. detail=f"Task with id '{task_id}' not found"
  149. )
  150. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  151. progress = calculate_progress(row["data"], row["annotation_count"])
  152. return TaskResponse(
  153. id=row["id"],
  154. project_id=row["project_id"],
  155. name=row["name"],
  156. data=data,
  157. status=row["status"],
  158. assigned_to=row["assigned_to"],
  159. created_at=row["created_at"],
  160. progress=progress
  161. )
  162. @router.put("/{task_id}", response_model=TaskResponse)
  163. async def update_task(request: Request, task_id: str, task: TaskUpdate):
  164. """
  165. Update an existing task.
  166. Requires authentication.
  167. """
  168. with get_db_connection() as conn:
  169. cursor = conn.cursor()
  170. cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
  171. if not cursor.fetchone():
  172. raise HTTPException(
  173. status_code=status.HTTP_404_NOT_FOUND,
  174. detail=f"Task with id '{task_id}' not found"
  175. )
  176. update_fields = []
  177. update_values = []
  178. if task.name is not None:
  179. update_fields.append("name = ?")
  180. update_values.append(task.name)
  181. if task.data is not None:
  182. update_fields.append("data = ?")
  183. update_values.append(json.dumps(task.data))
  184. if task.status is not None:
  185. update_fields.append("status = ?")
  186. update_values.append(task.status)
  187. if task.assigned_to is not None:
  188. update_fields.append("assigned_to = ?")
  189. update_values.append(task.assigned_to)
  190. if update_fields:
  191. update_values.append(task_id)
  192. cursor.execute(f"""
  193. UPDATE tasks SET {', '.join(update_fields)} WHERE id = ?
  194. """, tuple(update_values))
  195. cursor.execute("""
  196. SELECT
  197. t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
  198. COUNT(a.id) as annotation_count
  199. FROM tasks t
  200. LEFT JOIN annotations a ON t.id = a.task_id
  201. WHERE t.id = ?
  202. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  203. """, (task_id,))
  204. row = cursor.fetchone()
  205. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  206. progress = calculate_progress(row["data"], row["annotation_count"])
  207. return TaskResponse(
  208. id=row["id"],
  209. project_id=row["project_id"],
  210. name=row["name"],
  211. data=data,
  212. status=row["status"],
  213. assigned_to=row["assigned_to"],
  214. created_at=row["created_at"],
  215. progress=progress
  216. )
  217. @router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
  218. async def delete_task(request: Request, task_id: str):
  219. """
  220. Delete a task and all associated annotations.
  221. Requires authentication and admin role.
  222. """
  223. user = request.state.user
  224. if user["role"] != "admin":
  225. raise HTTPException(
  226. status_code=status.HTTP_403_FORBIDDEN,
  227. detail="只有管理员可以删除任务"
  228. )
  229. with get_db_connection() as conn:
  230. cursor = conn.cursor()
  231. cursor.execute("SELECT id FROM tasks WHERE id = ?", (task_id,))
  232. if not cursor.fetchone():
  233. raise HTTPException(
  234. status_code=status.HTTP_404_NOT_FOUND,
  235. detail=f"Task with id '{task_id}' not found"
  236. )
  237. cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
  238. return None
  239. @router.get("/projects/{project_id}/tasks", response_model=List[TaskResponse])
  240. async def get_project_tasks(request: Request, project_id: str):
  241. """
  242. Get all tasks for a specific project.
  243. Requires authentication.
  244. """
  245. with get_db_connection() as conn:
  246. cursor = conn.cursor()
  247. cursor.execute("SELECT id FROM projects WHERE id = ?", (project_id,))
  248. if not cursor.fetchone():
  249. raise HTTPException(
  250. status_code=status.HTTP_404_NOT_FOUND,
  251. detail=f"Project with id '{project_id}' not found"
  252. )
  253. cursor.execute("""
  254. SELECT
  255. t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at,
  256. COUNT(a.id) as annotation_count
  257. FROM tasks t
  258. LEFT JOIN annotations a ON t.id = a.task_id
  259. WHERE t.project_id = ?
  260. GROUP BY t.id, t.project_id, t.name, t.data, t.status, t.assigned_to, t.created_at
  261. ORDER BY t.created_at DESC
  262. """, (project_id,))
  263. rows = cursor.fetchall()
  264. tasks = []
  265. for row in rows:
  266. data = json.loads(row["data"]) if isinstance(row["data"], str) else row["data"]
  267. progress = calculate_progress(row["data"], row["annotation_count"])
  268. tasks.append(TaskResponse(
  269. id=row["id"],
  270. project_id=row["project_id"],
  271. name=row["name"],
  272. data=data,
  273. status=row["status"],
  274. assigned_to=row["assigned_to"],
  275. created_at=row["created_at"],
  276. progress=progress
  277. ))
  278. return tasks