rate_limit_middleware.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. """
  2. 限流中间件模块
  3. 基于 Redis 的分布式限流中间件,拦截所有请求进行限流检查。
  4. 需求引用: 5.1, 5.2
  5. """
  6. import logging
  7. from typing import Optional
  8. from fastapi import Request
  9. from starlette.middleware.base import BaseHTTPMiddleware
  10. from starlette.responses import JSONResponse
  11. from app.services.rate_limiter import rate_limiter, build_rate_limit_key
  12. logger = logging.getLogger(__name__)
  13. # 不需要限流的路径前缀
  14. EXCLUDED_PATHS = [
  15. "/health",
  16. "/static",
  17. "/exports",
  18. "/docs",
  19. "/openapi.json",
  20. "/redoc",
  21. "/api/sso/config",
  22. "/api/auth/verify",
  23. ]
  24. class RateLimitMiddleware(BaseHTTPMiddleware):
  25. """限流中间件
  26. 拦截所有请求进行限流检查:
  27. 1. 跳过不需要限流的路径(健康检查、静态资源等)
  28. 2. 从 JWT token 提取用户 ID,或使用客户端 IP
  29. 3. 根据请求路径获取对应的限流配置
  30. 4. 调用限流服务检查是否超限
  31. 5. 超限返回 429,否则继续处理并添加响应头
  32. """
  33. async def dispatch(self, request: Request, call_next):
  34. path = request.url.path
  35. # 1. 跳过不需要限流的路径
  36. if self._should_skip(path, request.method):
  37. return await call_next(request)
  38. # 2. 提取用户标识(user_id 或 IP)
  39. user_id = self._extract_user_id(request)
  40. client_ip = self._get_client_ip(request)
  41. # 3. 获取路径对应的限流配置
  42. limit_config = rate_limiter.get_limit_config(path)
  43. limit = limit_config["limit"]
  44. window = limit_config["window"]
  45. # 4. 构建限流键并检查
  46. rate_limit_key = build_rate_limit_key(
  47. user_id=user_id,
  48. ip=client_ip,
  49. path=path
  50. )
  51. allowed, remaining = await rate_limiter.is_allowed(
  52. key=rate_limit_key,
  53. limit=limit,
  54. window=window
  55. )
  56. # 5. 超限返回 429
  57. if not allowed:
  58. retry_after = rate_limiter.get_retry_after(window)
  59. logger.warning(
  60. f"限流触发: key={rate_limit_key}, path={path}, "
  61. f"limit={limit}, window={window}s"
  62. )
  63. return JSONResponse(
  64. status_code=429,
  65. content={
  66. "detail": "请求过于频繁,请稍后再试",
  67. "error": "rate_limit_exceeded",
  68. "retry_after": retry_after
  69. },
  70. headers={
  71. "Retry-After": str(retry_after),
  72. "X-RateLimit-Limit": str(limit),
  73. "X-RateLimit-Remaining": "0",
  74. "X-RateLimit-Reset": str(window)
  75. }
  76. )
  77. # 6. 继续处理请求
  78. response = await call_next(request)
  79. # 7. 添加限流相关响应头
  80. response.headers["X-RateLimit-Limit"] = str(limit)
  81. response.headers["X-RateLimit-Remaining"] = str(remaining)
  82. response.headers["X-RateLimit-Reset"] = str(window)
  83. return response
  84. def _should_skip(self, path: str, method: str = None) -> bool:
  85. """判断是否跳过限流检查
  86. CORS 预检(OPTIONS)请求必须跳过限流,
  87. 否则预检被拦截会导致前端拿不到 CORS 响应头。
  88. """
  89. if method == 'OPTIONS':
  90. return True
  91. for excluded in EXCLUDED_PATHS:
  92. if path.startswith(excluded):
  93. return True
  94. return False
  95. def _extract_user_id(self, request: Request) -> Optional[str]:
  96. """从请求中提取用户 ID
  97. 尝试从 Authorization header 中解析 JWT token 获取 user_id。
  98. 解析失败时返回 None,将使用 IP 地址进行限流。
  99. """
  100. auth_header = request.headers.get("Authorization")
  101. if not auth_header or not auth_header.startswith("Bearer "):
  102. return None
  103. try:
  104. from app.services.auth_service import AuthService
  105. token = auth_header[7:]
  106. payload = AuthService.verify_token(token)
  107. return payload.get("user_id")
  108. except Exception:
  109. # JWT 解析失败,尝试管理员 token
  110. try:
  111. from app.services.admin_auth_service import AdminAuthService
  112. token = auth_header[7:]
  113. payload = AdminAuthService.verify_token(token)
  114. admin_id = payload.get("admin_id")
  115. if admin_id:
  116. return f"admin_{admin_id}"
  117. except Exception:
  118. pass
  119. return None
  120. def _get_client_ip(self, request: Request) -> str:
  121. """获取客户端 IP 地址
  122. 优先从 X-Forwarded-For 或 X-Real-IP 头获取(反向代理场景),
  123. 否则使用直连客户端 IP。
  124. """
  125. # 检查 X-Forwarded-For(可能包含多个 IP,取第一个)
  126. forwarded_for = request.headers.get("X-Forwarded-For")
  127. if forwarded_for:
  128. return forwarded_for.split(",")[0].strip()
  129. # 检查 X-Real-IP
  130. real_ip = request.headers.get("X-Real-IP")
  131. if real_ip:
  132. return real_ip.strip()
  133. # 使用直连客户端 IP
  134. if request.client:
  135. return request.client.host
  136. return "unknown"