""" 安全工具类 提供各种安全相关的功能 """ import hashlib import hmac import time import uuid from typing import Dict, Any, Optional from fastapi import Request, HTTPException, status from starlette.requests import Request as StarletteRequest class SecurityUtils: """安全工具类""" @staticmethod def generate_nonce() -> str: """生成随机数(用于防止重放攻击)""" return str(uuid.uuid4()) @staticmethod def generate_timestamp() -> str: """生成时间戳""" return str(int(time.time())) @staticmethod def calculate_signature(data: str, secret: str) -> str: """ 计算HMAC-SHA256签名 Args: data: 要签名的数据 secret: 密钥 Returns: 十六进制格式的签名 """ return hmac.new( secret.encode(), data.encode(), hashlib.sha256 ).hexdigest() @staticmethod def verify_signature(data: str, signature: str, secret: str) -> bool: """ 验证签名 Args: data: 原始数据 signature: 要验证的签名 secret: 密钥 Returns: True 如果签名有效 """ expected_signature = SecurityUtils.calculate_signature(data, secret) return hmac.compare_digest(expected_signature, signature) @staticmethod def validate_request_headers(request: Request, required_headers: list) -> bool: """ 验证请求头是否包含必要的字段 Args: request: FastAPI请求对象 required_headers: 必要的请求头列表 Returns: True 如果所有必要字段都存在 """ for header in required_headers: if not request.headers.get(header): return False return True @staticmethod def rate_limit_key(ip: str, endpoint: str) -> str: """ 生成限流键 Args: ip: IP地址 endpoint: API端点 Returns: 限流键 """ return f"rate_limit:{ip}:{endpoint}" @staticmethod def sanitize_input(input_data: str) -> str: """ 清理输入数据,防止XSS攻击 Args: input_data: 输入数据 Returns: 清理后的数据 """ # 移除潜在的XSS字符 xss_chars = ['<', '>', '"', "'", '(', ')', '&', ';', '`'] sanitized = input_data for char in xss_chars: sanitized = sanitized.replace(char, '') return sanitized @staticmethod def validate_csrf_token(request: Request, stored_token: str) -> bool: """ 验证CSRF令牌 Args: request: 请求对象 stored_token: 存储的令牌 Returns: True 如果令牌有效 """ # 从请求头或表单数据中获取令牌 token = request.headers.get('X-CSRF-Token') or request.form.get('csrf_token') return token and hmac.compare_digest(token, stored_token) @staticmethod def get_client_ip(request: Request) -> str: """ 获取客户端真实IP地址 Args: request: 请求对象 Returns: 客户端IP地址 """ # 处理代理服务器的情况 x_forwarded_for = request.headers.get('X-Forwarded-For') if x_forwarded_for: # 如果有多个IP,取第一个(真实IP) ip = x_forwarded_for.split(',')[0].strip() else: ip = request.client.host if request.client else "unknown" return ip @staticmethod def check_sql_injection(input_data: str) -> bool: """ 检查SQL注入攻击 Args: input_data: 输入数据 Returns: True 如果检测到SQL注入 """ sql_patterns = [ r'union\s+select', r'or\s+1\s*=\s*1', r'and\s+1\s*=\s*1', r'drop\s+table', r'insert\s+into', r'delete\s+from', r'update\s+\w+\s+set', r'exec\s*\(', r'select\s+\*\s+from', r'--', r'/\*.*?\*/', r'xp_', r'sysobjects', r'information_schema', r'mysql\.info', r'pg_catalog' ] import re for pattern in sql_patterns: if re.search(pattern, input_data, re.IGNORECASE): return True return False @staticmethod def check_path_traversal(input_data: str) -> bool: """ 检查路径遍历攻击 Args: input_data: 输入数据 Returns: True 如果检测到路径遍历 """ traversal_chars = ['../', '..\\', '.\\', '.\\'] return any(char in input_data for char in traversal_chars) @staticmethod def secure_headers() -> Dict[str, str]: """ 生成安全HTTP头 Returns: 安全头字典 """ return { 'X-Content-Type-Options': 'nosniff', 'X-Frame-Options': 'DENY', 'X-XSS-Protection': '1; mode=block', 'Strict-Transport-Security': 'max-age=31536000; includeSubDomains', 'Content-Security-Policy': "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'", 'Referrer-Policy': 'strict-origin-when-cross-origin', 'Permissions-Policy': 'geolocation=(), microphone=(), camera=()' }