auth_middleware.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. """
  2. Authentication Middleware for SSO token verification.
  3. Validates SSO tokens via the SSO center's userinfo endpoint,
  4. with an in-memory cache to reduce external calls.
  5. """
  6. import logging
  7. from fastapi import Request, HTTPException, status
  8. from fastapi.responses import JSONResponse
  9. from starlette.middleware.base import BaseHTTPMiddleware
  10. from services.token_cache_service import TokenCacheService
  11. from services.oauth_service import OAuthService
  12. from config import settings
  13. logger = logging.getLogger(__name__)
  14. # 全局 token 缓存实例
  15. # SSO token 有效期 600 秒,缓存设置为 550 秒(留 50 秒余量)
  16. token_cache = TokenCacheService(
  17. ttl_seconds=getattr(settings, 'TOKEN_CACHE_TTL', 550)
  18. )
  19. class AuthMiddleware(BaseHTTPMiddleware):
  20. """
  21. SSO Token 认证中间件。
  22. 先查本地缓存,未命中则调用 SSO userinfo 端点验证。
  23. """
  24. PUBLIC_PATHS = {
  25. "/",
  26. "/health",
  27. "/docs",
  28. "/openapi.json",
  29. "/redoc",
  30. "/api/oauth/status",
  31. "/api/oauth/login",
  32. "/api/oauth/callback",
  33. "/api/oauth/refresh",
  34. }
  35. async def dispatch(self, request: Request, call_next):
  36. # Skip authentication for public paths
  37. logger.debug(f"AuthMiddleware: path={request.url.path}, method={request.method}")
  38. if request.url.path in self.PUBLIC_PATHS:
  39. logger.debug(f"Skipping auth for public path: {request.url.path}")
  40. return await call_next(request)
  41. # Skip authentication for OPTIONS requests (CORS preflight)
  42. if request.method == "OPTIONS":
  43. return await call_next(request)
  44. # Check if OAuth/SSO is enabled
  45. if not settings.OAUTH_ENABLED:
  46. return JSONResponse(
  47. status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  48. content={
  49. "detail": "SSO 认证未配置",
  50. "error_type": "sso_not_configured"
  51. }
  52. )
  53. # Extract token from Authorization header
  54. auth_header = request.headers.get("Authorization")
  55. if not auth_header:
  56. return JSONResponse(
  57. status_code=status.HTTP_401_UNAUTHORIZED,
  58. content={
  59. "detail": "缺少认证令牌",
  60. "error_type": "missing_token"
  61. }
  62. )
  63. # Verify Bearer token format
  64. parts = auth_header.split()
  65. if len(parts) != 2 or parts[0].lower() != "bearer":
  66. return JSONResponse(
  67. status_code=status.HTTP_401_UNAUTHORIZED,
  68. content={
  69. "detail": "无效的认证令牌格式",
  70. "error_type": "invalid_token_format"
  71. }
  72. )
  73. sso_token = parts[1]
  74. try:
  75. # 1. 先查本地缓存
  76. user_info = token_cache.get(sso_token)
  77. if user_info is None:
  78. # 2. 缓存未命中,调 SSO profile 验证(含角色信息)
  79. user_info = await OAuthService.verify_sso_token(sso_token)
  80. # 3. 同步用户到本地数据库(更新角色),获取本地用户ID
  81. try:
  82. local_user = OAuthService.sync_user_from_oauth(oauth_user_info=user_info)
  83. # 将本地user.id也存入user_info,供后续使用
  84. user_info["local_user_id"] = local_user.id
  85. except Exception as sync_err:
  86. logger.warning(f"用户同步失败(不影响认证): {sync_err}")
  87. # 4. 写入缓存
  88. token_cache.set(sso_token, user_info)
  89. # 提取用户信息,优先使用本地用户ID
  90. # 如果缓存中没有 local_user_id,则重新同步用户
  91. local_user_id = user_info.get("local_user_id")
  92. if not local_user_id:
  93. try:
  94. local_user = OAuthService.sync_user_from_oauth(oauth_user_info=user_info)
  95. local_user_id = local_user.id
  96. # 更新缓存,避免下次重复同步
  97. user_info["local_user_id"] = local_user_id
  98. token_cache.set(sso_token, user_info)
  99. except Exception as sync_err:
  100. logger.warning(f"重新同步用户失败,使用SSO ID: {sync_err}")
  101. user_id = local_user_id or user_info.get("id") or user_info.get("sub")
  102. username = (
  103. user_info.get("username")
  104. or user_info.get("preferred_username")
  105. or user_info.get("name")
  106. )
  107. email = user_info.get("email", "")
  108. role = user_info.get("role", "viewer")
  109. # Attach user info to request state
  110. request.state.user = {
  111. "id": str(user_id),
  112. "username": username,
  113. "email": email,
  114. "role": role,
  115. }
  116. response = await call_next(request)
  117. return response
  118. except HTTPException as e:
  119. error_type = "invalid_token"
  120. if e.status_code == 503:
  121. error_type = "sso_unavailable"
  122. elif e.status_code == 401:
  123. # SSO 返回 401 说明 token 过期或无效,统一标记为 token_expired
  124. # 让前端有机会用 refresh_token 刷新
  125. error_type = "token_expired"
  126. # 同时清除本地缓存中的过期 token
  127. token_cache.invalidate(sso_token)
  128. return JSONResponse(
  129. status_code=e.status_code,
  130. content={
  131. "detail": e.detail,
  132. "error_type": error_type
  133. }
  134. )
  135. except Exception as e:
  136. logger.error("认证过程发生错误:%s", str(e))
  137. return JSONResponse(
  138. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  139. content={
  140. "detail": "认证过程发生错误",
  141. "error_type": "auth_error"
  142. }
  143. )
  144. def require_role(*allowed_roles: str):
  145. """
  146. Decorator to check user role.
  147. Usage:
  148. @require_role("admin", "annotator")
  149. async def my_endpoint(request: Request):
  150. ...
  151. Args:
  152. allowed_roles: Tuple of allowed role names
  153. Returns:
  154. Decorator function
  155. """
  156. def decorator(func):
  157. async def wrapper(request: Request, *args, **kwargs):
  158. user = getattr(request.state, "user", None)
  159. if not user:
  160. raise HTTPException(
  161. status_code=status.HTTP_401_UNAUTHORIZED,
  162. detail="未认证"
  163. )
  164. if user["role"] not in allowed_roles:
  165. raise HTTPException(
  166. status_code=status.HTTP_403_FORBIDDEN,
  167. detail="权限不足"
  168. )
  169. return await func(request, *args, **kwargs)
  170. return wrapper
  171. return decorator