db.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
from collections.abc import AsyncIterator
from datetime import datetime, timezone

from sqlalchemy import Column, DateTime, Float, Integer, String, Text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase

from app.config import get_settings
  6. settings = get_settings()
  7. Base = DeclarativeBase()
  8. engine = create_async_engine(
  9. settings.database_url,
  10. echo=settings.backend_env == "development",
  11. )
  12. async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
  13. class TrainingJobModel(Base):
  14. __tablename__ = "training_jobs"
  15. id = Column(String(36), primary_key=True)
  16. model_id = Column(String(256), nullable=False)
  17. model_type = Column(String(32), nullable=False)
  18. dataset_id = Column(String(36), nullable=False)
  19. peft_method = Column(String(32), nullable=False)
  20. task_type = Column(String(32), default="sft") # sft/dpo/kto/orpo/rm/ppo
  21. dataset_template = Column(String(32), default="alpaca")
  22. status = Column(String(32), default="pending")
  23. progress = Column(Float, default=0.0)
  24. current_epoch = Column(Integer, default=0)
  25. current_step = Column(Integer, default=0)
  26. total_steps = Column(Integer, default=0)
  27. loss = Column(Float, nullable=True)
  28. learning_rate = Column(Float, nullable=True)
  29. epochs = Column(Integer, default=3)
  30. batch_size = Column(Integer, default=4)
  31. gradient_accumulation = Column(Integer, default=4)
  32. max_seq_length = Column(Integer, default=2048)
  33. warmup_ratio = Column(Float, default=0.05)
  34. save_strategy = Column(String(32), default="epoch")
  35. eval_strategy = Column(String(32), default="epoch")
  36. eval_steps = Column(Integer, default=100)
  37. lora_r = Column(Integer, default=16)
  38. lora_alpha = Column(Integer, default=32)
  39. lora_dropout = Column(Float, default=0.05)
  40. lora_target_modules = Column(String(256), default="all-linear")
  41. qlora_bits = Column(Integer, default=4)
  42. created_at = Column(DateTime, default=datetime.utcnow)
  43. started_at = Column(DateTime, nullable=True)
  44. finished_at = Column(DateTime, nullable=True)
  45. error_message = Column(Text, nullable=True)
  46. adapter_path = Column(String(512), nullable=True)
  47. class DatasetRecord(Base):
  48. __tablename__ = "datasets"
  49. id = Column(String(36), primary_key=True)
  50. name = Column(String(256), nullable=False)
  51. format = Column(String(16), nullable=False)
  52. record_count = Column(Integer, default=0)
  53. file_path = Column(String(512), nullable=False)
  54. created_at = Column(DateTime, default=datetime.utcnow)
  55. class ModelCache(Base):
  56. __tablename__ = "model_cache"
  57. id = Column(String(256), primary_key=True)
  58. name = Column(String(256), nullable=False)
  59. model_type = Column(String(32), nullable=False)
  60. path = Column(String(512), nullable=True)
  61. is_downloaded = Column(Integer, default=0)
  62. context_length = Column(Integer, nullable=True)
  63. supported_peft_methods = Column(String(256), default="")
  64. created_at = Column(DateTime, default=datetime.utcnow)
  65. class EvalResultModel(Base):
  66. __tablename__ = "eval_results"
  67. id = Column(String(36), primary_key=True)
  68. job_id = Column(String(36), nullable=False)
  69. metrics = Column(Text, default="{}")
  70. created_at = Column(DateTime, default=datetime.utcnow)
  71. class DeployTaskModel(Base):
  72. __tablename__ = "deploy_tasks"
  73. id = Column(String(36), primary_key=True)
  74. job_id = Column(String(36), nullable=False)
  75. status = Column(String(32), default="pending")
  76. output_path = Column(String(512), nullable=True)
  77. error = Column(Text, nullable=True)
  78. created_at = Column(DateTime, default=datetime.utcnow)
  79. async def init_db():
  80. """创建所有表(首次启动时调用)。"""
  81. async with engine.begin() as conn:
  82. await conn.run_sync(Base.metadata.create_all)
  83. async def get_db() -> AsyncSession:
  84. async with async_session() as session:
  85. yield session