| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- """
- Authentication Middleware for SSO token verification.
- Validates SSO tokens via the SSO center's userinfo endpoint,
- with an in-memory cache to reduce external calls.
- """
- import logging
- from fastapi import Request, HTTPException, status
- from fastapi.responses import JSONResponse
- from starlette.middleware.base import BaseHTTPMiddleware
- from services.token_cache_service import TokenCacheService
- from services.oauth_service import OAuthService
- from config import settings
logger = logging.getLogger(__name__)

# Module-level token cache instance shared by the whole process.
# Per the original note, SSO tokens are valid for 600 seconds; cached entries
# default to 550 seconds (a 50-second safety margin) unless
# settings.TOKEN_CACHE_TTL is defined.
token_cache = TokenCacheService(
    ttl_seconds=getattr(settings, "TOKEN_CACHE_TTL", 550),
)
class AuthMiddleware(BaseHTTPMiddleware):
    """
    SSO token authentication middleware.

    For every request that is neither a public path nor an OPTIONS request,
    the bearer token from the ``Authorization`` header is looked up in the
    module-level ``token_cache``; on a miss it is verified through
    ``OAuthService.verify_sso_token`` and the result is cached. The resolved
    user is exposed to downstream handlers as ``request.state.user`` (a dict
    with ``id``, ``username``, ``email`` and ``role``).

    Authentication failures are returned as JSON bodies of the form
    ``{"detail": ..., "error_type": ...}``.
    """

    # Paths served without authentication (exact match on request.url.path).
    PUBLIC_PATHS = {
        "/",
        "/health",
        "/docs",
        "/openapi.json",
        "/redoc",
        "/api/oauth/status",
        "/api/oauth/login",
        "/api/oauth/callback",
        "/api/oauth/refresh",
    }

    async def dispatch(self, request: Request, call_next):
        """
        Authenticate the request, then forward it to the next handler.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request further down the stack.

        Returns:
            The downstream response when authentication succeeds or is
            skipped; otherwise a JSON error response with status 503 (SSO
            disabled), 401 (missing or malformed token), the status code of
            an ``HTTPException`` raised while verifying the token, or 500 for
            any other failure during authentication.
        """
        path = request.url.path
        logger.debug("AuthMiddleware: path=%s, method=%s", path, request.method)

        # Public paths bypass authentication entirely.
        if path in self.PUBLIC_PATHS:
            logger.debug("Skipping auth for public path: %s", path)
            return await call_next(request)

        # OPTIONS requests (CORS preflight) carry no credentials.
        if request.method == "OPTIONS":
            return await call_next(request)

        if not settings.OAUTH_ENABLED:
            return self._error_response(
                status.HTTP_503_SERVICE_UNAVAILABLE,
                "SSO 认证未配置",
                "sso_not_configured",
            )

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return self._error_response(
                status.HTTP_401_UNAUTHORIZED,
                "缺少认证令牌",
                "missing_token",
            )

        sso_token = self._bearer_token(auth_header)
        if sso_token is None:
            return self._error_response(
                status.HTTP_401_UNAUTHORIZED,
                "无效的认证令牌格式",
                "invalid_token_format",
            )

        try:
            user_info = await self._load_user_info(sso_token)
            request.state.user = self._user_state(user_info)
        except HTTPException as exc:
            error_type = "invalid_token"
            if exc.status_code == 503:
                error_type = "sso_unavailable"
            elif exc.status_code == 401:
                # A 401 means the token is expired or invalid. It is reported
                # uniformly as token_expired so the frontend gets a chance to
                # use its refresh_token, and the stale cache entry is dropped.
                error_type = "token_expired"
                token_cache.invalidate(sso_token)
            return self._error_response(exc.status_code, exc.detail, error_type)
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("认证过程发生错误: %s", exc)
            return self._error_response(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                "认证过程发生错误",
                "auth_error",
            )

        # Deliberately outside the try block: an error raised by the
        # downstream application is not an authentication failure and must
        # not be rewritten into an "auth_error" response.
        return await call_next(request)

    @staticmethod
    def _error_response(status_code, detail, error_type):
        """Build the JSON error response shared by every failure branch."""
        return JSONResponse(
            status_code=status_code,
            content={
                "detail": detail,
                "error_type": error_type,
            },
        )

    @staticmethod
    def _bearer_token(auth_header):
        """Return the token from a ``Bearer <token>`` header, or None if malformed."""
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return None
        return parts[1]

    @staticmethod
    async def _load_user_info(sso_token):
        """Return the user info for a token from the cache or, on a miss, from SSO."""
        # 1. Try the local cache first.
        cached = token_cache.get(sso_token)
        if cached is not None:
            return cached

        # 2. Cache miss: verify the token against the SSO service.
        user_info = await OAuthService.verify_sso_token(sso_token)

        # 3. Sync the user into the local database; a failure here is logged
        #    and does not fail authentication.
        # NOTE(review): this is called synchronously inside an async method;
        # if it performs blocking I/O it will stall the event loop - confirm.
        try:
            OAuthService.sync_user_from_oauth(user_info)
        except Exception as sync_err:
            logger.warning("用户同步失败(不影响认证): %s", sync_err)

        # 4. Remember the verified token for subsequent requests.
        token_cache.set(sso_token, user_info)
        return user_info

    @staticmethod
    def _user_state(user_info):
        """Map the SSO user info onto the dict stored in ``request.state.user``."""
        user_id = user_info.get("id") or user_info.get("sub")
        username = (
            user_info.get("username")
            or user_info.get("preferred_username")
            or user_info.get("name")
        )
        # NOTE(review): when neither "id" nor "sub" is present the id becomes
        # the string "None"; confirm whether such tokens should be rejected.
        return {
            "id": str(user_id),
            "username": username,
            "email": user_info.get("email", ""),
            "role": user_info.get("role", "viewer"),
        }
def require_role(*allowed_roles: str):
    """
    Build a decorator that restricts an async endpoint to the given roles.

    The wrapper reads ``request.state.user`` and checks its ``"role"`` entry
    against ``allowed_roles`` before calling the endpoint. The endpoint must
    take the request as its first positional parameter, named ``request``.

    Usage:
        @require_role("admin", "annotator")
        async def my_endpoint(request: Request):
            ...

    Args:
        allowed_roles: Role names that are allowed to call the endpoint.

    Returns:
        A decorator that wraps an async endpoint with the role check.

    Raises:
        HTTPException: Raised by the wrapper with status 401 when
            ``request.state`` has no user, and with status 403 when the
            user's role is not one of ``allowed_roles``.
    """
    # Imported here so the decorator carries its own dependency.
    from functools import wraps

    def decorator(func):
        # wraps() copies __name__/__doc__ and sets __wrapped__, so name and
        # signature introspection (e.g. by FastAPI's routing) sees the
        # original endpoint instead of a generic "wrapper(request, *args,
        # **kwargs)".
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            user = getattr(request.state, "user", None)
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="未认证",
                )

            # .get() so a user record without a "role" key is refused with
            # 403 instead of surfacing as a KeyError.
            if user.get("role") not in allowed_roles:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="权限不足",
                )

            return await func(request, *args, **kwargs)

        return wrapper

    return decorator
|