"""Async database setup and ORM models.

Defines the SQLAlchemy declarative ``Base``, a cached async engine and session
factory, the table models, and helpers to create/migrate tables and to obtain
a session.
"""

import logging
import uuid
from datetime import datetime, timezone
from typing import AsyncIterator

from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base

from app.config import get_settings

logger = logging.getLogger(__name__)

settings = get_settings()
Base = declarative_base()

# Engine and session factory are cached here and built on the first call to
# the corresponding getter below.
_engine = None
_async_session = None


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value as ``datetime.utcnow()`` (naive, UTC) without
    calling that function, which is deprecated since Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _get_engine():
    """Return the shared async engine, creating it on the first call."""
    global _engine
    if _engine is None:
        _engine = create_async_engine(
            settings.database_url,
            # SQL echo is enabled only when backend_env is "development".
            echo=settings.backend_env == "development",
            pool_pre_ping=True,
        )
    return _engine


def _get_session():
    """Return the shared ``async_sessionmaker``, creating it on the first call."""
    global _async_session
    if _async_session is None:
        _async_session = async_sessionmaker(
            _get_engine(),
            class_=AsyncSession,
            expire_on_commit=False,
        )
    return _async_session


# async_session is an async_sessionmaker instance; calling it returns an
# AsyncSession. NOTE: this call runs at import time, so the engine and the
# session factory are built as soon as this module is imported.
async_session = _get_session()


async def init_db():
    """Create all tables, then add columns missing from existing tables.

    Intended to be called at startup.
    """
    async with _get_engine().begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    # Automatic migration: add newer columns to tables that already exist.
    await _migrate_tables()


async def _migrate_tables():
    """Add newly introduced columns to existing tables (best effort).

    Each ``ALTER TABLE`` runs in its own transaction. A failure (typically
    "column already exists") therefore cannot leave a shared transaction in an
    aborted state and block the remaining statements, which is what happens on
    PostgreSQL when all statements share one transaction.

    ``ProgrammingError`` and ``OperationalError`` are logged at debug level and
    otherwise ignored; any other exception propagates.
    """
    from sqlalchemy import text
    from sqlalchemy.exc import OperationalError, ProgrammingError

    alter_stmts = [
        # Columns added to deploy_tasks after its initial creation.
        "ALTER TABLE deploy_tasks ADD COLUMN deploy_mode VARCHAR(16) DEFAULT 'export'",
        "ALTER TABLE deploy_tasks ADD COLUMN endpoint_url VARCHAR(256)",
        "ALTER TABLE deploy_tasks ADD COLUMN port INTEGER",
        "ALTER TABLE deploy_tasks ADD COLUMN pid VARCHAR(32)",
        "ALTER TABLE deploy_tasks ADD COLUMN user_id VARCHAR(36)",
    ]
    engine = _get_engine()
    for stmt in alter_stmts:
        try:
            # One transaction per statement: commit on success, rollback on error.
            async with engine.begin() as conn:
                await conn.execute(text(stmt))
        except (ProgrammingError, OperationalError) as exc:
            # Expected when the column is already present; keep going.
            logger.debug("Skipped migration statement %r: %s", stmt, exc)


class TrainingJobModel(Base):
    """ORM model for the ``training_jobs`` table."""

    __tablename__ = "training_jobs"

    id = Column(String(36), primary_key=True)
    model_id = Column(String(256), nullable=False)
    model_type = Column(String(32), nullable=False)
    dataset_id = Column(String(36), nullable=False)
    peft_method = Column(String(32), nullable=False)
    task_type = Column(String(32), default="sft")
    dataset_template = Column(String(32), default="alpaca")
    status = Column(String(32), default="pending")
    progress = Column(Float, default=0.0)
    current_epoch = Column(Integer, default=0)
    current_step = Column(Integer, default=0)
    total_steps = Column(Integer, default=0)
    loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)
    epochs = Column(Integer, default=3)
    batch_size = Column(Integer, default=4)
    gradient_accumulation = Column(Integer, default=4)
    max_seq_length = Column(Integer, default=2048)
    warmup_ratio = Column(Float, default=0.05)
    save_strategy = Column(String(32), default="epoch")
    eval_strategy = Column(String(32), default="epoch")
    eval_steps = Column(Integer, default=100)
    lora_r = Column(Integer, default=16)
    lora_alpha = Column(Integer, default=32)
    lora_dropout = Column(Float, default=0.05)
    lora_target_modules = Column(String(256), default="all-linear")
    qlora_bits = Column(Integer, default=4)
    # PPO
    ppo_epochs = Column(Integer, default=4)
    vf_coef = Column(Float, default=0.1)
    kl_coef = Column(Float, default=0.2)
    response_length = Column(Integer, default=512)
    reward_model_path = Column(String(512), nullable=True)
    reward_type = Column(String(32), default="heuristic")
    created_at = Column(DateTime, default=_utcnow)
    started_at = Column(DateTime, nullable=True)
    finished_at = Column(DateTime, nullable=True)
    error_message = Column(Text, nullable=True)
    adapter_path = Column(String(512), nullable=True)


