remote_train.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. """远程训练入口脚本 — 在算力节点上执行。
  2. 不依赖 app.config / app.core.logging,避免引入 pydantic-settings / sqlalchemy 等额外包。
  3. """
  4. import asyncio
  5. import json
  6. import os
  7. import sys
  8. import time
  9. import traceback
  10. from datetime import datetime, timezone
  11. from pathlib import Path
  12. # 禁用 FlashAttention
  13. os.environ["PYTORCH_NO_FLASH"] = "1"
  14. os.environ["FLASH_ATTENTION_ENABLED"] = "0"
  15. # 解决 PyTorch 显存碎片化问题(避免 reserved unallocated 占用大量显存)
  16. os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
  17. # 禁用 torch.compile,避免 fork 大量 inductor worker 进程
  18. os.environ["PT2_COMPILE"] = "0"
  19. os.environ["TORCHINDUCTOR_MAX_WORKERS"] = "1"
  20. # 限制训练只用 GPU 2 和 3(GPU 0/1 被 VLLM 占用)
  21. # 沐曦 MPS 模式下 CUDA_VISIBLE_DEVICES 可能干扰设备映射,
  22. # 只设 METAX_VISIBLE_DEVICES,device_map 里用物理 GPU 号手动指定。
  23. os.environ["METAX_VISIBLE_DEVICES"] = "2,3"
  24. # 启用 MPS 多进程服务,允许与 VLLM 共享 GPU
  25. os.environ["MACA_MPS_MODE"] = "1"
  26. _progress_log_file = None
  27. # 直接从环境变量读取配置,避免引入 pydantic-settings
  28. _DATA_DIR = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
  29. _PROCESSED_DIR = _DATA_DIR / "processed"
  30. _ADAPTERS_DIR = _DATA_DIR / "adapters"
  31. _MODELS_DIR = _DATA_DIR / "models"
  32. def _remote_log(msg: str):
  33. """打印到 stderr(即远程训练日志 /tmp/train_{job_id}.log)。"""
  34. print(f"[remote_train] {msg}", file=sys.stderr)
  35. def _init_log_file(job_id: str):
  36. """初始化进度日志文件(通过 SSHFS 共享给主节点读取)。"""
  37. global _progress_log_file
  38. log_dir = _DATA_DIR / "logs"
  39. log_dir.mkdir(parents=True, exist_ok=True)
  40. _progress_log_file = log_dir / f"{job_id}.jsonl"
  41. _write_log(type="start", job_id=job_id)
  42. def _write_log(**kwargs):
  43. """追加一行 JSON 到共享日志文件。"""
  44. if _progress_log_file:
  45. entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs}
  46. with open(_progress_log_file, "a", encoding="utf-8") as f:
  47. f.write(json.dumps(entry, ensure_ascii=False) + "\n")
  48. f.flush()
  49. class FileProgressCallback:
  50. """HuggingFace Trainer 回调 — 写进度到共享日志文件。
  51. 只实现关心的回调,其余通过 __getattr__ 自动忽略。
  52. """
  53. def __init__(self, job_id: str):
  54. self.job_id = job_id
  55. def on_log(self, args, state, control, logs=None, **kwargs):
  56. if logs and "loss" in logs:
  57. _write_log(type="progress", epoch=int(state.epoch or 0),
  58. step=state.global_step, total_steps=state.max_steps or 0,
  59. loss=round(logs["loss"], 4),
  60. learning_rate=round(logs.get("learning_rate", 0), 8))
  61. def on_epoch_begin(self, args, state, control, **kwargs):
  62. _write_log(type="epoch_begin", epoch=int(state.epoch or 0))
  63. def on_epoch_end(self, args, state, control, metrics=None, **kwargs):
  64. _write_log(type="epoch_done", epoch=int(state.epoch or 0),
  65. eval_loss=metrics.get("eval_loss") if metrics and hasattr(metrics, "get") else None,
  66. eval_accuracy=metrics.get("eval_accuracy") if metrics and hasattr(metrics, "get") else None)
  67. def on_train_end(self, args, state, control, **kwargs):
  68. _write_log(type="completed", total_time_seconds=getattr(state, "train_runtime", 0),
  69. adapter_path=args.output_dir)
  70. def on_train_begin(self, args, state, control, **kwargs):
  71. _write_log(type="status", status="training")
  72. def on_save(self, args, state, control, **kwargs):
  73. _write_log(type="save", step=state.global_step)
  74. def on_evaluate(self, args, state, control, metrics=None, **kwargs):
  75. if metrics:
  76. _write_log(type="evaluate", epoch=int(state.epoch or 0),
  77. eval_loss=metrics.get("eval_loss"),
  78. eval_accuracy=metrics.get("eval_accuracy"))
  79. def __getattr__(self, name):
  80. """Trainer 期望其他回调方法存在,返回一个空函数自动忽略。"""
  81. return lambda *args, **kwargs: None
  82. async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict):
  83. """执行单个训练任务(远程调用入口)。"""
  84. _init_log_file(job_id)
  85. _remote_log(f"=== Training job started: {job_id} ===")
  86. _remote_log(f"model_id={model_id}, model_type={model_type}")
  87. _remote_log(f"dataset_path={dataset_path}")
  88. _remote_log(f"config={json.dumps(config, ensure_ascii=False)[:200]}")
  89. try:
  90. # dataset_path 由主节点直接传入
  91. if not dataset_path or not Path(dataset_path).exists():
  92. raise FileNotFoundError(f"Dataset not found: {dataset_path}")
  93. _remote_log(f"Dataset file exists: {dataset_path}")
  94. _write_log(type="status", status="preprocessing")
  95. _remote_log("Step 1: Preprocessing dataset...")
  96. # 预处理
  97. processed_path = str(_PROCESSED_DIR / f"{job_id}_processed.jsonl")
  98. task_type = config.get("task_type", "sft")
  99. template = config.get("dataset_template", "auto")
  100. _remote_log(f" task_type={task_type}, template={template}")
  101. _remote_log(f" output_path={processed_path}")
  102. # 选择引擎
  103. _remote_log(f" Selecting engine for model_type={model_type}...")
  104. if model_type == "vision":
  105. from app.engines.vision_engine import vision_engine
  106. engine = vision_engine
  107. elif model_type == "multimodal":
  108. from app.engines.multimodal_engine import multimodal_engine
  109. engine = multimodal_engine
  110. else:
  111. from app.engines.text_engine import text_engine
  112. engine = text_engine
  113. _remote_log(f" Engine loaded: {engine.__class__.__name__}")
  114. peft_method = config.get("peft_method", "lora")
  115. _remote_log(f" PEFT method: {peft_method}")
  116. _remote_log(" Running preprocess_dataset...")
  117. await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
  118. _remote_log(f" Preprocessing done, output: {processed_path}")
  119. _write_log(type="status", status="loading_model")
  120. _remote_log(f"Step 2: Loading model: {model_id}...")
  121. # 加载模型
  122. quantization_mode = "4bit" if peft_method == "qlora" else None
  123. _remote_log(f" Quantization: {quantization_mode}")
  124. await engine.load_model(model_id, quantization=quantization_mode)
  125. _remote_log(" Model loaded successfully")
  126. # 构建 PEFT 配置
  127. _remote_log("Step 3: Building PEFT config...")
  128. peft_config = engine.get_peft_config(peft_method, config)
  129. _remote_log(" PEFT config built")
  130. _write_log(type="status", status="training")
  131. _remote_log("Step 4: Starting training...")
  132. # 训练 — 传入文件日志回调替代 WebSocket 回调
  133. start_time = time.time()
  134. file_cb = FileProgressCallback(job_id)
  135. adapter_path = await engine.train(
  136. job_id=job_id,
  137. dataset_path=processed_path,
  138. peft_config=peft_config,
  139. training_args=config,
  140. callbacks=[file_cb],
  141. )
  142. elapsed = round(time.time() - start_time, 2)
  143. _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed)
  144. _remote_log(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)")
  145. _remote_log(f"=== Training job finished: {job_id} ===")
  146. return adapter_path
  147. except Exception as e:
  148. tb = traceback.format_exc()
  149. _write_log(type="error", message=str(e), traceback=tb)
  150. _remote_log(f"ERROR: {e}")
  151. _remote_log(tb)
  152. _remote_log(f"=== Training job failed: {job_id} ===")
  153. raise
  154. def main():
  155. """命令行入口:python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>"""
  156. if len(sys.argv) < 6:
  157. print("Usage: python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>")
  158. sys.exit(1)
  159. job_id = sys.argv[1]
  160. model_id = sys.argv[2]
  161. model_type = sys.argv[3]
  162. dataset_id = sys.argv[4]
  163. config_path = sys.argv[5]
  164. with open(config_path, encoding="utf-8") as f:
  165. config = json.load(f)
  166. asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config))
  167. if __name__ == "__main__":
  168. main()