| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
import uuid
from datetime import datetime
from typing import AsyncIterator

from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base

from app.config import get_settings
# Declarative base that every ORM model in this module subclasses.
Base = declarative_base()

# Application settings, read once when this module is imported.
settings = get_settings()

# The engine and the session factory are created lazily: both caches start as
# None and are filled in by _get_engine() / _get_session() on their first call.
_async_session = None
_engine = None
def _get_engine():
    """Return the module-wide async engine, building it on the first call.

    The engine is created from ``settings.database_url`` with
    ``pool_pre_ping=True``; SQL echo is switched on only when
    ``settings.backend_env`` equals ``"development"``. The instance is cached
    in the module-level ``_engine`` variable and reused afterwards.
    """
    global _engine
    if _engine is not None:
        return _engine

    database_url = settings.database_url
    echo_sql = settings.backend_env == "development"
    _engine = create_async_engine(database_url, echo=echo_sql, pool_pre_ping=True)
    return _engine
def _get_session():
    """Return the module-wide ``async_sessionmaker``, building it on the first call.

    The factory is bound to the engine from ``_get_engine()``, produces
    ``AsyncSession`` objects and is configured with ``expire_on_commit=False``.
    It is cached in the module-level ``_async_session`` variable.
    """
    global _async_session
    if _async_session is not None:
        return _async_session

    engine = _get_engine()
    _async_session = async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    return _async_session
# async_session is the async_sessionmaker instance; calling it returns a new
# AsyncSession.
# NOTE(review): this call runs at import time, so the engine and the session
# factory are actually built when the module is first imported.
async_session = _get_session()
async def init_db():
    """Create all tables (meant to be called on first startup).

    Runs ``Base.metadata.create_all`` inside an engine-level transaction and
    then calls ``_migrate_tables()``.
    """
    engine = _get_engine()
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    # Automatic migration: add newly introduced columns to existing tables.
    await _migrate_tables()
async def _migrate_tables():
    """Add newly introduced columns to tables that already exist.

    Each ``ALTER TABLE ... ADD COLUMN`` statement runs in its own transaction.
    On databases where a failed statement aborts the enclosing transaction
    (e.g. PostgreSQL), sharing one transaction would let a single
    "column already exists" error block every following statement.

    A statement that raises ``ProgrammingError`` or ``OperationalError`` is
    skipped and logged at debug level. This is the expected outcome when the
    column is already present, so the function is safe to call on every
    startup. Other exception types propagate to the caller.
    """
    import logging

    from sqlalchemy import text
    from sqlalchemy.exc import OperationalError, ProgrammingError

    log = logging.getLogger(__name__)

    alter_stmts = (
        # Columns added to deploy_tasks.
        "ALTER TABLE deploy_tasks ADD COLUMN deploy_mode VARCHAR(16) DEFAULT 'export'",
        "ALTER TABLE deploy_tasks ADD COLUMN endpoint_url VARCHAR(256)",
        "ALTER TABLE deploy_tasks ADD COLUMN port INTEGER",
        "ALTER TABLE deploy_tasks ADD COLUMN pid VARCHAR(32)",
        "ALTER TABLE deploy_tasks ADD COLUMN user_id VARCHAR(36)",
    )

    engine = _get_engine()
    for stmt in alter_stmts:
        try:
            # One transaction per statement: a failure rolls back only this
            # statement and leaves the remaining ones free to run.
            async with engine.begin() as conn:
                await conn.execute(text(stmt))
        except (ProgrammingError, OperationalError) as exc:
            log.debug("Skipping migration statement %r: %s", stmt, exc)
