db.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
import uuid
from datetime import datetime
from typing import AsyncGenerator

from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base

from app.config import get_settings
  7. settings = get_settings()
  8. Base = declarative_base()
  9. # 延迟创建 engine/session,在首次使用时再实例化
  10. _engine = None
  11. _async_session = None
  12. def _get_engine():
  13. global _engine
  14. if _engine is None:
  15. _engine = create_async_engine(
  16. settings.database_url,
  17. echo=settings.backend_env == "development",
  18. pool_pre_ping=True,
  19. )
  20. return _engine
  21. def _get_session():
  22. global _async_session
  23. if _async_session is None:
  24. _async_session = async_sessionmaker(_get_engine(), class_=AsyncSession, expire_on_commit=False)
  25. return _async_session
  26. # async_session 是 async_sessionmaker 实例,调用它返回 AsyncSession
  27. async_session = _get_session()
  28. async def init_db():
  29. """创建所有表(首次启动时调用)。"""
  30. async with _get_engine().begin() as conn:
  31. await conn.run_sync(Base.metadata.create_all)
class TrainingJobModel(Base):
    """ORM mapping for the ``training_jobs`` table (one row per training job).

    ``default=`` values below are Python-side defaults. SQLAlchemy applies them
    on INSERT when the attribute was not set; they are not server defaults.
    """

    __tablename__ = "training_jobs"

    # Primary key with no default: the creator of the row must supply it.
    # String(36) fits a canonical UUID string — presumably what callers use.
    id = Column(String(36), primary_key=True)

    # Model / dataset / method selection (all required except the two defaults).
    model_id = Column(String(256), nullable=False)
    model_type = Column(String(32), nullable=False)
    # NOTE(review): plain string with no ForeignKey to datasets.id, so the
    # database does not enforce that the referenced dataset exists.
    dataset_id = Column(String(36), nullable=False)
    peft_method = Column(String(32), nullable=False)
    task_type = Column(String(32), default="sft")
    dataset_template = Column(String(32), default="alpaca")

    # Status and progress counters, presumably updated while the job runs.
    status = Column(String(32), default="pending")
    # NOTE(review): scale (0-1 vs 0-100) is not visible here — TODO confirm.
    progress = Column(Float, default=0.0)
    current_epoch = Column(Integer, default=0)
    current_step = Column(Integer, default=0)
    total_steps = Column(Integer, default=0)
    # Nullable, no default: NULL until something writes them (presumably
    # reported by training; confirm whether learning_rate is the configured LR).
    loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)

    # Training hyperparameters with their defaults.
    epochs = Column(Integer, default=3)
    batch_size = Column(Integer, default=4)
    gradient_accumulation = Column(Integer, default=4)
    max_seq_length = Column(Integer, default=2048)
    warmup_ratio = Column(Float, default=0.05)
    save_strategy = Column(String(32), default="epoch")
    eval_strategy = Column(String(32), default="epoch")
    eval_steps = Column(Integer, default=100)

    # Adapter hyperparameters; presumably only relevant when peft_method selects
    # LoRA/QLoRA — verify against the trainer.
    lora_r = Column(Integer, default=16)
    lora_alpha = Column(Integer, default=32)
    lora_dropout = Column(Float, default=0.05)
    lora_target_modules = Column(String(256), default="all-linear")
    qlora_bits = Column(Integer, default=4)

    # Lifecycle timestamps. datetime.utcnow is passed uncalled, so it is
    # evaluated per INSERT and stores a naive UTC datetime.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12.
    created_at = Column(DateTime, default=datetime.utcnow)
    started_at = Column(DateTime, nullable=True)
    finished_at = Column(DateTime, nullable=True)

    # Outcome fields; NULL unless set.
    error_message = Column(Text, nullable=True)
    adapter_path = Column(String(512), nullable=True)
class DatasetRecord(Base):
    """ORM mapping for the ``datasets`` table: one row per registered dataset file."""

    __tablename__ = "datasets"

    # Primary key with no default: the creator of the row must supply it.
    id = Column(String(36), primary_key=True)
    name = Column(String(256), nullable=False)
    # Short format tag (up to 16 chars); allowed values are not defined here.
    format = Column(String(16), nullable=False)
    record_count = Column(Integer, default=0)
    # Location of the dataset file; required.
    file_path = Column(String(512), nullable=False)
    # Naive UTC timestamp evaluated per INSERT (datetime.utcnow passed uncalled).
    created_at = Column(DateTime, default=datetime.utcnow)
