Selaa lähdekoodia

修改db文件

lxylxy123321 1 viikko sitten
vanhempi
sitoutus
2903b09d92
1 muutettua tiedostoa jossa 34 lisäystä ja 18 poistoa
  1. 34 18
      backend/app/core/db.py

+ 34 - 18
backend/app/core/db.py

@@ -10,21 +10,43 @@ from app.config import get_settings
 
 settings = get_settings()
 
-# 确保数据库目录存在(从 URL 中解析路径)
-db_path = settings.database_url.removeprefix("sqlite+aiosqlite://")
-if db_path and not db_path.startswith(":memory"):
-    # 确保是绝对路径
-    db_path_obj = Path(db_path) if db_path.startswith("/") else Path("/") / db_path
-    db_path_obj.parent.mkdir(parents=True, exist_ok=True)
-
 Base = declarative_base()
 
-engine = create_async_engine(
-    settings.database_url,
-    echo=settings.backend_env == "development",
-)
+# 延迟创建 engine/session,在首次使用时再实例化
+# 避免模块导入阶段目录还未创建就尝试连接数据库
+_engine = None
+_async_session = None
+
+
+def _get_engine():
+    global _engine
+    if _engine is None:
+        _engine = create_async_engine(
+            settings.database_url,
+            echo=settings.backend_env == "development",
+        )
+    return _engine
+
+
+def _get_session():
+    global _async_session
+    if _async_session is None:
+        _async_session = async_sessionmaker(_get_engine(), class_=AsyncSession, expire_on_commit=False)
+    return _async_session
+
+
+async_session = _get_session
 
-async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+
+async def init_db():
+    """创建数据库目录 + 创建所有表(首次启动时调用)。"""
+    # 确保数据库文件目录存在
+    db_path = settings.database_url.removeprefix("sqlite+aiosqlite://")
+    if db_path and not db_path.startswith(":memory"):
+        db_path_obj = Path(db_path) if db_path.startswith("/") else Path("/") / db_path
+        db_path_obj.parent.mkdir(parents=True, exist_ok=True)
+    async with _get_engine().begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
 
 
 class TrainingJobModel(Base):
@@ -112,12 +134,6 @@ class DeployTaskModel(Base):
     created_at = Column(DateTime, default=datetime.utcnow)
 
 
-async def init_db():
-    """创建所有表(首次启动时调用)。"""
-    async with engine.begin() as conn:
-        await conn.run_sync(Base.metadata.create_all)
-
-
 async def get_db() -> AsyncSession:
     async with async_session() as session:
         yield session