import uuid
from datetime import datetime, timezone
from typing import AsyncIterator

from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base

from app.config import get_settings

settings = get_settings()

Base = declarative_base()


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The ``DateTime`` columns below are declared without
    ``timezone=True``, so the tzinfo is stripped to keep stored values exactly
    as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


# The engine and the session factory are created lazily, on first use,
# instead of at import time.
_engine = None
_async_session = None


def _get_engine():
    """Return the shared async engine, creating it on the first call."""
    global _engine
    if _engine is None:
        _engine = create_async_engine(
            settings.database_url,
            # Echo (log) emitted SQL only in the development environment.
            echo=settings.backend_env == "development",
            # Check that a pooled connection is alive before handing it out.
            pool_pre_ping=True,
        )
    return _engine


def _get_session():
    """Return the shared ``async_sessionmaker``, creating it (and the engine) on first call."""
    global _async_session
    if _async_session is None:
        # expire_on_commit=False keeps loaded ORM attributes usable after commit;
        # SQLAlchemy's asyncio docs recommend it to avoid implicit I/O on access.
        _async_session = async_sessionmaker(
            _get_engine(), class_=AsyncSession, expire_on_commit=False
        )
    return _async_session


class _LazySessionFactory:
    """Module-level stand-in for the ``async_sessionmaker``.

    Calling it (``async_session()``) returns a new ``AsyncSession`` just like
    the real sessionmaker, and any other attribute access (e.g.
    ``async_session.begin()``) is forwarded to the real sessionmaker. That
    sessionmaker, and the engine behind it, is only built on first use, so
    importing this module does not create the engine.
    """

    def __call__(self, **local_kw) -> AsyncSession:
        return _get_session()(**local_kw)

    def __getattr__(self, name):
        # Only reached for attributes not defined on this proxy class.
        return getattr(_get_session(), name)


# Call ``async_session()`` to obtain a new ``AsyncSession``.
async_session = _LazySessionFactory()


async def init_db():
    """Create every table registered on ``Base.metadata`` (call on first startup).

    ``create_all`` skips tables that already exist; it does not alter existing
    tables to match changed column definitions.
    """
    async with _get_engine().begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


class TrainingJobModel(Base):
    """Row in ``training_jobs``: one training job's settings, progress and outcome."""

    __tablename__ = "training_jobs"

    id = Column(String(36), primary_key=True)
    model_id = Column(String(256), nullable=False)
    model_type = Column(String(32), nullable=False)
    # NOTE(review): not declared as a ForeignKey to datasets.id — confirm intended.
    dataset_id = Column(String(36), nullable=False)
    peft_method = Column(String(32), nullable=False)
    task_type = Column(String(32), default="sft")
    dataset_template = Column(String(32), default="alpaca")

    # Progress tracking.
    status = Column(String(32), default="pending")
    progress = Column(Float, default=0.0)
    current_epoch = Column(Integer, default=0)
    current_step = Column(Integer, default=0)
    total_steps = Column(Integer, default=0)
    # Nullable with no default; presumably filled in while the job runs — confirm.
    loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)

    # Training configuration.
    epochs = Column(Integer, default=3)
    batch_size = Column(Integer, default=4)
    gradient_accumulation = Column(Integer, default=4)
    max_seq_length = Column(Integer, default=2048)
    warmup_ratio = Column(Float, default=0.05)
    save_strategy = Column(String(32), default="epoch")
    eval_strategy = Column(String(32), default="epoch")
    eval_steps = Column(Integer, default=100)

    # LoRA / QLoRA settings.
    lora_r = Column(Integer, default=16)
    lora_alpha = Column(Integer, default=32)
    lora_dropout = Column(Float, default=0.05)
    lora_target_modules = Column(String(256), default="all-linear")
    qlora_bits = Column(Integer, default=4)

    # created_at defaults to naive UTC; started_at/finished_at have no default here.
    created_at = Column(DateTime, default=_utcnow)
    started_at = Column(DateTime, nullable=True)
    finished_at = Column(DateTime, nullable=True)

    # Outcome.
    error_message = Column(Text, nullable=True)
    adapter_path = Column(String(512), nullable=True)


