training_service.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. import asyncio
  2. import json
  3. import uuid
  4. from datetime import datetime
  5. from typing import Any
  6. from app.config import get_settings
  7. from app.core.db import async_session, TrainingJobModel
  8. from app.core.job_queue import JobStatus, TrainingJob, job_queue
  9. from app.core.logging import logger
  10. from sqlalchemy import select
# Module-level application settings, fetched once at import time.
# NOTE(review): no function visible in this file reads this name
# (_generate_deepspeed_config calls get_settings() itself) — confirm whether
# other modules import `settings` from here before removing it.
settings = get_settings()
  12. async def create_training_job(config: dict[str, Any]) -> dict[str, Any]:
  13. """校验配置、创建任务记录、加入队列。"""
  14. job_id = str(uuid.uuid4())
  15. model_id = config.get("model_id", "")
  16. model_type = config.get("model_type", "text")
  17. dataset_id = config.get("dataset_id", "")
  18. peft_method = config.get("peft_method", "lora")
  19. task_type = config.get("task_type", "sft")
  20. dataset_template = config.get("dataset_template", "alpaca")
  21. # 写入数据库
  22. record = TrainingJobModel(
  23. id=job_id,
  24. model_id=model_id,
  25. model_type=model_type,
  26. dataset_id=dataset_id,
  27. peft_method=peft_method,
  28. task_type=task_type,
  29. dataset_template=dataset_template,
  30. status="pending",
  31. epochs=config.get("epochs", 3),
  32. batch_size=config.get("batch_size", 4),
  33. gradient_accumulation=config.get("gradient_accumulation", 4),
  34. learning_rate=config.get("learning_rate", 2e-4),
  35. max_seq_length=config.get("max_seq_length", 2048),
  36. warmup_ratio=config.get("warmup_ratio", 0.05),
  37. save_strategy=config.get("save_strategy", "epoch"),
  38. eval_strategy=config.get("eval_strategy", "epoch"),
  39. eval_steps=config.get("eval_steps", 100),
  40. lora_r=config.get("lora_r", 16),
  41. lora_alpha=config.get("lora_alpha", 32),
  42. lora_dropout=config.get("lora_dropout", 0.05),
  43. lora_target_modules=config.get("lora_target_modules", "all-linear"),
  44. qlora_bits=config.get("qlora_bits", 4),
  45. created_at=datetime.utcnow(),
  46. )
  47. async with async_session() as session:
  48. session.add(record)
  49. await session.commit()
  50. # 加入 JobQueue
  51. # 如果启用 DeepSpeed,生成配置文件
  52. if config.get("deepspeed", False):
  53. ds_config_path = _generate_deepspeed_config()
  54. config["deepspeed"] = ds_config_path
  55. job = TrainingJob(
  56. id=job_id,
  57. model_id=model_id,
  58. model_type=model_type,
  59. peft_method=peft_method,
  60. dataset_id=dataset_id,
  61. config=config,
  62. status=JobStatus.PENDING,
  63. )
  64. await job_queue.enqueue(job_id, job)
  65. logger.info(f"Training job created: {job_id}")
  66. return {
  67. "id": job_id,
  68. "model_id": model_id,
  69. "model_type": model_type,
  70. "peft_method": peft_method,
  71. "status": "pending",
  72. "created_at": record.created_at.isoformat(),
  73. }
  74. async def list_training_jobs() -> list[dict[str, Any]]:
  75. """列出所有训练任务。"""
  76. async with async_session() as session:
  77. result = await session.execute(select(TrainingJobModel).order_by(TrainingJobModel.created_at.desc()))
  78. records = result.scalars().all()
  79. return [_job_to_dict(r) for r in records]
  80. async def get_training_job(job_id: str) -> dict[str, Any] | None:
  81. """获取指定任务详情。"""
  82. async with async_session() as session:
  83. result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
  84. record = result.scalar_one_or_none()
  85. if record:
  86. return _job_to_dict(record)
  87. return None
  88. async def cancel_training_job(job_id: str) -> dict[str, Any]:
  89. """向运行中的任务发送取消信号。"""
  90. await job_queue.cancel(job_id)
  91. async with async_session() as session:
  92. result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
  93. record = result.scalar_one_or_none()
  94. if record:
  95. record.status = "cancelled"
  96. record.finished_at = datetime.utcnow()
  97. await session.commit()
  98. logger.info(f"Job cancelled: {job_id}")
  99. return {"status": "cancelled"}
  100. async def update_job_in_db(job):
  101. """JobQueue 回调:同步 job 状态到数据库。"""
  102. try:
  103. async with async_session() as session:
  104. result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job.id))
  105. record = result.scalar_one_or_none()
  106. if record:
  107. record.status = job.status.value if hasattr(job.status, "value") else str(job.status)
  108. record.progress = job.progress
  109. record.current_epoch = job.current_epoch
  110. record.current_step = job.current_step
  111. record.total_steps = job.total_steps
  112. record.loss = job.loss
  113. record.adapter_path = job.adapter_path
  114. record.error_message = job.error_message
  115. if job.status == JobStatus.TRAINING and not record.started_at:
  116. record.started_at = datetime.utcnow()
  117. if job.status.is_terminal:
  118. record.finished_at = datetime.utcnow()
  119. await session.commit()
  120. except Exception as e:
  121. logger.error(f"Failed to update job {job.id} in DB: {e}")
  122. def _job_to_dict(r) -> dict[str, Any]:
  123. return {
  124. "id": r.id,
  125. "model_id": r.model_id,
  126. "model_type": r.model_type,
  127. "peft_method": r.peft_method,
  128. "status": r.status,
  129. "progress": r.progress or 0.0,
  130. "current_epoch": r.current_epoch or 0,
  131. "current_step": r.current_step or 0,
  132. "total_steps": r.total_steps or 0,
  133. "loss": r.loss,
  134. "created_at": r.created_at.isoformat() if r.created_at else "",
  135. "started_at": r.started_at.isoformat() if r.started_at else None,
  136. "finished_at": r.finished_at.isoformat() if r.finished_at else None,
  137. "error_message": r.error_message,
  138. "adapter_path": r.adapter_path,
  139. }
  140. def _generate_deepspeed_config(stage: int = 2) -> str:
  141. """生成 DeepSpeed 配置文件,返回文件路径。"""
  142. import json
  143. from app.config import get_settings
  144. settings = get_settings()
  145. ds_config = {
  146. "fp16": {"enabled": True},
  147. "zero_optimization": {
  148. "stage": stage,
  149. "offload_optimizer": {"device": "cpu", "pin_memory": True},
  150. "offload_param": {"device": "cpu", "pin_memory": True},
  151. "overlap_comm": True,
  152. "contiguous_gradients": True,
  153. "reduce_bucket_size": "auto",
  154. "stage3_prefetch_bucket_size": "auto",
  155. "stage3_param_persistence_threshold": "auto",
  156. } if stage == 3 else {
  157. "stage": stage,
  158. "offload_optimizer": {"device": "cpu", "pin_memory": True},
  159. "allgather_partitions": True,
  160. "allgather_bucket_size": 2e8,
  161. "overlap_comm": True,
  162. "reduce_scatter": True,
  163. "reduce_bucket_size": 2e8,
  164. "contiguous_gradients": True,
  165. },
  166. "gradient_accumulation_steps": "auto",
  167. "gradient_clipping": "auto",
  168. "steps_per_print": 10,
  169. "train_batch_size": "auto",
  170. "train_micro_batch_size_per_gpu": "auto",
  171. "wall_clock_breakdown": False,
  172. }
  173. config_path = settings.data_dir / "deepspeed_config.json"
  174. with open(config_path, "w") as f:
  175. json.dump(ds_config, f, indent=2)
  176. logger.info(f"DeepSpeed config generated: {config_path}")
  177. return str(config_path)