|
|
@@ -10,21 +10,43 @@ from app.config import get_settings
|
|
|
|
|
|
settings = get_settings()
|
|
|
|
|
|
-# 确保数据库目录存在(从 URL 中解析路径)
|
|
|
-db_path = settings.database_url.removeprefix("sqlite+aiosqlite://")
|
|
|
-if db_path and not db_path.startswith(":memory"):
|
|
|
- # 确保是绝对路径
|
|
|
- db_path_obj = Path(db_path) if db_path.startswith("/") else Path("/") / db_path
|
|
|
- db_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
|
|
-
|
|
|
Base = declarative_base()
|
|
|
|
|
|
-engine = create_async_engine(
|
|
|
- settings.database_url,
|
|
|
- echo=settings.backend_env == "development",
|
|
|
-)
|
|
|
+# 延迟创建 engine/session,在首次使用时再实例化
|
|
|
+# 避免模块导入阶段目录还未创建就尝试连接数据库
|
|
|
+_engine = None
|
|
|
+_async_session = None
|
|
|
+
|
|
|
+
|
|
|
+def _get_engine():
|
|
|
+ global _engine
|
|
|
+ if _engine is None:
|
|
|
+ _engine = create_async_engine(
|
|
|
+ settings.database_url,
|
|
|
+ echo=settings.backend_env == "development",
|
|
|
+ )
|
|
|
+ return _engine
|
|
|
+
|
|
|
+
|
|
|
+def _get_session():
|
|
|
+ global _async_session
|
|
|
+ if _async_session is None:
|
|
|
+ _async_session = async_sessionmaker(_get_engine(), class_=AsyncSession, expire_on_commit=False)
|
|
|
+ return _async_session
|
|
|
+
|
|
|
+
|
|
|
+async_session = _get_session
|
|
|
|
|
|
-async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
|
|
+
|
|
|
+async def init_db():
|
|
|
+ """创建数据库目录 + 创建所有表(首次启动时调用)。"""
|
|
|
+ # 确保数据库文件目录存在
|
|
|
+ db_path = settings.database_url.removeprefix("sqlite+aiosqlite://")
|
|
|
+ if db_path and not db_path.startswith(":memory"):
|
|
|
+ db_path_obj = Path(db_path) if db_path.startswith("/") else Path("/") / db_path
|
|
|
+ db_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
|
|
+ async with _get_engine().begin() as conn:
|
|
|
+ await conn.run_sync(Base.metadata.create_all)
|
|
|
|
|
|
|
|
|
class TrainingJobModel(Base):
|
|
|
@@ -112,12 +134,6 @@ class DeployTaskModel(Base):
|
|
|
created_at = Column(DateTime, default=datetime.utcnow)
|
|
|
|
|
|
|
|
|
-async def init_db():
|
|
|
- """创建所有表(首次启动时调用)。"""
|
|
|
- async with engine.begin() as conn:
|
|
|
- await conn.run_sync(Base.metadata.create_all)
|
|
|
-
|
|
|
-
|
|
|
async def get_db() -> AsyncSession:
|
|
|
async with async_session() as session:
|
|
|
yield session
|