class DatasetRecord(Base):
    """Row in ``datasets``: an uploaded/registered dataset and where its file lives."""

    __tablename__ = "datasets"

    id = Column(String(36), primary_key=True)
    name = Column(String(256), nullable=False)
    format = Column(String(16), nullable=False)
    record_count = Column(Integer, default=0)
    file_path = Column(String(512), nullable=False)
    created_at = Column(DateTime, default=_utcnow)


class ModelCache(Base):
    """Row in ``model_cache``: metadata about a model and its local copy, if any."""

    __tablename__ = "model_cache"

    # String primary key (up to 256 chars); presumably the model identifier — confirm.
    id = Column(String(256), primary_key=True)
    name = Column(String(256), nullable=False)
    model_type = Column(String(32), nullable=False)
    path = Column(String(512), nullable=True)
    # Integer used as a 0/1 flag (defaults to 0).
    is_downloaded = Column(Integer, default=0)
    context_length = Column(Integer, nullable=True)
    # Plain string; presumably a delimited list of method names — confirm format.
    supported_peft_methods = Column(String(256), default="")
    created_at = Column(DateTime, default=_utcnow)


class EvalResultModel(Base):
    """Row in ``eval_results``: evaluation metrics recorded for a job."""

    __tablename__ = "eval_results"

    id = Column(String(36), primary_key=True)
    # NOTE(review): not declared as a ForeignKey to training_jobs.id.
    job_id = Column(String(36), nullable=False)
    # Text column defaulting to "{}"; presumably JSON-encoded — confirm with writers.
    metrics = Column(Text, default="{}")
    created_at = Column(DateTime, default=_utcnow)


class DeployTaskModel(Base):
    """Row in ``deploy_tasks``: a deployment task for a job and its result."""

    __tablename__ = "deploy_tasks"

    id = Column(String(36), primary_key=True)
    # NOTE(review): not declared as a ForeignKey to training_jobs.id.
    job_id = Column(String(36), nullable=False)
    status = Column(String(32), default="pending")
    output_path = Column(String(512), nullable=True)
    error = Column(Text, nullable=True)
    created_at = Column(DateTime, default=_utcnow)


class UserModel(Base):
    """Row in ``users``: account and profile data."""

    __tablename__ = "users"

    # Defaults to a random UUID4 string when not supplied.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    username = Column(String(128), unique=True, nullable=False, index=True)
    email = Column(String(256), nullable=True)
    real_name = Column(String(128), nullable=True)
    avatar_url = Column(String(512), nullable=True)
    company = Column(String(128), nullable=True)
    department = Column(String(128), nullable=True)
    position = Column(String(128), nullable=True)
    # JSON column; the callable default gives each new row its own empty list.
    roles = Column(JSON, default=list)
    # Integer 0/1 flags.
    is_active = Column(Integer, default=1, nullable=False)
    is_superuser = Column(Integer, default=0, nullable=False)
    created_at = Column(DateTime, default=_utcnow)
    # Refreshed to the current naive UTC time on every ORM UPDATE of the row.
    updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)


class RefreshTokenModel(Base):
    """Row in ``refresh_tokens``: an issued refresh token, its expiry and revocation flag."""

    __tablename__ = "refresh_tokens"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # NOTE(review): not declared as a ForeignKey to users.id.
    user_id = Column(String(36), nullable=False, index=True)
    # NOTE(review): stored verbatim; confirm whether callers store the raw token or a hash.
    token = Column(String(512), unique=True, nullable=False, index=True)
    expires_at = Column(DateTime, nullable=False)
    # Integer 0/1 flag (defaults to 0).
    revoked = Column(Integer, default=0, nullable=False)
    created_at = Column(DateTime, default=_utcnow)


async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield a fresh ``AsyncSession`` and close it when the consumer is done.

    This is an async generator, hence the ``AsyncIterator`` return type. It is
    presumably consumed as a per-request dependency — confirm against callers.
    """
    async with async_session() as session:
        yield session