test_statistics_api.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. """
  2. Statistics API tests.
  3. Tests for project and platform statistics endpoints.
  4. """
  5. import pytest
  6. import uuid
  7. import json
  8. from fastapi.testclient import TestClient
  9. from main import app
  10. from database import get_db_connection, init_database
  11. from services.jwt_service import JWTService
  12. import bcrypt
  # Shared test client; every test in this module sends its requests
  # to the FastAPI application object imported from ``main``.
  client = TestClient(app=app)
  @pytest.fixture(scope="module")
  def setup_database():
      """Initialise the test database once for this module.

      Calls ``init_database()`` and then yields with no value; there is no
      teardown step after the yield.
      """
      init_database()
      yield None
  @pytest.fixture
  def admin_user(setup_database):
      """Create a throw-away admin user and yield a bearer token for it.

      Yields:
          dict with ``token`` (from ``JWTService.create_access_token``) and
          ``user_id``. The inserted ``users`` row is deleted on teardown.
      """
      uid = f"admin_{uuid.uuid4().hex[:8]}"
      profile = {
          "id": uid,
          "username": f"test_admin_{uid}",
          "email": f"admin_{uid}@test.com",
          "role": "admin",
      }
      hashed = bcrypt.hashpw("admin123".encode(), bcrypt.gensalt()).decode()

      with get_db_connection() as conn:
          conn.cursor().execute(
              """
              INSERT INTO users (id, username, email, password_hash, role)
              VALUES (?, ?, ?, ?, 'admin')
              """,
              (profile["id"], profile["username"], profile["email"], hashed),
          )

      token = JWTService.create_access_token(profile)
      yield {"token": token, "user_id": uid}

      # Teardown: remove the user row inserted above.
      with get_db_connection() as conn:
          conn.cursor().execute("DELETE FROM users WHERE id = ?", (uid,))
  @pytest.fixture
  def annotator_user(setup_database):
      """Create a throw-away annotator user and yield a bearer token for it.

      Yields:
          dict with ``token`` (from ``JWTService.create_access_token``) and
          ``user_id``. The inserted ``users`` row is deleted on teardown.
      """
      uid = f"annotator_{uuid.uuid4().hex[:8]}"
      profile = {
          "id": uid,
          "username": f"test_annotator_{uid}",
          "email": f"annotator_{uid}@test.com",
          "role": "annotator",
      }
      hashed = bcrypt.hashpw("annotator123".encode(), bcrypt.gensalt()).decode()

      with get_db_connection() as conn:
          conn.cursor().execute(
              """
              INSERT INTO users (id, username, email, password_hash, role)
              VALUES (?, ?, ?, ?, 'annotator')
              """,
              (profile["id"], profile["username"], profile["email"], hashed),
          )

      token = JWTService.create_access_token(profile)
      yield {"token": token, "user_id": uid}

      # Teardown: remove the user row inserted above.
      with get_db_connection() as conn:
          conn.cursor().execute("DELETE FROM users WHERE id = ?", (uid,))
  @pytest.fixture
  def test_project_with_tasks(setup_database, annotator_user):
      """Create a project populated with tasks and annotations.

      Data created (deterministic):
        * 5 tasks whose statuses cycle pending -> in_progress -> completed,
          i.e. 2 pending, 2 in_progress and 1 completed.
        * Tasks 0-2 are assigned to ``annotator_user``; tasks 3-4 are unassigned.
        * Every task's ``data`` holds an ``items`` list of 3 entries.
        * Tasks 0 and 1 each receive one annotation by ``annotator_user``.

      Yields:
          The id of the created project. On teardown the project and the
          tasks/annotations belonging to it are deleted.
      """
      project_id = f"proj_{uuid.uuid4().hex[:8]}"
      annotator_id = annotator_user["user_id"]
      # Loop-invariant values, built once.
      statuses = ("pending", "in_progress", "completed")
      task_data = json.dumps({"items": [{"id": j} for j in range(3)]})
      annotation_result = json.dumps({"label": "test"})

      with get_db_connection() as conn:
          cursor = conn.cursor()
          cursor.execute(
              """
              INSERT INTO projects (id, name, description, config)
              VALUES (?, ?, ?, ?)
              """,
              (project_id, "Test Project", "Test Description", "<View></View>"),
          )

          task_ids = []
          for i in range(5):
              task_id = f"task_{uuid.uuid4().hex[:8]}"
              status = statuses[i % len(statuses)]
              assigned_to = annotator_id if i < 3 else None
              cursor.execute(
                  """
                  INSERT INTO tasks (id, project_id, name, data, status, assigned_to)
                  VALUES (?, ?, ?, ?, ?, ?)
                  """,
                  (task_id, project_id, f"Test Task {i}", task_data, status, assigned_to),
              )
              task_ids.append(task_id)

          # Annotate the first two tasks only.
          for task_id in task_ids[:2]:
              annotation_id = f"ann_{uuid.uuid4().hex[:8]}"
              cursor.execute(
                  """
                  INSERT INTO annotations (id, task_id, user_id, result)
                  VALUES (?, ?, ?, ?)
                  """,
                  (annotation_id, task_id, annotator_id, annotation_result),
              )

      yield project_id

      # Teardown. Child rows are deleted explicitly instead of relying on an
      # ON DELETE CASCADE that may not be declared (or, e.g. in SQLite without
      # PRAGMA foreign_keys=ON, not enforced). Leftover tasks/annotations would
      # otherwise leak into platform-wide statistics seen by other tests.
      # NOTE(review): assumes no other table must be cleared before tasks — confirm against schema.
      with get_db_connection() as conn:
          cursor = conn.cursor()
          cursor.execute(
              "DELETE FROM annotations WHERE task_id IN "
              "(SELECT id FROM tasks WHERE project_id = ?)",
              (project_id,),
          )
          cursor.execute("DELETE FROM tasks WHERE project_id = ?", (project_id,))
          cursor.execute("DELETE FROM projects WHERE id = ?", (project_id,))
  class TestOverviewStatistics:
      """Access control and payload shape of ``GET /api/statistics/overview``."""

      def test_overview_without_auth(self, setup_database):
          """An anonymous request is rejected with 401."""
          resp = client.get("/api/statistics/overview")
          assert resp.status_code == 401

      def test_overview_as_annotator(self, annotator_user):
          """An annotator is authenticated but not permitted: expect 403."""
          auth = {"Authorization": f"Bearer {annotator_user['token']}"}
          resp = client.get("/api/statistics/overview", headers=auth)
          assert resp.status_code == 403

      def test_overview_as_admin(self, admin_user):
          """An admin receives the overview with every expected field present."""
          auth = {"Authorization": f"Bearer {admin_user['token']}"}
          resp = client.get("/api/statistics/overview", headers=auth)
          assert resp.status_code == 200

          payload = resp.json()
          expected_fields = (
              "total_projects",
              "total_tasks",
              "completed_tasks",
              "in_progress_tasks",
              "pending_tasks",
              "total_users",
              "admin_count",
              "annotator_count",
              "total_annotations",
              "overall_completion_rate",
          )
          for field in expected_fields:
              assert field in payload
  class TestProjectStatistics:
      """Tests for ``GET /api/statistics/projects/{project_id}``."""

      @staticmethod
      def _auth_headers(token):
          """Build the bearer-token Authorization header for ``token``."""
          return {"Authorization": f"Bearer {token}"}

      def test_project_stats_without_auth(self, setup_database, test_project_with_tasks):
          """An anonymous request is rejected with 401."""
          response = client.get(f"/api/statistics/projects/{test_project_with_tasks}")
          assert response.status_code == 401

      def test_project_stats_as_annotator(self, annotator_user, test_project_with_tasks):
          """Annotators are allowed to read project statistics."""
          response = client.get(
              f"/api/statistics/projects/{test_project_with_tasks}",
              headers=self._auth_headers(annotator_user["token"]),
          )
          assert response.status_code == 200

      def test_project_stats_as_admin(self, admin_user, test_project_with_tasks):
          """Admins get the full project statistics payload."""
          response = client.get(
              f"/api/statistics/projects/{test_project_with_tasks}",
              headers=self._auth_headers(admin_user["token"]),
          )
          assert response.status_code == 200
          data = response.json()

          # Payload shape.
          assert data["project_id"] == test_project_with_tasks
          for field in (
              "project_name",
              "total_tasks",
              "completed_tasks",
              "in_progress_tasks",
              "pending_tasks",
              "total_items",
              "annotated_items",
              "task_completion_rate",
              "data_completion_rate",
              "user_stats",
          ):
              assert field in data

          # The fixture creates 5 tasks with statuses cycling
          # pending/in_progress/completed: 2 pending, 2 in_progress, 1 completed.
          assert data["total_tasks"] == 5
          assert data["pending_tasks"] == 2
          assert data["in_progress_tasks"] == 2
          assert data["completed_tasks"] == 1

      def test_project_stats_nonexistent(self, admin_user):
          """Statistics for an unknown project id return 404."""
          response = client.get(
              "/api/statistics/projects/nonexistent_project",
              headers=self._auth_headers(admin_user["token"]),
          )
          assert response.status_code == 404

      def test_project_stats_user_breakdown(self, admin_user, test_project_with_tasks, annotator_user):
          """The per-user breakdown includes the annotator the fixture assigned tasks to."""
          response = client.get(
              f"/api/statistics/projects/{test_project_with_tasks}",
              headers=self._auth_headers(admin_user["token"]),
          )
          assert response.status_code == 200
          user_stats = response.json()["user_stats"]
          assert isinstance(user_stats, list)

          # The fixture assigns tasks and annotations to annotator_user, so the
          # breakdown must contain that user. The old `if len(user_stats) > 0`
          # guard let an empty breakdown pass silently.
          entry = next(
              (s for s in user_stats if s.get("user_id") == annotator_user["user_id"]),
              None,
          )
          assert entry is not None, "annotator with assigned tasks missing from user_stats"
          for field in (
              "user_id",
              "username",
              "assigned_tasks",
              "completed_tasks",
              "annotation_count",
              "completion_rate",
          ):
              assert field in entry
  class TestStatisticsAccuracy:
      """Consistency checks on the numbers returned by the project statistics endpoint."""

      def test_task_count_accuracy(self, admin_user, test_project_with_tasks):
          """The total task count equals the sum of the per-status counts."""
          headers = {"Authorization": f"Bearer {admin_user['token']}"}
          response = client.get(f"/api/statistics/projects/{test_project_with_tasks}", headers=headers)
          assert response.status_code == 200
          data = response.json()

          by_status = data["completed_tasks"] + data["in_progress_tasks"] + data["pending_tasks"]
          assert data["total_tasks"] == by_status

      def test_completion_rate_calculation(self, admin_user, test_project_with_tasks):
          """task_completion_rate is completed/total as a percentage (2 decimal places)."""
          headers = {"Authorization": f"Bearer {admin_user['token']}"}
          response = client.get(f"/api/statistics/projects/{test_project_with_tasks}", headers=headers)
          assert response.status_code == 200
          data = response.json()

          # The fixture always creates tasks. A zero total is a failure, not a
          # reason to skip the check (the old `if` guard passed vacuously).
          assert data["total_tasks"] > 0
          expected_rate = data["completed_tasks"] / data["total_tasks"] * 100
          # Compare against the unrounded ratio with a tolerance covering
          # 2-decimal rounding, so the exact rounding mode does not matter.
          assert data["task_completion_rate"] == pytest.approx(expected_rate, abs=0.01)