瀏覽代碼

长久token 逻辑优化

lingmin_package@163.com 1 月之前
父節點
當前提交
e3cbb7163f
共有 2 個文件被更改,包括 162 次插入73 次删除
  1. 53 4
      backend/middleware/auth_middleware.py
  2. 109 69
      backend/scripts/generate_admin_token.py

+ 53 - 4
backend/middleware/auth_middleware.py

@@ -2,14 +2,18 @@
 Authentication Middleware for SSO token verification.
 Validates SSO tokens via the SSO center's userinfo endpoint,
 with an in-memory cache to reduce external calls.
+
+Also supports admin tokens generated by generate_admin_token.py script.
 """
 import logging
+from datetime import datetime, timezone
 from fastapi import Request, HTTPException, status
 from fastapi.responses import JSONResponse
 from starlette.middleware.base import BaseHTTPMiddleware
 from services.token_cache_service import TokenCacheService
 from services.oauth_service import OAuthService
 from config import settings
+from database import get_db_connection
 
 logger = logging.getLogger(__name__)
 
@@ -20,6 +24,42 @@ token_cache = TokenCacheService(
 )
 
 
+def verify_admin_token(token: str) -> dict:
+    """
+    验证管理员 Token(从数据库查询)
+
+    Args:
+        token: Token 字符串
+
+    Returns:
+        dict: 用户信息字典,或 None(Token 无效或已过期)
+    """
+    try:
+        with get_db_connection() as conn:
+            cursor = conn.cursor()
+            cursor.execute("""
+                SELECT at.user_id, u.username, u.email, u.role, at.expires_at
+                FROM admin_tokens at
+                JOIN users u ON at.user_id = u.id
+                WHERE at.token = %s AND at.expires_at > %s
+            """, (token, datetime.now(timezone.utc)))
+            row = cursor.fetchone()
+
+            if not row:
+                return None
+
+            return {
+                "id": row["user_id"],
+                "username": row["username"],
+                "email": row["email"],
+                "role": row["role"],
+                "is_admin_token": True,
+            }
+    except Exception as e:
+        logger.error(f"验证管理员 Token 失败:{e}")
+        return None
+
+
 class AuthMiddleware(BaseHTTPMiddleware):
     """
     SSO Token 认证中间件。
@@ -85,11 +125,20 @@ class AuthMiddleware(BaseHTTPMiddleware):
         sso_token = parts[1]
 
         try:
-            # 1. 先查本地缓存
-            user_info = token_cache.get(sso_token)
+            # 1. 先检查是否是管理员 Token(以 admin_token_ 开头)
+            user_info = None
+            if sso_token.startswith("admin_token_"):
+                logger.debug("检测到管理员 Token,尝试从数据库验证")
+                user_info = verify_admin_token(sso_token)
+                if user_info:
+                    logger.info(f"管理员 Token 验证成功:{user_info['username']}")
+
+            # 2. 如果不是管理员 Token,查本地缓存(SSO token)
+            if user_info is None:
+                user_info = token_cache.get(sso_token)
 
+            # 3. 缓存未命中,调 SSO profile 验证(含角色信息)
             if user_info is None:
-                # 2. 缓存未命中,调 SSO profile 验证(含角色信息)
                 user_info = await OAuthService.verify_sso_token(sso_token)
 
                 # 3. 同步用户到本地数据库(更新角色),获取本地用户ID
@@ -116,7 +165,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
                 except Exception as sync_err:
                     logger.warning(f"重新同步用户失败,使用SSO ID: {sync_err}")
             
