# db.py - async database engine/session setup and ORM table definitions.
import logging
import uuid
from datetime import datetime
from typing import AsyncIterator

from sqlalchemy import JSON, Column, DateTime, Float, Integer, String, Text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base

from app.config import get_settings
  7. settings = get_settings()
  8. Base = declarative_base()
  9. # 延迟创建 engine/session,在首次使用时再实例化
  10. _engine = None
  11. _async_session = None
  12. def _get_engine():
  13. global _engine
  14. if _engine is None:
  15. _engine = create_async_engine(
  16. settings.database_url,
  17. echo=settings.backend_env == "development",
  18. pool_pre_ping=True,
  19. )
  20. return _engine
  21. def _get_session():
  22. global _async_session
  23. if _async_session is None:
  24. _async_session = async_sessionmaker(_get_engine(), class_=AsyncSession, expire_on_commit=False)
  25. return _async_session
  26. # async_session 是 async_sessionmaker 实例,调用它返回 AsyncSession
  27. async_session = _get_session()
  28. async def init_db():
  29. """创建所有表(首次启动时调用)。"""
  30. async with _get_engine().begin() as conn:
  31. await conn.run_sync(Base.metadata.create_all)
  32. # 自动迁移:为已有表补齐新字段
  33. await _migrate_tables()
  34. async def _migrate_tables():
  35. """补齐表新增字段(兼容已有数据库)。"""
  36. from sqlalchemy import text
  37. # 用 DO 块判断列是否存在,避免 ADD COLUMN 重复列导致事务 abort
  38. alter_stmts = [
  39. "DO $$ BEGIN "
  40. "IF NOT EXISTS (SELECT 1 FROM information_schema.columns "
  41. "WHERE table_name='deploy_tasks' AND column_name='deploy_mode') THEN "
  42. "ALTER TABLE deploy_tasks ADD COLUMN deploy_mode VARCHAR(16) DEFAULT 'export'; "
  43. "END IF; END $$",
  44. "DO $$ BEGIN "
  45. "IF NOT EXISTS (SELECT 1 FROM information_schema.columns "
  46. "WHERE table_name='deploy_tasks' AND column_name='endpoint_url') THEN "
  47. "ALTER TABLE deploy_tasks ADD COLUMN endpoint_url VARCHAR(256); "
  48. "END IF; END $$",
  49. "DO $$ BEGIN "
  50. "IF NOT EXISTS (SELECT 1 FROM information_schema.columns "
  51. "WHERE table_name='deploy_tasks' AND column_name='port') THEN "
  52. "ALTER TABLE deploy_tasks ADD COLUMN port INTEGER; "
  53. "END IF; END $$",
  54. "DO $$ BEGIN "
  55. "IF NOT EXISTS (SELECT 1 FROM information_schema.columns "
  56. "WHERE table_name='deploy_tasks' AND column_name='pid') THEN "
  57. "ALTER TABLE deploy_tasks ADD COLUMN pid VARCHAR(32); "
  58. "END IF; END $$",
  59. "DO $$ BEGIN "
  60. "IF NOT EXISTS (SELECT 1 FROM information_schema.columns "
  61. "WHERE table_name='deploy_tasks' AND column_name='user_id') THEN "
  62. "ALTER TABLE deploy_tasks ADD COLUMN user_id VARCHAR(36); "
  63. "END IF; END $$",
  64. ]
  65. engine = _get_engine()
  66. for stmt in alter_stmts:
  67. try:
  68. async with engine.begin() as conn:
  69. await conn.execute(text(stmt))
  70. except Exception:
  71. pass
  72. class TrainingJobModel(Base):
  73. __tablename__ = "training_jobs"
  74. id = Column(String(36), primary_key=True)
  75. model_id = Column(String(256), nullable=False)
  76. model_type = Column(String(32), nullable=False)
  77. dataset_id = Column(String(36), nullable=False)
  78. peft_method = Column(String(32), nullable=False)
  79. task_type = Column(String(32), default="sft")
  80. dataset_template = Column(String(32), default="alpaca")
  81. status = Column(String(32), default="pending")
  82. progress = Column(Float, default=0.0)
  83. current_epoch = Column(Integer, default=0)
  84. current_step = Column(Integer, default=0)
  85. total_steps = Column(Integer, default=0)
  86. loss = Column(Float, nullable=True)
  87. learning_rate = Column(Float, nullable=True)
  88. epochs = Column(Integer, default=3)
  89. batch_size = Column(Integer, default=4)
  90. gradient_accumulation = Column(Integer, default=4)
  91. max_seq_length = Column(Integer, default=2048)
  92. warmup_ratio = Column(Float, default=0.05)
  93. save_strategy = Column(String(32), default="epoch")
  94. eval_strategy = Column(String(32), default="epoch")
  95. eval_steps = Column(Integer, default=100)
  96. lora_r = Column(Integer, default=16)
  97. lora_alpha = Column(Integer, default=32)
  98. lora_dropout = Column(Float, default=0.05)
  99. lora_target_modules = Column(String(256), default="all-linear")
  100. qlora_bits = Column(Integer, default=4)
  101. # PPO
  102. ppo_epochs = Column(Integer, default=4)
  103. vf_coef = Column(Float, default=0.1)
  104. kl_coef = Column(Float, default=0.2)
  105. response_length = Column(Integer, default=512)
  106. reward_model_path = Column(String(512), nullable=True)
  107. reward_type = Column(String(32), default="heuristic")
  108. created_at = Column(DateTime, default=datetime.utcnow)
  109. started_at = Column(DateTime, nullable=True)
  110. finished_at = Column(DateTime, nullable=True)
  111. error_message = Column(Text, nullable=True)
  112. adapter_path = Column(String(512), nullable=True)
  113. class DatasetRecord(Base):
  114. __tablename__ = "datasets"
  115. id = Column(String(36), primary_key=True)
  116. name = Column(String(256), nullable=False)
  117. format = Column(String(16), nullable=False)
  118. record_count = Column(Integer, default=0)
  119. file_path = Column(String(512), nullable=False)
  120. created_at = Column(DateTime, default=datetime.utcnow)
  121. class ModelCache(Base):
  122. __tablename__ = "model_cache"
  123. id = Column(String(256), primary_key=True)
  124. name = Column(String(256), nullable=False)
  125. model_type = Column(String(32), nullable=False)
  126. path = Column(String(512), nullable=True)
  127. is_downloaded = Column(Integer, default=0)
  128. context_length = Column(Integer, nullable=True)
  129. supported_peft_methods = Column(String(256), default="")
  130. created_at = Column(DateTime, default=datetime.utcnow)
  131. class EvalResultModel(Base):
  132. __tablename__ = "eval_results"
  133. id = Column(String(36), primary_key=True)
  134. job_id = Column(String(36), nullable=False)
  135. status = Column(String(32), default="pending") # pending|running|completed|failed
  136. metrics = Column(Text, default="{}")
  137. progress = Column(Float, default=0.0)
  138. error = Column(Text, nullable=True)
  139. created_at = Column(DateTime, default=datetime.utcnow)
  140. class DeployTaskModel(Base):
  141. __tablename__ = "deploy_tasks"
  142. id = Column(String(36), primary_key=True)
  143. job_id = Column(String(36), nullable=False)
  144. user_id = Column(String(36), nullable=True) # 部署任务所属用户
  145. status = Column(String(32), default="pending")
  146. deploy_mode = Column(String(16), default="export") # export | serve
  147. output_path = Column(String(512), nullable=True)
  148. endpoint_url = Column(String(256), nullable=True) # serve 模式下的 base_url
  149. port = Column(Integer, nullable=True) # serve 模式分配的端口
  150. pid = Column(String(32), nullable=True) # serve 模式远程进程 PID
  151. error = Column(Text, nullable=True)
  152. progress = Column(Float, default=0.0)
  153. finished_at = Column(DateTime, nullable=True)
  154. created_at = Column(DateTime, default=datetime.utcnow)
  155. class ModelDownloadTask(Base):
  156. __tablename__ = "model_download_tasks"
  157. id = Column(String(36), primary_key=True)
  158. model_id = Column(String(256), nullable=False)
  159. use_modelscope = Column(Integer, default=0)
  160. status = Column(String(32), default="pending") # pending|downloading|completed|failed
  161. path = Column(String(512), nullable=True)
  162. error = Column(Text, nullable=True)
  163. progress = Column(Float, default=0.0)
  164. created_at = Column(DateTime, default=datetime.utcnow)
  165. finished_at = Column(DateTime, nullable=True)
  166. class DatasetDownloadTask(Base):
  167. __tablename__ = "dataset_download_tasks"
  168. id = Column(String(36), primary_key=True)
  169. dataset_id = Column(String(256), nullable=False)
  170. use_modelscope = Column(Integer, default=0)
  171. status = Column(String(32), default="pending") # pending|downloading|completed|failed
  172. path = Column(String(512), nullable=True)
  173. error = Column(Text, nullable=True)
  174. record_count = Column(Integer, default=0)
  175. created_at = Column(DateTime, default=datetime.utcnow)
  176. finished_at = Column(DateTime, nullable=True)
  177. class UserModel(Base):
  178. __tablename__ = "users"
  179. id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
  180. username = Column(String(128), unique=True, nullable=False, index=True)
  181. email = Column(String(256), nullable=True)
  182. real_name = Column(String(128), nullable=True)
  183. avatar_url = Column(String(512), nullable=True)
  184. company = Column(String(128), nullable=True)
  185. department = Column(String(128), nullable=True)
  186. position = Column(String(128), nullable=True)
  187. roles = Column(JSON, default=list)
  188. is_active = Column(Integer, default=1, nullable=False)
  189. is_superuser = Column(Integer, default=0, nullable=False)
  190. created_at = Column(DateTime, default=datetime.utcnow)
  191. updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
  192. class RefreshTokenModel(Base):
  193. __tablename__ = "refresh_tokens"
  194. id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
  195. user_id = Column(String(36), nullable=False, index=True)
  196. token = Column(String(512), unique=True, nullable=False, index=True)
  197. expires_at = Column(DateTime, nullable=False)
  198. revoked = Column(Integer, default=0, nullable=False)
  199. created_at = Column(DateTime, default=datetime.utcnow)
  200. class ApiKeyModel(Base):
  201. __tablename__ = "api_keys"
  202. id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
  203. user_id = Column(String(36), nullable=False, index=True)
  204. key = Column(String(128), unique=True, nullable=False, index=True) # sk-xxx
  205. name = Column(String(128), nullable=False, default="default")
  206. status = Column(String(16), default="active") # active | revoked
  207. last_used_at = Column(DateTime, nullable=True)
  208. created_at = Column(DateTime, default=datetime.utcnow)
  209. async def get_db() -> AsyncSession:
  210. async with async_session() as session:
  211. yield session