auth_middleware.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. """
  2. Authentication Middleware for JWT verification.
  3. Validates locally-signed JWT tokens.
  4. Also supports admin tokens generated by generate_admin_token.py script.
  5. """
import functools
import logging
from datetime import datetime, timezone
from typing import Optional

from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from config import settings
from database import get_db_connection
from services import jwt_service
  14. logger = logging.getLogger(__name__)
  15. def verify_admin_token(token: str) -> dict:
  16. """
  17. 验证管理员 Token(从数据库查询)
  18. """
  19. try:
  20. with get_db_connection() as conn:
  21. cursor = conn.cursor()
  22. cursor.execute(
  23. """
  24. SELECT at.user_id, u.username, u.email, u.role, at.expires_at
  25. FROM admin_tokens at
  26. JOIN users u ON at.user_id = u.id
  27. WHERE at.token = %s AND at.expires_at > %s
  28. """,
  29. (token, datetime.now(timezone.utc)),
  30. )
  31. row = cursor.fetchone()
  32. if not row:
  33. return None
  34. return {
  35. "id": row["user_id"],
  36. "username": row["username"],
  37. "email": row["email"],
  38. "role": row["role"],
  39. "is_admin_token": True,
  40. }
  41. except Exception as e:
  42. logger.error(f"验证管理员 Token 失败:{e}")
  43. return None
  44. class AuthMiddleware(BaseHTTPMiddleware):
  45. """
  46. JWT 认证中间件。
  47. 验证本地签发的 JWT,或管理员 Token。
  48. """
  49. PUBLIC_PATHS = {
  50. "/",
  51. "/health",
  52. "/docs",
  53. "/openapi.json",
  54. "/redoc",
  55. "/api/oauth/status",
  56. "/api/oauth/login",
  57. "/api/oauth/exchange-code",
  58. "/api/oauth/refresh",
  59. "/api/oauth/logout",
  60. }
  61. OPEN_API_PREFIX = "/api/v1/open/"
  62. async def dispatch(self, request: Request, call_next):
  63. # Skip authentication for public paths
  64. if request.url.path in self.PUBLIC_PATHS:
  65. return await call_next(request)
  66. # Skip authentication for Open API routes (they use their own Bearer token check)
  67. if request.url.path.startswith(self.OPEN_API_PREFIX):
  68. return await call_next(request)
  69. # Skip authentication for OPTIONS requests (CORS preflight)
  70. if request.method == "OPTIONS":
  71. return await call_next(request)
  72. # Extract token from Authorization header
  73. auth_header = request.headers.get("Authorization")
  74. if not auth_header:
  75. return JSONResponse(
  76. status_code=status.HTTP_401_UNAUTHORIZED,
  77. content={
  78. "detail": "缺少认证令牌",
  79. "error_type": "missing_token",
  80. },
  81. )
  82. # Verify Bearer token format
  83. parts = auth_header.split()
  84. if len(parts) != 2 or parts[0].lower() != "bearer":
  85. return JSONResponse(
  86. status_code=status.HTTP_401_UNAUTHORIZED,
  87. content={
  88. "detail": "无效的认证令牌格式",
  89. "error_type": "invalid_token_format",
  90. },
  91. )
  92. token = parts[1]
  93. try:
  94. # 1. 先检查是否是管理员 Token(以 admin_token_ 开头)
  95. user_info = None
  96. if token.startswith("admin_token_"):
  97. user_info = verify_admin_token(token)
  98. if user_info:
  99. logger.info(f"管理员 Token 验证成功:{user_info['username']}")
  100. # 2. 如果不是管理员 Token,验证 JWT
  101. if user_info is None:
  102. payload = jwt_service.verify_token(token)
  103. role = payload.get("role", "")
  104. if role not in ("admin", "annotator", "viewer"):
  105. return JSONResponse(
  106. status_code=status.HTTP_403_FORBIDDEN,
  107. content={
  108. "detail": "未被识别的 SSO 角色,无权限访问",
  109. "error_type": "unrecognized_role",
  110. },
  111. )
  112. user_info = {
  113. "id": payload.get("sub"),
  114. "username": payload.get("username"),
  115. "email": payload.get("email", ""),
  116. "role": role,
  117. }
  118. # Attach user info to request state
  119. request.state.user = {
  120. "id": str(user_info["id"]),
  121. "username": user_info["username"],
  122. "email": user_info["email"],
  123. "role": user_info["role"],
  124. }
  125. response = await call_next(request)
  126. return response
  127. except HTTPException as e:
  128. error_type = "invalid_token"
  129. if e.status_code == 401:
  130. error_type = "token_expired"
  131. return JSONResponse(
  132. status_code=e.status_code,
  133. content={
  134. "detail": e.detail,
  135. "error_type": error_type,
  136. },
  137. )
  138. except Exception as e:
  139. logger.error("认证过程发生错误:%s", str(e))
  140. return JSONResponse(
  141. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  142. content={
  143. "detail": "认证过程发生错误",
  144. "error_type": "auth_error",
  145. },
  146. )
  147. def require_role(*allowed_roles: str):
  148. """
  149. Decorator to check user role.
  150. """
  151. def decorator(func):
  152. async def wrapper(request: Request, *args, **kwargs):
  153. user = getattr(request.state, "user", None)
  154. if not user:
  155. raise HTTPException(
  156. status_code=status.HTTP_401_UNAUTHORIZED,
  157. detail="未认证",
  158. )
  159. if user["role"] not in allowed_roles:
  160. raise HTTPException(
  161. status_code=status.HTTP_403_FORBIDDEN,
  162. detail="权限不足",
  163. )
  164. return await func(request, *args, **kwargs)
  165. return wrapper
  166. return decorator