"""
User API router.

Provides endpoints for user management and statistics.
"""

from typing import Dict, List, Optional, Sequence

from fastapi import APIRouter, HTTPException, status, Query, Request

from database import get_db_connection
from schemas.user import (
    UserResponse,
    UserWithStatsResponse,
    UserListResponse,
    UserStatsResponse,
    AssignableUserResponse,
    TaskStats
)

router = APIRouter(
    prefix="/api/users",
    tags=["users"]
)

# NOTE(review): the handlers below are declared ``async def`` but all database
# access is synchronous (plain ``with`` / ``cursor.execute``), so every request
# blocks the event loop while it queries. Consider plain ``def`` handlers so
# FastAPI runs them in its threadpool.

# Escape character used in LIKE patterns built from user-supplied search text.
# It is interpolated into the ESCAPE clause by ``_search_clause`` so the SQL
# and the Python-side escaping always agree.
_LIKE_ESCAPE = "\\"


def _like_contains_pattern(text: str) -> str:
    """
    Build a ``%text%`` LIKE pattern that matches ``text`` literally.

    ``%`` and ``_`` are LIKE wildcards; without escaping, searching for
    ``"a_b"`` would also match ``"axb"`` and searching for ``"%"`` would match
    every row. The escape character itself is escaped first so it cannot
    combine with a following character.

    Args:
        text: Raw search keyword from the client.

    Returns:
        A pattern to be used with ``LIKE ? ESCAPE '\\'``.
    """
    escaped = (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


def _search_clause(column_prefix: str = "") -> str:
    """
    Return a WHERE fragment matching username OR email with two LIKE params.

    Args:
        column_prefix: Optional table alias prefix (e.g. ``"u."``) used to
            qualify the columns when the query contains a join.

    Returns:
        SQL fragment with two ``?`` placeholders, both expecting a pattern
        produced by ``_like_contains_pattern``.
    """
    return (
        f"({column_prefix}username LIKE ? ESCAPE '{_LIKE_ESCAPE}' "
        f"OR {column_prefix}email LIKE ? ESCAPE '{_LIKE_ESCAPE}')"
    )


def _require_authenticated(request: Request) -> dict:
    """
    Return the user attached to ``request.state`` or reject the request.

    Args:
        request: FastAPI Request object.

    Returns:
        The current user object stored in ``request.state.user``.

    Raises:
        HTTPException: 401 if no user is attached to the request.
    """
    user = getattr(request.state, "user", None)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未认证"
        )
    return user


def require_admin(request: Request) -> dict:
    """
    Verify that the current user is an administrator.

    Args:
        request: FastAPI Request object.

    Returns:
        The current user object.

    Raises:
        HTTPException: 401 if unauthenticated, 403 if the user's role is not
            ``"admin"``.
    """
    user = _require_authenticated(request)
    if user["role"] != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="只有管理员可以访问此接口"
        )
    return user


def _build_task_stats(
    assigned_count: int,
    completed_count: int,
    in_progress_count: int,
    pending_count: int,
    annotation_count: int,
) -> TaskStats:
    """
    Assemble a ``TaskStats`` object, deriving the completion rate.

    The completion rate is ``completed / assigned * 100`` rounded to two
    decimals, or ``0.0`` when nothing is assigned (avoids division by zero).
    """
    completion_rate = 0.0
    if assigned_count > 0:
        completion_rate = round(completed_count / assigned_count * 100, 2)

    return TaskStats(
        assigned_count=assigned_count,
        completed_count=completed_count,
        in_progress_count=in_progress_count,
        pending_count=pending_count,
        annotation_count=annotation_count,
        completion_rate=completion_rate
    )


def get_user_task_stats(cursor, user_id: str) -> TaskStats:
    """
    Fetch task statistics for a single user.

    Args:
        cursor: Database cursor.
        user_id: User ID.

    Returns:
        TaskStats object.
    """
    # Task counts by status for tasks assigned to this user.
    cursor.execute("""
        SELECT
            COUNT(*) as assigned_count,
            SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed_count,
            SUM(CASE WHEN status = 'in_progress' THEN 1 ELSE 0 END) as in_progress_count,
            SUM(CASE WHEN status = 'pending' THEN 1 ELSE 0 END) as pending_count
        FROM tasks
        WHERE assigned_to = ?
    """, (user_id,))
    task_row = cursor.fetchone()

    # SUM over zero rows yields NULL, hence the ``or 0`` fallbacks.
    assigned_count = task_row["assigned_count"] or 0
    completed_count = task_row["completed_count"] or 0
    in_progress_count = task_row["in_progress_count"] or 0
    pending_count = task_row["pending_count"] or 0

    # Number of annotations made by this user.
    cursor.execute("""
        SELECT COUNT(*) as annotation_count
        FROM annotations
        WHERE user_id = ?
    """, (user_id,))
    annotation_row = cursor.fetchone()
    annotation_count = annotation_row["annotation_count"] or 0

    return _build_task_stats(
        assigned_count,
        completed_count,
        in_progress_count,
        pending_count,
        annotation_count,
    )


