data_encryption_service.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. """
  2. 数据加密服务
  3. 提供敏感数据的加密和解密功能,确保传输安全
  4. """
  5. import base64
  6. import json
  7. from typing import Dict, Any, Optional
  8. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  9. from cryptography.hazmat.primitives import padding
  10. from cryptography.hazmat.backends import default_backend
  11. import os
  12. from datetime import datetime, timedelta
  13. class DataEncryptionService:
  14. """数据加密服务类"""
  15. def __init__(self):
  16. self.encryption_key = os.getenv("ENCRYPTION_KEY")
  17. if not self.encryption_key:
  18. raise ValueError("Environment variable ENCRYPTION_KEY must be set for DataEncryptionService.")
  19. # 确保密钥是字符串类型
  20. self.encryption_key = str(self.encryption_key)
  21. # 与前端保持一致,只取前32位,不进行填充
  22. self.encryption_key = self.encryption_key[:32]
  23. def encrypt(self, data: Any) -> str:
  24. """
  25. 加密数据
  26. Args:
  27. data: 要加密的数据(可以是字符串、字典等)
  28. Returns:
  29. Base64编码的加密字符串
  30. """
  31. # 序列化数据
  32. if isinstance(data, (dict, list)):
  33. data_str = json.dumps(data, ensure_ascii=False)
  34. else:
  35. data_str = str(data)
  36. # 生成随机IV
  37. iv = os.urandom(16)
  38. # 创建加密器
  39. cipher = Cipher(
  40. algorithms.AES(self.encryption_key.encode()),
  41. modes.CBC(iv),
  42. backend=default_backend()
  43. )
  44. encryptor = cipher.encryptor()
  45. # 添加PKCS7填充
  46. padder = padding.PKCS7(128).padder()
  47. padded_data = padder.update(data_str.encode()) + padder.finalize()
  48. # 加密数据
  49. encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
  50. # 组合IV和加密数据
  51. combined = iv + encrypted_data
  52. # Base64编码
  53. return base64.b64encode(combined).decode()
  54. def decrypt(self, encrypted_data: str, session_key: Optional[str] = None) -> Any:
  55. """
  56. 解密数据
  57. Args:
  58. encrypted_data: Base64编码的加密字符串
  59. session_key: 会话密钥,如果提供则使用会话密钥解密
  60. Returns:
  61. 解密后的原始数据
  62. """
  63. # Base64解码
  64. combined = base64.b64decode(encrypted_data)
  65. # 分离IV和加密数据
  66. iv = combined[:16]
  67. encrypted_data = combined[16:]
  68. # 使用会话密钥或默认密钥
  69. key = session_key if session_key else self.encryption_key
  70. # 创建解密器
  71. cipher = Cipher(
  72. algorithms.AES(key.encode()),
  73. modes.CBC(iv),
  74. backend=default_backend()
  75. )
  76. decryptor = cipher.decryptor()
  77. # 解密数据
  78. padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
  79. # 移除PKCS7填充
  80. unpadder = padding.PKCS7(128).unpadder()
  81. data_str = unpadder.update(padded_data) + unpadder.finalize()
  82. # 尝试解析JSON,如果不是JSON则返回字符串
  83. try:
  84. return json.loads(data_str.decode())
  85. except json.JSONDecodeError:
  86. return data_str.decode()
  87. def encrypt_sensitive_fields(self, data: Dict[str, Any], sensitive_fields: list) -> Dict[str, Any]:
  88. """
  89. 加密字典中的敏感字段
  90. Args:
  91. data: 原始数据字典
  92. sensitive_fields: 需要加密的字段名列表
  93. Returns:
  94. 加密后的数据字典
  95. """
  96. encrypted_data = data.copy()
  97. for field in sensitive_fields:
  98. if field in encrypted_data and encrypted_data[field] is not None:
  99. encrypted_data[field] = self.encrypt(encrypted_data[field])
  100. return encrypted_data
  101. def decrypt_sensitive_fields(self, data: Dict[str, Any], sensitive_fields: list) -> Dict[str, Any]:
  102. """
  103. 解密字典中的敏感字段
  104. Args:
  105. data: 加密的数据字典
  106. sensitive_fields: 需要解密的字段名列表
  107. Returns:
  108. 解密后的数据字典
  109. """
  110. decrypted_data = data.copy()
  111. for field in sensitive_fields:
  112. if field in decrypted_data and decrypted_data[field] is not None:
  113. try:
  114. decrypted_data[field] = self.decrypt(decrypted_data[field])
  115. except Exception:
  116. # 如果解密失败,保留原值
  117. pass
  118. return decrypted_data
  119. def generate_secure_token(self, user_id: str, expires_in: int = 3600) -> str:
  120. """
  121. 生成安全令牌(用于敏感操作)
  122. Args:
  123. user_id: 用户ID
  124. expires_in: 过期时间(秒)
  125. Returns:
  126. 加密的令牌
  127. """
  128. token_data = {
  129. "user_id": user_id,
  130. "timestamp": datetime.utcnow().isoformat(),
  131. "expires_in": expires_in
  132. }
  133. return self.encrypt(token_data)
  134. def verify_secure_token(self, token: str) -> Optional[Dict[str, Any]]:
  135. """
  136. 验证安全令牌
  137. Args:
  138. token: 加密的令牌
  139. Returns:
  140. 令牌数据或None(如果无效)
  141. """
  142. try:
  143. token_data = self.decrypt(token)
  144. # 检查过期时间
  145. timestamp = datetime.fromisoformat(token_data["timestamp"])
  146. expires_delta = timedelta(seconds=token_data["expires_in"])
  147. if datetime.utcnow() > timestamp + expires_delta:
  148. return None
  149. return token_data
  150. except Exception:
  151. return None
  152. @staticmethod
  153. def mask_sensitive_data(data: str, mask_char: str = '*', visible_chars: int = 4) -> str:
  154. """
  155. 掩码敏感数据(如手机号、邮箱等)
  156. Args:
  157. data: 原始数据
  158. mask_char: 掩码字符
  159. visible_chars: 显示的字符数
  160. Returns:
  161. 掩码后的数据
  162. """
  163. if not data or len(data) <= visible_chars * 2:
  164. return data
  165. # 对邮箱特殊处理
  166. if '@' in data:
  167. parts = data.split('@')
  168. username = parts[0]
  169. domain = parts[1]
  170. if len(username) > visible_chars:
  171. username = username[:visible_chars] + mask_char * (len(username) - visible_chars)
  172. return f"{username}@{domain}"
  173. # 对手机号特殊处理
  174. if data.isdigit() and len(data) >= 11:
  175. return data[:3] + mask_char * (len(data) - 7) + data[-4:]
  176. # 默认处理
  177. return data[:visible_chars] + mask_char * (len(data) - visible_chars * 2) + data[-visible_chars:]
  178. # 全局实例
  179. encryption_service = DataEncryptionService()