auth_middleware.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. """
  2. Authentication Middleware for SSO token verification.
  3. Validates SSO tokens via the SSO center's userinfo endpoint,
  4. with an in-memory cache to reduce external calls.
  5. """
import functools
import logging

from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from config import settings
from services.oauth_service import OAuthService
from services.token_cache_service import TokenCacheService
  13. logger = logging.getLogger(__name__)
  14. # 全局 token 缓存实例
  15. # SSO token 有效期 600 秒,缓存设置为 550 秒(留 50 秒余量)
  16. token_cache = TokenCacheService(
  17. ttl_seconds=getattr(settings, 'TOKEN_CACHE_TTL', 550)
  18. )
  19. class AuthMiddleware(BaseHTTPMiddleware):
  20. """
  21. SSO Token 认证中间件。
  22. 先查本地缓存,未命中则调用 SSO userinfo 端点验证。
  23. """
  24. PUBLIC_PATHS = {
  25. "/",
  26. "/health",
  27. "/docs",
  28. "/openapi.json",
  29. "/redoc",
  30. "/api/oauth/status",
  31. "/api/oauth/login",
  32. "/api/oauth/callback",
  33. "/api/oauth/refresh",
  34. }
  35. async def dispatch(self, request: Request, call_next):
  36. # Skip authentication for public paths
  37. logger.debug(f"AuthMiddleware: path={request.url.path}, method={request.method}")
  38. if request.url.path in self.PUBLIC_PATHS:
  39. logger.debug(f"Skipping auth for public path: {request.url.path}")
  40. return await call_next(request)
  41. # Skip authentication for OPTIONS requests (CORS preflight)
  42. if request.method == "OPTIONS":
  43. return await call_next(request)
  44. # Check if OAuth/SSO is enabled
  45. if not settings.OAUTH_ENABLED:
  46. return JSONResponse(
  47. status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  48. content={
  49. "detail": "SSO 认证未配置",
  50. "error_type": "sso_not_configured"
  51. }
  52. )
  53. # Extract token from Authorization header
  54. auth_header = request.headers.get("Authorization")
  55. if not auth_header:
  56. return JSONResponse(
  57. status_code=status.HTTP_401_UNAUTHORIZED,
  58. content={
  59. "detail": "缺少认证令牌",
  60. "error_type": "missing_token"
  61. }
  62. )
  63. # Verify Bearer token format
  64. parts = auth_header.split()
  65. if len(parts) != 2 or parts[0].lower() != "bearer":
  66. return JSONResponse(
  67. status_code=status.HTTP_401_UNAUTHORIZED,
  68. content={
  69. "detail": "无效的认证令牌格式",
  70. "error_type": "invalid_token_format"
  71. }
  72. )
  73. sso_token = parts[1]
  74. try:
  75. # 1. 先查本地缓存
  76. user_info = token_cache.get(sso_token)
  77. if user_info is None:
  78. # 2. 缓存未命中,调 SSO profile 验证(含角色信息)
  79. user_info = await OAuthService.verify_sso_token(sso_token)
  80. # 3. 同步用户到本地数据库(更新角色)
  81. try:
  82. OAuthService.sync_user_from_oauth(user_info)
  83. except Exception as sync_err:
  84. logger.warning(f"用户同步失败(不影响认证): {sync_err}")
  85. # 4. 写入缓存
  86. token_cache.set(sso_token, user_info)
  87. # 提取用户信息
  88. user_id = user_info.get("id") or user_info.get("sub")
  89. username = (
  90. user_info.get("username")
  91. or user_info.get("preferred_username")
  92. or user_info.get("name")
  93. )
  94. email = user_info.get("email", "")
  95. role = user_info.get("role", "viewer")
  96. # Attach user info to request state
  97. request.state.user = {
  98. "id": str(user_id),
  99. "username": username,
  100. "email": email,
  101. "role": role,
  102. }
  103. response = await call_next(request)
  104. return response
  105. except HTTPException as e:
  106. error_type = "invalid_token"
  107. if e.status_code == 503:
  108. error_type = "sso_unavailable"
  109. elif e.status_code == 401:
  110. # SSO 返回 401 说明 token 过期或无效,统一标记为 token_expired
  111. # 让前端有机会用 refresh_token 刷新
  112. error_type = "token_expired"
  113. # 同时清除本地缓存中的过期 token
  114. token_cache.invalidate(sso_token)
  115. return JSONResponse(
  116. status_code=e.status_code,
  117. content={
  118. "detail": e.detail,
  119. "error_type": error_type
  120. }
  121. )
  122. except Exception as e:
  123. logger.error(f"认证过程发生错误: {e}")
  124. return JSONResponse(
  125. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  126. content={
  127. "detail": "认证过程发生错误",
  128. "error_type": "auth_error"
  129. }
  130. )
  131. def require_role(*allowed_roles: str):
  132. """
  133. Decorator to check user role.
  134. Usage:
  135. @require_role("admin", "annotator")
  136. async def my_endpoint(request: Request):
  137. ...
  138. Args:
  139. allowed_roles: Tuple of allowed role names
  140. Returns:
  141. Decorator function
  142. """
  143. def decorator(func):
  144. async def wrapper(request: Request, *args, **kwargs):
  145. user = getattr(request.state, "user", None)
  146. if not user:
  147. raise HTTPException(
  148. status_code=status.HTTP_401_UNAUTHORIZED,
  149. detail="未认证"
  150. )
  151. if user["role"] not in allowed_roles:
  152. raise HTTPException(
  153. status_code=status.HTTP_403_FORBIDDEN,
  154. detail="权限不足"
  155. )
  156. return await func(request, *args, **kwargs)
  157. return wrapper
  158. return decorator