security_utils.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. """
  2. 安全工具类
  3. 提供各种安全相关的功能
  4. """
  5. import hashlib
  6. import hmac
  7. import time
  8. import uuid
  9. from typing import Dict, Any, Optional
  10. from fastapi import Request, HTTPException, status
  11. from starlette.requests import Request as StarletteRequest
  12. class SecurityUtils:
  13. """安全工具类"""
  14. @staticmethod
  15. def generate_nonce() -> str:
  16. """生成随机数(用于防止重放攻击)"""
  17. return str(uuid.uuid4())
  18. @staticmethod
  19. def generate_timestamp() -> str:
  20. """生成时间戳"""
  21. return str(int(time.time()))
  22. @staticmethod
  23. def calculate_signature(data: str, secret: str) -> str:
  24. """
  25. 计算HMAC-SHA256签名
  26. Args:
  27. data: 要签名的数据
  28. secret: 密钥
  29. Returns:
  30. 十六进制格式的签名
  31. """
  32. return hmac.new(
  33. secret.encode(),
  34. data.encode(),
  35. hashlib.sha256
  36. ).hexdigest()
  37. @staticmethod
  38. def verify_signature(data: str, signature: str, secret: str) -> bool:
  39. """
  40. 验证签名
  41. Args:
  42. data: 原始数据
  43. signature: 要验证的签名
  44. secret: 密钥
  45. Returns:
  46. True 如果签名有效
  47. """
  48. expected_signature = SecurityUtils.calculate_signature(data, secret)
  49. return hmac.compare_digest(expected_signature, signature)
  50. @staticmethod
  51. def validate_request_headers(request: Request, required_headers: list) -> bool:
  52. """
  53. 验证请求头是否包含必要的字段
  54. Args:
  55. request: FastAPI请求对象
  56. required_headers: 必要的请求头列表
  57. Returns:
  58. True 如果所有必要字段都存在
  59. """
  60. for header in required_headers:
  61. if not request.headers.get(header):
  62. return False
  63. return True
  64. @staticmethod
  65. def rate_limit_key(ip: str, endpoint: str) -> str:
  66. """
  67. 生成限流键
  68. Args:
  69. ip: IP地址
  70. endpoint: API端点
  71. Returns:
  72. 限流键
  73. """
  74. return f"rate_limit:{ip}:{endpoint}"
  75. @staticmethod
  76. def sanitize_input(input_data: str) -> str:
  77. """
  78. 清理输入数据,防止XSS攻击
  79. Args:
  80. input_data: 输入数据
  81. Returns:
  82. 清理后的数据
  83. """
  84. # 移除潜在的XSS字符
  85. xss_chars = ['<', '>', '"', "'", '(', ')', '&', ';', '`']
  86. sanitized = input_data
  87. for char in xss_chars:
  88. sanitized = sanitized.replace(char, '')
  89. return sanitized
  90. @staticmethod
  91. def validate_csrf_token(request: Request, stored_token: str) -> bool:
  92. """
  93. 验证CSRF令牌
  94. Args:
  95. request: 请求对象
  96. stored_token: 存储的令牌
  97. Returns:
  98. True 如果令牌有效
  99. """
  100. # 从请求头或表单数据中获取令牌
  101. token = request.headers.get('X-CSRF-Token') or request.form.get('csrf_token')
  102. return token and hmac.compare_digest(token, stored_token)
  103. @staticmethod
  104. def get_client_ip(request: Request) -> str:
  105. """
  106. 获取客户端真实IP地址
  107. Args:
  108. request: 请求对象
  109. Returns:
  110. 客户端IP地址
  111. """
  112. # 处理代理服务器的情况
  113. x_forwarded_for = request.headers.get('X-Forwarded-For')
  114. if x_forwarded_for:
  115. # 如果有多个IP,取第一个(真实IP)
  116. ip = x_forwarded_for.split(',')[0].strip()
  117. else:
  118. ip = request.client.host if request.client else "unknown"
  119. return ip
  120. @staticmethod
  121. def check_sql_injection(input_data: str) -> bool:
  122. """
  123. 检查SQL注入攻击
  124. Args:
  125. input_data: 输入数据
  126. Returns:
  127. True 如果检测到SQL注入
  128. """
  129. sql_patterns = [
  130. r'union\s+select',
  131. r'or\s+1\s*=\s*1',
  132. r'and\s+1\s*=\s*1',
  133. r'drop\s+table',
  134. r'insert\s+into',
  135. r'delete\s+from',
  136. r'update\s+\w+\s+set',
  137. r'exec\s*\(',
  138. r'select\s+\*\s+from',
  139. r'--',
  140. r'/\*.*?\*/',
  141. r'xp_',
  142. r'sysobjects',
  143. r'information_schema',
  144. r'mysql\.info',
  145. r'pg_catalog'
  146. ]
  147. import re
  148. for pattern in sql_patterns:
  149. if re.search(pattern, input_data, re.IGNORECASE):
  150. return True
  151. return False
  152. @staticmethod
  153. def check_path_traversal(input_data: str) -> bool:
  154. """
  155. 检查路径遍历攻击
  156. Args:
  157. input_data: 输入数据
  158. Returns:
  159. True 如果检测到路径遍历
  160. """
  161. traversal_chars = ['../', '..\\', '.\\', '.\\']
  162. return any(char in input_data for char in traversal_chars)
  163. @staticmethod
  164. def secure_headers() -> Dict[str, str]:
  165. """
  166. 生成安全HTTP头
  167. Returns:
  168. 安全头字典
  169. """
  170. return {
  171. 'X-Content-Type-Options': 'nosniff',
  172. 'X-Frame-Options': 'DENY',
  173. 'X-XSS-Protection': '1; mode=block',
  174. 'Strict-Transport-Security': 'max-age=31536000; includeSubDomains',
  175. 'Content-Security-Policy': "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'",
  176. 'Referrer-Policy': 'strict-origin-when-cross-origin',
  177. 'Permissions-Policy': 'geolocation=(), microphone=(), camera=()'
  178. }