"""
Statistics API tests.

Tests for project and platform statistics endpoints.
"""
import json
import uuid

import bcrypt
import pytest
from fastapi.testclient import TestClient

from main import app
from database import get_db_connection, init_database
from services.jwt_service import JWTService

# Shared test client for the FastAPI app.
client = TestClient(app)

# Shape of the project built by the ``test_project_with_tasks`` fixture.
# Task ``i`` gets status ``TASK_STATUSES[i % 3]``, so the five tasks are
# pending, in_progress, completed, pending, in_progress.
TASK_STATUSES = ("pending", "in_progress", "completed")
TASK_COUNT = 5
ASSIGNED_TASK_COUNT = 3   # the first N tasks are assigned to the annotator
ANNOTATED_TASK_COUNT = 2  # the first N tasks get one annotation each
ITEMS_PER_TASK = 3
EXPECTED_STATUS_COUNTS = {"completed": 1, "in_progress": 2, "pending": 2}

# Keys each statistics payload is expected to expose.
OVERVIEW_FIELDS = frozenset({
    "total_projects",
    "total_tasks",
    "completed_tasks",
    "in_progress_tasks",
    "pending_tasks",
    "total_users",
    "admin_count",
    "annotator_count",
    "total_annotations",
    "overall_completion_rate",
})
PROJECT_FIELDS = frozenset({
    "project_id",
    "project_name",
    "total_tasks",
    "completed_tasks",
    "in_progress_tasks",
    "pending_tasks",
    "total_items",
    "annotated_items",
    "task_completion_rate",
    "data_completion_rate",
    "user_stats",
})
USER_STAT_FIELDS = frozenset({
    "user_id",
    "username",
    "assigned_tasks",
    "completed_tasks",
    "annotation_count",
    "completion_rate",
})

OVERVIEW_PATH = "/api/statistics/overview"


def _project_stats_path(project_id):
    """Return the project-statistics endpoint path for *project_id*."""
    return f"/api/statistics/projects/{project_id}"


def _get(path, token=None):
    """GET *path* through the test client, sending *token* as a Bearer token if given."""
    headers = {"Authorization": f"Bearer {token}"} if token else None
    return client.get(path, headers=headers)


def _assert_has_fields(payload, fields):
    """Assert that the JSON object *payload* contains every key in *fields*."""
    missing = set(fields) - set(payload)
    assert not missing, f"response is missing fields: {sorted(missing)}"


def _create_user(role, password):
    """Insert a user with *role* and return ``(user_id, access_token)``.

    The id is ``<role>_<8 hex chars>``, the username ``test_<role>_<id>`` and
    the e-mail ``<role>_<id>@test.com``; the token is issued by
    ``JWTService.create_access_token`` from those same values.
    """
    user_id = f"{role}_{uuid.uuid4().hex[:8]}"
    username = f"test_{role}_{user_id}"
    email = f"{role}_{user_id}@test.com"
    password_hash = bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            """
            INSERT INTO users (id, username, email, password_hash, role)
            VALUES (?, ?, ?, ?, ?)
            """,
            (user_id, username, email, password_hash, role),
        )

    token = JWTService.create_access_token(
        {"id": user_id, "username": username, "email": email, "role": role}
    )
    return user_id, token


def _delete_user(user_id):
    """Delete the user row with id *user_id*."""
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("DELETE FROM users WHERE id = ?", (user_id,))


@pytest.fixture(scope="module")
def setup_database():
    """Initialise the test database once per module."""
    init_database()
    yield


@pytest.fixture
def admin_user(setup_database):
    """Create an admin user; yield ``{"token", "user_id"}``, then delete it."""
    user_id, token = _create_user("admin", "admin123")
    yield {"token": token, "user_id": user_id}
    _delete_user(user_id)


@pytest.fixture
def annotator_user(setup_database):
    """Create an annotator user; yield ``{"token", "user_id"}``, then delete it."""
    user_id, token = _create_user("annotator", "annotator123")
    yield {"token": token, "user_id": user_id}
    _delete_user(user_id)


