| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233 |
- """
- 数据加密服务
- 提供敏感数据的加密和解密功能,确保传输安全
- """
- import base64
- import json
- from typing import Dict, Any, Optional
- from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
- from cryptography.hazmat.primitives import padding
- from cryptography.hazmat.backends import default_backend
- import os
- from datetime import datetime, timedelta
- class DataEncryptionService:
- """数据加密服务类"""
- def __init__(self):
- self.encryption_key = os.getenv("ENCRYPTION_KEY")
- if not self.encryption_key:
- raise ValueError("Environment variable ENCRYPTION_KEY must be set for DataEncryptionService.")
- # 确保密钥是字符串类型
- self.encryption_key = str(self.encryption_key)
-
- # 与前端保持一致,只取前32位,不进行填充
- self.encryption_key = self.encryption_key[:32]
- def encrypt(self, data: Any) -> str:
- """
- 加密数据
- Args:
- data: 要加密的数据(可以是字符串、字典等)
- Returns:
- Base64编码的加密字符串
- """
- # 序列化数据
- if isinstance(data, (dict, list)):
- data_str = json.dumps(data, ensure_ascii=False)
- else:
- data_str = str(data)
- # 生成随机IV
- iv = os.urandom(16)
- # 创建加密器
- cipher = Cipher(
- algorithms.AES(self.encryption_key.encode()),
- modes.CBC(iv),
- backend=default_backend()
- )
- encryptor = cipher.encryptor()
- # 添加PKCS7填充
- padder = padding.PKCS7(128).padder()
- padded_data = padder.update(data_str.encode()) + padder.finalize()
- # 加密数据
- encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
- # 组合IV和加密数据
- combined = iv + encrypted_data
- # Base64编码
- return base64.b64encode(combined).decode()
- def decrypt(self, encrypted_data: str, session_key: Optional[str] = None) -> Any:
- """
- 解密数据
- Args:
- encrypted_data: Base64编码的加密字符串
- session_key: 会话密钥,如果提供则使用会话密钥解密
- Returns:
- 解密后的原始数据
- """
- # Base64解码
- combined = base64.b64decode(encrypted_data)
- # 分离IV和加密数据
- iv = combined[:16]
- encrypted_data = combined[16:]
- # 使用会话密钥或默认密钥
- key = session_key if session_key else self.encryption_key
- # 创建解密器
- cipher = Cipher(
- algorithms.AES(key.encode()),
- modes.CBC(iv),
- backend=default_backend()
- )
- decryptor = cipher.decryptor()
- # 解密数据
- padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
- # 移除PKCS7填充
- unpadder = padding.PKCS7(128).unpadder()
- data_str = unpadder.update(padded_data) + unpadder.finalize()
- # 尝试解析JSON,如果不是JSON则返回字符串
- try:
- return json.loads(data_str.decode())
- except json.JSONDecodeError:
- return data_str.decode()
- def encrypt_sensitive_fields(self, data: Dict[str, Any], sensitive_fields: list) -> Dict[str, Any]:
- """
- 加密字典中的敏感字段
- Args:
- data: 原始数据字典
- sensitive_fields: 需要加密的字段名列表
- Returns:
- 加密后的数据字典
- """
- encrypted_data = data.copy()
- for field in sensitive_fields:
- if field in encrypted_data and encrypted_data[field] is not None:
- encrypted_data[field] = self.encrypt(encrypted_data[field])
- return encrypted_data
- def decrypt_sensitive_fields(self, data: Dict[str, Any], sensitive_fields: list) -> Dict[str, Any]:
- """
- 解密字典中的敏感字段
- Args:
- data: 加密的数据字典
- sensitive_fields: 需要解密的字段名列表
- Returns:
- 解密后的数据字典
- """
- decrypted_data = data.copy()
- for field in sensitive_fields:
- if field in decrypted_data and decrypted_data[field] is not None:
- try:
- decrypted_data[field] = self.decrypt(decrypted_data[field])
- except Exception:
- # 如果解密失败,保留原值
- pass
- return decrypted_data
- def generate_secure_token(self, user_id: str, expires_in: int = 3600) -> str:
- """
- 生成安全令牌(用于敏感操作)
- Args:
- user_id: 用户ID
- expires_in: 过期时间(秒)
- Returns:
- 加密的令牌
- """
- token_data = {
- "user_id": user_id,
- "timestamp": datetime.utcnow().isoformat(),
- "expires_in": expires_in
- }
- return self.encrypt(token_data)
- def verify_secure_token(self, token: str) -> Optional[Dict[str, Any]]:
- """
- 验证安全令牌
- Args:
- token: 加密的令牌
- Returns:
- 令牌数据或None(如果无效)
- """
- try:
- token_data = self.decrypt(token)
- # 检查过期时间
- timestamp = datetime.fromisoformat(token_data["timestamp"])
- expires_delta = timedelta(seconds=token_data["expires_in"])
- if datetime.utcnow() > timestamp + expires_delta:
- return None
- return token_data
- except Exception:
- return None
- @staticmethod
- def mask_sensitive_data(data: str, mask_char: str = '*', visible_chars: int = 4) -> str:
- """
- 掩码敏感数据(如手机号、邮箱等)
- Args:
- data: 原始数据
- mask_char: 掩码字符
- visible_chars: 显示的字符数
- Returns:
- 掩码后的数据
- """
- if not data or len(data) <= visible_chars * 2:
- return data
- # 对邮箱特殊处理
- if '@' in data:
- parts = data.split('@')
- username = parts[0]
- domain = parts[1]
- if len(username) > visible_chars:
- username = username[:visible_chars] + mask_char * (len(username) - visible_chars)
- return f"{username}@{domain}"
- # 对手机号特殊处理
- if data.isdigit() and len(data) >= 11:
- return data[:3] + mask_char * (len(data) - 7) + data[-4:]
- # 默认处理
- return data[:visible_chars] + mask_char * (len(data) - visible_chars * 2) + data[-visible_chars:]
- # 全局实例
- encryption_service = DataEncryptionService()
|