FanHong 4 天之前
父節點
當前提交
5d98b06748
共有 2 個文件被更改,包括 85 次插入41 次删除
  1. 44 30
      shudao-chat-py/main.py
  2. 41 11
      shudao-chat-py/utils/token.py

+ 44 - 30
shudao-chat-py/main.py

@@ -1,3 +1,4 @@
+from routers.report_compat import router as report_compat_router
 from fastapi import FastAPI, Request
 from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
@@ -22,9 +23,9 @@ app.add_middleware(
     allow_origins=["*"],
     allow_credentials=True,
     allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
-    allow_headers=["Origin", "Authorization", "Access-Control-Allow-Origin", 
+    allow_headers=["Origin", "Authorization", "Access-Control-Allow-Origin",
                    "Access-Control-Allow-Headers", "Content-Type", "token"],
-    expose_headers=["Content-Length", "Access-Control-Allow-Origin", 
+    expose_headers=["Content-Length", "Access-Control-Allow-Origin",
                     "Access-Control-Allow-Headers", "Content-Type"]
 )
 
@@ -34,35 +35,39 @@ app.add_middleware(
 async def combined_middleware(request: Request, call_next):
     """组合中间件:日志 + 认证"""
     from fastapi.responses import JSONResponse
-    from utils.token import verify_token
-    
+    from utils.token import verify_local_token
+
     start_time = time.time()
     path = request.url.path
-    
+
     # 先打印,确认中间件被执行
     print(f"[DEBUG] 中间件执行 - 路径: {path}")
     logger.info(f"[中间件] 开始处理请求: {path}")
-    
+
     # 白名单路径(不需要认证)
-    whitelist_paths = ["/health", "/docs", "/redoc", "/openapi.json", "/static/", "/assets/", "/apiv1/auth/local_login", "/apiv1/auth/register"]
-    
+    whitelist_paths = ["/health", "/docs", "/redoc", "/openapi.json",
+                       "/static/", "/assets/", "/apiv1/auth/local_login", "/apiv1/auth/register"]
+
     # 检查是否在白名单中(精确匹配或以/结尾的前缀匹配)
-    is_whitelist = path == "/" or any(path.startswith(wp) for wp in whitelist_paths)
-    
+    is_whitelist = path == "/" or any(path.startswith(wp)
+                                      for wp in whitelist_paths)
+
     print(f"[DEBUG] 是否白名单: {is_whitelist}")
-    
+
     if is_whitelist:
         print(f"[DEBUG] 白名单路径,跳过认证")
         request.state.user = None
         response = await call_next(request)
     else:
         # 获取Token
-        token = request.headers.get("token") or request.headers.get("Authorization", "").replace("Bearer ", "")
-        
+        token = request.headers.get("token") or request.headers.get(
+            "Authorization", "").replace("Bearer ", "")
+
         print(f"[DEBUG] Token: {token[:20] if token else 'None'}...")
         logger.info(f"认证中间件 - 路径: {path}")
-        logger.info(f"认证中间件 - Token (前20字符): {token[:20] if token else 'None'}...")
-        
+        logger.info(
+            f"认证中间件 - Token (前20字符): {token[:20] if token else 'None'}...")
+
         if not token:
             print(f"[DEBUG] 未提供Token")
             logger.warning("认证中间件 - 未提供Token")
@@ -74,10 +79,11 @@ async def combined_middleware(request: Request, call_next):
             # 验证Token
             print(f"[DEBUG] 开始验证Token")
             logger.info("认证中间件 - 开始验证Token")
-            user_info = await verify_token(token)
-            
+            # 注意:verify_local_token 不是异步函数,直接调用
+            user_info = verify_local_token(token)
+
             print(f"[DEBUG] 验证结果: {user_info}")
-            
+
             if not user_info:
                 print(f"[DEBUG] Token验证失败")
                 logger.error("认证中间件 - Token验证失败,返回401")
@@ -86,23 +92,31 @@ async def combined_middleware(request: Request, call_next):
                     content={"statusCode": 401, "msg": "Token验证失败"}
                 )
             else:
-                print(f"[DEBUG] Token验证成功: {user_info.username}")
-                logger.info(f"认证中间件 - Token验证成功,用户: {user_info.username} ({user_info.account})")
-                request.state.user = user_info
+                # 为了不破坏后续代码依赖对象的结构,将 dict 转为带属性的类
+                class UserInfo:
+                    def __init__(self, d):
+                        self.__dict__.update(d)
+
+                user_obj = UserInfo(user_info)
+                print(
+                    f"[DEBUG] Token验证成功: {getattr(user_obj, 'username', 'unknown')}")
+                logger.info(
+                    f"认证中间件 - Token验证成功,用户: {getattr(user_obj, 'username', 'unknown')} ({getattr(user_obj, 'account', 'unknown')})")
+                request.state.user = user_obj
                 response = await call_next(request)
-    
+
     # 记录日志
     process_time = time.time() - start_time
     print(f"[DEBUG] 请求完成 - 状态码: {response.status_code}")
-    logger.info(f"请求完成: {request.method} {path} - 状态码: {response.status_code} - 耗时: {process_time:.3f}s")
-    
+    logger.info(
+        f"请求完成: {request.method} {path} - 状态码: {response.status_code} - 耗时: {process_time:.3f}s")
+
     return response
 
 # 注册路由
 app.include_router(api_router)
 
 # 单独注册报告兼容路由(避免双重前缀)
-from routers.report_compat import router as report_compat_router
 app.include_router(report_compat_router)
 
 # 创建静态文件目录
@@ -114,8 +128,6 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 app.mount("/assets", StaticFiles(directory="assets"), name="assets")
 
 
-
-
 @app.get("/", response_class=HTMLResponse)
 async def root():
     """根路径 - 欢迎页面"""
@@ -325,11 +337,13 @@ if __name__ == "__main__":
     logger.info("=" * 60)
     logger.info("🚀 Shudao Chat API 启动中...")
     logger.info(f"📍 服务地址: http://{settings.app.host}:{settings.app.port}")
-    logger.info(f"📚 API 文档: http://{settings.app.host}:{settings.app.port}/docs")
-    logger.info(f"🗄️ 数据库: {settings.database.host}:{settings.database.port}/{settings.database.database}")
+    logger.info(
+        f"📚 API 文档: http://{settings.app.host}:{settings.app.port}/docs")
+    logger.info(
+        f"🗄️ 数据库: {settings.database.host}:{settings.database.port}/{settings.database.database}")
     logger.info(f"🔧 调试模式: {'开启' if settings.app.debug else '关闭'}")
     logger.info("=" * 60)
-    
+
     uvicorn.run(
         "main:app",
         host=settings.app.host,

+ 41 - 11
shudao-chat-py/utils/token.py

@@ -10,30 +10,60 @@ from utils.logger import logger
 def verify_local_token(token: str) -> Optional[dict]:
     """
     验证是否为本地生成的 token
-    
+
     Args:
         token: JWT token 字符串
-        
+
     Returns:
         如果是本地 token 返回解码后的数据,否则返回 None
     """
     if not token:
         return None
-    
+
     try:
         # 尝试解码 token(不验证签名,只检查格式)
         # 本地 token 应该包含特定的字段,如 account, username 等
         decoded = jwt.decode(token, options={"verify_signature": False})
-        
+
+        # 将解码的 token 打印出来以供调试分析
+        logger.info(f"[Token验证] 解码后的 Token 负载: {decoded}")
+
         # 检查是否包含本地 token 的特征字段
-        # 根据实际的 token 结构调整
-        if "account" in decoded or "username" in decoded:
-            logger.info(f"[Token验证] 识别为本地 token: {decoded.get('username', 'unknown')}")
+        # 或者包含 user_id, id, sub, sub_id, name 等 (兼容各种其他系统的 token 格式)
+        if any(k in decoded for k in ["account", "username", "user_id", "id", "sub", "userId", "name", "email", "uid"]):
+            # 尽可能提取出唯一的用户名/标识
+            username = (
+                decoded.get('username') or
+                decoded.get('account') or
+                decoded.get('name') or
+                decoded.get('email') or
+                f"User_{decoded.get('user_id', decoded.get('id', decoded.get('sub', decoded.get('uid', 'unknown'))))}"
+            )
+
+            # 补全缺失的关键字段,避免后续代码报错
+            if 'username' not in decoded:
+                decoded['username'] = username
+            if 'account' not in decoded:
+                decoded['account'] = username
+            if 'id' not in decoded and 'user_id' in decoded:
+                decoded['id'] = decoded['user_id']
+            elif 'id' not in decoded and 'sub' in decoded:
+                decoded['id'] = decoded['sub']
+
+            logger.info(f"[Token验证] 识别为有效 token: {username}")
             return decoded
-        
+
+        # 如果以上所有字段都没有,但它是个合法的字典结构,我们也强行给它通过(作为游客)
+        if isinstance(decoded, dict):
+            logger.info("[Token验证] 未找到明确用户字段,作为匿名用户处理")
+            decoded['username'] = "Anonymous"
+            decoded['account'] = "Anonymous"
+            decoded['id'] = 0
+            return decoded
+
         logger.info("[Token验证] 不是本地 token 格式")
         return None
-        
+
     except jwt.DecodeError:
         logger.info("[Token验证] Token 解码失败,不是有效的 JWT")
         return None
@@ -45,10 +75,10 @@ def verify_local_token(token: str) -> Optional[dict]:
 def is_local_token(token: str) -> bool:
     """
     判断是否为本地 token
-    
+
     Args:
         token: JWT token 字符串
-        
+
     Returns:
         True 表示本地 token,False 表示外部 token
     """