db.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. from datetime import datetime
  2. from sqlalchemy import Column, DateTime, Float, Integer, String, Text
  3. from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
  4. from sqlalchemy.orm import declarative_base
  5. from app.config import get_settings
# Application settings, read once when this module is imported.
settings = get_settings()
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
# Defer creating the engine/session factory: both start as None and are
# instantiated on first use by _get_engine() / _get_session().
_engine = None
_async_session = None
  11. def _get_engine():
  12. global _engine
  13. if _engine is None:
  14. _engine = create_async_engine(
  15. settings.database_url,
  16. echo=settings.backend_env == "development",
  17. pool_pre_ping=True,
  18. )
  19. return _engine
  20. def _get_session():
  21. global _async_session
  22. if _async_session is None:
  23. _async_session = async_sessionmaker(_get_engine(), class_=AsyncSession, expire_on_commit=False)
  24. return _async_session
# Unprefixed alias for _get_session.
# NOTE(review): calling async_session() returns the async_sessionmaker factory
# itself, not an AsyncSession; a session is only produced by calling that
# factory, i.e. async_session()(). Callers that write
# `async with async_session() as session:` should be checked against this.
async_session = _get_session
  26. async def init_db():
  27. """创建所有表(首次启动时调用)。"""
  28. async with _get_engine().begin() as conn:
  29. await conn.run_sync(Base.metadata.create_all)
  30. class TrainingJobModel(Base):
  31. __tablename__ = "training_jobs"
  32. id = Column(String(36), primary_key=True)
  33. model_id = Column(String(256), nullable=False)
  34. model_type = Column(String(32), nullable=False)
  35. dataset_id = Column(String(36), nullable=False)
  36. peft_method = Column(String(32), nullable=False)
  37. task_type = Column(String(32), default="sft")
  38. dataset_template = Column(String(32), default="alpaca")
  39. status = Column(String(32), default="pending")
  40. progress = Column(Float, default=0.0)
  41. current_epoch = Column(Integer, default=0)
  42. current_step = Column(Integer, default=0)
  43. total_steps = Column(Integer, default=0)
  44. loss = Column(Float, nullable=True)
  45. learning_rate = Column(Float, nullable=True)
  46. epochs = Column(Integer, default=3)
  47. batch_size = Column(Integer, default=4)
  48. gradient_accumulation = Column(Integer, default=4)
  49. max_seq_length = Column(Integer, default=2048)
  50. warmup_ratio = Column(Float, default=0.05)
  51. save_strategy = Column(String(32), default="epoch")
  52. eval_strategy = Column(String(32), default="epoch")
  53. eval_steps = Column(Integer, default=100)
  54. lora_r = Column(Integer, default=16)
  55. lora_alpha = Column(Integer, default=32)
  56. lora_dropout = Column(Float, default=0.05)
  57. lora_target_modules = Column(String(256), default="all-linear")
  58. qlora_bits = Column(Integer, default=4)
  59. created_at = Column(DateTime, default=datetime.utcnow)
  60. started_at = Column(DateTime, nullable=True)
  61. finished_at = Column(DateTime, nullable=True)
  62. error_message = Column(Text, nullable=True)
  63. adapter_path = Column(String(512), nullable=True)
  64. class DatasetRecord(Base):
  65. __tablename__ = "datasets"
  66. id = Column(String(36), primary_key=True)
  67. name = Column(String(256), nullable=False)
  68. format = Column(String(16), nullable=False)
  69. record_count = Column(Integer, default=0)
  70. file_path = Column(String(512), nullable=False)
  71. created_at = Column(DateTime, default=datetime.utcnow)
  72. class ModelCache(Base):
  73. __tablename__ = "model_cache"
  74. id = Column(String(256), primary_key=True)
  75. name = Column(String(256), nullable=False)
  76. model_type = Column(String(32), nullable=False)
  77. path = Column(String(512), nullable=True)
  78. is_downloaded = Column(Integer, default=0)
  79. context_length = Column(Integer, nullable=True)
  80. supported_peft_methods = Column(String(256), default="")
  81. created_at = Column(DateTime, default=datetime.utcnow)
  82. class EvalResultModel(Base):
  83. __tablename__ = "eval_results"
  84. id = Column(String(36), primary_key=True)
  85. job_id = Column(String(36), nullable=False)
  86. metrics = Column(Text, default="{}")
  87. created_at = Column(DateTime, default=datetime.utcnow)
  88. class DeployTaskModel(Base):
  89. __tablename__ = "deploy_tasks"
  90. id = Column(String(36), primary_key=True)
  91. job_id = Column(String(36), nullable=False)
  92. status = Column(String(32), default="pending")
  93. output_path = Column(String(512), nullable=True)
  94. error = Column(Text, nullable=True)
  95. created_at = Column(DateTime, default=datetime.utcnow)
  96. async def get_db() -> AsyncSession:
  97. async with async_session() as session:
  98. yield session