rate_limiter.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. """
  2. 限流服务模块
  3. 基于 Redis 的滑动窗口限流器实现。
  4. 需求引用: 5.1, 5.2, 5.3, 5.4, 5.5
  5. """
  6. from typing import Optional, Tuple
  7. import time
  8. import os
  9. import logging
  10. from app.core.redis import redis_manager
  11. logger = logging.getLogger(__name__)
  12. # 不同端点的限流配置
  13. RATE_LIMIT_CONFIG = {
  14. "/api/llm/chat": {"limit": 30, "window": 60}, # 对话:每分钟30次
  15. "/api/llm/stream": {"limit": 30, "window": 60}, # 流式对话:每分钟30次
  16. "/api/image/generate": {"limit": 10, "window": 60}, # 生图:每分钟10次
  17. "/api/video/generate": {"limit": 5, "window": 60}, # 生视频:每分钟5次
  18. "/api/audio/tts": {"limit": 20, "window": 60}, # 语音合成:每分钟20次
  19. "/api/audio/asr": {"limit": 20, "window": 60}, # 语音识别:每分钟20次
  20. "default": {"limit": 100, "window": 60} # 默认:每分钟100次
  21. }
  22. class RateLimiter:
  23. """基于 Redis 的滑动窗口限流器
  24. 使用 Redis Sorted Set 实现滑动窗口算法:
  25. - 每个请求记录为 Sorted Set 的一个成员,score 为请求时间戳
  26. - 检查时移除窗口外的旧记录,统计窗口内的请求数
  27. - 支持分布式部署,多实例共享限流计数
  28. """
  29. def __init__(self):
  30. self.default_limit = int(os.getenv("RATE_LIMIT_DEFAULT", "100"))
  31. self.default_window = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
  32. async def is_allowed(
  33. self,
  34. key: str,
  35. limit: int = None,
  36. window: int = None
  37. ) -> Tuple[bool, int]:
  38. """
  39. 检查请求是否允许
  40. 使用滑动窗口算法:
  41. 1. 移除窗口外的旧请求记录
  42. 2. 添加当前请求记录
  43. 3. 统计窗口内请求数
  44. 4. 判断是否超过限制
  45. Args:
  46. key: 限流键(通常为 user_id 或 IP)
  47. limit: 限流阈值,默认使用环境变量配置
  48. window: 时间窗口(秒),默认使用环境变量配置
  49. Returns:
  50. Tuple[bool, int]: (是否允许, 剩余配额)
  51. """
  52. limit = limit or self.default_limit
  53. window = window or self.default_window
  54. redis = redis_manager.get_client()
  55. if not redis:
  56. # 降级:无 Redis 时不限流
  57. logger.warning("Redis 不可用,限流降级为不限流")
  58. return True, limit
  59. now = time.time()
  60. window_start = now - window
  61. redis_key = f"ratelimit:{key}"
  62. try:
  63. # 使用 pipeline 批量执行,减少网络往返
  64. pipe = redis.pipeline()
  65. # 移除窗口外的请求记录
  66. pipe.zremrangebyscore(redis_key, 0, window_start)
  67. # 添加当前请求(使用时间戳作为 score 和 member 的一部分确保唯一性)
  68. member = f"{now}:{id(self)}:{os.getpid()}"
  69. pipe.zadd(redis_key, {member: now})
  70. # 获取窗口内请求数
  71. pipe.zcard(redis_key)
  72. # 设置过期时间(窗口时间 + 1秒缓冲)
  73. pipe.expire(redis_key, window + 1)
  74. results = await pipe.execute()
  75. current_count = results[2]
  76. if current_count > limit:
  77. # 超过限制,移除刚添加的请求记录
  78. await redis.zrem(redis_key, member)
  79. return False, 0
  80. remaining = limit - current_count
  81. return True, remaining
  82. except Exception as e:
  83. logger.error(f"限流检查失败: {e}")
  84. # 异常时降级为不限流
  85. return True, limit
  86. def get_retry_after(self, window: int = None) -> int:
  87. """获取重试等待时间
  88. Args:
  89. window: 时间窗口(秒)
  90. Returns:
  91. int: 建议的重试等待时间(秒)
  92. """
  93. return window or self.default_window
  94. def get_limit_config(self, path: str) -> dict:
  95. """获取指定路径的限流配置
  96. 根据请求路径匹配限流配置,支持前缀匹配。
  97. Args:
  98. path: API 路径
  99. Returns:
  100. dict: 包含 limit 和 window 的配置字典
  101. """
  102. # 精确匹配
  103. if path in RATE_LIMIT_CONFIG:
  104. return RATE_LIMIT_CONFIG[path]
  105. # 前缀匹配(按路径长度降序排列,优先匹配更具体的路径)
  106. sorted_paths = sorted(
  107. [p for p in RATE_LIMIT_CONFIG.keys() if p != "default"],
  108. key=len,
  109. reverse=True
  110. )
  111. for config_path in sorted_paths:
  112. if path.startswith(config_path):
  113. return RATE_LIMIT_CONFIG[config_path]
  114. # 返回默认配置
  115. return RATE_LIMIT_CONFIG["default"]
  116. async def get_current_count(self, key: str, window: int = None) -> int:
  117. """获取当前窗口内的请求计数
  118. Args:
  119. key: 限流键
  120. window: 时间窗口(秒)
  121. Returns:
  122. int: 当前请求计数
  123. """
  124. window = window or self.default_window
  125. redis = redis_manager.get_client()
  126. if not redis:
  127. return 0
  128. now = time.time()
  129. window_start = now - window
  130. redis_key = f"ratelimit:{key}"
  131. try:
  132. # 先清理过期记录
  133. await redis.zremrangebyscore(redis_key, 0, window_start)
  134. # 获取当前计数
  135. count = await redis.zcard(redis_key)
  136. return count
  137. except Exception as e:
  138. logger.error(f"获取限流计数失败: {e}")
  139. return 0
  140. async def reset(self, key: str) -> bool:
  141. """重置指定键的限流计数
  142. Args:
  143. key: 限流键
  144. Returns:
  145. bool: 是否重置成功
  146. """
  147. redis = redis_manager.get_client()
  148. if not redis:
  149. return False
  150. redis_key = f"ratelimit:{key}"
  151. try:
  152. await redis.delete(redis_key)
  153. return True
  154. except Exception as e:
  155. logger.error(f"重置限流计数失败: {e}")
  156. return False
  157. def build_rate_limit_key(
  158. user_id: Optional[str] = None,
  159. ip: Optional[str] = None,
  160. path: Optional[str] = None
  161. ) -> str:
  162. """构建限流键
  163. 优先使用 user_id,其次使用 IP 地址。
  164. 可选择性地加入 path 实现端点级别的限流。
  165. Args:
  166. user_id: 用户 ID
  167. ip: 客户端 IP 地址
  168. path: API 路径
  169. Returns:
  170. str: 限流键
  171. """
  172. identifier = user_id if user_id else f"ip:{ip or 'unknown'}"
  173. if path:
  174. return f"{identifier}:{path}"
  175. return identifier
  176. # 全局单例
  177. rate_limiter = RateLimiter()