"""Remote training entry script, executed on the compute node.

Deliberately does not depend on ``app.config`` / ``app.core.logging`` so that
extra packages such as pydantic-settings / sqlalchemy are not pulled in.
"""

import asyncio
import json
import os
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path

# Disable FlashAttention.
os.environ["PYTORCH_NO_FLASH"] = "1"
os.environ["FLASH_ATTENTION_ENABLED"] = "0"

# Disable torch.compile to avoid forking a large number of inductor workers.
os.environ["PT2_COMPILE"] = "0"
os.environ["TORCHINDUCTOR_MAX_WORKERS"] = "1"

# Restrict training to GPU 2 and 3 by default (GPU 0/1 are occupied by VLLM).
# MetaX GPUs prefer METAX_VISIBLE_DEVICES; CUDA_VISIBLE_DEVICES is set as a
# fallback.  COMPUTE_NODE_VISIBLE_DEVICES optionally overrides the device
# list; when it is unset or empty the historical "2,3" is used.
_VISIBLE_DEVICES = os.environ.get("COMPUTE_NODE_VISIBLE_DEVICES") or "2,3"
os.environ["METAX_VISIBLE_DEVICES"] = _VISIBLE_DEVICES
os.environ["CUDA_VISIBLE_DEVICES"] = _VISIBLE_DEVICES

# Path of the per-job JSONL progress log; None until _init_log_file() runs.
_progress_log_file = None

# Read configuration straight from the environment to avoid pydantic-settings.
_DATA_DIR = Path(
    os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")
)
_PROCESSED_DIR = _DATA_DIR / "processed"
_ADAPTERS_DIR = _DATA_DIR / "adapters"
_MODELS_DIR = _DATA_DIR / "models"

_USAGE = (
    "Usage: python -m app.engines.remote_train "
    "<job_id> <model_id> <model_type> <dataset_path> <config_path>"
)


def _remote_log(msg: str) -> None:
    """Print *msg* to stderr (the remote training log /tmp/train_{job_id}.log)."""
    print(f"[remote_train] {msg}", file=sys.stderr)


def _init_log_file(job_id: str) -> None:
    """Create the progress log for *job_id* and write the initial "start" entry.

    The log lives under ``<data dir>/logs`` and is shared with the main node
    through SSHFS.
    """
    global _progress_log_file
    log_dir = _DATA_DIR / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    _progress_log_file = log_dir / f"{job_id}.jsonl"
    _write_log(type="start", job_id=job_id)


