""" Authentication Middleware for JWT token verification. Validates JWT tokens and attaches user info to request state. """ from fastapi import Request, HTTPException, status from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from services.jwt_service import JWTService import jwt class AuthMiddleware(BaseHTTPMiddleware): """ Authentication middleware for JWT token verification. Validates JWT tokens and attaches user info to request state. """ # Public endpoints that don't require authentication PUBLIC_PATHS = { "/", "/health", "/docs", "/openapi.json", "/redoc", "/api/auth/register", "/api/auth/login", "/api/auth/refresh", "/api/oauth/status", "/api/oauth/login", "/api/oauth/callback" } async def dispatch(self, request: Request, call_next): """ Process each request through authentication. Args: request: FastAPI Request object call_next: Next middleware or route handler Returns: Response from next handler or error response """ # Skip authentication for public paths if request.url.path in self.PUBLIC_PATHS: return await call_next(request) # Skip authentication for OPTIONS requests (CORS preflight) if request.method == "OPTIONS": return await call_next(request) # Extract token from Authorization header auth_header = request.headers.get("Authorization") if not auth_header: return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={ "detail": "缺少认证令牌", "error_type": "missing_token" } ) # Verify Bearer token format parts = auth_header.split() if len(parts) != 2 or parts[0].lower() != "bearer": return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={ "detail": "无效的认证令牌格式", "error_type": "invalid_token_format" } ) token = parts[1] try: # Verify and decode token payload = JWTService.verify_token(token, "access") if not payload: return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={ "detail": "无效的认证令牌", "error_type": "invalid_token" } ) # Attach user info to request state request.state.user = { "id": payload["sub"], "username": payload["username"], "email": payload["email"], "role": payload["role"] } # Continue to route handler response = await call_next(request) return response except jwt.ExpiredSignatureError: return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={ "detail": "认证令牌已过期", "error_type": "token_expired" } ) except jwt.InvalidTokenError: return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={ "detail": "无效的认证令牌", "error_type": "invalid_token" } ) except Exception as e: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={ "detail": "认证过程发生错误", "error_type": "auth_error" } ) def require_role(*allowed_roles: str): """ Decorator to check user role. Usage: @require_role("admin", "annotator") async def my_endpoint(request: Request): ... Args: allowed_roles: Tuple of allowed role names Returns: Decorator function """ def decorator(func): async def wrapper(request: Request, *args, **kwargs): user = getattr(request.state, "user", None) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="未认证" ) if user["role"] not in allowed_roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="权限不足" ) return await func(request, *args, **kwargs) return wrapper return decorator