""" 限流中间件模块 基于 Redis 的分布式限流中间件,拦截所有请求进行限流检查。 需求引用: 5.1, 5.2 """ import logging from typing import Optional from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse from app.services.rate_limiter import rate_limiter, build_rate_limit_key logger = logging.getLogger(__name__) # 不需要限流的路径前缀 EXCLUDED_PATHS = [ "/health", "/static", "/exports", "/docs", "/openapi.json", "/redoc", "/api/sso/config", "/api/auth/verify", ] class RateLimitMiddleware(BaseHTTPMiddleware): """限流中间件 拦截所有请求进行限流检查: 1. 跳过不需要限流的路径(健康检查、静态资源等) 2. 从 JWT token 提取用户 ID,或使用客户端 IP 3. 根据请求路径获取对应的限流配置 4. 调用限流服务检查是否超限 5. 超限返回 429,否则继续处理并添加响应头 """ async def dispatch(self, request: Request, call_next): path = request.url.path # 1. 跳过不需要限流的路径 if self._should_skip(path, request.method): return await call_next(request) # 2. 提取用户标识(user_id 或 IP) user_id = self._extract_user_id(request) client_ip = self._get_client_ip(request) # 3. 获取路径对应的限流配置 limit_config = rate_limiter.get_limit_config(path) limit = limit_config["limit"] window = limit_config["window"] # 4. 构建限流键并检查 rate_limit_key = build_rate_limit_key( user_id=user_id, ip=client_ip, path=path ) allowed, remaining = await rate_limiter.is_allowed( key=rate_limit_key, limit=limit, window=window ) # 5. 超限返回 429 if not allowed: retry_after = rate_limiter.get_retry_after(window) logger.warning( f"限流触发: key={rate_limit_key}, path={path}, " f"limit={limit}, window={window}s" ) return JSONResponse( status_code=429, content={ "detail": "请求过于频繁,请稍后再试", "error": "rate_limit_exceeded", "retry_after": retry_after }, headers={ "Retry-After": str(retry_after), "X-RateLimit-Limit": str(limit), "X-RateLimit-Remaining": "0", "X-RateLimit-Reset": str(window) } ) # 6. 继续处理请求 response = await call_next(request) # 7. 添加限流相关响应头 response.headers["X-RateLimit-Limit"] = str(limit) response.headers["X-RateLimit-Remaining"] = str(remaining) response.headers["X-RateLimit-Reset"] = str(window) return response def _should_skip(self, path: str, method: str = None) -> bool: """判断是否跳过限流检查 CORS 预检(OPTIONS)请求必须跳过限流, 否则预检被拦截会导致前端拿不到 CORS 响应头。 """ if method == 'OPTIONS': return True for excluded in EXCLUDED_PATHS: if path.startswith(excluded): return True return False def _extract_user_id(self, request: Request) -> Optional[str]: """从请求中提取用户 ID 尝试从 Authorization header 中解析 JWT token 获取 user_id。 解析失败时返回 None,将使用 IP 地址进行限流。 """ auth_header = request.headers.get("Authorization") if not auth_header or not auth_header.startswith("Bearer "): return None try: from app.services.auth_service import AuthService token = auth_header[7:] payload = AuthService.verify_token(token) return payload.get("user_id") except Exception: # JWT 解析失败,尝试管理员 token try: from app.services.admin_auth_service import AdminAuthService token = auth_header[7:] payload = AdminAuthService.verify_token(token) admin_id = payload.get("admin_id") if admin_id: return f"admin_{admin_id}" except Exception: pass return None def _get_client_ip(self, request: Request) -> str: """获取客户端 IP 地址 优先从 X-Forwarded-For 或 X-Real-IP 头获取(反向代理场景), 否则使用直连客户端 IP。 """ # 检查 X-Forwarded-For(可能包含多个 IP,取第一个) forwarded_for = request.headers.get("X-Forwarded-For") if forwarded_for: return forwarded_for.split(",")[0].strip() # 检查 X-Real-IP real_ip = request.headers.get("X-Real-IP") if real_ip: return real_ip.strip() # 使用直连客户端 IP if request.client: return request.client.host return "unknown"