|
|
@@ -4,8 +4,6 @@ from sqlalchemy import Column, DateTime, Float, Integer, String, Text
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
from sqlalchemy.orm import declarative_base
|
|
|
|
|
|
-from pathlib import Path
|
|
|
-
|
|
|
from app.config import get_settings
|
|
|
|
|
|
settings = get_settings()
|
|
|
@@ -13,7 +11,6 @@ settings = get_settings()
|
|
|
Base = declarative_base()
|
|
|
|
|
|
# 延迟创建 engine/session,在首次使用时再实例化
|
|
|
-# 避免模块导入阶段目录还未创建就尝试连接数据库
|
|
|
_engine = None
|
|
|
_async_session = None
|
|
|
|
|
|
@@ -24,6 +21,7 @@ def _get_engine():
|
|
|
_engine = create_async_engine(
|
|
|
settings.database_url,
|
|
|
echo=settings.backend_env == "development",
|
|
|
+ pool_pre_ping=True,
|
|
|
)
|
|
|
return _engine
|
|
|
|
|
|
@@ -39,12 +37,7 @@ async_session = _get_session
|
|
|
|
|
|
|
|
|
async def init_db():
|
|
|
- """创建数据库目录 + 创建所有表(首次启动时调用)。"""
|
|
|
- # 确保数据库文件目录存在
|
|
|
- db_path = settings.database_url.removeprefix("sqlite+aiosqlite://")
|
|
|
- if db_path and not db_path.startswith(":memory"):
|
|
|
- db_path_obj = Path(db_path) if db_path.startswith("/") else Path("/") / db_path
|
|
|
- db_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
|
|
+ """创建所有表(首次启动时调用)。"""
|
|
|
async with _get_engine().begin() as conn:
|
|
|
await conn.run_sync(Base.metadata.create_all)
|
|
|
|
|
|
@@ -57,7 +50,7 @@ class TrainingJobModel(Base):
|
|
|
model_type = Column(String(32), nullable=False)
|
|
|
dataset_id = Column(String(36), nullable=False)
|
|
|
peft_method = Column(String(32), nullable=False)
|
|
|
- task_type = Column(String(32), default="sft") # sft/dpo/kto/orpo/rm/ppo
|
|
|
+ task_type = Column(String(32), default="sft")
|
|
|
dataset_template = Column(String(32), default="alpaca")
|
|
|
|
|
|
status = Column(String(32), default="pending")
|