remote_train.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. """远程训练入口脚本 — 在算力节点上执行。"""
  2. import asyncio
  3. import json
  4. import os
  5. import sys
  6. import signal
  7. import time
  8. import traceback
  9. from datetime import datetime, timezone
  10. from pathlib import Path
  11. # 禁用 FlashAttention
  12. os.environ["PYTORCH_NO_FLASH"] = "1"
  13. os.environ["FLASH_ATTENTION_ENABLED"] = "0"
  14. _progress_log_file = None
  15. def _init_log_file(data_dir: Path, job_id: str):
  16. """初始化进度日志文件(通过 SSHFS 共享给主节点读取)。"""
  17. global _progress_log_file
  18. log_dir = data_dir / "logs"
  19. log_dir.mkdir(parents=True, exist_ok=True)
  20. _progress_log_file = log_dir / f"{job_id}.jsonl"
  21. _write_log(type="start", job_id=job_id)
  22. def _write_log(**kwargs):
  23. """追加一行 JSON 到共享日志文件。"""
  24. if _progress_log_file:
  25. entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs}
  26. with open(_progress_log_file, "a", encoding="utf-8") as f:
  27. f.write(json.dumps(entry, ensure_ascii=False) + "\n")
  28. f.flush()
  29. class FileProgressCallback:
  30. """HuggingFace Trainer 回调 — 写进度到共享日志文件。"""
  31. def __init__(self, job_id: str):
  32. self.job_id = job_id
  33. def on_log(self, args, state, control, logs=None, **kwargs):
  34. if logs and "loss" in logs:
  35. _write_log(type="progress", epoch=int(state.epoch or 0),
  36. step=state.global_step, total_steps=state.max_steps or 0,
  37. loss=round(logs["loss"], 4),
  38. learning_rate=round(logs.get("learning_rate", 0), 8))
  39. def on_epoch_begin(self, args, state, control, **kwargs):
  40. _write_log(type="epoch_begin", epoch=int(state.epoch or 0))
  41. def on_epoch_end(self, args, state, control, metrics=None, **kwargs):
  42. _write_log(type="epoch_done", epoch=int(state.epoch or 0),
  43. eval_loss=metrics.get("eval_loss") if metrics and hasattr(metrics, "get") else None,
  44. eval_accuracy=metrics.get("eval_accuracy") if metrics and hasattr(metrics, "get") else None)
  45. def on_train_end(self, args, state, control, **kwargs):
  46. _write_log(type="completed", total_time_seconds=getattr(state, "train_runtime", 0),
  47. adapter_path=args.output_dir)
  48. def on_train_begin(self, args, state, control, **kwargs):
  49. _write_log(type="status", status="training")
  50. def on_save(self, args, state, control, **kwargs):
  51. _write_log(type="save", step=state.global_step)
  52. def on_evaluate(self, args, state, control, metrics=None, **kwargs):
  53. if metrics:
  54. _write_log(type="evaluate", epoch=int(state.epoch or 0),
  55. eval_loss=metrics.get("eval_loss"),
  56. eval_accuracy=metrics.get("eval_accuracy"))
  57. async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict):
  58. """执行单个训练任务(远程调用入口)。"""
  59. from app.config import get_settings
  60. from app.core.logging import logger
  61. settings = get_settings()
  62. _init_log_file(settings.data_dir, job_id)
  63. try:
  64. # dataset_path 由主节点直接传入
  65. if not dataset_path or not Path(dataset_path).exists():
  66. raise FileNotFoundError(f"Dataset not found: {dataset_path}")
  67. _write_log(type="status", status="preprocessing")
  68. # 预处理
  69. processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
  70. task_type = config.get("task_type", "sft")
  71. template = config.get("dataset_template", "alpaca")
  72. # DEBUG: 诊断权限
  73. import stat
  74. from pathlib import Path
  75. proc_dir = settings.processed_dir
  76. _write_log(type="debug",
  77. proc_dir=str(proc_dir),
  78. proc_dir_exists=proc_dir.exists(),
  79. proc_dir_writable=os.access(proc_dir, os.W_OK) if proc_dir.exists() else False,
  80. dataset_path=dataset_path,
  81. dataset_exists=Path(dataset_path).exists())
  82. if proc_dir.exists():
  83. st = proc_dir.stat()
  84. _write_log(type="debug",
  85. proc_dir_mode=oct(st.st_mode),
  86. proc_dir_uid=st.st_uid,
  87. proc_dir_gid=st.st_gid,
  88. my_uid=os.getuid(),
  89. my_gid=os.getgid())
  90. # 选择引擎
  91. if model_type == "vision":
  92. from app.engines.vision_engine import vision_engine
  93. engine = vision_engine
  94. elif model_type == "multimodal":
  95. from app.engines.multimodal_engine import multimodal_engine
  96. engine = multimodal_engine
  97. else:
  98. from app.engines.text_engine import text_engine
  99. engine = text_engine
  100. peft_method = config.get("peft_method", "lora")
  101. await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
  102. _write_log(type="status", status="loading_model")
  103. # 加载模型
  104. await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
  105. # 构建 PEFT 配置
  106. peft_config = engine.get_peft_config(peft_method, config)
  107. _write_log(type="status", status="training")
  108. # 训练 — 传入文件日志回调替代 WebSocket 回调
  109. start_time = time.time()
  110. file_cb = FileProgressCallback(job_id)
  111. adapter_path = await engine.train(
  112. job_id=job_id,
  113. dataset_path=processed_path,
  114. peft_config=peft_config,
  115. training_args=config,
  116. callbacks=[file_cb],
  117. )
  118. elapsed = round(time.time() - start_time, 2)
  119. _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed)
  120. logger.info(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)")
  121. return adapter_path
  122. except Exception as e:
  123. _write_log(type="error", message=str(e), traceback=traceback.format_exc())
  124. logger.error(f"Remote training failed: {job_id} - {e}")
  125. raise
  126. def main():
  127. """命令行入口:python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>"""
  128. if len(sys.argv) < 6:
  129. print("Usage: python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>")
  130. sys.exit(1)
  131. job_id = sys.argv[1]
  132. model_id = sys.argv[2]
  133. model_type = sys.argv[3]
  134. dataset_id = sys.argv[4]
  135. config_path = sys.argv[5]
  136. with open(config_path, encoding="utf-8") as f:
  137. config = json.load(f)
  138. asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config))
  139. if __name__ == "__main__":
  140. main()