class TrainingJobModel(Base):
    """ORM mapping for the ``training_jobs`` table."""

    __tablename__ = "training_jobs"

    # Primary key and the required (NOT NULL) references of a job.
    id = Column(String(36), primary_key=True)
    model_id = Column(String(256), nullable=False)
    model_type = Column(String(32), nullable=False)
    dataset_id = Column(String(36), nullable=False)
    peft_method = Column(String(32), nullable=False)

    # Selectors that fall back to a string default.
    task_type = Column(String(32), default="sft")
    dataset_template = Column(String(32), default="alpaca")

    # Run-state columns; loss and learning_rate are nullable with no default.
    status = Column(String(32), default="pending")
    progress = Column(Float, default=0.0)
    current_epoch = Column(Integer, default=0)
    current_step = Column(Integer, default=0)
    total_steps = Column(Integer, default=0)
    loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)

    # Trainer settings and their column defaults.
    epochs = Column(Integer, default=3)
    batch_size = Column(Integer, default=4)
    gradient_accumulation = Column(Integer, default=4)
    max_seq_length = Column(Integer, default=2048)
    warmup_ratio = Column(Float, default=0.05)
    save_strategy = Column(String(32), default="epoch")
    eval_strategy = Column(String(32), default="epoch")
    eval_steps = Column(Integer, default=100)

    # lora_* / qlora_* settings.
    lora_r = Column(Integer, default=16)
    lora_alpha = Column(Integer, default=32)
    lora_dropout = Column(Float, default=0.05)
    lora_target_modules = Column(String(256), default="all-linear")
    qlora_bits = Column(Integer, default=4)

    # PPO
    ppo_epochs = Column(Integer, default=4)
    vf_coef = Column(Float, default=0.1)
    kl_coef = Column(Float, default=0.2)
    response_length = Column(Integer, default=512)
    reward_model_path = Column(String(512), nullable=True)
    reward_type = Column(String(32), default="heuristic")

    # created_at defaults to datetime.utcnow (a naive UTC value); the other
    # two timestamps are nullable and have no default.
    created_at = Column(DateTime, default=datetime.utcnow)
    started_at = Column(DateTime, nullable=True)
    finished_at = Column(DateTime, nullable=True)

    # Nullable result fields.
    error_message = Column(Text, nullable=True)
    adapter_path = Column(String(512), nullable=True)
class DatasetRecord(Base):
    """ORM mapping for the ``datasets`` table."""

    __tablename__ = "datasets"

    id = Column(String(36), primary_key=True)

    # Required descriptive fields.
    name = Column(String(256), nullable=False)
    format = Column(String(16), nullable=False)
    file_path = Column(String(512), nullable=False)

    # record_count defaults to 0; created_at defaults to datetime.utcnow.
    record_count = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.utcnow)
class ModelCache(Base):
    """ORM mapping for the ``model_cache`` table."""

    __tablename__ = "model_cache"

    # Unlike most tables in this module, the key is up to 256 characters long.
    id = Column(String(256), primary_key=True)
    name = Column(String(256), nullable=False)
    model_type = Column(String(32), nullable=False)
    path = Column(String(512), nullable=True)

    # Stored as an integer defaulting to 0 (presumably a 0/1 flag; confirm
    # against callers).
    is_downloaded = Column(Integer, default=0)
    context_length = Column(Integer, nullable=True)

    # Kept as a single string, empty by default; the encoding of multiple
    # methods is not visible here.
    supported_peft_methods = Column(String(256), default="")
    created_at = Column(DateTime, default=datetime.utcnow)
