| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- """
- Authentication Middleware for JWT token verification.
- Validates JWT tokens and attaches user info to request state.
- """
import functools
import logging

import jwt
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from services.jwt_service import JWTService
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Authentication middleware for JWT token verification.

    For every request that is not a public path or a CORS preflight, requires
    an ``Authorization: Bearer <token>`` header, verifies it as an access
    token, and attaches the decoded user info to ``request.state.user`` for
    downstream handlers. Rejections are returned as JSON bodies of the form
    ``{"detail": ..., "error_type": ...}``.
    """

    # Public endpoints that don't require authentication (exact path match).
    PUBLIC_PATHS = {
        "/",
        "/health",
        "/docs",
        "/openapi.json",
        "/redoc",
        "/api/auth/register",
        "/api/auth/login",
        "/api/auth/refresh",
        "/api/oauth/status",
        "/api/oauth/login",
        "/api/oauth/callback",
    }

    @staticmethod
    def _json_error(status_code: int, detail: str, error_type: str) -> JSONResponse:
        """Build the JSON error response shared by every rejection path."""
        return JSONResponse(
            status_code=status_code,
            content={
                "detail": detail,
                "error_type": error_type,
            },
        )

    async def dispatch(self, request: Request, call_next):
        """
        Process each request through authentication.

        Args:
            request: FastAPI Request object.
            call_next: Next middleware or route handler.

        Returns:
            The downstream response, or a JSON error response:
            401 ``missing_token`` / ``invalid_token_format`` /
            ``invalid_token`` / ``token_expired``, or 500 ``auth_error``
            when token verification fails unexpectedly.
        """
        # Public paths and CORS preflight (OPTIONS) requests skip authentication.
        if request.url.path in self.PUBLIC_PATHS:
            return await call_next(request)
        if request.method == "OPTIONS":
            return await call_next(request)

        # Extract token from the Authorization header.
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return self._json_error(
                status.HTTP_401_UNAUTHORIZED, "缺少认证令牌", "missing_token"
            )

        # Expect exactly "Bearer <token>" (scheme is case-insensitive).
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return self._json_error(
                status.HTTP_401_UNAUTHORIZED,
                "无效的认证令牌格式",
                "invalid_token_format",
            )

        token = parts[1]

        # Only verification and claim extraction live inside the try: errors
        # raised later by downstream handlers must propagate normally instead
        # of being reported as authentication failures.
        try:
            payload = JWTService.verify_token(token, "access")
            user = None
            if payload:
                user = {
                    "id": payload["sub"],
                    "username": payload["username"],
                    "email": payload["email"],
                    "role": payload["role"],
                }
        except jwt.ExpiredSignatureError:
            return self._json_error(
                status.HTTP_401_UNAUTHORIZED, "认证令牌已过期", "token_expired"
            )
        except jwt.InvalidTokenError:
            return self._json_error(
                status.HTTP_401_UNAUTHORIZED, "无效的认证令牌", "invalid_token"
            )
        except KeyError:
            # The token verified but lacks one of the claims copied below;
            # it is unusable, so treat it as invalid rather than a server error.
            return self._json_error(
                status.HTTP_401_UNAUTHORIZED, "无效的认证令牌", "invalid_token"
            )
        except Exception:
            # Unexpected failure inside verification: record it instead of
            # discarding the exception, then report a generic auth error.
            logging.getLogger(__name__).exception(
                "Unexpected error while verifying access token"
            )
            return self._json_error(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                "认证过程发生错误",
                "auth_error",
            )

        # verify_token returned a falsy payload: reject the token.
        if user is None:
            return self._json_error(
                status.HTTP_401_UNAUTHORIZED, "无效的认证令牌", "invalid_token"
            )

        # Attach user info for route handlers, then continue the chain.
        request.state.user = user
        return await call_next(request)
def require_role(*allowed_roles: str):
    """
    Decorator that restricts an async endpoint to users with one of the given roles.

    Reads the authenticated user from ``request.state.user`` (set by
    ``AuthMiddleware`` in this module) and checks its ``"role"`` entry.

    Usage:
        @require_role("admin", "annotator")
        async def my_endpoint(request: Request):
            ...

    Args:
        allowed_roles: Role names permitted to call the endpoint.

    Returns:
        A decorator for async endpoints whose first parameter is ``request``.

    Raises (when the wrapped endpoint is called):
        HTTPException: 401 if no authenticated user is attached to the
            request; 403 if the user's role is not in ``allowed_roles``.
    """
    def decorator(func):
        # functools.wraps copies the endpoint's name/docstring and sets
        # __wrapped__, so inspect.signature (which frameworks such as FastAPI
        # use to build endpoint parameters) reports the original signature
        # rather than the generic (request, *args, **kwargs).
        @functools.wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            user = getattr(request.state, "user", None)

            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="未认证"
                )

            # .get(): a user record without a role is forbidden (403), not a crash.
            if user.get("role") not in allowed_roles:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="权限不足"
                )

            return await func(request, *args, **kwargs)
        return wrapper
    return decorator
|