| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227 |
- """
- 限流服务模块
- 基于 Redis 的滑动窗口限流器实现。
- 需求引用: 5.1, 5.2, 5.3, 5.4, 5.5
- """
- from typing import Optional, Tuple
- import time
- import os
- import logging
- from app.core.redis import redis_manager
- logger = logging.getLogger(__name__)
- # 不同端点的限流配置
- RATE_LIMIT_CONFIG = {
- "/api/llm/chat": {"limit": 30, "window": 60}, # 对话:每分钟30次
- "/api/llm/stream": {"limit": 30, "window": 60}, # 流式对话:每分钟30次
- "/api/image/generate": {"limit": 10, "window": 60}, # 生图:每分钟10次
- "/api/video/generate": {"limit": 5, "window": 60}, # 生视频:每分钟5次
- "/api/audio/tts": {"limit": 20, "window": 60}, # 语音合成:每分钟20次
- "/api/audio/asr": {"limit": 20, "window": 60}, # 语音识别:每分钟20次
- "default": {"limit": 100, "window": 60} # 默认:每分钟100次
- }
- class RateLimiter:
- """基于 Redis 的滑动窗口限流器
-
- 使用 Redis Sorted Set 实现滑动窗口算法:
- - 每个请求记录为 Sorted Set 的一个成员,score 为请求时间戳
- - 检查时移除窗口外的旧记录,统计窗口内的请求数
- - 支持分布式部署,多实例共享限流计数
- """
-
- def __init__(self):
- self.default_limit = int(os.getenv("RATE_LIMIT_DEFAULT", "100"))
- self.default_window = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
-
- async def is_allowed(
- self,
- key: str,
- limit: int = None,
- window: int = None
- ) -> Tuple[bool, int]:
- """
- 检查请求是否允许
-
- 使用滑动窗口算法:
- 1. 移除窗口外的旧请求记录
- 2. 添加当前请求记录
- 3. 统计窗口内请求数
- 4. 判断是否超过限制
-
- Args:
- key: 限流键(通常为 user_id 或 IP)
- limit: 限流阈值,默认使用环境变量配置
- window: 时间窗口(秒),默认使用环境变量配置
-
- Returns:
- Tuple[bool, int]: (是否允许, 剩余配额)
- """
- limit = limit or self.default_limit
- window = window or self.default_window
-
- redis = redis_manager.get_client()
- if not redis:
- # 降级:无 Redis 时不限流
- logger.warning("Redis 不可用,限流降级为不限流")
- return True, limit
-
- now = time.time()
- window_start = now - window
- redis_key = f"ratelimit:{key}"
-
- try:
- # 使用 pipeline 批量执行,减少网络往返
- pipe = redis.pipeline()
- # 移除窗口外的请求记录
- pipe.zremrangebyscore(redis_key, 0, window_start)
- # 添加当前请求(使用时间戳作为 score 和 member 的一部分确保唯一性)
- member = f"{now}:{id(self)}:{os.getpid()}"
- pipe.zadd(redis_key, {member: now})
- # 获取窗口内请求数
- pipe.zcard(redis_key)
- # 设置过期时间(窗口时间 + 1秒缓冲)
- pipe.expire(redis_key, window + 1)
-
- results = await pipe.execute()
- current_count = results[2]
-
- if current_count > limit:
- # 超过限制,移除刚添加的请求记录
- await redis.zrem(redis_key, member)
- return False, 0
-
- remaining = limit - current_count
- return True, remaining
-
- except Exception as e:
- logger.error(f"限流检查失败: {e}")
- # 异常时降级为不限流
- return True, limit
-
- def get_retry_after(self, window: int = None) -> int:
- """获取重试等待时间
-
- Args:
- window: 时间窗口(秒)
-
- Returns:
- int: 建议的重试等待时间(秒)
- """
- return window or self.default_window
-
- def get_limit_config(self, path: str) -> dict:
- """获取指定路径的限流配置
-
- 根据请求路径匹配限流配置,支持前缀匹配。
-
- Args:
- path: API 路径
-
- Returns:
- dict: 包含 limit 和 window 的配置字典
- """
- # 精确匹配
- if path in RATE_LIMIT_CONFIG:
- return RATE_LIMIT_CONFIG[path]
-
- # 前缀匹配(按路径长度降序排列,优先匹配更具体的路径)
- sorted_paths = sorted(
- [p for p in RATE_LIMIT_CONFIG.keys() if p != "default"],
- key=len,
- reverse=True
- )
-
- for config_path in sorted_paths:
- if path.startswith(config_path):
- return RATE_LIMIT_CONFIG[config_path]
-
- # 返回默认配置
- return RATE_LIMIT_CONFIG["default"]
-
- async def get_current_count(self, key: str, window: int = None) -> int:
- """获取当前窗口内的请求计数
-
- Args:
- key: 限流键
- window: 时间窗口(秒)
-
- Returns:
- int: 当前请求计数
- """
- window = window or self.default_window
-
- redis = redis_manager.get_client()
- if not redis:
- return 0
-
- now = time.time()
- window_start = now - window
- redis_key = f"ratelimit:{key}"
-
- try:
- # 先清理过期记录
- await redis.zremrangebyscore(redis_key, 0, window_start)
- # 获取当前计数
- count = await redis.zcard(redis_key)
- return count
- except Exception as e:
- logger.error(f"获取限流计数失败: {e}")
- return 0
-
- async def reset(self, key: str) -> bool:
- """重置指定键的限流计数
-
- Args:
- key: 限流键
-
- Returns:
- bool: 是否重置成功
- """
- redis = redis_manager.get_client()
- if not redis:
- return False
-
- redis_key = f"ratelimit:{key}"
-
- try:
- await redis.delete(redis_key)
- return True
- except Exception as e:
- logger.error(f"重置限流计数失败: {e}")
- return False
- def build_rate_limit_key(
- user_id: Optional[str] = None,
- ip: Optional[str] = None,
- path: Optional[str] = None
- ) -> str:
- """构建限流键
-
- 优先使用 user_id,其次使用 IP 地址。
- 可选择性地加入 path 实现端点级别的限流。
-
- Args:
- user_id: 用户 ID
- ip: 客户端 IP 地址
- path: API 路径
-
- Returns:
- str: 限流键
- """
- identifier = user_id if user_id else f"ip:{ip or 'unknown'}"
-
- if path:
- return f"{identifier}:{path}"
-
- return identifier
- # 全局单例
- rate_limiter = RateLimiter()
|