""" 数据加密服务 提供敏感数据的加密和解密功能,确保传输安全 """ import base64 import json from typing import Dict, Any, Optional from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives import padding from cryptography.hazmat.backends import default_backend import os from datetime import datetime, timedelta class DataEncryptionService: """数据加密服务类""" def __init__(self): self.encryption_key = os.getenv("ENCRYPTION_KEY") if not self.encryption_key: raise ValueError("Environment variable ENCRYPTION_KEY must be set for DataEncryptionService.") # 确保密钥是字符串类型 self.encryption_key = str(self.encryption_key) # 与前端保持一致,只取前32位,不进行填充 self.encryption_key = self.encryption_key[:32] def encrypt(self, data: Any) -> str: """ 加密数据 Args: data: 要加密的数据(可以是字符串、字典等) Returns: Base64编码的加密字符串 """ # 序列化数据 if isinstance(data, (dict, list)): data_str = json.dumps(data, ensure_ascii=False) else: data_str = str(data) # 生成随机IV iv = os.urandom(16) # 创建加密器 cipher = Cipher( algorithms.AES(self.encryption_key.encode()), modes.CBC(iv), backend=default_backend() ) encryptor = cipher.encryptor() # 添加PKCS7填充 padder = padding.PKCS7(128).padder() padded_data = padder.update(data_str.encode()) + padder.finalize() # 加密数据 encrypted_data = encryptor.update(padded_data) + encryptor.finalize() # 组合IV和加密数据 combined = iv + encrypted_data # Base64编码 return base64.b64encode(combined).decode() def decrypt(self, encrypted_data: str, session_key: Optional[str] = None) -> Any: """ 解密数据 Args: encrypted_data: Base64编码的加密字符串 session_key: 会话密钥,如果提供则使用会话密钥解密 Returns: 解密后的原始数据 """ # Base64解码 combined = base64.b64decode(encrypted_data) # 分离IV和加密数据 iv = combined[:16] encrypted_data = combined[16:] # 使用会话密钥或默认密钥 key = session_key if session_key else self.encryption_key # 创建解密器 cipher = Cipher( algorithms.AES(key.encode()), modes.CBC(iv), backend=default_backend() ) decryptor = cipher.decryptor() # 解密数据 padded_data = decryptor.update(encrypted_data) + decryptor.finalize() # 移除PKCS7填充 unpadder = padding.PKCS7(128).unpadder() data_str = unpadder.update(padded_data) + unpadder.finalize() # 尝试解析JSON,如果不是JSON则返回字符串 try: return json.loads(data_str.decode()) except json.JSONDecodeError: return data_str.decode() def encrypt_sensitive_fields(self, data: Dict[str, Any], sensitive_fields: list) -> Dict[str, Any]: """ 加密字典中的敏感字段 Args: data: 原始数据字典 sensitive_fields: 需要加密的字段名列表 Returns: 加密后的数据字典 """ encrypted_data = data.copy() for field in sensitive_fields: if field in encrypted_data and encrypted_data[field] is not None: encrypted_data[field] = self.encrypt(encrypted_data[field]) return encrypted_data def decrypt_sensitive_fields(self, data: Dict[str, Any], sensitive_fields: list) -> Dict[str, Any]: """ 解密字典中的敏感字段 Args: data: 加密的数据字典 sensitive_fields: 需要解密的字段名列表 Returns: 解密后的数据字典 """ decrypted_data = data.copy() for field in sensitive_fields: if field in decrypted_data and decrypted_data[field] is not None: try: decrypted_data[field] = self.decrypt(decrypted_data[field]) except Exception: # 如果解密失败,保留原值 pass return decrypted_data def generate_secure_token(self, user_id: str, expires_in: int = 3600) -> str: """ 生成安全令牌(用于敏感操作) Args: user_id: 用户ID expires_in: 过期时间(秒) Returns: 加密的令牌 """ token_data = { "user_id": user_id, "timestamp": datetime.utcnow().isoformat(), "expires_in": expires_in } return self.encrypt(token_data) def verify_secure_token(self, token: str) -> Optional[Dict[str, Any]]: """ 验证安全令牌 Args: token: 加密的令牌 Returns: 令牌数据或None(如果无效) """ try: token_data = self.decrypt(token) # 检查过期时间 timestamp = datetime.fromisoformat(token_data["timestamp"]) expires_delta = timedelta(seconds=token_data["expires_in"]) if datetime.utcnow() > timestamp + expires_delta: return None return token_data except Exception: return None @staticmethod def mask_sensitive_data(data: str, mask_char: str = '*', visible_chars: int = 4) -> str: """ 掩码敏感数据(如手机号、邮箱等) Args: data: 原始数据 mask_char: 掩码字符 visible_chars: 显示的字符数 Returns: 掩码后的数据 """ if not data or len(data) <= visible_chars * 2: return data # 对邮箱特殊处理 if '@' in data: parts = data.split('@') username = parts[0] domain = parts[1] if len(username) > visible_chars: username = username[:visible_chars] + mask_char * (len(username) - visible_chars) return f"{username}@{domain}" # 对手机号特殊处理 if data.isdigit() and len(data) >= 11: return data[:3] + mask_char * (len(data) - 7) + data[-4:] # 默认处理 return data[:visible_chars] + mask_char * (len(data) - visible_chars * 2) + data[-visible_chars:] # 全局实例 encryption_service = DataEncryptionService()