def _get_task_stats_for_users(cursor, user_ids: Sequence[str]) -> Dict[str, TaskStats]:
    """
    Fetch task statistics for many users with a constant number of queries.

    Produces the same numbers as calling ``get_user_task_stats`` once per user,
    but issues two grouped queries instead of two queries per user.

    Args:
        cursor: Database cursor.
        user_ids: IDs to fetch statistics for. Callers pass one page of users
            (at most 100 here), which keeps the ``IN (...)`` parameter count
            well below common driver limits.

    Returns:
        Mapping from ``str(user_id)`` to ``TaskStats``; every requested ID is
        present, with zero counts for users that have no tasks/annotations.
    """
    if not user_ids:
        return {}

    placeholders = ", ".join("?" for _ in user_ids)
    params = tuple(user_ids)

    cursor.execute(f"""
        SELECT
            assigned_to AS user_id,
            COUNT(*) as assigned_count,
            SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed_count,
            SUM(CASE WHEN status = 'in_progress' THEN 1 ELSE 0 END) as in_progress_count,
            SUM(CASE WHEN status = 'pending' THEN 1 ELSE 0 END) as pending_count
        FROM tasks
        WHERE assigned_to IN ({placeholders})
        GROUP BY assigned_to
    """, params)
    # Keys are normalised with str() so lookups work even if the driver returns
    # a different Python type for tasks.assigned_to than for users.id.
    task_counts = {
        str(row["user_id"]): (
            row["assigned_count"] or 0,
            row["completed_count"] or 0,
            row["in_progress_count"] or 0,
            row["pending_count"] or 0,
        )
        for row in cursor.fetchall()
    }

    cursor.execute(f"""
        SELECT user_id, COUNT(*) as annotation_count
        FROM annotations
        WHERE user_id IN ({placeholders})
        GROUP BY user_id
    """, params)
    annotation_counts = {
        str(row["user_id"]): row["annotation_count"] or 0
        for row in cursor.fetchall()
    }

    stats: Dict[str, TaskStats] = {}
    for user_id in user_ids:
        key = str(user_id)
        assigned, completed, in_progress, pending = task_counts.get(key, (0, 0, 0, 0))
        stats[key] = _build_task_stats(
            assigned,
            completed,
            in_progress,
            pending,
            annotation_counts.get(key, 0),
        )
    return stats


def _fetch_user_row(cursor, user_id: str):
    """
    Load one user row (id, username, email, role, created_at) or raise 404.

    Args:
        cursor: Database cursor.
        user_id: User ID.

    Returns:
        The fetched row.

    Raises:
        HTTPException: 404 if no user with this ID exists.
    """
    cursor.execute("""
        SELECT id, username, email, role, created_at
        FROM users
        WHERE id = ?
    """, (user_id,))
    row = cursor.fetchone()
    if not row:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"用户 '{user_id}' 不存在"
        )
    return row


@router.get("", response_model=UserListResponse)
async def list_users(
    request: Request,
    role: Optional[str] = Query(None, description="按角色筛选"),
    search: Optional[str] = Query(None, description="按用户名或邮箱搜索"),
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(50, ge=1, le=100, description="返回记录数")
):
    """
    List users with their task statistics (admin only).

    Supports filtering by role and a keyword search over username/email.

    Args:
        request: FastAPI Request object.
        role: Role filter (admin/annotator/viewer).
        search: Search keyword matched literally against username or email.
        skip: Pagination offset.
        limit: Page size.

    Returns:
        The page of users and the total number of matching users.
    """
    require_admin(request)

    with get_db_connection() as conn:
        cursor = conn.cursor()

        # Optional WHERE clause shared by the count and the page query.
        where_clauses: List[str] = []
        params: list = []

        if role:
            where_clauses.append("role = ?")
            params.append(role)

        if search:
            where_clauses.append(_search_clause())
            search_pattern = _like_contains_pattern(search)
            params.extend([search_pattern, search_pattern])

        where_sql = ""
        if where_clauses:
            where_sql = "WHERE " + " AND ".join(where_clauses)

        # Total number of matching users (independent of pagination).
        cursor.execute(
            f"SELECT COUNT(*) as total FROM users {where_sql}",
            tuple(params),
        )
        total = cursor.fetchone()["total"]

        # One page of matching users, newest first.
        cursor.execute(f"""
            SELECT id, username, email, role, created_at
            FROM users
            {where_sql}
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?
        """, (*params, limit, skip))
        rows = cursor.fetchall()

        # Statistics for the whole page in a constant number of queries,
        # instead of two queries per user.
        stats_by_user = _get_task_stats_for_users(cursor, [row["id"] for row in rows])

        users = [
            UserWithStatsResponse(
                id=row["id"],
                username=row["username"],
                email=row["email"],
                role=row["role"],
                created_at=row["created_at"],
                task_stats=stats_by_user[str(row["id"])]
            )
            for row in rows
        ]

        return UserListResponse(users=users, total=total)


