rate_limit.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # coding=utf-8
  2. """
  3. 应用 API 限流中间件
  4. 基于滑动窗口算法实现 API 调用限流
  5. """
  6. import time
  7. from collections import defaultdict
  8. from django.http import JsonResponse
  9. from django.utils.deprecation import MiddlewareMixin
  10. from common.utils.logger import maxkb_logger
  11. class RateLimitMiddleware(MiddlewareMixin):
  12. """
  13. 应用 API 限流中间件
  14. 使用内存中的滑动窗口算法实现限流
  15. 适用于单实例部署场景
  16. """
  17. def __init__(self, get_response=None):
  18. super().__init__(get_response)
  19. # 内存存储:{app_id: [(timestamp, count)]}
  20. self._requests = defaultdict(list)
  21. self._last_cleanup = time.time()
  22. def process_request(self, request):
  23. """
  24. 处理请求限流检查
  25. """
  26. # 只对应用 API 进行限流检查
  27. path = request.path
  28. if not self._is_app_api(path):
  29. return None
  30. # 获取应用 ID
  31. app_id = self._extract_app_id(path, request)
  32. if not app_id:
  33. return None
  34. # 获取或创建限流配置
  35. rate_config = self._get_rate_config(app_id)
  36. if not rate_config or not rate_config.get('is_enabled'):
  37. return None
  38. # 检查是否超过限流
  39. client_ip = self._get_client_ip(request)
  40. is_limited = self._check_rate_limit(
  41. app_id,
  42. rate_config['max_requests'],
  43. rate_config['window_seconds']
  44. )
  45. if is_limited:
  46. maxkb_logger.warning(
  47. f"Rate limit exceeded for app {app_id} from {client_ip}"
  48. )
  49. return JsonResponse({
  50. 'code': 429,
  51. 'message': '请求过于频繁,请稍后再试',
  52. 'data': {
  53. 'retry_after': rate_config.get('window_seconds', 60)
  54. }
  55. }, status=429)
  56. # 记录请求
  57. self._record_request(app_id, client_ip, path)
  58. return None
  59. def _is_app_api(self, path: str) -> bool:
  60. """判断是否为应用 API"""
  61. return '/api/workspace/' in path and '/application/' in path
  62. def _extract_app_id(self, path: str, request) -> str:
  63. """从路径或请求中提取应用 ID"""
  64. # 尝试从路径中提取
  65. parts = path.split('/')
  66. for i, part in enumerate(parts):
  67. if part == 'application' and i + 1 < len(parts):
  68. return parts[i + 1]
  69. return None
  70. def _get_client_ip(self, request) -> str:
  71. """获取客户端 IP"""
  72. x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
  73. if x_forwarded_for:
  74. return x_forwarded_for.split(',')[0].strip()
  75. return request.META.get('REMOTE_ADDR', '')
  76. def _get_rate_config(self, app_id: str) -> dict:
  77. """
  78. 获取应用的限流配置
  79. 从数据库查询,实际生产中应使用缓存
  80. """
  81. try:
  82. from application.models.rate_limit import RateLimit
  83. from application.models.application import Application
  84. try:
  85. application = Application.objects.get(id=app_id)
  86. except Application.DoesNotExist:
  87. return None
  88. try:
  89. rate_limit = RateLimit.objects.get(application=application)
  90. return {
  91. 'is_enabled': rate_limit.is_enabled,
  92. 'max_requests': rate_limit.max_requests,
  93. 'window_seconds': rate_limit.window_seconds,
  94. }
  95. except RateLimit.DoesNotExist:
  96. return None
  97. except Exception as e:
  98. maxkb_logger.error(f"Failed to get rate config: {e}")
  99. return None
  100. def _check_rate_limit(self, app_id: str, max_requests: int,
  101. window_seconds: int) -> bool:
  102. """
  103. 检查是否超过限流
  104. 使用滑动窗口算法
  105. """
  106. now = time.time()
  107. window_start = now - window_seconds
  108. # 清理过期记录
  109. self._requests[app_id] = [
  110. ts for ts in self._requests[app_id]
  111. if ts > window_start
  112. ]
  113. # 检查是否超过限制
  114. return len(self._requests[app_id]) >= max_requests
  115. def _record_request(self, app_id: str, client_ip: str, path: str):
  116. """记录请求"""
  117. now = time.time()
  118. self._requests[app_id].append(now)
  119. # 定期清理过期数据
  120. if now - self._last_cleanup > 300: # 5分钟清理一次
  121. self._cleanup_old_requests()
  122. self._last_cleanup = now
  123. def _cleanup_old_requests(self):
  124. """清理所有过期的请求记录"""
  125. now = time.time()
  126. for app_id in list(self._requests.keys()):
  127. self._requests[app_id] = [
  128. ts for ts in self._requests[app_id]
  129. if now - ts < 3600 # 保留1小时内的记录
  130. ]
  131. if not self._requests[app_id]:
  132. del self._requests[app_id]