Explorar o código

长久token 逻辑优化

lingmin_package@163.com hai 1 mes
pai
achega
e3cbb7163f
Modificáronse 2 ficheiros con 162 adicións e 73 borrados
  1. 53 4
      backend/middleware/auth_middleware.py
  2. 109 69
      backend/scripts/generate_admin_token.py

+ 53 - 4
backend/middleware/auth_middleware.py

@@ -2,14 +2,18 @@
 Authentication Middleware for SSO token verification.
 Authentication Middleware for SSO token verification.
 Validates SSO tokens via the SSO center's userinfo endpoint,
 Validates SSO tokens via the SSO center's userinfo endpoint,
 with an in-memory cache to reduce external calls.
 with an in-memory cache to reduce external calls.
+
+Also supports admin tokens generated by generate_admin_token.py script.
 """
 """
 import logging
 import logging
+from datetime import datetime, timezone
 from fastapi import Request, HTTPException, status
 from fastapi import Request, HTTPException, status
 from fastapi.responses import JSONResponse
 from fastapi.responses import JSONResponse
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.base import BaseHTTPMiddleware
 from services.token_cache_service import TokenCacheService
 from services.token_cache_service import TokenCacheService
 from services.oauth_service import OAuthService
 from services.oauth_service import OAuthService
 from config import settings
 from config import settings
+from database import get_db_connection
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -20,6 +24,42 @@ token_cache = TokenCacheService(
 )
 )
 
 
 
 
+def verify_admin_token(token: str) -> dict:
+    """
+    验证管理员 Token(从数据库查询)
+
+    Args:
+        token: Token 字符串
+
+    Returns:
+        dict: 用户信息字典,或 None(Token 无效或已过期)
+    """
+    try:
+        with get_db_connection() as conn:
+            cursor = conn.cursor()
+            cursor.execute("""
+                SELECT at.user_id, u.username, u.email, u.role, at.expires_at
+                FROM admin_tokens at
+                JOIN users u ON at.user_id = u.id
+                WHERE at.token = %s AND at.expires_at > %s
+            """, (token, datetime.now(timezone.utc)))
+            row = cursor.fetchone()
+
+            if not row:
+                return None
+
+            return {
+                "id": row["user_id"],
+                "username": row["username"],
+                "email": row["email"],
+                "role": row["role"],
+                "is_admin_token": True,
+            }
+    except Exception as e:
+        logger.error(f"验证管理员 Token 失败:{e}")
+        return None
+
+
 class AuthMiddleware(BaseHTTPMiddleware):
 class AuthMiddleware(BaseHTTPMiddleware):
     """
     """
     SSO Token 认证中间件。
     SSO Token 认证中间件。
@@ -85,11 +125,20 @@ class AuthMiddleware(BaseHTTPMiddleware):
         sso_token = parts[1]
         sso_token = parts[1]
 
 
         try:
         try:
-            # 1. 先查本地缓存
-            user_info = token_cache.get(sso_token)
+            # 1. 先检查是否是管理员 Token(以 admin_token_ 开头)
+            user_info = None
+            if sso_token.startswith("admin_token_"):
+                logger.debug("检测到管理员 Token,尝试从数据库验证")
+                user_info = verify_admin_token(sso_token)
+                if user_info:
+                    logger.info(f"管理员 Token 验证成功:{user_info['username']}")
+
+            # 2. 如果不是管理员 Token,查本地缓存(SSO token)
+            if user_info is None:
+                user_info = token_cache.get(sso_token)
 
 
+            # 3. 缓存未命中,调 SSO profile 验证(含角色信息)
             if user_info is None:
             if user_info is None:
-                # 2. 缓存未命中,调 SSO profile 验证(含角色信息)
                 user_info = await OAuthService.verify_sso_token(sso_token)
                 user_info = await OAuthService.verify_sso_token(sso_token)
 
 
                 # 3. 同步用户到本地数据库(更新角色),获取本地用户ID
                 # 3. 同步用户到本地数据库(更新角色),获取本地用户ID
@@ -116,7 +165,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
                 except Exception as sync_err:
                 except Exception as sync_err:
                     logger.warning(f"重新同步用户失败,使用SSO ID: {sync_err}")
                     logger.warning(f"重新同步用户失败,使用SSO ID: {sync_err}")
             
             