-            user_id = local_user_id or user_info.get("id") or user_info.get("sub")
+            user_id = local_user_id or user_info.get("id") or user_info.get("sub") if not user_info.get("is_admin_token") else user_info.get("id")
             username = (
                 user_info.get("username")
                 or user_info.get("preferred_username")

+ 109 - 69
backend/scripts/generate_admin_token.py

@@ -1,26 +1,26 @@
 #!/usr/bin/env python3
 """
-生成长期有效的管理员Token脚本
+生成长期有效的管理员 Token 脚本
 
 功能:
 1. 查找管理员用户
-2. 生成99999天有效期的Token
-3. 输出Token并验证有效性
+2. 生成一个随机的长期 Token
+3. 将 Token 持久化到数据库的 admin_tokens 表中
 
 使用方式:
     cd backend
     python scripts/generate_admin_token.py
 
-注意:需要在backend目录下运行,以确保正确加载配置
+注意:需要在 backend 目录下运行,以确保正确加载配置
 """
 import sys
 import os
 
-# 添加backend目录到路径
+# 添加 backend 目录到路径
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 from datetime import datetime, timedelta, timezone
-import jwt
+import secrets
 from config import settings
 from database import get_db_connection
 
@@ -46,100 +46,135 @@ def find_admin_user():
         return None
 
 
-def create_long_term_token(user_data: dict, days: int = 99999) -> str:
+def ensure_admin_tokens_table():
+    """确保 admin_tokens 表存在"""
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute("""
+            CREATE TABLE IF NOT EXISTS admin_tokens (
+                id VARCHAR(36) PRIMARY KEY,
+                user_id VARCHAR(36) NOT NULL,
+                token VARCHAR(255) NOT NULL UNIQUE,
+                expires_at TIMESTAMP NOT NULL,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                INDEX idx_admin_tokens_token (token),
+                INDEX idx_admin_tokens_user_id (user_id)
+            )
+        """)
+        conn.commit()
+
+
+def create_long_term_token(user_data: dict, years: int = 10) -> str:
     """
-    创建长期有效的Token
-    
+    创建长期有效的 Token(持久化到数据库)
+
     Args:
         user_data: 用户信息字典
-        days: 有效天数,默认99999天
-        
+        years: 有效年数,默认 10 年
+
     Returns:
-        str: JWT Token
+        str: Token 字符串
     """
-    expire = datetime.now(timezone.utc) + timedelta(days=days)
-    payload = {
-        "sub": user_data["id"],
-        "username": user_data["username"],
-        "email": user_data["email"],
-        "role": user_data["role"],
-        "exp": expire,
-        "iat": datetime.now(timezone.utc),
-        "type": "access"
-    }
-    return jwt.encode(
-        payload,
-        settings.JWT_SECRET_KEY,
-        algorithm=settings.JWT_ALGORITHM
-    )
+    # 生成随机 token
+    token = f"admin_token_{secrets.token_urlsafe(32)}"
+    token_id = f"token_{datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
+    expire = datetime.now(timezone.utc) + timedelta(days=years * 365)
+
+    # 持久化到数据库
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute("""
+            INSERT INTO admin_tokens (id, user_id, token, expires_at)
+            VALUES (%s, %s, %s, %s)
+        """, (token_id, user_data["id"], token, expire.replace(microsecond=0)))
+        conn.commit()
+
+    return token
 
 
 def verify_token(token: str) -> dict:
     """
-    验证Token有效性
-    
+    验证 Token 有效性(从数据库查询)
+
     Args:
-        token: JWT Token
-        
+        token: Token 字符串
+
     Returns:
-        dict: 解码后的payload
+        dict: 用户信息字典
+
+    Raises:
+        Exception: Token 无效或已过期
     """
-    try:
-        payload = jwt.decode(
-            token,
-            settings.JWT_SECRET_KEY,
-            algorithms=[settings.JWT_ALGORITHM]
-        )
-        return payload
-    except jwt.ExpiredSignatureError:
-        raise Exception("Token已过期")
-    except jwt.InvalidTokenError as e:
-        raise Exception(f"Token无效: {str(e)}")
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute("""
+            SELECT at.user_id, u.username, u.email, u.role, at.expires_at
+            FROM admin_tokens at
+            JOIN users u ON at.user_id = u.id
+            WHERE at.token = %s AND at.expires_at > %s
+        """, (token, datetime.now(timezone.utc)))
+        row = cursor.fetchone()
+
+        if not row:
+            raise Exception("Token 无效或已过期")
+
+        return {
+            "id": row["user_id"],
+            "username": row["username"],
+            "email": row["email"],
+            "role": row["role"],
+            "expires_at": row["expires_at"]
+        }
 
 
 def main():
     print("=" * 60)
-    print("管理员长期Token生成工具")
+    print("管理员长期 Token 生成工具")
     print("=" * 60)
     print()
-    
+
+    # 确保表存在
+    print("正在检查数据库表...")
+    ensure_admin_tokens_table()
+    print("[OK] 数据库表检查完成")
+    print()
+
     # 查找管理员用户
     print("正在查找管理员用户...")
     admin_user = find_admin_user()
-    
+
     if not admin_user:
-        print("\n❌ 错误: 未找到管理员用户!")
+        print("\n[ERROR] 未找到管理员用户!")
         print("\n请先创建管理员用户,可以使用以下方式:")
         print("  1. 运行 python create_test_user.py 创建测试用户")
-        print("  2. 或通过API注册用户后在数据库中将role改为admin")
+        print("  2. 或通过 API 注册用户后在数据库中将 role 改为 admin")
         sys.exit(1)
-    
-    print(f"✓ 找到管理员用户: {admin_user['username']} ({admin_user['email']})")
+
+    print(f"[OK] 找到管理员用户:{admin_user['username']} ({admin_user['email']})")
     print()
-    
-    # 生成Token
-    print("正在生成99999天有效期的Token...")
-    token = create_long_term_token(admin_user, days=99999)
-    print("✓ Token生成成功!")
+
+    # 生成 Token
+    print("正在生成 10 年有效期的 Token...")
+    token = create_long_term_token(admin_user, years=10)
+    print("[OK] Token 生成成功!")
     print()
-    
-    # 验证Token
-    print("正在验证Token有效性...")
+
+    # 验证 Token
+    print("正在验证 Token 有效性...")
     try:
-        payload = verify_token(token)
-        expire_time = datetime.fromtimestamp(payload["exp"])
-        print(f"✓ Token验证通过!")
-        print(f"  - 用户ID: {payload['sub']}")
-        print(f"  - 用户名: {payload['username']}")
-        print(f"  - 角色: {payload['role']}")
-        print(f"  - 过期时间: {expire_time.strftime('%Y-%m-%d %H:%M:%S')}")
+        user_info = verify_token(token)
+        print(f"[OK] Token 验证通过!")
+        print(f"  - 用户 ID: {user_info['id']}")
+        print(f"  - 用户名:{user_info['username']}")
+        print(f"  - 角色:{user_info['role']}")
+        print(f"  - 过期时间:{user_info['expires_at'].strftime('%Y-%m-%d %H:%M:%S')}")
     except Exception as e:
-        print(f"❌ Token验证失败: {str(e)}")
+        print(f"[ERROR] Token 验证失败:{str(e)}")
         sys.exit(1)
-    
+
     print()
     print("=" * 60)
-    print("生成的管理员Token (请妥善保管):")
+    print("生成的管理员 Token (请妥善保管):")
     print("=" * 60)
     print()
     print(token)
@@ -147,9 +182,14 @@ def main():
     print("=" * 60)
     print()
     print("使用方式:")
-    print("  在HTTP请求头中添加:")
+    print("  在 HTTP 请求头中添加:")
     print(f"  Authorization: Bearer {token[:50]}...")
     print()
+    print("注意:")
+    print("  - 此 Token 持久化存储在数据库中")
+    print("  - 服务器重启后 Token 仍然有效")
+    print("  - Token 过期时间:10 年")
+    print()
 
 
 if __name__ == "__main__":