|
|
@@ -1,3 +1,4 @@
|
|
|
+from routers.report_compat import router as report_compat_router
|
|
|
from fastapi import FastAPI, Request
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
@@ -22,9 +23,9 @@ app.add_middleware(
|
|
|
allow_origins=["*"],
|
|
|
allow_credentials=True,
|
|
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
|
|
- allow_headers=["Origin", "Authorization", "Access-Control-Allow-Origin",
|
|
|
+ allow_headers=["Origin", "Authorization", "Access-Control-Allow-Origin",
|
|
|
"Access-Control-Allow-Headers", "Content-Type", "token"],
|
|
|
- expose_headers=["Content-Length", "Access-Control-Allow-Origin",
|
|
|
+ expose_headers=["Content-Length", "Access-Control-Allow-Origin",
|
|
|
"Access-Control-Allow-Headers", "Content-Type"]
|
|
|
)
|
|
|
|
|
|
@@ -34,35 +35,39 @@ app.add_middleware(
|
|
|
async def combined_middleware(request: Request, call_next):
|
|
|
"""组合中间件:日志 + 认证"""
|
|
|
from fastapi.responses import JSONResponse
|
|
|
- from utils.token import verify_token
|
|
|
-
|
|
|
+ from utils.token import verify_local_token
|
|
|
+
|
|
|
start_time = time.time()
|
|
|
path = request.url.path
|
|
|
-
|
|
|
+
|
|
|
# 先打印,确认中间件被执行
|
|
|
print(f"[DEBUG] 中间件执行 - 路径: {path}")
|
|
|
logger.info(f"[中间件] 开始处理请求: {path}")
|
|
|
-
|
|
|
+
|
|
|
# 白名单路径(不需要认证)
|
|
|
- whitelist_paths = ["/health", "/docs", "/redoc", "/openapi.json", "/static/", "/assets/", "/apiv1/auth/local_login", "/apiv1/auth/register"]
|
|
|
-
|
|
|
+ whitelist_paths = ["/health", "/docs", "/redoc", "/openapi.json",
|
|
|
+ "/static/", "/assets/", "/apiv1/auth/local_login", "/apiv1/auth/register"]
|
|
|
+
|
|
|
# 检查是否在白名单中(精确匹配或以/结尾的前缀匹配)
|
|
|
- is_whitelist = path == "/" or any(path.startswith(wp) for wp in whitelist_paths)
|
|
|
-
|
|
|
+ is_whitelist = path == "/" or any(path.startswith(wp)
|
|
|
+ for wp in whitelist_paths)
|
|
|
+
|
|
|
print(f"[DEBUG] 是否白名单: {is_whitelist}")
|
|
|
-
|
|
|
+
|
|
|
if is_whitelist:
|
|
|
print(f"[DEBUG] 白名单路径,跳过认证")
|
|
|
request.state.user = None
|
|
|
response = await call_next(request)
|
|
|
else:
|
|
|
# 获取Token
|
|
|
- token = request.headers.get("token") or request.headers.get("Authorization", "").replace("Bearer ", "")
|
|
|
-
|
|
|
+ token = request.headers.get("token") or request.headers.get(
|
|
|
+ "Authorization", "").replace("Bearer ", "")
|
|
|
+
|
|
|
print(f"[DEBUG] Token: {token[:20] if token else 'None'}...")
|
|
|
logger.info(f"认证中间件 - 路径: {path}")
|
|
|
- logger.info(f"认证中间件 - Token (前20字符): {token[:20] if token else 'None'}...")
|
|
|
-
|
|
|
+ logger.info(
|
|
|
+ f"认证中间件 - Token (前20字符): {token[:20] if token else 'None'}...")
|
|
|
+
|
|
|
if not token:
|
|
|
print(f"[DEBUG] 未提供Token")
|
|
|
logger.warning("认证中间件 - 未提供Token")
|
|
|
@@ -74,10 +79,11 @@ async def combined_middleware(request: Request, call_next):
|
|
|
# 验证Token
|
|
|
print(f"[DEBUG] 开始验证Token")
|
|
|
logger.info("认证中间件 - 开始验证Token")
|
|
|
- user_info = await verify_token(token)
|
|
|
-
|
|
|
+ # 注意:verify_local_token 不是异步函数,直接调用
|
|
|
+ user_info = verify_local_token(token)
|
|
|
+
|
|
|
print(f"[DEBUG] 验证结果: {user_info}")
|
|
|
-
|
|
|
+
|
|
|
if not user_info:
|
|
|
print(f"[DEBUG] Token验证失败")
|
|
|
logger.error("认证中间件 - Token验证失败,返回401")
|
|
|
@@ -86,23 +92,31 @@ async def combined_middleware(request: Request, call_next):
|
|
|
content={"statusCode": 401, "msg": "Token验证失败"}
|
|
|
)
|
|
|
else:
|
|
|
- print(f"[DEBUG] Token验证成功: {user_info.username}")
|
|
|
- logger.info(f"认证中间件 - Token验证成功,用户: {user_info.username} ({user_info.account})")
|
|
|
- request.state.user = user_info
|
|
|
+ # 为了不破坏后续代码依赖对象的结构,将 dict 转为带属性的类
|
|
|
+ class UserInfo:
|
|
|
+ def __init__(self, d):
|
|
|
+ self.__dict__.update(d)
|
|
|
+
|
|
|
+ user_obj = UserInfo(user_info)
|
|
|
+ print(
|
|
|
+ f"[DEBUG] Token验证成功: {getattr(user_obj, 'username', 'unknown')}")
|
|
|
+ logger.info(
|
|
|
+ f"认证中间件 - Token验证成功,用户: {getattr(user_obj, 'username', 'unknown')} ({getattr(user_obj, 'account', 'unknown')})")
|
|
|
+ request.state.user = user_obj
|
|
|
response = await call_next(request)
|
|
|
-
|
|
|
+
|
|
|
# 记录日志
|
|
|
process_time = time.time() - start_time
|
|
|
print(f"[DEBUG] 请求完成 - 状态码: {response.status_code}")
|
|
|
- logger.info(f"请求完成: {request.method} {path} - 状态码: {response.status_code} - 耗时: {process_time:.3f}s")
|
|
|
-
|
|
|
+ logger.info(
|
|
|
+ f"请求完成: {request.method} {path} - 状态码: {response.status_code} - 耗时: {process_time:.3f}s")
|
|
|
+
|
|
|
return response
|
|
|
|
|
|
# 注册路由
|
|
|
app.include_router(api_router)
|
|
|
|
|
|
# 单独注册报告兼容路由(避免双重前缀)
|
|
|
-from routers.report_compat import router as report_compat_router
|
|
|
app.include_router(report_compat_router)
|
|
|
|
|
|
# 创建静态文件目录
|
|
|
@@ -114,8 +128,6 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
app.mount("/assets", StaticFiles(directory="assets"), name="assets")
|
|
|
|
|
|
|
|
|
-
|
|
|
-
|
|
|
@app.get("/", response_class=HTMLResponse)
|
|
|
async def root():
|
|
|
"""根路径 - 欢迎页面"""
|
|
|
@@ -325,11 +337,13 @@ if __name__ == "__main__":
|
|
|
logger.info("=" * 60)
|
|
|
logger.info("🚀 Shudao Chat API 启动中...")
|
|
|
logger.info(f"📍 服务地址: http://{settings.app.host}:{settings.app.port}")
|
|
|
- logger.info(f"📚 API 文档: http://{settings.app.host}:{settings.app.port}/docs")
|
|
|
- logger.info(f"🗄️ 数据库: {settings.database.host}:{settings.database.port}/{settings.database.database}")
|
|
|
+ logger.info(
|
|
|
+ f"📚 API 文档: http://{settings.app.host}:{settings.app.port}/docs")
|
|
|
+ logger.info(
|
|
|
+ f"🗄️ 数据库: {settings.database.host}:{settings.database.port}/{settings.database.database}")
|
|
|
logger.info(f"🔧 调试模式: {'开启' if settings.app.debug else '关闭'}")
|
|
|
logger.info("=" * 60)
|
|
|
-
|
|
|
+
|
|
|
uvicorn.run(
|
|
|
"main:app",
|
|
|
host=settings.app.host,
|