test_task_assignment_api.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. """
  2. Task Assignment API tests.
  3. Tests for task assignment endpoints.
  4. """
  5. import pytest
  6. import uuid
  7. import json
  8. from fastapi.testclient import TestClient
  9. from main import app
  10. from database import get_db_connection, init_database
  11. from services.jwt_service import JWTService
  12. import bcrypt
# Module-level test client shared by every test in this file.
client = TestClient(app)
  15. @pytest.fixture(scope="module")
  16. def setup_database():
  17. """初始化测试数据库"""
  18. init_database()
  19. yield
  20. @pytest.fixture
  21. def admin_user(setup_database):
  22. """创建管理员用户"""
  23. admin_id = f"admin_{uuid.uuid4().hex[:8]}"
  24. password_hash = bcrypt.hashpw("admin123".encode(), bcrypt.gensalt()).decode()
  25. with get_db_connection() as conn:
  26. cursor = conn.cursor()
  27. cursor.execute("""
  28. INSERT INTO users (id, username, email, password_hash, role)
  29. VALUES (?, ?, ?, ?, 'admin')
  30. """, (admin_id, f"test_admin_{admin_id}", f"admin_{admin_id}@test.com", password_hash))
  31. user_data = {
  32. "id": admin_id,
  33. "username": f"test_admin_{admin_id}",
  34. "email": f"admin_{admin_id}@test.com",
  35. "role": "admin"
  36. }
  37. token = JWTService.create_access_token(user_data)
  38. yield {"token": token, "user_id": admin_id, "user_data": user_data}
  39. # 清理
  40. with get_db_connection() as conn:
  41. cursor = conn.cursor()
  42. cursor.execute("DELETE FROM users WHERE id = ?", (admin_id,))
  43. @pytest.fixture
  44. def annotator_users(setup_database):
  45. """创建多个标注人员用户"""
  46. annotators = []
  47. for i in range(3):
  48. annotator_id = f"annotator_{uuid.uuid4().hex[:8]}"
  49. password_hash = bcrypt.hashpw("annotator123".encode(), bcrypt.gensalt()).decode()
  50. with get_db_connection() as conn:
  51. cursor = conn.cursor()
  52. cursor.execute("""
  53. INSERT INTO users (id, username, email, password_hash, role)
  54. VALUES (?, ?, ?, ?, 'annotator')
  55. """, (annotator_id, f"test_annotator_{annotator_id}", f"annotator_{annotator_id}@test.com", password_hash))
  56. user_data = {
  57. "id": annotator_id,
  58. "username": f"test_annotator_{annotator_id}",
  59. "email": f"annotator_{annotator_id}@test.com",
  60. "role": "annotator"
  61. }
  62. token = JWTService.create_access_token(user_data)
  63. annotators.append({"token": token, "user_id": annotator_id, "user_data": user_data})
  64. yield annotators
  65. # 清理
  66. with get_db_connection() as conn:
  67. cursor = conn.cursor()
  68. for annotator in annotators:
  69. cursor.execute("DELETE FROM users WHERE id = ?", (annotator["user_id"],))
  70. @pytest.fixture
  71. def test_project(setup_database):
  72. """创建测试项目"""
  73. project_id = f"proj_{uuid.uuid4().hex[:8]}"
  74. with get_db_connection() as conn:
  75. cursor = conn.cursor()
  76. cursor.execute("""
  77. INSERT INTO projects (id, name, description, config)
  78. VALUES (?, ?, ?, ?)
  79. """, (project_id, "Test Project", "Test Description", "<View></View>"))
  80. yield project_id
  81. # 清理
  82. with get_db_connection() as conn:
  83. cursor = conn.cursor()
  84. cursor.execute("DELETE FROM projects WHERE id = ?", (project_id,))
  85. @pytest.fixture
  86. def test_tasks(setup_database, test_project):
  87. """创建测试任务"""
  88. task_ids = []
  89. for i in range(5):
  90. task_id = f"task_{uuid.uuid4().hex[:8]}"
  91. data = json.dumps({"items": [{"id": j} for j in range(3)]})
  92. with get_db_connection() as conn:
  93. cursor = conn.cursor()
  94. cursor.execute("""
  95. INSERT INTO tasks (id, project_id, name, data, status)
  96. VALUES (?, ?, ?, ?, 'pending')
  97. """, (task_id, test_project, f"Test Task {i}", data))
  98. task_ids.append(task_id)
  99. yield task_ids
  100. # 清理
  101. with get_db_connection() as conn:
  102. cursor = conn.cursor()
  103. for task_id in task_ids:
  104. cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
  105. class TestTaskAssignment:
  106. """单个任务分配测试"""
  107. def test_assign_task_without_auth(self, setup_database, test_tasks):
  108. """未认证时应返回 401"""
  109. response = client.put(
  110. f"/api/tasks/{test_tasks[0]}/assign",
  111. json={"user_id": "some_user"}
  112. )
  113. assert response.status_code == 401
  114. def test_assign_task_as_annotator(self, annotator_users, test_tasks):
  115. """标注人员分配任务应返回 403"""
  116. headers = {"Authorization": f"Bearer {annotator_users[0]['token']}"}
  117. response = client.put(
  118. f"/api/tasks/{test_tasks[0]}/assign",
  119. json={"user_id": annotator_users[1]["user_id"]},
  120. headers=headers
  121. )
  122. assert response.status_code == 403
  123. def test_assign_task_as_admin(self, admin_user, annotator_users, test_tasks):
  124. """管理员可以分配任务"""
  125. headers = {"Authorization": f"Bearer {admin_user['token']}"}
  126. response = client.put(
  127. f"/api/tasks/{test_tasks[0]}/assign",
  128. json={"user_id": annotator_users[0]["user_id"]},
  129. headers=headers
  130. )
  131. assert response.status_code == 200
  132. data = response.json()
  133. assert data["task_id"] == test_tasks[0]
  134. assert data["assigned_to"] == annotator_users[0]["user_id"]
  135. assert data["assigned_by"] == admin_user["user_id"]
  136. def test_assign_nonexistent_task(self, admin_user, annotator_users):
  137. """分配不存在的任务应返回 404"""
  138. headers = {"Authorization": f"Bearer {admin_user['token']}"}
  139. response = client.put(
  140. "/api/tasks/nonexistent_task/assign",
  141. json={"user_id": annotator_users[0]["user_id"]},
  142. headers=headers
  143. )
  144. assert response.status_code == 404
  145. def test_assign_to_nonexistent_user(self, admin_user, test_tasks):
  146. """分配给不存在的用户应返回 404"""
  147. headers = {"Authorization": f"Bearer {admin_user['token']}"}
  148. response = client.put(
  149. f"/api/tasks/{test_tasks[0]}/assign",
  150. json={"user_id": "nonexistent_user"},
  151. headers=headers
  152. )
  153. assert response.status_code == 404
  154. class TestBatchAssignment:
  155. """批量任务分配测试"""
  156. def test_batch_assign_without_auth(self, setup_database, test_tasks, annotator_users):
  157. """未认证时应返回 401"""
  158. response = client.post(
  159. "/api/tasks/batch-assign",
  160. json={
  161. "task_ids": test_tasks[:2],
  162. "user_ids": [annotator_users[0]["user_id"]],
  163. "mode": "round_robin"
  164. }
  165. )
  166. assert response.status_code == 401
  167. def test_batch_assign_round_robin(self, admin_user, annotator_users, test_tasks):
  168. """轮询分配模式测试"""
  169. headers = {"Authorization": f"Bearer {admin_user['token']}"}
  170. user_ids = [a["user_id"] for a in annotator_users[:2]]
  171. response = client.post(
  172. "/api/tasks/batch-assign",
  173. json={
  174. "task_ids": test_tasks,
  175. "user_ids": user_ids,
  176. "mode": "round_robin"
  177. },
  178. headers=headers
  179. )
  180. assert response.status_code == 200
  181. data = response.json()
  182. assert data["success_count"] == len(test_tasks)
  183. assert data["failed_count"] == 0
  184. # 验证轮询分配
  185. assignments = data["assignments"]
  186. for i, assignment in enumerate(assignments):
  187. expected_user = user_ids[i % len(user_ids)]
  188. assert assignment["assigned_to"] == expected_user
  189. def test_batch_assign_equal(self, admin_user, annotator_users, test_tasks):
  190. """平均分配模式测试"""
  191. headers = {"Authorization": f"Bearer {admin_user['token']}"}
  192. user_ids = [a["user_id"] for a in annotator_users[:2]]
  193. response = client.post(
  194. "/api/tasks/batch-assign",
  195. json={
  196. "task_ids": test_tasks,
  197. "user_ids": user_ids,
  198. "mode": "equal"
  199. },
  200. headers=headers
  201. )
  202. assert response.status_code == 200
  203. data = response.json()
  204. assert data["success_count"] == len(test_tasks)
  205. # 验证平均分配:每个用户分配的任务数差异不超过 1
  206. user_task_counts = {}
  207. for assignment in data["assignments"]:
  208. user_id = assignment["assigned_to"]
  209. user_task_counts[user_id] = user_task_counts.get(user_id, 0) + 1
  210. counts = list(user_task_counts.values())
  211. assert max(counts) - min(counts) <= 1
  212. def test_batch_assign_invalid_mode(self, admin_user, annotator_users, test_tasks):
  213. """无效分配模式应返回 400"""
  214. headers = {"Authorization": f"Bearer {admin_user['token']}"}
  215. response = client.post(
  216. "/api/tasks/batch-assign",
  217. json={
  218. "task_ids": test_tasks[:2],
  219. "user_ids": [annotator_users[0]["user_id"]],
  220. "mode": "invalid_mode"
  221. },
  222. headers=headers
  223. )
  224. assert response.status_code == 400
  225. class TestMyTasks:
  226. """当前用户任务列表测试"""
  227. def test_my_tasks_without_auth(self, setup_database):
  228. """未认证时应返回 401"""
  229. response = client.get("/api/tasks/my-tasks")
  230. assert response.status_code == 401
  231. def test_my_tasks_empty(self, annotator_users):
  232. """没有分配任务时返回空列表"""
  233. headers = {"Authorization": f"Bearer {annotator_users[0]['token']}"}
  234. response = client.get("/api/tasks/my-tasks", headers=headers)
  235. assert response.status_code == 200
  236. data = response.json()
  237. assert data["total"] == 0
  238. assert data["tasks"] == []
  239. def test_my_tasks_with_assignments(self, admin_user, annotator_users, test_tasks):
  240. """分配任务后可以看到自己的任务"""
  241. # 先分配任务
  242. admin_headers = {"Authorization": f"Bearer {admin_user['token']}"}
  243. target_user = annotator_users[0]
  244. # 分配 2 个任务给第一个标注人员
  245. for task_id in test_tasks[:2]:
  246. client.put(
  247. f"/api/tasks/{task_id}/assign",
  248. json={"user_id": target_user["user_id"]},
  249. headers=admin_headers
  250. )
  251. # 查询自己的任务
  252. user_headers = {"Authorization": f"Bearer {target_user['token']}"}
  253. response = client.get("/api/tasks/my-tasks", headers=user_headers)
  254. assert response.status_code == 200
  255. data = response.json()
  256. assert data["total"] == 2
  257. assert len(data["tasks"]) == 2
  258. # 验证所有任务都是分配给当前用户的
  259. for task in data["tasks"]:
  260. assert task["assigned_to"] == target_user["user_id"]
  261. def test_my_tasks_only_shows_own_tasks(self, admin_user, annotator_users, test_tasks):
  262. """标注人员只能看到分配给自己的任务"""
  263. admin_headers = {"Authorization": f"Bearer {admin_user['token']}"}
  264. # 分配不同任务给不同用户
  265. client.put(
  266. f"/api/tasks/{test_tasks[0]}/assign",
  267. json={"user_id": annotator_users[0]["user_id"]},
  268. headers=admin_headers
  269. )
  270. client.put(
  271. f"/api/tasks/{test_tasks[1]}/assign",
  272. json={"user_id": annotator_users[1]["user_id"]},
  273. headers=admin_headers
  274. )
  275. # 第一个用户只能看到自己的任务
  276. user1_headers = {"Authorization": f"Bearer {annotator_users[0]['token']}"}
  277. response1 = client.get("/api/tasks/my-tasks", headers=user1_headers)
  278. data1 = response1.json()
  279. for task in data1["tasks"]:
  280. assert task["assigned_to"] == annotator_users[0]["user_id"]
  281. # 第二个用户只能看到自己的任务
  282. user2_headers = {"Authorization": f"Bearer {annotator_users[1]['token']}"}
  283. response2 = client.get("/api/tasks/my-tasks", headers=user2_headers)
  284. data2 = response2.json()
  285. for task in data2["tasks"]:
  286. assert task["assigned_to"] == annotator_users[1]["user_id"]