"""
Assignment Service for task distribution.

Handles task assignment preview and dispatch to annotators.
"""
from datetime import datetime
from typing import List, Dict, Any, Optional

from database import get_db_connection
from schemas.project import ProjectStatus
from schemas.task import (
    DispatchRequest,
    AssignmentPreviewRequest,
    AnnotatorAssignment,
    AssignmentPreviewResponse,
    DispatchResponse,
)
  16. class AssignmentService:
  17. """Service for managing task assignments."""
  18. def get_annotator_workload(self, annotator_ids: Optional[List[str]] = None) -> Dict[str, Dict[str, Any]]:
  19. """
  20. Get current workload for annotators.
  21. Args:
  22. annotator_ids: Optional list of annotator IDs to filter
  23. Returns:
  24. Dictionary mapping annotator_id to workload info
  25. """
  26. with get_db_connection() as conn:
  27. cursor = conn.cursor()
  28. # Get all annotators or filter by IDs
  29. if annotator_ids:
  30. placeholders = ",".join(["?" for _ in annotator_ids])
  31. cursor.execute(f"""
  32. SELECT id, username, email
  33. FROM users
  34. WHERE role = 'annotator' AND id IN ({placeholders})
  35. """, annotator_ids)
  36. else:
  37. cursor.execute("""
  38. SELECT id, username, email
  39. FROM users
  40. WHERE role = 'annotator'
  41. """)
  42. annotators = cursor.fetchall()
  43. workload = {}
  44. for annotator in annotators:
  45. annotator_id = annotator["id"]
  46. # Get task counts for this annotator
  47. cursor.execute("""
  48. SELECT
  49. COUNT(*) as total_assigned,
  50. SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed,
  51. SUM(CASE WHEN status = 'in_progress' THEN 1 ELSE 0 END) as in_progress,
  52. SUM(CASE WHEN status = 'pending' THEN 1 ELSE 0 END) as pending
  53. FROM tasks
  54. WHERE assigned_to = ?
  55. """, (annotator_id,))
  56. counts = cursor.fetchone()
  57. workload[annotator_id] = {
  58. "id": annotator_id,
  59. "username": annotator["username"],
  60. "email": annotator["email"],
  61. "total_assigned": counts["total_assigned"] or 0,
  62. "completed": counts["completed"] or 0,
  63. "in_progress": counts["in_progress"] or 0,
  64. "pending": counts["pending"] or 0,
  65. }
  66. return workload
  67. def preview_assignment(
  68. self,
  69. project_id: str,
  70. annotator_ids: List[str]
  71. ) -> AssignmentPreviewResponse:
  72. """
  73. Preview task assignment distribution.
  74. Args:
  75. project_id: Project ID
  76. annotator_ids: List of annotator IDs to assign tasks to
  77. Returns:
  78. Preview of how tasks would be distributed
  79. """
  80. with get_db_connection() as conn:
  81. cursor = conn.cursor()
  82. # Check project exists and get status
  83. cursor.execute("""
  84. SELECT id, name, status FROM projects WHERE id = ?
  85. """, (project_id,))
  86. project = cursor.fetchone()
  87. if not project:
  88. raise ValueError(f"项目 '{project_id}' 不存在")
  89. # Get total and unassigned tasks count
  90. cursor.execute("""
  91. SELECT
  92. COUNT(*) as total,
  93. SUM(CASE WHEN assigned_to IS NULL THEN 1 ELSE 0 END) as unassigned
  94. FROM tasks
  95. WHERE project_id = ?
  96. """, (project_id,))
  97. counts = cursor.fetchone()
  98. total_count = counts["total"] or 0
  99. unassigned_count = counts["unassigned"] or 0
  100. if unassigned_count == 0:
  101. raise ValueError("没有待分配的任务")
  102. if not annotator_ids:
  103. raise ValueError("请选择至少一个标注人员")
  104. # Validate annotators exist
  105. placeholders = ",".join(["?" for _ in annotator_ids])
  106. cursor.execute(f"""
  107. SELECT id, username FROM users
  108. WHERE id IN ({placeholders}) AND role = 'annotator'
  109. """, annotator_ids)
  110. valid_annotators = cursor.fetchall()
  111. if len(valid_annotators) != len(annotator_ids):
  112. raise ValueError("部分标注人员ID无效")
  113. # Calculate distribution (round-robin style, equal distribution)
  114. num_annotators = len(annotator_ids)
  115. base_count = unassigned_count // num_annotators
  116. remainder = unassigned_count % num_annotators
  117. # Get current workload
  118. workload = self.get_annotator_workload(annotator_ids)
  119. assignments = []
  120. for i, annotator in enumerate(valid_annotators):
  121. annotator_id = annotator["id"]
  122. # First 'remainder' annotators get one extra task
  123. task_count = base_count + (1 if i < remainder else 0)
  124. current_workload = workload.get(annotator_id, {})
  125. # Current workload = total assigned - completed (active tasks)
  126. active_tasks = current_workload.get("total_assigned", 0) - current_workload.get("completed", 0)
  127. assignments.append(AnnotatorAssignment(
  128. user_id=annotator_id,
  129. username=annotator["username"],
  130. task_count=task_count,
  131. percentage=round(task_count / unassigned_count * 100, 1) if unassigned_count > 0 else 0,
  132. current_workload=active_tasks,
  133. ))
  134. return AssignmentPreviewResponse(
  135. project_id=project_id,
  136. total_tasks=total_count,
  137. unassigned_tasks=unassigned_count,
  138. assignments=assignments,
  139. )
  140. def dispatch_tasks(
  141. self,
  142. project_id: str,
  143. annotator_ids: List[str],
  144. admin_id: str
  145. ) -> DispatchResponse:
  146. """
  147. Dispatch tasks to annotators.
  148. Args:
  149. project_id: Project ID
  150. annotator_ids: List of annotator IDs to assign tasks to
  151. admin_id: ID of admin performing the dispatch
  152. Returns:
  153. Dispatch result with assignment details
  154. """
  155. with get_db_connection() as conn:
  156. cursor = conn.cursor()
  157. # Check project exists and status
  158. cursor.execute("""
  159. SELECT id, name, status FROM projects WHERE id = ?
  160. """, (project_id,))
  161. project = cursor.fetchone()
  162. if not project:
  163. raise ValueError(f"项目 '{project_id}' 不存在")
  164. current_status = ProjectStatus(project["status"]) if project["status"] else ProjectStatus.DRAFT
  165. # Only allow dispatch in ready status
  166. if current_status != ProjectStatus.READY:
  167. raise ValueError(f"只能在 ready 状态下分发任务,当前状态: {current_status.value}")
  168. # Get unassigned task IDs
  169. cursor.execute("""
  170. SELECT id FROM tasks
  171. WHERE project_id = ? AND assigned_to IS NULL
  172. ORDER BY id
  173. """, (project_id,))
  174. unassigned_tasks = [row["id"] for row in cursor.fetchall()]
  175. if not unassigned_tasks:
  176. raise ValueError("没有待分配的任务")
  177. if not annotator_ids:
  178. raise ValueError("请选择至少一个标注人员")
  179. # Validate annotators
  180. placeholders = ",".join(["?" for _ in annotator_ids])
  181. cursor.execute(f"""
  182. SELECT id, username FROM users
  183. WHERE id IN ({placeholders}) AND role = 'annotator'
  184. """, annotator_ids)
  185. valid_annotators = {row["id"]: row["username"] for row in cursor.fetchall()}
  186. if len(valid_annotators) != len(annotator_ids):
  187. raise ValueError("部分标注人员ID无效")
  188. # Distribute tasks
  189. num_annotators = len(annotator_ids)
  190. assignments_result = {aid: {"count": 0, "name": valid_annotators[aid]} for aid in annotator_ids}
  191. for i, task_id in enumerate(unassigned_tasks):
  192. annotator_id = annotator_ids[i % num_annotators]
  193. # Update task assignment
  194. cursor.execute("""
  195. UPDATE tasks
  196. SET assigned_to = ?, status = 'pending'
  197. WHERE id = ?
  198. """, (annotator_id, task_id))
  199. assignments_result[annotator_id]["count"] += 1
  200. # Update project status to in_progress
  201. cursor.execute("""
  202. UPDATE projects
  203. SET status = ?, updated_at = ?
  204. WHERE id = ?
  205. """, (ProjectStatus.IN_PROGRESS.value, datetime.now(), project_id))
  206. # Build response
  207. total_assigned = len(unassigned_tasks)
  208. # Get updated workload
  209. workload = self.get_annotator_workload(annotator_ids)
  210. assignments = []
  211. for aid, info in assignments_result.items():
  212. current_workload = workload.get(aid, {})
  213. active_tasks = current_workload.get("total_assigned", 0) - current_workload.get("completed", 0)
  214. assignments.append(AnnotatorAssignment(
  215. user_id=aid,
  216. username=info["name"],
  217. task_count=info["count"],
  218. percentage=round(info["count"] / total_assigned * 100, 1) if total_assigned > 0 else 0,
  219. current_workload=active_tasks,
  220. ))
  221. return DispatchResponse(
  222. project_id=project_id,
  223. success=True,
  224. total_assigned=total_assigned,
  225. assignments=assignments,
  226. project_status=ProjectStatus.IN_PROGRESS.value,
  227. )
  228. # Singleton instance
  229. assignment_service = AssignmentService()