main.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. """
  2. 智创空间后端服务入口
  3. """
  4. import os
  5. import time
  6. import json
  7. import logging
  8. from logging.handlers import RotatingFileHandler
  9. from datetime import datetime
  10. from contextlib import asynccontextmanager
  11. from fastapi import FastAPI, Request
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from starlette.middleware.base import BaseHTTPMiddleware
  14. from starlette.responses import Response
  15. from app.routers import (
  16. model_router, auth_router, user_router,
  17. oss_router,
  18. local_model_router, platform_api_key_router, openai_compat_router, platform_stats_router,
  19. user_local_model_permission_router, password_strength_router
  20. )
  21. from app.routers.admin_auth_router import router as admin_auth_router
  22. from app.routers.admin_user_router import router as admin_user_router
  23. from app.routers.admin_model_router import router as admin_model_router
  24. from app.routers.admin_log_router import router as admin_log_router
  25. from app.routers.admin_stats_router import router as admin_stats_router
  26. from app.routers.admin_config_router import router as admin_config_router
  27. from app.routers.admin_local_model_router import router as admin_local_model_router
  28. from app.routers.admin_oss_router import router as admin_oss_router
  29. from app.routers.oauth_sso_router import router as sso_router
  30. from app.routers.admin_local_config_router import router as admin_local_config_router
  31. from app.core.async_logger import async_log_queue
  32. from app.core.redis import redis_manager
  33. from app.middleware import register_exception_handlers
  34. from app.middleware.rate_limit_middleware import RateLimitMiddleware
  35. from app.database import engine, SessionLocal
  36. from app.services.user_service import UserService
  37. # ==================== 日志配置 ====================
  38. # 创建 logs 目录
  39. os.makedirs('logs', exist_ok=True)
  40. # 配置日志格式
  41. log_formatter = logging.Formatter(
  42. '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  43. )
  44. # 配置根日志记录器
  45. root_logger = logging.getLogger()
  46. root_logger.setLevel(logging.INFO)
  47. # 清除已有的 handlers,避免重复(reload 模式下会重复加载)
  48. if root_logger.handlers:
  49. root_logger.handlers.clear()
  50. # 控制台处理器
  51. console_handler = logging.StreamHandler()
  52. console_handler.setFormatter(log_formatter)
  53. root_logger.addHandler(console_handler)
  54. # 文件处理器(所有日志)
  55. file_handler = RotatingFileHandler(
  56. 'logs/app.log',
  57. maxBytes=10*1024*1024, # 10MB
  58. backupCount=5,
  59. encoding='utf-8'
  60. )
  61. file_handler.setFormatter(log_formatter)
  62. root_logger.addHandler(file_handler)
  63. # 错误日志文件处理器
  64. error_handler = RotatingFileHandler(
  65. 'logs/error.log',
  66. maxBytes=10*1024*1024, # 10MB
  67. backupCount=5,
  68. encoding='utf-8'
  69. )
  70. error_handler.setLevel(logging.ERROR)
  71. error_handler.setFormatter(log_formatter)
  72. root_logger.addHandler(error_handler)
  73. logger = logging.getLogger(__name__)
  74. # 抑制第三方库的冗余日志(OSS SDK、HTTP 客户端等)
  75. for _noisy_logger in ['oss2', 'aiohttp', 'urllib3', 'urllib3.connectionpool', 'httpcore', 'httpx']:
  76. logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
  77. logger.info("Logging configured: logs/app.log, logs/error.log")
  78. class RequestLogMiddleware(BaseHTTPMiddleware):
  79. """全局请求日志中间件"""
  80. async def dispatch(self, request: Request, call_next):
  81. start_time = time.time()
  82. # 尝试获取用户ID和用户名
  83. user_id = None
  84. username = None
  85. auth_header = request.headers.get("Authorization")
  86. if auth_header and auth_header.startswith("Bearer "):
  87. try:
  88. from app.services.auth_service import AuthService
  89. from app.services.admin_auth_service import AdminAuthService
  90. token = auth_header[7:]
  91. # 先尝试普通用户token
  92. try:
  93. payload = AuthService.verify_token(token)
  94. user_id = payload.get("user_id")
  95. except:
  96. # 再尝试管理员token
  97. try:
  98. payload = AdminAuthService.verify_token(token)
  99. user_id = f"admin_{payload.get('admin_id')}"
  100. except:
  101. pass
  102. except:
  103. pass
  104. response = await call_next(request)
  105. duration_ms = int((time.time() - start_time) * 1000)
  106. # 记录到控制台
  107. log_data = {
  108. "timestamp": datetime.now().isoformat(),
  109. "method": request.method,
  110. "path": str(request.url.path),
  111. "query_params": str(request.query_params) if request.query_params else None,
  112. "user_id": user_id or "anonymous",
  113. "status_code": response.status_code,
  114. "duration_ms": duration_ms
  115. }
  116. if response.status_code >= 400:
  117. logger.warning(f"Request: {json.dumps(log_data, ensure_ascii=False)}")
  118. else:
  119. logger.info(f"Request: {json.dumps(log_data, ensure_ascii=False)}")
  120. # 记录到异步队列(非阻塞,需求 6.1, 6.4)
  121. if user_id and not request.url.path.startswith(("/health", "/static", "/exports")):
  122. async_log_queue.enqueue({
  123. "user_id": user_id,
  124. "api_path": str(request.url.path),
  125. "method": request.method,
  126. "status_code": response.status_code,
  127. "duration_ms": duration_ms,
  128. "request_params": dict(request.query_params) if request.query_params else None,
  129. "request_ip": request.client.host if request.client else None
  130. })
  131. return response
  132. def init_admin_user():
  133. """初始化管理员用户"""
  134. db = SessionLocal()
  135. try:
  136. user_service = UserService(db)
  137. user_service.init_admin_user()
  138. logger.info("管理员用户初始化完成")
  139. except Exception as e:
  140. logger.error(f"管理员用户初始化失败: {e}")
  141. finally:
  142. db.close()
  143. @asynccontextmanager
  144. async def lifespan(app: FastAPI):
  145. """应用生命周期管理"""
  146. logger.info("=" * 50)
  147. logger.info("智创空间后端服务启动中...")
  148. logger.info(f"数据库连接: {engine.url}")
  149. try:
  150. with engine.connect() as conn:
  151. logger.info("数据库连接成功")
  152. init_admin_user()
  153. # 初始化 Redis 连接(需求 4.1)
  154. redis_connected = await redis_manager.connect()
  155. if redis_connected:
  156. logger.info("Redis 连接已建立")
  157. else:
  158. logger.warning("Redis 连接失败,系统将以降级模式运行(无缓存、无分布式限流)")
  159. # 启动异步日志队列(需求 6.1)
  160. from app.core.async_database import AsyncSessionLocal
  161. await async_log_queue.start(AsyncSessionLocal)
  162. logger.info("异步日志队列已启动")
  163. # 启动定时任务(仅第一个 worker 运行,避免多 worker 重复执行)
  164. _is_main_worker = os.environ.get("_SCHEDULER_STARTED") != "1"
  165. if _is_main_worker:
  166. os.environ["_SCHEDULER_STARTED"] = "1"
  167. from apscheduler.schedulers.background import BackgroundScheduler
  168. # TODO: hourly_deduction_task 模块已移除,需要重新实现或确认是否需要
  169. # from app.services.hourly_deduction_task import run_hourly_deduction
  170. scheduler = BackgroundScheduler()
  171. # 每小时整点执行定时扣减任务(已禁用)
  172. # scheduler.add_job(
  173. # run_hourly_deduction,
  174. # 'cron',
  175. # hour='*',
  176. # minute=0,
  177. # id='hourly_deduction',
  178. # name='每小时余额扣减任务'
  179. # )
  180. # 爬虫数据同步任务:每天凌晨3点执行
  181. try:
  182. import asyncio
  183. from app.services.crawler_sync_service import sync_from_crawler
  184. def run_crawler_sync():
  185. db = SessionLocal()
  186. try:
  187. asyncio.run(sync_from_crawler(db))
  188. except Exception as e:
  189. logger.error(f"爬虫同步任务异常: {e}")
  190. finally:
  191. db.close()
  192. scheduler.add_job(
  193. run_crawler_sync,
  194. 'cron',
  195. hour=3,
  196. minute=0,
  197. id='crawler_sync',
  198. name='爬虫数据同步任务'
  199. )
  200. logger.info("定时任务已启动:爬虫数据同步任务每天凌晨3点执行")
  201. # 启动时立即触发一次同步
  202. import threading
  203. threading.Thread(target=run_crawler_sync, daemon=True, name="crawler_sync_startup").start()
  204. logger.info("已触发启动时爬虫数据同步")
  205. except Exception as _:
  206. logger.exception("注册爬虫同步任务失败")
  207. if _is_main_worker:
  208. scheduler.start()
  209. logger.info("定时任务调度器已启动(当前为主 worker)")
  210. else:
  211. logger.info("非主 worker,跳过定时任务调度器启动")
  212. except Exception as e:
  213. logger.error(f"数据库连接失败: {e}")
  214. logger.info("=" * 50)
  215. yield
  216. # 停止异步日志队列(需求 6.1)
  217. logger.info("正在停止异步日志队列...")
  218. await async_log_queue.stop()
  219. # 关闭 Redis 连接(需求 4.1)
  220. logger.info("正在关闭 Redis 连接...")
  221. await redis_manager.close()
  222. logger.info("服务关闭")
  223. app = FastAPI(
  224. title="智创空间API",
  225. description="智创空间后端服务,包含模型广场等模块",
  226. version="1.0.0",
  227. lifespan=lifespan,
  228. docs_url="/docs",
  229. redoc_url="/redoc"
  230. )
  231. # 添加安全中间件
  232. from app.middleware.security_middleware import SecurityMiddleware
  233. # 添加全局请求日志中间件
  234. app.add_middleware(RequestLogMiddleware)
  235. # 添加安全中间件(在CORS之前添加)
  236. app.add_middleware(SecurityMiddleware)
  237. # 添加限流中间件(需求 5.1)
  238. app.add_middleware(RateLimitMiddleware)
  239. app.add_middleware(
  240. CORSMiddleware,
  241. allow_origins=["*"],
  242. allow_credentials=True,
  243. allow_methods=["*"],
  244. allow_headers=["*"],
  245. )
  246. register_exception_handlers(app)
  247. # 注意:local_model_router 必须在 model_router 之前注册,
  248. # 因为 model_router 有 /{model_id} 路由会捕获 /local
  249. app.include_router(local_model_router)
  250. app.include_router(model_router)
  251. app.include_router(auth_router)
  252. app.include_router(user_router)
  253. app.include_router(oss_router)
  254. app.include_router(platform_api_key_router)
  255. app.include_router(openai_compat_router)
  256. app.include_router(platform_stats_router)
  257. app.include_router(user_local_model_permission_router)
  258. app.include_router(password_strength_router)
  259. # 管理后台路由
  260. app.include_router(admin_auth_router)
  261. app.include_router(admin_user_router)
  262. app.include_router(admin_model_router)
  263. app.include_router(admin_log_router)
  264. app.include_router(admin_stats_router)
  265. app.include_router(admin_config_router)
  266. app.include_router(admin_local_model_router)
  267. app.include_router(admin_oss_router)
  268. app.include_router(sso_router)
  269. app.include_router(admin_local_config_router)
  270. # 短信验证码路由
  271. from app.routers.sms_router import router as sms_router
  272. app.include_router(sms_router)
  273. # 邮箱验证码路由
  274. from app.routers.email_router import router as email_router
  275. app.include_router(email_router)
  276. # 公开品牌配置接口(无需登录,前端用于显示 logo/名称)
  277. @app.get("/api/public/branding")
  278. async def get_public_branding():
  279. """返回平台品牌配置(system_name, system_logo, icp_number)"""
  280. from app.models.config import SystemConfig
  281. from app.database import SessionLocal
  282. import json as _json
  283. db = SessionLocal()
  284. try:
  285. def _get(key: str, default: str) -> str:
  286. row = db.query(SystemConfig).filter(SystemConfig.config_key == key).first()
  287. if row:
  288. try:
  289. return _json.loads(row.config_value)
  290. except Exception:
  291. return row.config_value
  292. return default
  293. return {
  294. "system_name": _get("system_name", "智创空间"),
  295. "system_logo": _get("system_logo", ""),
  296. "icp_number": _get("icp_number", ""),
  297. }
  298. finally:
  299. db.close()
  300. @app.get("/health")
  301. async def health_check():
  302. """基本健康检查
  303. 返回系统整体健康状态(healthy/degraded/unhealthy)。
  304. 需求引用: 8.1, 8.2, 8.3, 8.5
  305. """
  306. from app.services.health_service import health_service
  307. overall = await health_service.get_overall_health()
  308. return {"status": overall["status"]}
  309. @app.get("/health/detailed")
  310. async def health_check_detailed():
  311. """详细健康检查
  312. 返回所有组件的详细状态信息,包括:
  313. - 数据库连接状态和连接池使用情况
  314. - Redis 连接状态和内存使用情况
  315. - 异步日志队列状态
  316. 需求引用: 8.1, 8.2, 8.3, 8.4, 8.5
  317. """
  318. from app.services.health_service import health_service
  319. return await health_service.get_overall_health()
  320. if __name__ == "__main__":
  321. import uvicorn
  322. host = os.getenv("APP_HOST", "0.0.0.0")
  323. port = int(os.getenv("APP_PORT", "8010"))
  324. debug = os.getenv("DEBUG", "False").lower() == "true"
  325. logger.info(f"启动开发服务器: http://{host}:{port}")
  326. # 配置 reload 参数
  327. reload_config = {}
  328. if debug:
  329. # 只监控 app 目录,避免监控 logs
  330. reload_config = {
  331. "reload": True,
  332. "reload_dirs": ["app"], # 只监控 app 目录
  333. "reload_includes": ["*.py"], # 只监控 Python 文件
  334. }
  335. uvicorn.run(
  336. "main:app",
  337. host=host,
  338. port=port,
  339. log_level="info",
  340. **reload_config
  341. )