| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- """
- Authentication Middleware for JWT verification.
- Validates locally-signed JWT tokens.
- Also supports admin tokens generated by generate_admin_token.py script.
- """
import logging
from datetime import datetime, timezone
from functools import wraps
from typing import Optional

from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware

from services import jwt_service
from config import settings
from database import get_db_connection
# Module-level logger named after this module (used for auth success/failure logs).
logger = logging.getLogger(__name__)
def verify_admin_token(token: str) -> Optional[dict]:
    """
    Validate an admin token by looking it up in the database.

    The token must exist in ``admin_tokens`` and its ``expires_at`` must be
    later than the current UTC time. The owning user is loaded via a join on
    ``users``.

    This performs synchronous database I/O; async callers should run it in a
    worker thread.

    Args:
        token: The raw bearer token string.

    Returns:
        A dict with ``id``, ``username``, ``email``, ``role`` and
        ``is_admin_token=True`` when the token is valid and unexpired.
        ``None`` when the token is unknown or expired, or when the lookup
        fails for any reason (failures are logged with a traceback, not
        raised).
    """
    try:
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT at.user_id, u.username, u.email, u.role, at.expires_at
                FROM admin_tokens at
                JOIN users u ON at.user_id = u.id
                WHERE at.token = %s AND at.expires_at > %s
                """,
                (token, datetime.now(timezone.utc)),
            )
            row = cursor.fetchone()
            if not row:
                return None
            # Rows are accessed by column name, so get_db_connection presumably
            # yields a dict-style cursor -- TODO confirm.
            return {
                "id": row["user_id"],
                "username": row["username"],
                "email": row["email"],
                "role": row["role"],
                "is_admin_token": True,
            }
    except Exception as e:
        # Deliberately best-effort: any failure means "not a valid admin token".
        # logger.exception keeps the traceback so DB/schema problems are diagnosable.
        logger.exception("验证管理员 Token 失败:%s", e)
        return None
class AuthMiddleware(BaseHTTPMiddleware):
    """
    JWT authentication middleware.

    A request is accepted if its ``Authorization: Bearer <token>`` header
    carries either of the following:

    * a database-backed admin token (prefix ``admin_token_``), or
    * a locally signed JWT that ``jwt_service.verify_token`` accepts and whose
      ``role`` claim is ``admin``, ``annotator`` or ``viewer``.

    On success the caller is stored on ``request.state.user`` as a dict with
    ``id`` (str), ``username``, ``email`` and ``role``, and the request is
    passed on. Otherwise a JSON error response with ``detail`` and
    ``error_type`` keys is returned and the route is never invoked.
    """

    # Exact request paths that never require authentication.
    PUBLIC_PATHS = {
        "/",
        "/health",
        "/docs",
        "/openapi.json",
        "/redoc",
        "/api/oauth/status",
        "/api/oauth/login",
        "/api/oauth/exchange-code",
        "/api/oauth/refresh",
        "/api/oauth/logout",
    }
    # Open API routes perform their own Bearer-token check, so they are skipped here.
    OPEN_API_PREFIX = "/api/v1/open/"

    @staticmethod
    def _error_response(status_code: int, detail, error_type: str) -> JSONResponse:
        """Build the JSON error body shared by every rejection path."""
        return JSONResponse(
            status_code=status_code,
            content={
                "detail": detail,
                "error_type": error_type,
            },
        )

    def _is_exempt(self, request: Request) -> bool:
        """Return True for public paths, Open API routes and CORS preflight requests."""
        path = request.url.path
        return (
            path in self.PUBLIC_PATHS
            or path.startswith(self.OPEN_API_PREFIX)
            or request.method == "OPTIONS"
        )

    async def _resolve_user(self, token: str):
        """
        Map a bearer token to the dict stored on ``request.state.user``.

        Returns:
            The user dict on success, or a 403 ``JSONResponse`` when a JWT
            carries an unrecognized role.

        Raises:
            Whatever the token verifiers raise (e.g. ``HTTPException`` from
            ``jwt_service.verify_token``); ``dispatch`` maps these to responses.
        """
        user_info = None
        # 1. Tokens with the admin prefix are checked against the database first.
        if token.startswith("admin_token_"):
            # verify_admin_token does synchronous DB I/O; run it in a worker
            # thread so it does not block the event loop.
            user_info = await run_in_threadpool(verify_admin_token, token)
            if user_info:
                logger.info("管理员 Token 验证成功:%s", user_info["username"])

        # 2. Not an admin token (or the admin lookup failed): verify as a JWT.
        if user_info is None:
            payload = jwt_service.verify_token(token)
            role = payload.get("role", "")
            if role not in ("admin", "annotator", "viewer"):
                return self._error_response(
                    status.HTTP_403_FORBIDDEN,
                    "未被识别的 SSO 角色,无权限访问",
                    "unrecognized_role",
                )
            user_info = {
                "id": payload.get("sub"),
                "username": payload.get("username"),
                "email": payload.get("email", ""),
                "role": role,
            }

        return {
            "id": str(user_info["id"]),
            "username": user_info["username"],
            "email": user_info["email"],
            "role": user_info["role"],
        }

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request, then hand it to the next ASGI layer."""
        if self._is_exempt(request):
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return self._error_response(
                status.HTTP_401_UNAUTHORIZED, "缺少认证令牌", "missing_token"
            )

        # Expect exactly "Bearer <token>"; the scheme is compared case-insensitively.
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return self._error_response(
                status.HTTP_401_UNAUTHORIZED,
                "无效的认证令牌格式",
                "invalid_token_format",
            )

        # Only token verification is guarded here. Exceptions raised later by
        # downstream route handlers must propagate normally rather than be
        # misreported as authentication errors.
        try:
            user = await self._resolve_user(parts[1])
        except HTTPException as e:
            # Every 401 from the verifier is labelled "token_expired", other
            # statuses "invalid_token".
            # NOTE(review): this also labels non-expiry 401s (e.g. a bad
            # signature) as expired -- confirm against jwt_service.
            error_type = "token_expired" if e.status_code == 401 else "invalid_token"
            return self._error_response(e.status_code, e.detail, error_type)
        except Exception as e:
            logger.exception("认证过程发生错误:%s", str(e))
            return self._error_response(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                "认证过程发生错误",
                "auth_error",
            )

        # _resolve_user returns a ready-made response for rejected roles.
        if isinstance(user, JSONResponse):
            return user

        request.state.user = user
        return await call_next(request)
def require_role(*allowed_roles: str):
    """
    Build a decorator that restricts an async endpoint to the given roles.

    The wrapped function is expected to take the ``Request`` as its first
    argument. The check reads ``request.state.user``, which is populated by
    ``AuthMiddleware``.

    ``functools.wraps`` preserves the endpoint's name, docstring and
    signature. Without it, FastAPI would introspect the generic
    ``(request, *args, **kwargs)`` wrapper instead of the real endpoint
    parameters.

    Args:
        *allowed_roles: Role names permitted to call the endpoint.

    Raises:
        HTTPException: 401 if no authenticated user is attached to the
            request; 403 if the user's role is not in ``allowed_roles``.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            user = getattr(request.state, "user", None)
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="未认证",
                )
            # .get: a user dict without a role is denied (403) instead of
            # crashing with a KeyError (500).
            if user.get("role") not in allowed_roles:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="权限不足",
                )
            return await func(request, *args, **kwargs)
        return wrapper
    return decorator
|