auth_middleware.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. """
  2. Authentication Middleware for JWT token verification.
  3. Validates JWT tokens and attaches user info to request state.
  4. """
  5. from fastapi import Request, HTTPException, status
  6. from fastapi.responses import JSONResponse
  7. from starlette.middleware.base import BaseHTTPMiddleware
  8. from services.jwt_service import JWTService
  9. import jwt
  10. class AuthMiddleware(BaseHTTPMiddleware):
  11. """
  12. Authentication middleware for JWT token verification.
  13. Validates JWT tokens and attaches user info to request state.
  14. """
  15. # Public endpoints that don't require authentication
  16. PUBLIC_PATHS = {
  17. "/",
  18. "/health",
  19. "/docs",
  20. "/openapi.json",
  21. "/redoc",
  22. "/api/auth/register",
  23. "/api/auth/login",
  24. "/api/auth/refresh",
  25. "/api/oauth/status",
  26. "/api/oauth/login",
  27. "/api/oauth/callback"
  28. }
  29. async def dispatch(self, request: Request, call_next):
  30. """
  31. Process each request through authentication.
  32. Args:
  33. request: FastAPI Request object
  34. call_next: Next middleware or route handler
  35. Returns:
  36. Response from next handler or error response
  37. """
  38. # Skip authentication for public paths
  39. if request.url.path in self.PUBLIC_PATHS:
  40. return await call_next(request)
  41. # Skip authentication for OPTIONS requests (CORS preflight)
  42. if request.method == "OPTIONS":
  43. return await call_next(request)
  44. # Extract token from Authorization header
  45. auth_header = request.headers.get("Authorization")
  46. if not auth_header:
  47. return JSONResponse(
  48. status_code=status.HTTP_401_UNAUTHORIZED,
  49. content={
  50. "detail": "缺少认证令牌",
  51. "error_type": "missing_token"
  52. }
  53. )
  54. # Verify Bearer token format
  55. parts = auth_header.split()
  56. if len(parts) != 2 or parts[0].lower() != "bearer":
  57. return JSONResponse(
  58. status_code=status.HTTP_401_UNAUTHORIZED,
  59. content={
  60. "detail": "无效的认证令牌格式",
  61. "error_type": "invalid_token_format"
  62. }
  63. )
  64. token = parts[1]
  65. try:
  66. # Verify and decode token
  67. payload = JWTService.verify_token(token, "access")
  68. if not payload:
  69. return JSONResponse(
  70. status_code=status.HTTP_401_UNAUTHORIZED,
  71. content={
  72. "detail": "无效的认证令牌",
  73. "error_type": "invalid_token"
  74. }
  75. )
  76. # Attach user info to request state
  77. request.state.user = {
  78. "id": payload["sub"],
  79. "username": payload["username"],
  80. "email": payload["email"],
  81. "role": payload["role"]
  82. }
  83. # Continue to route handler
  84. response = await call_next(request)
  85. return response
  86. except jwt.ExpiredSignatureError:
  87. return JSONResponse(
  88. status_code=status.HTTP_401_UNAUTHORIZED,
  89. content={
  90. "detail": "认证令牌已过期",
  91. "error_type": "token_expired"
  92. }
  93. )
  94. except jwt.InvalidTokenError:
  95. return JSONResponse(
  96. status_code=status.HTTP_401_UNAUTHORIZED,
  97. content={
  98. "detail": "无效的认证令牌",
  99. "error_type": "invalid_token"
  100. }
  101. )
  102. except Exception as e:
  103. return JSONResponse(
  104. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  105. content={
  106. "detail": "认证过程发生错误",
  107. "error_type": "auth_error"
  108. }
  109. )
  110. def require_role(*allowed_roles: str):
  111. """
  112. Decorator to check user role.
  113. Usage:
  114. @require_role("admin", "annotator")
  115. async def my_endpoint(request: Request):
  116. ...
  117. Args:
  118. allowed_roles: Tuple of allowed role names
  119. Returns:
  120. Decorator function
  121. """
  122. def decorator(func):
  123. async def wrapper(request: Request, *args, **kwargs):
  124. user = getattr(request.state, "user", None)
  125. if not user:
  126. raise HTTPException(
  127. status_code=status.HTTP_401_UNAUTHORIZED,
  128. detail="未认证"
  129. )
  130. if user["role"] not in allowed_roles:
  131. raise HTTPException(
  132. status_code=status.HTTP_403_FORBIDDEN,
  133. detail="权限不足"
  134. )
  135. return await func(request, *args, **kwargs)
  136. return wrapper
  137. return decorator