class ModelCache(Base):
    """ORM mapping for the ``model_cache`` table: known models and their local state."""

    __tablename__ = "model_cache"

    # Primary key up to 256 chars (wider than the 36-char ids elsewhere);
    # presumably a model identifier string rather than a UUID — confirm.
    id = Column(String(256), primary_key=True)
    name = Column(String(256), nullable=False)
    model_type = Column(String(32), nullable=False)
    # Local path; NULL when not set.
    path = Column(String(512), nullable=True)
    # Integer used as a flag, default 0 (presumably 0 = not downloaded, 1 = downloaded).
    is_downloaded = Column(Integer, default=0)
    context_length = Column(Integer, nullable=True)
    # Stored as a single string, default ""; the delimiter/encoding of the list
    # is decided by the code that writes it — TODO confirm.
    supported_peft_methods = Column(String(256), default="")
    # Naive UTC timestamp evaluated per INSERT (datetime.utcnow passed uncalled).
    created_at = Column(DateTime, default=datetime.utcnow)
class EvalResultModel(Base):
    """ORM mapping for the ``eval_results`` table: evaluation metrics for a job."""

    __tablename__ = "eval_results"

    # Primary key with no default: the creator of the row must supply it.
    id = Column(String(36), primary_key=True)
    # NOTE(review): no ForeignKey to training_jobs.id; not enforced by the DB.
    job_id = Column(String(36), nullable=False)
    # Text column defaulting to the string "{}" — presumably JSON-encoded by the
    # caller (unlike UserModel.roles, this is not a JSON column type).
    metrics = Column(Text, default="{}")
    # Naive UTC timestamp evaluated per INSERT (datetime.utcnow passed uncalled).
    created_at = Column(DateTime, default=datetime.utcnow)
class DeployTaskModel(Base):
    """ORM mapping for the ``deploy_tasks`` table: deployment tasks tied to a job."""

    __tablename__ = "deploy_tasks"

    # Primary key with no default: the creator of the row must supply it.
    id = Column(String(36), primary_key=True)
    # NOTE(review): no ForeignKey to training_jobs.id; not enforced by the DB.
    job_id = Column(String(36), nullable=False)
    status = Column(String(32), default="pending")
    # Result location and error text; both NULL unless set.
    output_path = Column(String(512), nullable=True)
    error = Column(Text, nullable=True)
    # Naive UTC timestamp evaluated per INSERT (datetime.utcnow passed uncalled).
    created_at = Column(DateTime, default=datetime.utcnow)
class UserModel(Base):
    """ORM mapping for the ``users`` table: user accounts and profile fields."""

    __tablename__ = "users"

    # Generated client-side on INSERT: a random UUID4 rendered as a 36-char string.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Unique and indexed.
    username = Column(String(128), unique=True, nullable=False, index=True)
    # Optional profile fields.
    email = Column(String(256), nullable=True)
    real_name = Column(String(128), nullable=True)
    avatar_url = Column(String(512), nullable=True)
    company = Column(String(128), nullable=True)
    department = Column(String(128), nullable=True)
    position = Column(String(128), nullable=True)
    # JSON column; default=list is a callable, so each INSERT gets a fresh empty list.
    roles = Column(JSON, default=list)
    # Integer flags (not Boolean columns): active by default, not superuser by default.
    is_active = Column(Integer, default=1, nullable=False)
    is_superuser = Column(Integer, default=0, nullable=False)
    # Naive UTC timestamps; updated_at is also refreshed by SQLAlchemy on UPDATE.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class RefreshTokenModel(Base):
    """ORM mapping for the ``refresh_tokens`` table: issued refresh tokens per user."""

    __tablename__ = "refresh_tokens"

    # Generated client-side on INSERT: a random UUID4 rendered as a 36-char string.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Indexed for per-user lookups. NOTE(review): no ForeignKey to users.id.
    user_id = Column(String(36), nullable=False, index=True)
    # Unique and indexed. The value is stored exactly as written by the caller.
    # NOTE(review): if raw tokens are stored, consider persisting a hash instead.
    token = Column(String(512), unique=True, nullable=False, index=True)
    # Required expiry time; presumably naive UTC like created_at — confirm.
    expires_at = Column(DateTime, nullable=False)
    # Integer flag, default 0 (presumably 1 once revoked).
    revoked = Column(Integer, default=0, nullable=False)
    # Naive UTC timestamp evaluated per INSERT (datetime.utcnow passed uncalled).
    created_at = Column(DateTime, default=datetime.utcnow)
  121. async def get_db() -> AsyncSession:
  122. async with async_session() as session:
  123. yield session