| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- """
- 限流中间件模块
- 基于 Redis 的分布式限流中间件,拦截所有请求进行限流检查。
- 需求引用: 5.1, 5.2
- """
- import logging
- from typing import Optional
- from fastapi import Request
- from starlette.middleware.base import BaseHTTPMiddleware
- from starlette.responses import JSONResponse
- from app.services.rate_limiter import rate_limiter, build_rate_limit_key
- logger = logging.getLogger(__name__)
- # 不需要限流的路径前缀
- EXCLUDED_PATHS = [
- "/health",
- "/static",
- "/exports",
- "/docs",
- "/openapi.json",
- "/redoc",
- "/api/sso/config",
- "/api/auth/verify",
- ]
- class RateLimitMiddleware(BaseHTTPMiddleware):
- """限流中间件
-
- 拦截所有请求进行限流检查:
- 1. 跳过不需要限流的路径(健康检查、静态资源等)
- 2. 从 JWT token 提取用户 ID,或使用客户端 IP
- 3. 根据请求路径获取对应的限流配置
- 4. 调用限流服务检查是否超限
- 5. 超限返回 429,否则继续处理并添加响应头
- """
-
- async def dispatch(self, request: Request, call_next):
- path = request.url.path
- # 1. 跳过不需要限流的路径
- if self._should_skip(path, request.method):
- return await call_next(request)
-
- # 2. 提取用户标识(user_id 或 IP)
- user_id = self._extract_user_id(request)
- client_ip = self._get_client_ip(request)
-
- # 3. 获取路径对应的限流配置
- limit_config = rate_limiter.get_limit_config(path)
- limit = limit_config["limit"]
- window = limit_config["window"]
-
- # 4. 构建限流键并检查
- rate_limit_key = build_rate_limit_key(
- user_id=user_id,
- ip=client_ip,
- path=path
- )
-
- allowed, remaining = await rate_limiter.is_allowed(
- key=rate_limit_key,
- limit=limit,
- window=window
- )
-
- # 5. 超限返回 429
- if not allowed:
- retry_after = rate_limiter.get_retry_after(window)
- logger.warning(
- f"限流触发: key={rate_limit_key}, path={path}, "
- f"limit={limit}, window={window}s"
- )
- return JSONResponse(
- status_code=429,
- content={
- "detail": "请求过于频繁,请稍后再试",
- "error": "rate_limit_exceeded",
- "retry_after": retry_after
- },
- headers={
- "Retry-After": str(retry_after),
- "X-RateLimit-Limit": str(limit),
- "X-RateLimit-Remaining": "0",
- "X-RateLimit-Reset": str(window)
- }
- )
-
- # 6. 继续处理请求
- response = await call_next(request)
-
- # 7. 添加限流相关响应头
- response.headers["X-RateLimit-Limit"] = str(limit)
- response.headers["X-RateLimit-Remaining"] = str(remaining)
- response.headers["X-RateLimit-Reset"] = str(window)
-
- return response
-
- def _should_skip(self, path: str, method: str = None) -> bool:
- """判断是否跳过限流检查
- CORS 预检(OPTIONS)请求必须跳过限流,
- 否则预检被拦截会导致前端拿不到 CORS 响应头。
- """
- if method == 'OPTIONS':
- return True
- for excluded in EXCLUDED_PATHS:
- if path.startswith(excluded):
- return True
- return False
-
- def _extract_user_id(self, request: Request) -> Optional[str]:
- """从请求中提取用户 ID
-
- 尝试从 Authorization header 中解析 JWT token 获取 user_id。
- 解析失败时返回 None,将使用 IP 地址进行限流。
- """
- auth_header = request.headers.get("Authorization")
- if not auth_header or not auth_header.startswith("Bearer "):
- return None
-
- try:
- from app.services.auth_service import AuthService
- token = auth_header[7:]
- payload = AuthService.verify_token(token)
- return payload.get("user_id")
- except Exception:
- # JWT 解析失败,尝试管理员 token
- try:
- from app.services.admin_auth_service import AdminAuthService
- token = auth_header[7:]
- payload = AdminAuthService.verify_token(token)
- admin_id = payload.get("admin_id")
- if admin_id:
- return f"admin_{admin_id}"
- except Exception:
- pass
-
- return None
-
- def _get_client_ip(self, request: Request) -> str:
- """获取客户端 IP 地址
-
- 优先从 X-Forwarded-For 或 X-Real-IP 头获取(反向代理场景),
- 否则使用直连客户端 IP。
- """
- # 检查 X-Forwarded-For(可能包含多个 IP,取第一个)
- forwarded_for = request.headers.get("X-Forwarded-For")
- if forwarded_for:
- return forwarded_for.split(",")[0].strip()
-
- # 检查 X-Real-IP
- real_ip = request.headers.get("X-Real-IP")
- if real_ip:
- return real_ip.strip()
-
- # 使用直连客户端 IP
- if request.client:
- return request.client.host
-
- return "unknown"
|