auth.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. """
  2. 认证依赖
  3. 提供JWT令牌验证和用户获取的依赖注入
  4. 需求: 5.1, 5.4, 6.1, 6.2, 6.3, 6.4
  5. """
  6. import json
  7. import logging
  8. from fastapi import Depends, HTTPException, status
  9. from fastapi.security import OAuth2PasswordBearer
  10. from jose import JWTError
  11. from sqlalchemy.orm import Session
  12. from app.database import get_db
  13. from app.models.user import User
  14. from app.models.admin import AdminUser
  15. from app.services.auth_service import AuthService
  16. from app.services.user_service import UserService
  17. from app.services.token_revocation_service import token_revocation_service
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Bearer-token extractors for regular users and for admins; each tokenUrl
# names the corresponding login endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
admin_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/admin/auth/login")

# TTL (seconds) for cached user info. Per the original note it is kept far
# below the JWT lifetime (cited there as 24h); 5 minutes is enough.
_USER_CACHE_TTL = 300
# Key prefix for cached users; the user id is appended to form the Redis key.
_USER_CACHE_PREFIX = "auth_user:"
  24. def _get_user_from_cache(user_id: str):
  25. """从 Redis 同步缓存获取用户基本信息,未命中返回 None。"""
  26. try:
  27. from app.core.redis import redis_manager
  28. r = redis_manager.get_sync_client()
  29. if not r:
  30. return None
  31. raw = r.get(f"{_USER_CACHE_PREFIX}{user_id}")
  32. if not raw:
  33. return None
  34. data = json.loads(raw)
  35. # 反序列化 datetime 字段
  36. from datetime import datetime, date
  37. for dt_field in ("created_at", "updated_at"):
  38. if data.get(dt_field):
  39. data[dt_field] = datetime.fromisoformat(data[dt_field])
  40. if data.get("registration_date"):
  41. data["registration_date"] = date.fromisoformat(data["registration_date"])
  42. # 用 SimpleNamespace 构造轻量对象,避免触发 SQLAlchemy ORM 机制
  43. from types import SimpleNamespace
  44. user = SimpleNamespace(**data)
  45. return user
  46. except Exception as e:
  47. logger.debug(f"用户缓存读取失败: {e}")
  48. return None
  49. def _set_user_cache(user: User) -> None:
  50. """将用户基本信息写入 Redis 缓存。"""
  51. try:
  52. from app.core.redis import redis_manager
  53. r = redis_manager.get_sync_client()
  54. if not r:
  55. return
  56. data = {
  57. "id": user.id,
  58. "username": user.username,
  59. "nickname": user.nickname,
  60. "email": user.email,
  61. "phone": user.phone,
  62. "apikey": user.apikey,
  63. "status": user.status,
  64. "avatar": user.avatar,
  65. "created_at": user.created_at.isoformat() if user.created_at else None,
  66. "updated_at": user.updated_at.isoformat() if user.updated_at else None,
  67. "registration_date": user.registration_date.isoformat() if getattr(user, "registration_date", None) else None,
  68. }
  69. r.setex(f"{_USER_CACHE_PREFIX}{user.id}", _USER_CACHE_TTL, json.dumps(data))
  70. except Exception as e:
  71. logger.debug(f"用户缓存写入失败: {e}")
  72. def invalidate_user_cache(user_id: str) -> None:
  73. """主动清除用户缓存(用户信息更新时调用)。"""
  74. try:
  75. from app.core.redis import redis_manager
  76. r = redis_manager.get_sync_client()
  77. if r:
  78. r.delete(f"{_USER_CACHE_PREFIX}{user_id}")
  79. except Exception:
  80. pass
  81. def get_current_user(
  82. token: str = Depends(oauth2_scheme),
  83. db: Session = Depends(get_db)
  84. ) -> User:
  85. """从JWT令牌获取当前用户。"""
  86. credentials_exception = HTTPException(
  87. status_code=status.HTTP_401_UNAUTHORIZED,
  88. detail="Invalid credentials",
  89. headers={"WWW-Authenticate": "Bearer"},
  90. )
  91. try:
  92. payload = AuthService.verify_token(token)
  93. if token_revocation_service.is_payload_revoked(payload):
  94. raise credentials_exception
  95. user_id = payload.get("user_id")
  96. if not user_id:
  97. raise credentials_exception
  98. except JWTError:
  99. raise credentials_exception
  100. user_service = UserService(db)
  101. user = user_service.get_user_by_id(user_id)
  102. if not user:
  103. raise credentials_exception
  104. return user
  105. def get_current_apikey(user: User = Depends(get_current_user)) -> str:
  106. """获取当前用户的APIkey"""
  107. if not user.apikey:
  108. raise HTTPException(
  109. status_code=status.HTTP_403_FORBIDDEN,
  110. detail="No APIkey configured"
  111. )
  112. return user.apikey
  113. def get_current_admin(
  114. token: str = Depends(admin_oauth2_scheme),
  115. db: Session = Depends(get_db)
  116. ) -> AdminUser:
  117. """从JWT令牌获取当前管理员"""
  118. from app.services.admin_auth_service import AdminAuthService
  119. credentials_exception = HTTPException(
  120. status_code=status.HTTP_401_UNAUTHORIZED,
  121. detail="Invalid admin credentials",
  122. headers={"WWW-Authenticate": "Bearer"},
  123. )
  124. try:
  125. payload = AdminAuthService.verify_token(token)
  126. admin_id = payload.get("admin_id")
  127. if not admin_id:
  128. raise credentials_exception
  129. except JWTError:
  130. raise credentials_exception
  131. admin = db.query(AdminUser).filter(AdminUser.id == admin_id).first()
  132. if not admin or admin.status != "active":
  133. raise credentials_exception
  134. return admin