@router.get("/annotators", response_model=List[AssignableUserResponse])
async def list_annotators(
    request: Request,
    search: Optional[str] = Query(None, description="按用户名或邮箱搜索")
):
    """
    List users that tasks can be assigned to (any authenticated user).

    Returns every user with the ``annotator`` role together with their
    current workload.

    Args:
        request: FastAPI Request object.
        search: Search keyword matched literally against username or email.

    Returns:
        Assignable users ordered by username.
    """
    _require_authenticated(request)

    with get_db_connection() as conn:
        cursor = conn.cursor()

        # Columns are qualified with ``u.`` because the query joins ``tasks``.
        where_clauses = ["u.role = 'annotator'"]
        params: list = []

        if search:
            where_clauses.append(_search_clause("u."))
            search_pattern = _like_contains_pattern(search)
            params.extend([search_pattern, search_pattern])

        where_sql = "WHERE " + " AND ".join(where_clauses)

        # LEFT JOIN keeps annotators with no tasks; COUNT(CASE ...) ignores
        # the NULLs produced for them, yielding 0.
        query_sql = f"""
            SELECT
                u.id,
                u.username,
                u.email,
                COUNT(CASE WHEN t.status != 'completed' THEN 1 END) as current_task_count,
                COUNT(CASE WHEN t.status = 'completed' THEN 1 END) as completed_task_count
            FROM users u
            LEFT JOIN tasks t ON u.id = t.assigned_to
            {where_sql}
            GROUP BY u.id, u.username, u.email
            ORDER BY u.username
        """
        cursor.execute(query_sql, tuple(params))

        return [
            AssignableUserResponse(
                id=row["id"],
                username=row["username"],
                email=row["email"],
                current_task_count=row["current_task_count"] or 0,
                completed_task_count=row["completed_task_count"] or 0
            )
            for row in cursor.fetchall()
        ]


@router.get("/{user_id}", response_model=UserWithStatsResponse)
async def get_user(request: Request, user_id: str):
    """
    Get a user's details and task statistics (admin only).

    Args:
        request: FastAPI Request object.
        user_id: User ID.

    Returns:
        User details with statistics.

    Raises:
        HTTPException: 401/403 from ``require_admin``; 404 if the user does
            not exist.
    """
    require_admin(request)

    with get_db_connection() as conn:
        cursor = conn.cursor()
        row = _fetch_user_row(cursor, user_id)
        task_stats = get_user_task_stats(cursor, user_id)

        return UserWithStatsResponse(
            id=row["id"],
            username=row["username"],
            email=row["email"],
            role=row["role"],
            created_at=row["created_at"],
            task_stats=task_stats
        )


@router.get("/{user_id}/stats", response_model=UserStatsResponse)
async def get_user_stats(request: Request, user_id: str):
    """
    Get detailed statistics for a user (admin only).

    Includes the user's basic info, task statistics and their 10 most
    recently created assigned tasks.

    Args:
        request: FastAPI Request object.
        user_id: User ID.

    Returns:
        Detailed user statistics.

    Raises:
        HTTPException: 401/403 from ``require_admin``; 404 if the user does
            not exist.
    """
    require_admin(request)

    with get_db_connection() as conn:
        cursor = conn.cursor()

        row = _fetch_user_row(cursor, user_id)
        user = UserResponse(
            id=row["id"],
            username=row["username"],
            email=row["email"],
            role=row["role"],
            created_at=row["created_at"]
        )

        task_stats = get_user_task_stats(cursor, user_id)

        # Most recent tasks; LEFT JOIN keeps tasks whose project is missing
        # (project_name is then None).
        cursor.execute("""
            SELECT t.id, t.name, t.status, t.created_at, p.name as project_name
            FROM tasks t
            LEFT JOIN projects p ON t.project_id = p.id
            WHERE t.assigned_to = ?
            ORDER BY t.created_at DESC
            LIMIT 10
        """, (user_id,))

        # NOTE(review): str() turns a NULL created_at into the string "None";
        # confirm whether the response schema should allow None instead.
        recent_tasks = [
            {
                "id": task_row["id"],
                "name": task_row["name"],
                "status": task_row["status"],
                "created_at": str(task_row["created_at"]),
                "project_name": task_row["project_name"]
            }
            for task_row in cursor.fetchall()
        ]

        return UserStatsResponse(
            user=user,
            task_stats=task_stats,
            recent_tasks=recent_tasks
        )