def _write_log(**kwargs) -> None:
    """Append one JSON line (UTC timestamp plus *kwargs*) to the shared log.

    Does nothing when the log file has not been initialised yet.
    """
    if not _progress_log_file:
        return
    entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs}
    # Leaving the ``with`` block closes the file, which also flushes it.
    with open(_progress_log_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")


def _metric(metrics, key: str):
    """Return ``metrics[key]`` when *metrics* is a non-empty mapping, else None."""
    if metrics and hasattr(metrics, "get"):
        return metrics.get(key)
    return None


class FileProgressCallback:
    """HuggingFace Trainer callback that writes progress to the shared log file.

    Only the events of interest are implemented; any other ``on_*`` event is
    answered with a no-op through ``__getattr__``.
    """

    # Class-level default so the attribute resolves even on instances created
    # without __init__ (e.g. by copy/pickle) instead of reaching __getattr__.
    _train_started_at = None

    def __init__(self, job_id: str):
        self.job_id = job_id
        self._train_started_at = None

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the training start time and report the "training" status."""
        self._train_started_at = time.monotonic()
        _write_log(type="status", status="training")

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Report the start of an epoch."""
        _write_log(type="epoch_begin", epoch=int(state.epoch or 0))

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Report step/loss/learning-rate for log records that carry a loss."""
        if not logs or "loss" not in logs:
            return
        _write_log(
            type="progress",
            epoch=int(state.epoch or 0),
            step=state.global_step,
            total_steps=state.max_steps or 0,
            loss=round(logs["loss"], 4),
            learning_rate=round(logs.get("learning_rate", 0), 8),
        )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Report evaluation metrics when any were supplied."""
        if not metrics:
            return
        _write_log(
            type="evaluate",
            epoch=int(state.epoch or 0),
            eval_loss=metrics.get("eval_loss"),
            eval_accuracy=metrics.get("eval_accuracy"),
        )

    def on_epoch_end(self, args, state, control, metrics=None, **kwargs):
        """Report the end of an epoch, with eval metrics when available."""
        _write_log(
            type="epoch_done",
            epoch=int(state.epoch or 0),
            eval_loss=_metric(metrics, "eval_loss"),
            eval_accuracy=_metric(metrics, "eval_accuracy"),
        )

    def on_save(self, args, state, control, **kwargs):
        """Report that a checkpoint was saved at the current step."""
        _write_log(type="save", step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        """Report the end of the training loop together with its duration.

        NOTE(review): run_training() also writes a "completed" entry once
        engine.train() returns; confirm the log reader copes with two.
        """
        # NOTE(review): upstream TrainerState does not appear to define
        # ``train_runtime`` (confirm); fall back to the time measured since
        # on_train_begin so the field is not always 0.
        runtime = getattr(state, "train_runtime", None)
        if runtime is None:
            started = self._train_started_at
            runtime = 0 if started is None else round(time.monotonic() - started, 2)
        _write_log(
            type="completed",
            total_time_seconds=runtime,
            adapter_path=args.output_dir,
        )

    def __getattr__(self, name):
        """Return a no-op handler for Trainer events that are not implemented.

        Only ``on_*`` names are faked. Anything else (dunder probes such as
        ``__deepcopy__``, typos, ...) raises AttributeError as usual, so
        ``hasattr``, ``copy`` and ``pickle`` behave normally.
        """
        if name.startswith("on_"):
            return lambda *args, **kwargs: None
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )


def _select_engine(model_type: str):
    """Import and return the engine singleton matching *model_type*."""
    # Imports stay local so that only the selected engine module is loaded.
    if model_type == "vision":
        from app.engines.vision_engine import vision_engine

        return vision_engine
    if model_type == "multimodal":
        from app.engines.multimodal_engine import multimodal_engine

        return multimodal_engine
    from app.engines.text_engine import text_engine

    return text_engine


async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict):
    """Run a single training job (remote invocation entry point).

    Steps: preprocess the dataset, load the model, build the PEFT config and
    train, reporting status to the shared progress log along the way.

    Args:
        job_id: Job identifier; also names the log and processed files.
        model_id: Model identifier handed to the engine's ``load_model``.
        model_type: "vision", "multimodal", or anything else for text.
        dataset_path: Path to the raw dataset, passed in by the main node.
        config: Training configuration; reads ``task_type``,
            ``dataset_template`` and ``peft_method`` and is forwarded to the
            engine as ``training_args``.

    Returns:
        Whatever the engine's ``train`` returns (logged as the adapter path).

    Raises:
        FileNotFoundError: If *dataset_path* is empty or does not exist.
        Exception: Any engine error is logged and re-raised unchanged.
    """
    _init_log_file(job_id)
    _remote_log(f"=== Training job started: {job_id} ===")
    _remote_log(f"model_id={model_id}, model_type={model_type}")
    _remote_log(f"dataset_path={dataset_path}")
    _remote_log(f"config={json.dumps(config, ensure_ascii=False)[:200]}")

    try:
        # dataset_path is passed in directly by the main node.
        if not dataset_path or not Path(dataset_path).exists():
            raise FileNotFoundError(f"Dataset not found: {dataset_path}")
        _remote_log(f"Dataset file exists: {dataset_path}")

        _write_log(type="status", status="preprocessing")
        _remote_log("Step 1: Preprocessing dataset...")

        # Preprocessing. Make sure the output directory exists, mirroring
        # what _init_log_file does for the log directory.
        _PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
        processed_path = str(_PROCESSED_DIR / f"{job_id}_processed.jsonl")
        task_type = config.get("task_type", "sft")
        template = config.get("dataset_template", "alpaca")
        _remote_log(f" task_type={task_type}, template={template}")
        _remote_log(f" output_path={processed_path}")

        # Engine selection.
        _remote_log(f" Selecting engine for model_type={model_type}...")
        engine = _select_engine(model_type)
        _remote_log(f" Engine loaded: {engine.__class__.__name__}")

        peft_method = config.get("peft_method", "lora")
        _remote_log(f" PEFT method: {peft_method}")

        _remote_log(" Running preprocess_dataset...")
        await engine.preprocess_dataset(
            dataset_path, processed_path, task_type=task_type, template=template
        )
        _remote_log(f" Preprocessing done, output: {processed_path}")

        _write_log(type="status", status="loading_model")
        _remote_log(f"Step 2: Loading model: {model_id}...")

        # Model loading: QLoRA requests 4-bit quantization.
        quantization_mode = "4bit" if peft_method == "qlora" else None
        _remote_log(f" Quantization: {quantization_mode}")
        await engine.load_model(model_id, quantization=quantization_mode)
        _remote_log(" Model loaded successfully")

        # PEFT configuration.
        _remote_log("Step 3: Building PEFT config...")
        peft_config = engine.get_peft_config(peft_method, config)
        _remote_log(" PEFT config built")

        _write_log(type="status", status="training")
        _remote_log("Step 4: Starting training...")

        # Training: a file-log callback replaces the WebSocket callback.
        start_time = time.time()
        file_cb = FileProgressCallback(job_id)
        adapter_path = await engine.train(
            job_id=job_id,
            dataset_path=processed_path,
            peft_config=peft_config,
            training_args=config,
            callbacks=[file_cb],
        )
        elapsed = round(time.time() - start_time, 2)

        # NOTE(review): FileProgressCallback.on_train_end also writes a
        # "completed" entry; this one carries the engine's adapter path.
        _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed)
        _remote_log(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)")
        _remote_log(f"=== Training job finished: {job_id} ===")
        return adapter_path

    except Exception as e:
        # Top-level boundary: record the failure for the main node, re-raise.
        tb = traceback.format_exc()
        _write_log(type="error", message=str(e), traceback=tb)
        _remote_log(f"ERROR: {e}")
        _remote_log(tb)
        _remote_log(f"=== Training job failed: {job_id} ===")
        raise


def main():
    """Command-line entry point.

    Usage: python -m app.engines.remote_train <job_id> <model_id>
    <model_type> <dataset_path> <config_path>

    *config_path* must point to a UTF-8 JSON file holding the training
    configuration. Exits with status 1 when arguments are missing.
    """
    if len(sys.argv) < 6:
        print(_USAGE, file=sys.stderr)
        sys.exit(1)

    job_id, model_id, model_type, dataset_path, config_path = sys.argv[1:6]

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    asyncio.run(run_training(job_id, model_id, model_type, dataset_path, config))


if __name__ == "__main__":
    main()