""" Authentication Middleware for JWT verification. Validates locally-signed JWT tokens. Also supports admin tokens generated by generate_admin_token.py script. """ import logging from datetime import datetime, timezone from fastapi import Request, HTTPException, status from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from services import jwt_service from config import settings from database import get_db_connection logger = logging.getLogger(__name__) def verify_admin_token(token: str) -> dict: """ 验证管理员 Token(从数据库查询) """ try: with get_db_connection() as conn: cursor = conn.cursor() cursor.execute( """ SELECT at.user_id, u.username, u.email, u.role, at.expires_at FROM admin_tokens at JOIN users u ON at.user_id = u.id WHERE at.token = %s AND at.expires_at > %s """, (token, datetime.now(timezone.utc)), ) row = cursor.fetchone() if not row: return None return { "id": row["user_id"], "username": row["username"], "email": row["email"], "role": row["role"], "is_admin_token": True, } except Exception as e: logger.error(f"验证管理员 Token 失败:{e}") return None class AuthMiddleware(BaseHTTPMiddleware): """ JWT 认证中间件。 验证本地签发的 JWT,或管理员 Token。 """ PUBLIC_PATHS = { "/", "/health", "/docs", "/openapi.json", "/redoc", "/api/oauth/status", "/api/oauth/login", "/api/oauth/exchange-code", "/api/oauth/refresh", "/api/oauth/logout", } OPEN_API_PREFIX = "/api/v1/open/" async def dispatch(self, request: Request, call_next): # Skip authentication for public paths if request.url.path in self.PUBLIC_PATHS: return await call_next(request) # Skip authentication for Open API routes (they use their own Bearer token check) if request.url.path.startswith(self.OPEN_API_PREFIX): return await call_next(request) # Skip authentication for OPTIONS requests (CORS preflight) if request.method == "OPTIONS": return await call_next(request) # Extract token from Authorization header auth_header = request.headers.get("Authorization") if not auth_header: return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={ "detail": "缺少认证令牌", "error_type": "missing_token", }, ) # Verify Bearer token format parts = auth_header.split() if len(parts) != 2 or parts[0].lower() != "bearer": return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={ "detail": "无效的认证令牌格式", "error_type": "invalid_token_format", }, ) token = parts[1] try: # 1. 先检查是否是管理员 Token(以 admin_token_ 开头) user_info = None if token.startswith("admin_token_"): user_info = verify_admin_token(token) if user_info: logger.info(f"管理员 Token 验证成功:{user_info['username']}") # 2. 如果不是管理员 Token,验证 JWT if user_info is None: payload = jwt_service.verify_token(token) role = payload.get("role", "") if role not in ("admin", "annotator", "viewer"): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content={ "detail": "未被识别的 SSO 角色,无权限访问", "error_type": "unrecognized_role", }, ) user_info = { "id": payload.get("sub"), "username": payload.get("username"), "email": payload.get("email", ""), "role": role, } # Attach user info to request state request.state.user = { "id": str(user_info["id"]), "username": user_info["username"], "email": user_info["email"], "role": user_info["role"], } response = await call_next(request) return response except HTTPException as e: error_type = "invalid_token" if e.status_code == 401: error_type = "token_expired" return JSONResponse( status_code=e.status_code, content={ "detail": e.detail, "error_type": error_type, }, ) except Exception as e: logger.error("认证过程发生错误:%s", str(e)) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={ "detail": "认证过程发生错误", "error_type": "auth_error", }, ) def require_role(*allowed_roles: str): """ Decorator to check user role. """ def decorator(func): async def wrapper(request: Request, *args, **kwargs): user = getattr(request.state, "user", None) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="未认证", ) if user["role"] not in allowed_roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="权限不足", ) return await func(request, *args, **kwargs) return wrapper return decorator