@pytest.fixture
def test_project_with_tasks(setup_database, annotator_user):
    """Create a project with tasks and annotations; yield its id.

    The project has ``TASK_COUNT`` tasks whose statuses cycle through
    ``TASK_STATUSES`` (see ``EXPECTED_STATUS_COUNTS``). The first
    ``ASSIGNED_TASK_COUNT`` tasks are assigned to ``annotator_user``, and the
    first ``ANNOTATED_TASK_COUNT`` tasks each get one annotation by that user.
    Every task's data holds ``ITEMS_PER_TASK`` items.
    """
    project_id = f"proj_{uuid.uuid4().hex[:8]}"
    annotator_id = annotator_user["user_id"]
    # All tasks share the same payload, so serialise it once.
    task_data = json.dumps({"items": [{"id": j} for j in range(ITEMS_PER_TASK)]})

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            """
            INSERT INTO projects (id, name, description, config)
            VALUES (?, ?, ?, ?)
            """,
            (project_id, "Test Project", "Test Description", ""),
        )

        task_ids = []
        for i in range(TASK_COUNT):
            task_id = f"task_{uuid.uuid4().hex[:8]}"
            status = TASK_STATUSES[i % len(TASK_STATUSES)]
            assigned_to = annotator_id if i < ASSIGNED_TASK_COUNT else None
            cursor.execute(
                """
                INSERT INTO tasks (id, project_id, name, data, status, assigned_to)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (task_id, project_id, f"Test Task {i}", task_data, status, assigned_to),
            )
            task_ids.append(task_id)

        for task_id in task_ids[:ANNOTATED_TASK_COUNT]:
            annotation_id = f"ann_{uuid.uuid4().hex[:8]}"
            cursor.execute(
                """
                INSERT INTO annotations (id, task_id, user_id, result)
                VALUES (?, ?, ?, ?)
                """,
                (annotation_id, task_id, annotator_id, json.dumps({"label": "test"})),
            )

    yield project_id

    # Delete child rows explicitly. Removing only the project would leave the
    # tasks and annotations behind unless the schema cascades (and, for
    # SQLite, foreign-key enforcement is switched on). Leftover rows would
    # skew platform-wide counts and could block deleting the annotator.
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "DELETE FROM annotations WHERE task_id IN "
            "(SELECT id FROM tasks WHERE project_id = ?)",
            (project_id,),
        )
        cursor.execute("DELETE FROM tasks WHERE project_id = ?", (project_id,))
        cursor.execute("DELETE FROM projects WHERE id = ?", (project_id,))


class TestOverviewStatistics:
    """Platform-wide overview statistics."""

    def test_overview_without_auth(self, setup_database):
        """Unauthenticated requests are rejected with 401."""
        response = _get(OVERVIEW_PATH)
        assert response.status_code == 401

    def test_overview_as_annotator(self, annotator_user):
        """Annotators are forbidden (403) from the platform overview."""
        response = _get(OVERVIEW_PATH, annotator_user["token"])
        assert response.status_code == 403

    def test_overview_as_admin(self, admin_user):
        """Admins receive the overview containing every expected field."""
        response = _get(OVERVIEW_PATH, admin_user["token"])
        assert response.status_code == 200
        _assert_has_fields(response.json(), OVERVIEW_FIELDS)


class TestProjectStatistics:
    """Per-project statistics."""

    def test_project_stats_without_auth(self, setup_database, test_project_with_tasks):
        """Unauthenticated requests are rejected with 401."""
        response = _get(_project_stats_path(test_project_with_tasks))
        assert response.status_code == 401

    def test_project_stats_as_annotator(self, annotator_user, test_project_with_tasks):
        """Annotators may view project statistics."""
        response = _get(
            _project_stats_path(test_project_with_tasks), annotator_user["token"]
        )
        assert response.status_code == 200

    def test_project_stats_as_admin(self, admin_user, test_project_with_tasks):
        """Admins receive the full project statistics payload."""
        response = _get(
            _project_stats_path(test_project_with_tasks), admin_user["token"]
        )
        assert response.status_code == 200
        data = response.json()

        _assert_has_fields(data, PROJECT_FIELDS)
        assert data["project_id"] == test_project_with_tasks
        assert data["total_tasks"] == TASK_COUNT

    def test_project_stats_nonexistent(self, admin_user):
        """Statistics for an unknown project return 404."""
        response = _get(_project_stats_path("nonexistent_project"), admin_user["token"])
        assert response.status_code == 404

    def test_project_stats_user_breakdown(self, admin_user, test_project_with_tasks, annotator_user):
        """Project statistics include a per-user breakdown for the annotator."""
        response = _get(
            _project_stats_path(test_project_with_tasks), admin_user["token"]
        )
        assert response.status_code == 200
        user_stats = response.json()["user_stats"]
        assert isinstance(user_stats, list)

        # The fixture assigns tasks to (and records annotations by) the
        # annotator, so their entry must be present, not merely optional.
        stats_by_user = {stat["user_id"]: stat for stat in user_stats}
        assert annotator_user["user_id"] in stats_by_user
        _assert_has_fields(stats_by_user[annotator_user["user_id"]], USER_STAT_FIELDS)


class TestStatisticsAccuracy:
    """Consistency and accuracy of the computed statistics."""

    def test_task_count_accuracy(self, admin_user, test_project_with_tasks):
        """Per-status task counts match the fixture and add up to the total."""
        response = _get(
            _project_stats_path(test_project_with_tasks), admin_user["token"]
        )
        assert response.status_code == 200
        data = response.json()

        sum_by_status = (
            data["completed_tasks"]
            + data["in_progress_tasks"]
            + data["pending_tasks"]
        )
        assert data["total_tasks"] == sum_by_status
        assert data["completed_tasks"] == EXPECTED_STATUS_COUNTS["completed"]
        assert data["in_progress_tasks"] == EXPECTED_STATUS_COUNTS["in_progress"]
        assert data["pending_tasks"] == EXPECTED_STATUS_COUNTS["pending"]

    def test_completion_rate_calculation(self, admin_user, test_project_with_tasks):
        """The task completion rate is completed / total as a percentage."""
        response = _get(
            _project_stats_path(test_project_with_tasks), admin_user["token"]
        )
        assert response.status_code == 200
        data = response.json()

        # The fixture always creates tasks; a zero total means the endpoint
        # is broken, so fail instead of silently skipping the check.
        assert data["total_tasks"] > 0
        expected_rate = round(data["completed_tasks"] / data["total_tasks"] * 100, 2)
        # Tolerate last-digit differences from the server's float rounding.
        assert data["task_completion_rate"] == pytest.approx(expected_rate, abs=0.01)