training_service.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import asyncio
  2. import json
  3. import uuid
  4. from datetime import datetime
  5. from typing import Any
  6. from app.config import get_settings
  7. from app.core.db import async_session, TrainingJobModel
  8. from app.core.job_queue import JobStatus, TrainingJob, job_queue
  9. from app.core.logging import logger
  10. from sqlalchemy import select
  11. settings = get_settings()
  12. async def create_training_job(config: dict[str, Any]) -> dict[str, Any]:
  13. """校验配置、创建任务记录、加入队列。"""
  14. job_id = str(uuid.uuid4())
  15. model_id = config.get("model_id", "")
  16. model_type = config.get("model_type", "text")
  17. dataset_id = config.get("dataset_id", "")
  18. peft_method = config.get("peft_method", "lora")
  19. task_type = config.get("task_type", "sft")
  20. dataset_template = config.get("dataset_template", "auto")
  21. # 写入数据库
  22. record = TrainingJobModel(
  23. id=job_id,
  24. model_id=model_id,
  25. model_type=model_type,
  26. dataset_id=dataset_id,
  27. peft_method=peft_method,
  28. task_type=task_type,
  29. dataset_template=dataset_template,
  30. status="pending",
  31. epochs=config.get("epochs", 3),
  32. batch_size=config.get("batch_size", 4),
  33. gradient_accumulation=config.get("gradient_accumulation", 4),
  34. learning_rate=config.get("learning_rate", 2e-4),
  35. max_seq_length=config.get("max_seq_length", 2048),
  36. warmup_ratio=config.get("warmup_ratio", 0.05),
  37. save_strategy=config.get("save_strategy", "epoch"),
  38. eval_strategy=config.get("eval_strategy", "epoch"),
  39. eval_steps=config.get("eval_steps", 100),
  40. lora_r=config.get("lora_r", 16),
  41. lora_alpha=config.get("lora_alpha", 32),
  42. lora_dropout=config.get("lora_dropout", 0.05),
  43. lora_target_modules=config.get("lora_target_modules", "all-linear"),
  44. qlora_bits=config.get("qlora_bits", 4),
  45. # PPO fields
  46. ppo_epochs=config.get("ppo_epochs", 4),
  47. vf_coef=config.get("vf_coef", 0.1),
  48. kl_coef=config.get("kl_coef", 0.2),
  49. response_length=config.get("response_length", 512),
  50. reward_model_path=config.get("reward_model_path"),
  51. reward_type=config.get("reward_type", "heuristic"),
  52. created_at=datetime.utcnow(),
  53. )
  54. async with async_session() as session:
  55. session.add(record)
  56. await session.commit()
  57. # 加入 JobQueue
  58. # DeepSpeed 需要多 GPU,单卡模式已禁用
  59. if config.get("deepspeed", False):
  60. config["deepspeed"] = False
  61. logger.warning("DeepSpeed requires multiple GPUs, but only GPU 3 is available. DeepSpeed disabled.")
  62. job = TrainingJob(
  63. id=job_id,
  64. model_id=model_id,
  65. model_type=model_type,
  66. peft_method=peft_method,
  67. dataset_id=dataset_id,
  68. config=config,
  69. status=JobStatus.PENDING,
  70. )
  71. await job_queue.enqueue(job_id, job)
  72. logger.info(f"Training job created: {job_id}")
  73. return {
  74. "id": job_id,
  75. "model_id": model_id,
  76. "model_type": model_type,
  77. "peft_method": peft_method,
  78. "status": "pending",
  79. "created_at": record.created_at.isoformat(),
  80. }
  81. async def list_training_jobs() -> list[dict[str, Any]]:
  82. """列出所有训练任务。"""
  83. async with async_session() as session:
  84. result = await session.execute(select(TrainingJobModel).order_by(TrainingJobModel.created_at.desc()))
  85. records = result.scalars().all()
  86. return [_job_to_dict(r) for r in records]
  87. async def get_training_job(job_id: str) -> dict[str, Any] | None:
  88. """获取指定任务详情。"""
  89. async with async_session() as session:
  90. result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
  91. record = result.scalar_one_or_none()
  92. if record:
  93. return _job_to_dict(record)
  94. return None
  95. async def cancel_training_job(job_id: str) -> dict[str, Any]:
  96. """向运行中的任务发送取消信号。"""
  97. await job_queue.cancel(job_id)
  98. async with async_session() as session:
  99. result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
  100. record = result.scalar_one_or_none()
  101. if record:
  102. record.status = "cancelled"
  103. record.finished_at = datetime.utcnow()
  104. await session.commit()
  105. logger.info(f"Job cancelled: {job_id}")
  106. return {"status": "cancelled"}
  107. async def update_job_in_db(job):
  108. """JobQueue 回调:同步 job 状态到数据库。"""
  109. try:
  110. async with async_session() as session:
  111. result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job.id))
  112. record = result.scalar_one_or_none()
  113. if record:
  114. record.status = job.status.value if hasattr(job.status, "value") else str(job.status)
  115. record.progress = job.progress
  116. record.current_epoch = job.current_epoch
  117. record.current_step = job.current_step
  118. record.total_steps = job.total_steps
  119. record.loss = job.loss
  120. record.adapter_path = job.adapter_path
  121. record.error_message = job.error_message
  122. if job.status == JobStatus.TRAINING and not record.started_at:
  123. record.started_at = datetime.utcnow()
  124. if job.status.is_terminal:
  125. record.finished_at = datetime.utcnow()
  126. await session.commit()
  127. except Exception as e:
  128. logger.error(f"Failed to update job {job.id} in DB: {e}")
  129. def _job_to_dict(r) -> dict[str, Any]:
  130. return {
  131. "id": r.id,
  132. "model_id": r.model_id,
  133. "model_type": r.model_type,
  134. "peft_method": r.peft_method,
  135. "status": r.status,
  136. "progress": r.progress or 0.0,
  137. "current_epoch": r.current_epoch or 0,
  138. "current_step": r.current_step or 0,
  139. "total_steps": r.total_steps or 0,
  140. "loss": r.loss,
  141. "created_at": r.created_at.isoformat() if r.created_at else "",
  142. "started_at": r.started_at.isoformat() if r.started_at else None,
  143. "finished_at": r.finished_at.isoformat() if r.finished_at else None,
  144. "error_message": r.error_message,
  145. "adapter_path": r.adapter_path,
  146. }
  147. def _generate_deepspeed_config(stage: int = 2) -> str:
  148. """生成 DeepSpeed 配置文件,返回文件路径。"""
  149. import json
  150. from app.config import get_settings
  151. settings = get_settings()
  152. ds_config = {
  153. "fp16": {"enabled": True},
  154. "zero_optimization": {
  155. "stage": stage,
  156. "offload_optimizer": {"device": "cpu", "pin_memory": True},
  157. "offload_param": {"device": "cpu", "pin_memory": True},
  158. "overlap_comm": True,
  159. "contiguous_gradients": True,
  160. "reduce_bucket_size": "auto",
  161. "stage3_prefetch_bucket_size": "auto",
  162. "stage3_param_persistence_threshold": "auto",
  163. } if stage == 3 else {
  164. "stage": stage,
  165. "offload_optimizer": {"device": "cpu", "pin_memory": True},
  166. "allgather_partitions": True,
  167. "allgather_bucket_size": 2e8,
  168. "overlap_comm": True,
  169. "reduce_scatter": True,
  170. "reduce_bucket_size": 2e8,
  171. "contiguous_gradients": True,
  172. },
  173. "gradient_accumulation_steps": "auto",
  174. "gradient_clipping": "auto",
  175. "steps_per_print": 10,
  176. "train_batch_size": "auto",
  177. "train_micro_batch_size_per_gpu": "auto",
  178. "wall_clock_breakdown": False,
  179. }
  180. config_path = settings.data_dir / "deepspeed_config.json"
  181. with open(config_path, "w") as f:
  182. json.dump(ds_config, f, indent=2)
  183. logger.info(f"DeepSpeed config generated: {config_path}")
  184. return str(config_path)