| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- """
- JWT 令牌失效管理
- 支持两种模式:
- - Redis 模式(生产):多进程/多实例共享,重启不丢失
- - 内存模式(降级):Redis 不可用时自动回退,单进程有效
- """
- import time
- import logging
- from threading import Lock
- from typing import Any
- logger = logging.getLogger(__name__)
- # Redis key 前缀
- _JTI_PREFIX = "revoked_jti:"
- _USER_LOGOUT_PREFIX = "user_logout_after:"
- # JTI 在 Redis 中的 TTL 与 JWT 过期时间对齐(最长 25 小时)
- _JTI_TTL = 90000
- class TokenRevocationService:
- """
- 维护被撤销 token 与用户会话失效时间。
- 优先使用 Redis 实现分布式撤销,Redis 不可用时自动降级到进程内存。
- """
- def __init__(self):
- self._lock = Lock()
- # 内存降级存储
- self._revoked_jti_exp: dict[str, int] = {}
- self._user_logout_after: dict[str, int] = {}
- def _get_redis(self):
- """获取同步 Redis 客户端,不可用返回 None。"""
- try:
- from app.core.redis import redis_manager
- return redis_manager.get_sync_client()
- except Exception:
- return None
- def _cleanup_memory(self) -> None:
- """清理内存中已过期的 JTI。"""
- now = int(time.time())
- expired = [jti for jti, exp in self._revoked_jti_exp.items() if exp <= now]
- for jti in expired:
- self._revoked_jti_exp.pop(jti, None)
- def revoke_payload(self, payload: dict[str, Any]) -> None:
- """将当前 payload 对应的 jti 标记为撤销。"""
- jti = payload.get("jti")
- exp = payload.get("exp")
- if not jti or not exp:
- return
- try:
- exp_ts = int(exp)
- except (TypeError, ValueError):
- return
- redis = self._get_redis()
- if redis:
- try:
- ttl = max(exp_ts - int(time.time()), 1)
- redis.setex(f"{_JTI_PREFIX}{jti}", ttl, "1")
- return
- except Exception as e:
- logger.warning(f"Redis revoke_payload 失败,降级到内存: {e}")
- # 内存降级
- with self._lock:
- self._cleanup_memory()
- self._revoked_jti_exp[str(jti)] = exp_ts
- def revoke_user_sessions(self, user_id: str) -> None:
- """使指定用户在当前时间之前签发的 token 全部失效。"""
- now = int(time.time())
- redis = self._get_redis()
- if redis:
- try:
- redis.setex(f"{_USER_LOGOUT_PREFIX}{user_id}", _JTI_TTL, str(now))
- return
- except Exception as e:
- logger.warning(f"Redis revoke_user_sessions 失败,降级到内存: {e}")
- with self._lock:
- self._cleanup_memory()
- self._user_logout_after[str(user_id)] = now
- def is_payload_revoked(self, payload: dict[str, Any]) -> bool:
- """判断 payload 对应 token 是否已失效。"""
- jti = payload.get("jti")
- user_id = payload.get("user_id")
- redis = self._get_redis()
- if redis:
- try:
- # 检查 JTI 是否被撤销
- if jti and redis.exists(f"{_JTI_PREFIX}{jti}"):
- return True
- # 检查用户会话是否被全局撤销
- if user_id:
- cutoff_raw = redis.get(f"{_USER_LOGOUT_PREFIX}{user_id}")
- if cutoff_raw:
- try:
- cutoff = int(cutoff_raw)
- issued_at = int(payload.get("iat", 0))
- if issued_at < cutoff:
- return True
- except (TypeError, ValueError):
- return True
- return False
- except Exception as e:
- logger.warning(f"Redis is_payload_revoked 失败,降级到内存: {e}")
- # 内存降级检查
- with self._lock:
- self._cleanup_memory()
- if jti and str(jti) in self._revoked_jti_exp:
- return True
- if not user_id:
- return False
- cutoff = self._user_logout_after.get(str(user_id))
- if cutoff is None:
- return False
- try:
- issued_at = int(payload.get("iat", 0))
- except (TypeError, ValueError):
- return True
- return issued_at < cutoff
- token_revocation_service = TokenRevocationService()
|