-            user_id = local_user_id or user_info.get("id") or user_info.get("sub")
+            user_id = local_user_id or user_info.get("id") or user_info.get("sub") if not user_info.get("is_admin_token") else user_info.get("id")
             username = (
             username = (
                 user_info.get("username")
                 user_info.get("username")
                 or user_info.get("preferred_username")
                 or user_info.get("preferred_username")

+ 109 - 69
backend/scripts/generate_admin_token.py

@@ -1,26 +1,26 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
 """
 """
-生成长期有效的管理员Token脚本
+生成长期有效的管理员 Token 脚本
 
 
 功能:
 功能:
 1. 查找管理员用户
 1. 查找管理员用户
-2. 生成99999天有效期的Token
-3. 输出Token并验证有效性
+2. 生成一个随机的长期 Token
+3. 将 Token 持久化到数据库的 admin_tokens 表中
 
 
 使用方式:
 使用方式:
     cd backend
     cd backend
     python scripts/generate_admin_token.py
     python scripts/generate_admin_token.py
 
 
-注意:需要在backend目录下运行,以确保正确加载配置
+注意:需要在 backend 目录下运行,以确保正确加载配置
 """
 """
 import sys
 import sys
 import os
 import os
 
 
-# 添加backend目录到路径
+# 添加 backend 目录到路径
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 
 from datetime import datetime, timedelta, timezone
 from datetime import datetime, timedelta, timezone
-import jwt
+import secrets
 from config import settings
 from config import settings
 from database import get_db_connection
 from database import get_db_connection
 
 
@@ -46,100 +46,135 @@ def find_admin_user():
         return None
         return None
 
 
 
 
-def create_long_term_token(user_data: dict, days: int = 99999) -> str:
+def ensure_admin_tokens_table():
+    """确保 admin_tokens 表存在"""
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute("""
+            CREATE TABLE IF NOT EXISTS admin_tokens (
+                id VARCHAR(36) PRIMARY KEY,
+                user_id VARCHAR(36) NOT NULL,
+                token VARCHAR(255) NOT NULL UNIQUE,
+                expires_at TIMESTAMP NOT NULL,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                INDEX idx_admin_tokens_token (token),
+                INDEX idx_admin_tokens_user_id (user_id)
+            )
+        """)
+        conn.commit()
+
+
+def create_long_term_token(user_data: dict, years: int = 10) -> str:
     """
     """
-    创建长期有效的Token
-    
+    创建长期有效的 Token(持久化到数据库)
+
     Args:
     Args:
         user_data: 用户信息字典
         user_data: 用户信息字典
-        days: 有效天数,默认99999天
-        
+        years: 有效年数,默认 10 年
+
     Returns:
     Returns:
-        str: JWT Token
+        str: Token 字符串
     """
     """
-    expire = datetime.now(timezone.utc) + timedelta(days=days)
-    payload = {
-        "sub": user_data["id"],
-        "username": user_data["username"],
-        "email": user_data["email"],
-        "role": user_data["role"],
-        "exp": expire,
-        "iat": datetime.now(timezone.utc),
-        "type": "access"
-    }
-    return jwt.encode(
-        payload,
-        settings.JWT_SECRET_KEY,
-        algorithm=settings.JWT_ALGORITHM
-    )
+    # 生成随机 token
+    token = f"admin_token_{secrets.token_urlsafe(32)}"
+    token_id = f"token_{datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
+    expire = datetime.now(timezone.utc) + timedelta(days=years * 365)
+
+    # 持久化到数据库
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute("""
+            INSERT INTO admin_tokens (id, user_id, token, expires_at)
+            VALUES (%s, %s, %s, %s)
+        """, (token_id, user_data["id"], token, expire.replace(microsecond=0)))
+        conn.commit()
+
+    return token
 
 
 
 
 def verify_token(token: str) -> dict:
 def verify_token(token: str) -> dict:
     """
     """
-    验证Token有效性
-    
+    验证 Token 有效性(从数据库查询)
+
     Args:
     Args:
-        token: JWT Token
-        
+        token: Token 字符串
+
     Returns:
     Returns:
-        dict: 解码后的payload
+        dict: 用户信息字典
+
+    Raises:
+        Exception: Token 无效或已过期
     """
     """
-    try:
-        payload = jwt.decode(
-            token,
-            settings.JWT_SECRET_KEY,
-            algorithms=[settings.JWT_ALGORITHM]
-        )
-        return payload
-    except jwt.ExpiredSignatureError:
-        raise Exception("Token已过期")
-    except jwt.InvalidTokenError as e:
-        raise Exception(f"Token无效: {str(e)}")
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute("""
+            SELECT at.user_id, u.username, u.email, u.role, at.expires_at
+            FROM admin_tokens at
+            JOIN users u ON at.user_id = u.id
+            WHERE at.token = %s AND at.expires_at > %s
+        """, (token, datetime.now(timezone.utc)))
+        row = cursor.fetchone()
+
+        if not row:
+            raise Exception("Token 无效或已过期")
+
+        return {
+            "id": row["user_id"],
+            "username": row["username"],
+            "email": row["email"],
+            "role": row["role"],
+            "expires_at": row["expires_at"]
+        }
 
 
 
 
 def main():
 def main():
     print("=" * 60)
     print("=" * 60)
-    print("管理员长期Token生成工具")
+    print("管理员长期 Token 生成工具")
     print("=" * 60)
     print("=" * 60)
     print()
     print()
-    
+
+    # 确保表存在
+    print("正在检查数据库表...")
+    ensure_admin_tokens_table()
+    print("[OK] 数据库表检查完成")
+    print()
+
     # 查找管理员用户
     # 查找管理员用户
     print("正在查找管理员用户...")
     print("正在查找管理员用户...")
     admin_user = find_admin_user()
     admin_user = find_admin_user()
-    
+
     if not admin_user:
     if not admin_user:
-        print("\n❌ 错误: 未找到管理员用户!")
+        print("\n[ERROR] 未找到管理员用户!")
         print("\n请先创建管理员用户,可以使用以下方式:")
         print("\n请先创建管理员用户,可以使用以下方式:")
         print("  1. 运行 python create_test_user.py 创建测试用户")
         print("  1. 运行 python create_test_user.py 创建测试用户")
-        print("  2. 或通过API注册用户后在数据库中将role改为admin")
+        print("  2. 或通过 API 注册用户后在数据库中将 role 改为 admin")
         sys.exit(1)
         sys.exit(1)
-    
-    print(f"✓ 找到管理员用户: {admin_user['username']} ({admin_user['email']})")
+
+    print(f"[OK] 找到管理员用户:{admin_user['username']} ({admin_user['email']})")
     print()
     print()
-    
-    # 生成Token
-    print("正在生成99999天有效期的Token...")
-    token = create_long_term_token(admin_user, days=99999)
-    print("✓ Token生成成功!")
+
+    # 生成 Token
+    print("正在生成 10 年有效期的 Token...")
+    token = create_long_term_token(admin_user, years=10)
+    print("[OK] Token 生成成功!")
     print()
     print()
-    
-    # 验证Token
-    print("正在验证Token有效性...")
+
+    # 验证 Token
+    print("正在验证 Token 有效性...")
     try:
     try:
-        payload = verify_token(token)
-        expire_time = datetime.fromtimestamp(payload["exp"])
-        print(f"✓ Token验证通过!")
-        print(f"  - 用户ID: {payload['sub']}")
-        print(f"  - 用户名: {payload['username']}")
-        print(f"  - 角色: {payload['role']}")
-        print(f"  - 过期时间: {expire_time.strftime('%Y-%m-%d %H:%M:%S')}")
+        user_info = verify_token(token)
+        print(f"[OK] Token 验证通过!")
+        print(f"  - 用户 ID: {user_info['id']}")
+        print(f"  - 用户名:{user_info['username']}")
+        print(f"  - 角色:{user_info['role']}")
+        print(f"  - 过期时间:{user_info['expires_at'].strftime('%Y-%m-%d %H:%M:%S')}")
     except Exception as e:
     except Exception as e:
-        print(f"❌ Token验证失败: {str(e)}")
+        print(f"[ERROR] Token 验证失败:{str(e)}")
         sys.exit(1)
         sys.exit(1)
-    
+
     print()
     print()
     print("=" * 60)
     print("=" * 60)
-    print("生成的管理员Token (请妥善保管):")
+    print("生成的管理员 Token (请妥善保管):")
     print("=" * 60)
     print("=" * 60)
     print()
     print()
     print(token)
     print(token)
@@ -147,9 +182,14 @@ def main():
     print("=" * 60)
     print("=" * 60)
     print()
     print()
     print("使用方式:")
     print("使用方式:")
-    print("  在HTTP请求头中添加:")
+    print("  在 HTTP 请求头中添加:")
     print(f"  Authorization: Bearer {token[:50]}...")
     print(f"  Authorization: Bearer {token[:50]}...")
     print()
     print()
+    print("注意:")
+    print("  - 此 Token 持久化存储在数据库中")
+    print("  - 服务器重启后 Token 仍然有效")
+    print("  - Token 过期时间:10 年")
+    print()
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":