"""远程训练入口脚本 — 在算力节点上执行。""" import asyncio import json import os import sys import signal import time import traceback from datetime import datetime, timezone from pathlib import Path # 禁用 FlashAttention os.environ["PYTORCH_NO_FLASH"] = "1" os.environ["FLASH_ATTENTION_ENABLED"] = "0" _progress_log_file = None def _init_log_file(data_dir: Path, job_id: str): """初始化进度日志文件(通过 SSHFS 共享给主节点读取)。""" global _progress_log_file log_dir = data_dir / "logs" log_dir.mkdir(parents=True, exist_ok=True) _progress_log_file = log_dir / f"{job_id}.jsonl" _write_log(type="start", job_id=job_id) def _write_log(**kwargs): """追加一行 JSON 到共享日志文件。""" if _progress_log_file: entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs} with open(_progress_log_file, "a", encoding="utf-8") as f: f.write(json.dumps(entry, ensure_ascii=False) + "\n") f.flush() class FileProgressCallback: """HuggingFace Trainer 回调 — 写进度到共享日志文件。""" def __init__(self, job_id: str): self.job_id = job_id def on_log(self, args, state, control, logs=None, **kwargs): if logs and "loss" in logs: _write_log(type="progress", epoch=int(state.epoch or 0), step=state.global_step, total_steps=state.max_steps or 0, loss=round(logs["loss"], 4), learning_rate=round(logs.get("learning_rate", 0), 8)) def on_epoch_begin(self, args, state, control, **kwargs): _write_log(type="epoch_begin", epoch=int(state.epoch or 0)) def on_epoch_end(self, args, state, control, metrics=None, **kwargs): _write_log(type="epoch_done", epoch=int(state.epoch or 0), eval_loss=metrics.get("eval_loss") if metrics and hasattr(metrics, "get") else None, eval_accuracy=metrics.get("eval_accuracy") if metrics and hasattr(metrics, "get") else None) def on_train_end(self, args, state, control, **kwargs): _write_log(type="completed", total_time_seconds=getattr(state, "train_runtime", 0), adapter_path=args.output_dir) def on_train_begin(self, args, state, control, **kwargs): _write_log(type="status", status="training") def on_save(self, args, state, control, **kwargs): _write_log(type="save", step=state.global_step) def on_evaluate(self, args, state, control, metrics=None, **kwargs): if metrics: _write_log(type="evaluate", epoch=int(state.epoch or 0), eval_loss=metrics.get("eval_loss"), eval_accuracy=metrics.get("eval_accuracy")) async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict): """执行单个训练任务(远程调用入口)。""" from app.config import get_settings from app.core.logging import logger settings = get_settings() _init_log_file(settings.data_dir, job_id) try: # dataset_path 由主节点直接传入 if not dataset_path or not Path(dataset_path).exists(): raise FileNotFoundError(f"Dataset not found: {dataset_path}") _write_log(type="status", status="preprocessing") # 预处理 processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl") task_type = config.get("task_type", "sft") template = config.get("dataset_template", "alpaca") # 选择引擎 if model_type == "vision": from app.engines.vision_engine import vision_engine engine = vision_engine elif model_type == "multimodal": from app.engines.multimodal_engine import multimodal_engine engine = multimodal_engine else: from app.engines.text_engine import text_engine engine = text_engine peft_method = config.get("peft_method", "lora") await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template) _write_log(type="status", status="loading_model") # 加载模型 await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None) # 构建 PEFT 配置 peft_config = engine.get_peft_config(peft_method, config) _write_log(type="status", status="training") # 训练 — 传入文件日志回调替代 WebSocket 回调 start_time = time.time() file_cb = FileProgressCallback(job_id) adapter_path = await engine.train( job_id=job_id, dataset_path=processed_path, peft_config=peft_config, training_args=config, callbacks=[file_cb], ) elapsed = round(time.time() - start_time, 2) _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed) logger.info(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)") return adapter_path except Exception as e: _write_log(type="error", message=str(e), traceback=traceback.format_exc()) logger.error(f"Remote training failed: {job_id} - {e}") raise def main(): """命令行入口:python -m app.engines.remote_train """ if len(sys.argv) < 6: print("Usage: python -m app.engines.remote_train ") sys.exit(1) job_id = sys.argv[1] model_id = sys.argv[2] model_type = sys.argv[3] dataset_id = sys.argv[4] config_path = sys.argv[5] with open(config_path, encoding="utf-8") as f: config = json.load(f) asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config)) if __name__ == "__main__": main()