token_revocation_service.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. """
  2. JWT 令牌失效管理
  3. 支持两种模式:
  4. - Redis 模式(生产):多进程/多实例共享,重启不丢失
  5. - 内存模式(降级):Redis 不可用时自动回退,单进程有效
  6. """
  7. import time
  8. import logging
  9. from threading import Lock
  10. from typing import Any
  11. logger = logging.getLogger(__name__)
  12. # Redis key 前缀
  13. _JTI_PREFIX = "revoked_jti:"
  14. _USER_LOGOUT_PREFIX = "user_logout_after:"
  15. # JTI 在 Redis 中的 TTL 与 JWT 过期时间对齐(最长 25 小时)
  16. _JTI_TTL = 90000
  17. class TokenRevocationService:
  18. """
  19. 维护被撤销 token 与用户会话失效时间。
  20. 优先使用 Redis 实现分布式撤销,Redis 不可用时自动降级到进程内存。
  21. """
  22. def __init__(self):
  23. self._lock = Lock()
  24. # 内存降级存储
  25. self._revoked_jti_exp: dict[str, int] = {}
  26. self._user_logout_after: dict[str, int] = {}
  27. def _get_redis(self):
  28. """获取同步 Redis 客户端,不可用返回 None。"""
  29. try:
  30. from app.core.redis import redis_manager
  31. return redis_manager.get_sync_client()
  32. except Exception:
  33. return None
  34. def _cleanup_memory(self) -> None:
  35. """清理内存中已过期的 JTI。"""
  36. now = int(time.time())
  37. expired = [jti for jti, exp in self._revoked_jti_exp.items() if exp <= now]
  38. for jti in expired:
  39. self._revoked_jti_exp.pop(jti, None)
  40. def revoke_payload(self, payload: dict[str, Any]) -> None:
  41. """将当前 payload 对应的 jti 标记为撤销。"""
  42. jti = payload.get("jti")
  43. exp = payload.get("exp")
  44. if not jti or not exp:
  45. return
  46. try:
  47. exp_ts = int(exp)
  48. except (TypeError, ValueError):
  49. return
  50. redis = self._get_redis()
  51. if redis:
  52. try:
  53. ttl = max(exp_ts - int(time.time()), 1)
  54. redis.setex(f"{_JTI_PREFIX}{jti}", ttl, "1")
  55. return
  56. except Exception as e:
  57. logger.warning(f"Redis revoke_payload 失败,降级到内存: {e}")
  58. # 内存降级
  59. with self._lock:
  60. self._cleanup_memory()
  61. self._revoked_jti_exp[str(jti)] = exp_ts
  62. def revoke_user_sessions(self, user_id: str) -> None:
  63. """使指定用户在当前时间之前签发的 token 全部失效。"""
  64. now = int(time.time())
  65. redis = self._get_redis()
  66. if redis:
  67. try:
  68. redis.setex(f"{_USER_LOGOUT_PREFIX}{user_id}", _JTI_TTL, str(now))
  69. return
  70. except Exception as e:
  71. logger.warning(f"Redis revoke_user_sessions 失败,降级到内存: {e}")
  72. with self._lock:
  73. self._cleanup_memory()
  74. self._user_logout_after[str(user_id)] = now
  75. def is_payload_revoked(self, payload: dict[str, Any]) -> bool:
  76. """判断 payload 对应 token 是否已失效。"""
  77. jti = payload.get("jti")
  78. user_id = payload.get("user_id")
  79. redis = self._get_redis()
  80. if redis:
  81. try:
  82. # 检查 JTI 是否被撤销
  83. if jti and redis.exists(f"{_JTI_PREFIX}{jti}"):
  84. return True
  85. # 检查用户会话是否被全局撤销
  86. if user_id:
  87. cutoff_raw = redis.get(f"{_USER_LOGOUT_PREFIX}{user_id}")
  88. if cutoff_raw:
  89. try:
  90. cutoff = int(cutoff_raw)
  91. issued_at = int(payload.get("iat", 0))
  92. if issued_at < cutoff:
  93. return True
  94. except (TypeError, ValueError):
  95. return True
  96. return False
  97. except Exception as e:
  98. logger.warning(f"Redis is_payload_revoked 失败,降级到内存: {e}")
  99. # 内存降级检查
  100. with self._lock:
  101. self._cleanup_memory()
  102. if jti and str(jti) in self._revoked_jti_exp:
  103. return True
  104. if not user_id:
  105. return False
  106. cutoff = self._user_logout_after.get(str(user_id))
  107. if cutoff is None:
  108. return False
  109. try:
  110. issued_at = int(payload.get("iat", 0))
  111. except (TypeError, ValueError):
  112. return True
  113. return issued_at < cutoff
  114. token_revocation_service = TokenRevocationService()