auth_middleware.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. """
  2. Authentication Middleware for SSO token verification.
  3. Validates SSO tokens via the SSO center's userinfo endpoint,
  4. with an in-memory cache to reduce external calls.
  5. Also supports admin tokens generated by generate_admin_token.py script.
  6. """
  7. import logging
  8. from datetime import datetime, timezone
  9. from fastapi import Request, HTTPException, status
  10. from fastapi.responses import JSONResponse
  11. from starlette.middleware.base import BaseHTTPMiddleware
  12. from services.token_cache_service import TokenCacheService
  13. from services.oauth_service import OAuthService
  14. from config import settings
  15. from database import get_db_connection
  16. logger = logging.getLogger(__name__)
  17. # 全局 token 缓存实例
  18. # SSO token 有效期 600 秒,缓存设置为 550 秒(留 50 秒余量)
  19. token_cache = TokenCacheService(
  20. ttl_seconds=getattr(settings, 'TOKEN_CACHE_TTL', 550)
  21. )
  22. def verify_admin_token(token: str) -> dict:
  23. """
  24. 验证管理员 Token(从数据库查询)
  25. Args:
  26. token: Token 字符串
  27. Returns:
  28. dict: 用户信息字典,或 None(Token 无效或已过期)
  29. """
  30. try:
  31. with get_db_connection() as conn:
  32. cursor = conn.cursor()
  33. cursor.execute("""
  34. SELECT at.user_id, u.username, u.email, u.role, at.expires_at
  35. FROM admin_tokens at
  36. JOIN users u ON at.user_id = u.id
  37. WHERE at.token = %s AND at.expires_at > %s
  38. """, (token, datetime.now(timezone.utc)))
  39. row = cursor.fetchone()
  40. if not row:
  41. return None
  42. return {
  43. "id": row["user_id"],
  44. "username": row["username"],
  45. "email": row["email"],
  46. "role": row["role"],
  47. "is_admin_token": True,
  48. }
  49. except Exception as e:
  50. logger.error(f"验证管理员 Token 失败:{e}")
  51. return None
  52. class AuthMiddleware(BaseHTTPMiddleware):
  53. """
  54. SSO Token 认证中间件。
  55. 先查本地缓存,未命中则调用 SSO userinfo 端点验证。
  56. """
  57. PUBLIC_PATHS = {
  58. "/",
  59. "/health",
  60. "/docs",
  61. "/openapi.json",
  62. "/redoc",
  63. "/api/oauth/status",
  64. "/api/oauth/login",
  65. "/api/oauth/callback",
  66. "/api/oauth/refresh",
  67. }
  68. async def dispatch(self, request: Request, call_next):
  69. # Skip authentication for public paths
  70. logger.debug(f"AuthMiddleware: path={request.url.path}, method={request.method}")
  71. if request.url.path in self.PUBLIC_PATHS:
  72. logger.debug(f"Skipping auth for public path: {request.url.path}")
  73. return await call_next(request)
  74. # Skip authentication for OPTIONS requests (CORS preflight)
  75. if request.method == "OPTIONS":
  76. return await call_next(request)
  77. # Check if OAuth/SSO is enabled
  78. if not settings.OAUTH_ENABLED:
  79. return JSONResponse(
  80. status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  81. content={
  82. "detail": "SSO 认证未配置",
  83. "error_type": "sso_not_configured"
  84. }
  85. )
  86. # Extract token from Authorization header
  87. auth_header = request.headers.get("Authorization")
  88. if not auth_header:
  89. return JSONResponse(
  90. status_code=status.HTTP_401_UNAUTHORIZED,
  91. content={
  92. "detail": "缺少认证令牌",
  93. "error_type": "missing_token"
  94. }
  95. )
  96. # Verify Bearer token format
  97. parts = auth_header.split()
  98. if len(parts) != 2 or parts[0].lower() != "bearer":
  99. return JSONResponse(
  100. status_code=status.HTTP_401_UNAUTHORIZED,
  101. content={
  102. "detail": "无效的认证令牌格式",
  103. "error_type": "invalid_token_format"
  104. }
  105. )
  106. sso_token = parts[1]
  107. try:
  108. # 1. 先检查是否是管理员 Token(以 admin_token_ 开头)
  109. user_info = None
  110. if sso_token.startswith("admin_token_"):
  111. logger.debug("检测到管理员 Token,尝试从数据库验证")
  112. user_info = verify_admin_token(sso_token)
  113. if user_info:
  114. logger.info(f"管理员 Token 验证成功:{user_info['username']}")
  115. # 2. 如果不是管理员 Token,查本地缓存(SSO token)
  116. if user_info is None:
  117. user_info = token_cache.get(sso_token)
  118. # 3. 缓存未命中,调 SSO profile 验证(含角色信息)
  119. if user_info is None:
  120. user_info = await OAuthService.verify_sso_token(sso_token)
  121. # 3. 同步用户到本地数据库(更新角色),获取本地用户ID
  122. try:
  123. local_user = OAuthService.sync_user_from_oauth(oauth_user_info=user_info)
  124. # 将本地user.id也存入user_info,供后续使用
  125. user_info["local_user_id"] = local_user.id
  126. except Exception as sync_err:
  127. logger.warning(f"用户同步失败(不影响认证): {sync_err}")
  128. # 4. 写入缓存
  129. token_cache.set(sso_token, user_info)
  130. # 提取用户信息,优先使用本地用户ID
  131. # 如果缓存中没有 local_user_id,则重新同步用户
  132. local_user_id = user_info.get("local_user_id")
  133. if not local_user_id:
  134. try:
  135. local_user = OAuthService.sync_user_from_oauth(oauth_user_info=user_info)
  136. local_user_id = local_user.id
  137. # 更新缓存,避免下次重复同步
  138. user_info["local_user_id"] = local_user_id
  139. token_cache.set(sso_token, user_info)
  140. except Exception as sync_err:
  141. logger.warning(f"重新同步用户失败,使用SSO ID: {sync_err}")
  142. user_id = local_user_id or user_info.get("id") or user_info.get("sub") if not user_info.get("is_admin_token") else user_info.get("id")
  143. username = (
  144. user_info.get("username")
  145. or user_info.get("preferred_username")
  146. or user_info.get("name")
  147. )
  148. email = user_info.get("email", "")
  149. role = user_info.get("role", "viewer")
  150. # Attach user info to request state
  151. request.state.user = {
  152. "id": str(user_id),
  153. "username": username,
  154. "email": email,
  155. "role": role,
  156. }
  157. response = await call_next(request)
  158. return response
  159. except HTTPException as e:
  160. error_type = "invalid_token"
  161. if e.status_code == 503:
  162. error_type = "sso_unavailable"
  163. elif e.status_code == 401:
  164. # SSO 返回 401 说明 token 过期或无效,统一标记为 token_expired
  165. # 让前端有机会用 refresh_token 刷新
  166. error_type = "token_expired"
  167. # 同时清除本地缓存中的过期 token
  168. token_cache.invalidate(sso_token)
  169. return JSONResponse(
  170. status_code=e.status_code,
  171. content={
  172. "detail": e.detail,
  173. "error_type": error_type
  174. }
  175. )
  176. except Exception as e:
  177. logger.error("认证过程发生错误:%s", str(e))
  178. return JSONResponse(
  179. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  180. content={
  181. "detail": "认证过程发生错误",
  182. "error_type": "auth_error"
  183. }
  184. )
  185. def require_role(*allowed_roles: str):
  186. """
  187. Decorator to check user role.
  188. Usage:
  189. @require_role("admin", "annotator")
  190. async def my_endpoint(request: Request):
  191. ...
  192. Args:
  193. allowed_roles: Tuple of allowed role names
  194. Returns:
  195. Decorator function
  196. """
  197. def decorator(func):
  198. async def wrapper(request: Request, *args, **kwargs):
  199. user = getattr(request.state, "user", None)
  200. if not user:
  201. raise HTTPException(
  202. status_code=status.HTTP_401_UNAUTHORIZED,
  203. detail="未认证"
  204. )
  205. if user["role"] not in allowed_roles:
  206. raise HTTPException(
  207. status_code=status.HTTP_403_FORBIDDEN,
  208. detail="权限不足"
  209. )
  210. return await func(request, *args, **kwargs)
  211. return wrapper
  212. return decorator