training_service.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import asyncio
  2. import json
  3. import uuid
  4. from datetime import datetime
  5. from typing import Any
  6. from app.config import get_settings
  7. from app.core.db import async_session, TrainingJobModel
  8. from app.core.job_queue import JobStatus, TrainingJob, job_queue
  9. from app.core.logging import logger
  10. from sqlalchemy import select
# Application settings, resolved once when this module is imported.
settings = get_settings()
  12. async def create_training_job(config: dict[str, Any]) -> dict[str, Any]:
  13. """校验配置、创建任务记录、加入队列。"""
  14. job_id = str(uuid.uuid4())
  15. model_id = config.get("model_id", "")
  16. model_type = config.get("model_type", "text")
  17. dataset_id = config.get("dataset_id", "")
  18. peft_method = config.get("peft_method", "lora")
  19. task_type = config.get("task_type", "sft")
  20. dataset_template = config.get("dataset_template", "auto")
  21. # 写入数据库
  22. record = TrainingJobModel(
  23. id=job_id,
  24. model_id=model_id,
  25. model_type=model_type,
  26. dataset_id=dataset_id,
  27. peft_method=peft_method,
  28. task_type=task_type,
  29. dataset_template=dataset_template,
  30. status="pending",
  31. epochs=config.get("epochs", 3),
  32. batch_size=config.get("batch_size", 4),
  33. gradient_accumulation=config.get("gradient_accumulation", 4),
  34. learning_rate=config.get("learning_rate", 2e-4),
  35. max_seq_length=config.get("max_seq_length", 2048),
  36. warmup_ratio=config.get("warmup_ratio", 0.05),
  37. save_strategy=config.get("save_strategy", "epoch"),
  38. eval_strategy=config.get("eval_strategy", "epoch"),
  39. eval_steps=config.get("eval_steps", 100),
  40. lora_r=config.get("lora_r", 16),
  41. lora_alpha=config.get("lora_alpha", 32),
  42. lora_dropout=config.get("lora_dropout", 0.05),
  43. lora_target_modules=config.get("lora_target_modules", "all-linear"),
  44. qlora_bits=config.get("qlora_bits", 4),
  45. # PPO fields
  46. ppo_epochs=config.get("ppo_epochs", 4),
  47. vf_coef=config.get("vf_coef", 0.1),
  48. kl_coef=config.get("kl_coef", 0.2),
  49. response_length=config.get("response_length", 512),
  50. reward_model_path=config.get("reward_model_path"),
  51. reward_type=config.get("reward_type", "heuristic"),
  52. created_at=datetime.utcnow(),
  53. )
  54. async with async_session() as session:
  55. session.add(record)
  56. await session.commit()
  57. # 加入 JobQueue
  58. num_gpus = config.get("num_gpus", 1)
  59. # 单 GPU 模式禁用 DeepSpeed;多 GPU 使用 DDP,DeepSpeed 暂不支持
  60. if config.get("deepspeed", False):
  61. config["deepspeed"] = False
  62. if num_gpus < 2:
  63. logger.warning("DeepSpeed requires multiple GPUs, but only 1 GPU is configured. DeepSpeed disabled.")
  64. else:
  65. logger.warning("DeepSpeed is not yet supported on MetaX GPU. Using DDP instead. DeepSpeed disabled.")
  66. logger.info(f"Training job {job_id}: num_gpus={num_gpus}, batch_size={config.get('batch_size', 4)}")
  67. job = TrainingJob(
  68. id=job_id,
  69. model_id=model_id,
  70. model_type=model_type,
  71. peft_method=peft_method,
  72. dataset_id=dataset_id,
  73. config=config,
  74. status=JobStatus.PENDING,
  75. )
  76. await job_queue.enqueue(job_id, job)
  77. logger.info(f"Training job created: {job_id}")
  78. return {
  79. "id": job_id,
  80. "model_id": model_id,
  81. "model_type": model_type,
  82. "peft_method": peft_method,
  83. "status": "pending",
  84. "created_at": record.created_at.isoformat(),
  85. }
  86. async def list_training_jobs() -> list[dict[str, Any]]:
  87. """列出所有训练任务。"""
  88. async with async_session() as session:
  89. result = await session.execute(select(TrainingJobModel).order_by(TrainingJobModel.created_at.desc()))
  90. records = result.scalars().all()
  91. return [_job_to_dict(r) for r in records]
  92. async def get_training_job(job_id: str) -> dict[str, Any] | None:
  93. """获取指定任务详情。"""
  94. async with async_session() as session:
  95. result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
  96. record = result.scalar_one_or_none()
  97. if record:
  98. return _job_to_dict(record)
  99. return None
  100. async def cancel_training_job(job_id: str) -> dict[str, Any]:
  101. """向运行中的任务发送取消信号。"""
  102. await job_queue.cancel(job_id)
  103. async with async_session() as session:
  104. result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
  105. record = result.scalar_one_or_none()
  106. if record:
  107. record.status = "cancelled"
  108. record.finished_at = datetime.utcnow()
  109. await session.commit()
  110. logger.info(f"Job cancelled: {job_id}")
  111. return {"status": "cancelled"}
  112. async def update_job_in_db(job):
  113. """JobQueue 回调:同步 job 状态到数据库。"""
  114. try:
  115. async with async_session() as session:
  116. result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job.id))
  117. record = result.scalar_one_or_none()
  118. if record:
  119. record.status = job.status.value if hasattr(job.status, "value") else str(job.status)
  120. record.progress = job.progress
  121. record.current_epoch = job.current_epoch
  122. record.current_step = job.current_step
  123. record.total_steps = job.total_steps
  124. record.loss = job.loss
  125. record.adapter_path = job.adapter_path
  126. record.error_message = job.error_message
  127. if job.status == JobStatus.TRAINING and not record.started_at:
  128. record.started_at = datetime.utcnow()
  129. if job.status.is_terminal:
  130. record.finished_at = datetime.utcnow()
  131. await session.commit()
  132. except Exception as e:
  133. logger.error(f"Failed to update job {job.id} in DB: {e}")
  134. def _job_to_dict(r) -> dict[str, Any]:
  135. return {
  136. "id": r.id,
  137. "model_id": r.model_id,
  138. "model_type": r.model_type,
  139. "peft_method": r.peft_method,
  140. "status": r.status,
  141. "progress": r.progress or 0.0,
  142. "current_epoch": r.current_epoch or 0,
  143. "current_step": r.current_step or 0,
  144. "total_steps": r.total_steps or 0,
  145. "loss": r.loss,
  146. "created_at": r.created_at.isoformat() if r.created_at else "",
  147. "started_at": r.started_at.isoformat() if r.started_at else None,
  148. "finished_at": r.finished_at.isoformat() if r.finished_at else None,
  149. "error_message": r.error_message,
  150. "adapter_path": r.adapter_path,
  151. }
  152. def _generate_deepspeed_config(stage: int = 2) -> str:
  153. """生成 DeepSpeed 配置文件,返回文件路径。"""
  154. import json
  155. from app.config import get_settings
  156. settings = get_settings()
  157. ds_config = {
  158. "fp16": {"enabled": True},
  159. "zero_optimization": {
  160. "stage": stage,
  161. "offload_optimizer": {"device": "cpu", "pin_memory": True},
  162. "offload_param": {"device": "cpu", "pin_memory": True},
  163. "overlap_comm": True,
  164. "contiguous_gradients": True,
  165. "reduce_bucket_size": "auto",
  166. "stage3_prefetch_bucket_size": "auto",
  167. "stage3_param_persistence_threshold": "auto",
  168. } if stage == 3 else {
  169. "stage": stage,
  170. "offload_optimizer": {"device": "cpu", "pin_memory": True},
  171. "allgather_partitions": True,
  172. "allgather_bucket_size": 2e8,
  173. "overlap_comm": True,
  174. "reduce_scatter": True,
  175. "reduce_bucket_size": 2e8,
  176. "contiguous_gradients": True,
  177. },
  178. "gradient_accumulation_steps": "auto",
  179. "gradient_clipping": "auto",
  180. "steps_per_print": 10,
  181. "train_batch_size": "auto",
  182. "train_micro_batch_size_per_gpu": "auto",
  183. "wall_clock_breakdown": False,
  184. }
  185. config_path = settings.data_dir / "deepspeed_config.json"
  186. with open(config_path, "w") as f:
  187. json.dump(ds_config, f, indent=2)
  188. logger.info(f"DeepSpeed config generated: {config_path}")
  189. return str(config_path)