class DatasetRecord(Base):
    """ORM model for the ``datasets`` table."""

    __tablename__ = "datasets"

    id = Column(String(36), primary_key=True)
    name = Column(String(256), nullable=False)
    format = Column(String(16), nullable=False)
    record_count = Column(Integer, default=0)
    file_path = Column(String(512), nullable=False)
    created_at = Column(DateTime, default=_utcnow)


class ModelCache(Base):
    """ORM model for the ``model_cache`` table."""

    __tablename__ = "model_cache"

    id = Column(String(256), primary_key=True)
    name = Column(String(256), nullable=False)
    model_type = Column(String(32), nullable=False)
    path = Column(String(512), nullable=True)
    is_downloaded = Column(Integer, default=0)
    context_length = Column(Integer, nullable=True)
    supported_peft_methods = Column(String(256), default="")
    created_at = Column(DateTime, default=_utcnow)


class EvalResultModel(Base):
    """ORM model for the ``eval_results`` table."""

    __tablename__ = "eval_results"

    id = Column(String(36), primary_key=True)
    job_id = Column(String(36), nullable=False)
    status = Column(String(32), default="pending")  # pending|running|completed|failed
    metrics = Column(Text, default="{}")
    progress = Column(Float, default=0.0)
    error = Column(Text, nullable=True)
    created_at = Column(DateTime, default=_utcnow)


class DeployTaskModel(Base):
    """ORM model for the ``deploy_tasks`` table."""

    __tablename__ = "deploy_tasks"

    id = Column(String(36), primary_key=True)
    job_id = Column(String(36), nullable=False)
    user_id = Column(String(36), nullable=True)  # user that owns the deploy task
    status = Column(String(32), default="pending")
    deploy_mode = Column(String(16), default="export")  # export | serve
    output_path = Column(String(512), nullable=True)
    endpoint_url = Column(String(256), nullable=True)  # base_url in serve mode
    port = Column(Integer, nullable=True)  # port allocated in serve mode
    pid = Column(String(32), nullable=True)  # PID of the remote process in serve mode
    error = Column(Text, nullable=True)
    progress = Column(Float, default=0.0)
    finished_at = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=_utcnow)


class ModelDownloadTask(Base):
    """ORM model for the ``model_download_tasks`` table."""

    __tablename__ = "model_download_tasks"

    id = Column(String(36), primary_key=True)
    model_id = Column(String(256), nullable=False)
    use_modelscope = Column(Integer, default=0)
    status = Column(String(32), default="pending")  # pending|downloading|completed|failed
    path = Column(String(512), nullable=True)
    error = Column(Text, nullable=True)
    progress = Column(Float, default=0.0)
    created_at = Column(DateTime, default=_utcnow)
    finished_at = Column(DateTime, nullable=True)


class DatasetDownloadTask(Base):
    """ORM model for the ``dataset_download_tasks`` table."""

    __tablename__ = "dataset_download_tasks"

    id = Column(String(36), primary_key=True)
    dataset_id = Column(String(256), nullable=False)
    use_modelscope = Column(Integer, default=0)
    status = Column(String(32), default="pending")  # pending|downloading|completed|failed
    path = Column(String(512), nullable=True)
    error = Column(Text, nullable=True)
    record_count = Column(Integer, default=0)
    created_at = Column(DateTime, default=_utcnow)
    finished_at = Column(DateTime, nullable=True)


class UserModel(Base):
    """ORM model for the ``users`` table; ``id`` defaults to a fresh UUID4 string."""

    __tablename__ = "users"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    username = Column(String(128), unique=True, nullable=False, index=True)
    email = Column(String(256), nullable=True)
    real_name = Column(String(128), nullable=True)
    avatar_url = Column(String(512), nullable=True)
    company = Column(String(128), nullable=True)
    department = Column(String(128), nullable=True)
    position = Column(String(128), nullable=True)
    roles = Column(JSON, default=list)
    is_active = Column(Integer, default=1, nullable=False)
    is_superuser = Column(Integer, default=0, nullable=False)
    created_at = Column(DateTime, default=_utcnow)
    updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)


class RefreshTokenModel(Base):
    """ORM model for the ``refresh_tokens`` table; ``id`` defaults to a UUID4 string."""

    __tablename__ = "refresh_tokens"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = Column(String(36), nullable=False, index=True)
    token = Column(String(512), unique=True, nullable=False, index=True)
    expires_at = Column(DateTime, nullable=False)
    revoked = Column(Integer, default=0, nullable=False)
    created_at = Column(DateTime, default=_utcnow)


class ApiKeyModel(Base):
    """ORM model for the ``api_keys`` table; ``id`` defaults to a UUID4 string."""

    __tablename__ = "api_keys"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = Column(String(36), nullable=False, index=True)
    key = Column(String(128), unique=True, nullable=False, index=True)  # sk-xxx
    name = Column(String(128), nullable=False, default="default")
    status = Column(String(16), default="active")  # active | revoked
    last_used_at = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=_utcnow)


async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield one session from ``async_session``.

    The session is opened with ``async with``, so it is closed when the
    generator is resumed past the ``yield`` or is closed.
    """
    async with async_session() as session:
        yield session