class EvalResultModel(Base):
    """ORM mapping for the ``eval_results`` table."""

    __tablename__ = "eval_results"

    id = Column(String(36), primary_key=True)
    job_id = Column(String(36), nullable=False)
    status = Column(String(32), default="pending")  # pending|running|completed|failed

    # Text column whose default is the literal string "{}".
    metrics = Column(Text, default="{}")
    progress = Column(Float, default=0.0)
    error = Column(Text, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
class DeployTaskModel(Base):
    """ORM mapping for the ``deploy_tasks`` table."""

    __tablename__ = "deploy_tasks"

    id = Column(String(36), primary_key=True)
    job_id = Column(String(36), nullable=False)
    user_id = Column(String(36), nullable=True)  # user who owns the deploy task

    status = Column(String(32), default="pending")
    deploy_mode = Column(String(16), default="export")  # export | serve
    output_path = Column(String(512), nullable=True)

    # Columns that apply to serve mode.
    endpoint_url = Column(String(256), nullable=True)  # base_url in serve mode
    port = Column(Integer, nullable=True)  # port allocated in serve mode
    pid = Column(String(32), nullable=True)  # PID of the remote process in serve mode

    error = Column(Text, nullable=True)
    progress = Column(Float, default=0.0)
    finished_at = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
class ModelDownloadTask(Base):
    """ORM mapping for the ``model_download_tasks`` table."""

    __tablename__ = "model_download_tasks"

    id = Column(String(36), primary_key=True)
    model_id = Column(String(256), nullable=False)

    # Integer column defaulting to 0 (presumably a 0/1 flag; confirm against
    # callers).
    use_modelscope = Column(Integer, default=0)
    status = Column(String(32), default="pending")  # pending|downloading|completed|failed

    # Nullable result fields.
    path = Column(String(512), nullable=True)
    error = Column(Text, nullable=True)

    progress = Column(Float, default=0.0)
    created_at = Column(DateTime, default=datetime.utcnow)
    finished_at = Column(DateTime, nullable=True)
class DatasetDownloadTask(Base):
    """ORM mapping for the ``dataset_download_tasks`` table."""

    __tablename__ = "dataset_download_tasks"

    id = Column(String(36), primary_key=True)
    dataset_id = Column(String(256), nullable=False)

    # Integer column defaulting to 0 (presumably a 0/1 flag; confirm against
    # callers).
    use_modelscope = Column(Integer, default=0)
    status = Column(String(32), default="pending")  # pending|downloading|completed|failed

    # Nullable result fields.
    path = Column(String(512), nullable=True)
    error = Column(Text, nullable=True)

    record_count = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.utcnow)
    finished_at = Column(DateTime, nullable=True)
class UserModel(Base):
    """ORM mapping for the ``users`` table."""

    __tablename__ = "users"

    # The primary key defaults to a freshly generated UUID4 string.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    username = Column(String(128), unique=True, nullable=False, index=True)

    # Optional profile fields.
    email = Column(String(256), nullable=True)
    real_name = Column(String(128), nullable=True)
    avatar_url = Column(String(512), nullable=True)
    company = Column(String(128), nullable=True)
    department = Column(String(128), nullable=True)
    position = Column(String(128), nullable=True)

    # JSON column; the callable default gives each row its own new list.
    roles = Column(JSON, default=list)

    # Stored as NOT NULL integers with defaults 1 and 0.
    is_active = Column(Integer, default=1, nullable=False)
    is_superuser = Column(Integer, default=0, nullable=False)

    # updated_at is additionally refreshed with datetime.utcnow on UPDATE.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class RefreshTokenModel(Base):
    """ORM mapping for the ``refresh_tokens`` table."""

    __tablename__ = "refresh_tokens"

    # The primary key defaults to a freshly generated UUID4 string.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))

    # Plain indexed string column; no foreign-key constraint is declared here.
    user_id = Column(String(36), nullable=False, index=True)
    token = Column(String(512), unique=True, nullable=False, index=True)
    expires_at = Column(DateTime, nullable=False)

    # NOT NULL integer defaulting to 0.
    revoked = Column(Integer, default=0, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow)
class ApiKeyModel(Base):
    """ORM mapping for the ``api_keys`` table."""

    __tablename__ = "api_keys"

    # The primary key defaults to a freshly generated UUID4 string.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))

    # Plain indexed string column; no foreign-key constraint is declared here.
    user_id = Column(String(36), nullable=False, index=True)
    key = Column(String(128), unique=True, nullable=False, index=True)  # sk-xxx
    name = Column(String(128), nullable=False, default="default")
    status = Column(String(16), default="active")  # active | revoked

    # Nullable with no default.
    last_used_at = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield one ``AsyncSession`` from the module-level session factory.

    This is an async generator: the session is opened by calling
    ``async_session()``, handed to the consumer, and the ``async with`` block
    is exited once the consumer finishes or raises.

    Yields:
        The ``AsyncSession`` opened for this call.
    """
    async with async_session() as session:
        yield session
|