"""远程训练入口脚本 — 在算力节点上执行。 不依赖 app.config / app.core.logging,避免引入 pydantic-settings / sqlalchemy 等额外包。 """ import asyncio import json import os import re import sys import time import traceback from datetime import datetime, timezone from pathlib import Path # 禁用 FlashAttention os.environ["PYTORCH_NO_FLASH"] = "1" os.environ["FLASH_ATTENTION_ENABLED"] = "0" # 解决 PyTorch 显存碎片化问题(避免 reserved unallocated 占用大量显存) os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # 禁用 torch.compile,避免 fork 大量 inductor worker 进程 os.environ["PT2_COMPILE"] = "0" os.environ["TORCHINDUCTOR_MAX_WORKERS"] = "1" # CUDA_VISIBLE_DEVICES 由 docker exec 层设置,此处不再覆盖 # 单 GPU 模式: "3" (物理 GPU 3 → 逻辑 cuda:0) # 多 GPU 模式: "2,3" (物理 GPU 2,3 → 逻辑 cuda:0,1) # 启用 MPS 多进程服务,允许与 VLLM 共享 GPU os.environ["MACA_MPS_MODE"] = "1" # Triton 编译缓存 — 持久化编译产物,避免每次训练都重新编译 os.environ["TRITON_CACHE_DIR"] = "/root/Fine-tuning/backend/data/.triton_cache" os.environ["TRITON_HOME"] = "/root/Fine-tuning/backend/data/.triton_cache" # 减少 Triton 编译时的冗余输出 os.environ["MLIR_ENABLE_DUMP"] = "0" os.environ["TRITON_INTERPRET"] = "0" _progress_log_file = None # 直接从环境变量读取配置,避免引入 pydantic-settings _DATA_DIR = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")) _PROCESSED_DIR = _DATA_DIR / "processed" _ADAPTERS_DIR = _DATA_DIR / "adapters" _MODELS_DIR = _DATA_DIR / "models" def _remote_log(msg: str): """打印到 stderr(即远程训练日志 /tmp/train_{job_id}.log)。""" print(f"[remote_train] {msg}", file=sys.stderr) def _init_log_file(job_id: str): """初始化进度日志文件(通过 SSHFS 共享给主节点读取)。""" global _progress_log_file log_dir = _DATA_DIR / "logs" log_dir.mkdir(parents=True, exist_ok=True) _progress_log_file = log_dir / f"{job_id}.jsonl" _write_log(type="start", job_id=job_id) def _write_log(**kwargs): """追加一行 JSON 到共享日志文件。""" if _progress_log_file: entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs} with open(_progress_log_file, "a", encoding="utf-8") as f: f.write(json.dumps(entry, ensure_ascii=False) + "\n") f.flush() class FileProgressCallback: """HuggingFace Trainer 回调 — 写进度到共享日志文件。 只实现关心的回调,其余通过 __getattr__ 自动忽略。 """ def __init__(self, job_id: str): self.job_id = job_id def on_step_end(self, args, state, control, **kwargs): """每步结束记录进度,便于观察训练是否在推进。""" _write_log(type="step", step=state.global_step, total_steps=state.max_steps or 0, epoch=round(state.epoch or 0, 2)) if state.global_step % 5 == 0 or state.global_step <= 3: _remote_log(f"Step {state.global_step}/{state.max_steps} done (epoch {state.epoch:.2f})") def on_log(self, args, state, control, logs=None, **kwargs): if logs and "loss" in logs: _write_log(type="progress", epoch=int(state.epoch or 0), step=state.global_step, total_steps=state.max_steps or 0, loss=round(logs["loss"], 4), learning_rate=round(logs.get("learning_rate", 0), 8)) def on_epoch_begin(self, args, state, control, **kwargs): _write_log(type="epoch_begin", epoch=int(state.epoch or 0)) def on_epoch_end(self, args, state, control, metrics=None, **kwargs): _write_log(type="epoch_done", epoch=int(state.epoch or 0), eval_loss=metrics.get("eval_loss") if metrics and hasattr(metrics, "get") else None, eval_accuracy=metrics.get("eval_accuracy") if metrics and hasattr(metrics, "get") else None) def on_train_end(self, args, state, control, **kwargs): _write_log(type="completed", total_time_seconds=getattr(state, "train_runtime", 0), adapter_path=args.output_dir) def on_train_begin(self, args, state, control, **kwargs): _write_log(type="status", status="training") def on_save(self, args, state, control, **kwargs): _write_log(type="save", step=state.global_step) def on_evaluate(self, args, state, control, metrics=None, **kwargs): if metrics: _write_log(type="evaluate", epoch=int(state.epoch or 0), eval_loss=metrics.get("eval_loss"), eval_accuracy=metrics.get("eval_accuracy")) def __getattr__(self, name): """Trainer 期望其他回调方法存在,返回一个空函数自动忽略。""" return lambda *args, **kwargs: None async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict, rank: int = 0, local_rank: int = 0, world_size: int = 1): """执行单个训练任务(远程调用入口)。""" is_main = (rank == 0) # 只有 rank 0 写 JSONL 进度日志,避免多进程文件竞争 if is_main: _init_log_file(job_id) _remote_log(f"[rank {rank}] === Training job started: {job_id} ===") if is_main: _remote_log(f"model_id={model_id}, model_type={model_type}") _remote_log(f"dataset_path={dataset_path}") _remote_log(f"config={json.dumps(config, ensure_ascii=False)[:200]}") if world_size > 1: _remote_log(f"DDP: world_size={world_size}, batch_size per GPU={config.get('batch_size', 8)}") try: # dataset_path 由主节点直接传入 if not dataset_path or not Path(dataset_path).exists(): raise FileNotFoundError(f"Dataset not found: {dataset_path}") if is_main: _write_log(type="status", status="preprocessing") _remote_log("Step 1: Preprocessing dataset...") # 预处理 — DDP 模式下只有 rank 0 执行,其他 rank 等待 processed_path = str(_PROCESSED_DIR / f"{job_id}_processed.jsonl") task_type = config.get("task_type", "sft") template = config.get("dataset_template", "auto") if is_main: _remote_log(f" task_type={task_type}, template={template}") # 选择引擎 if model_type == "vision": from app.engines.vision_engine import vision_engine engine = vision_engine elif model_type == "multimodal": from app.engines.multimodal_engine import multimodal_engine engine = multimodal_engine else: from app.engines.text_engine import text_engine engine = text_engine _remote_log(f" Engine loaded: {engine.__class__.__name__}") _remote_log(" Running preprocess_dataset...") await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template) _remote_log(f" Preprocessing done, output: {processed_path}") # 写同步标记,通知其他 rank 预处理完成 if world_size > 1: done_marker = _PROCESSED_DIR / f"{job_id}_preprocess_done" done_marker.write_text("done") else: # 非 rank 0 等待预处理完成 done_marker = _PROCESSED_DIR / f"{job_id}_preprocess_done" waited = 0 while not done_marker.exists() and waited < 120: await asyncio.sleep(1) waited += 1 # 选择引擎(所有 rank 都需要) if model_type == "vision": from app.engines.vision_engine import vision_engine engine = vision_engine elif model_type == "multimodal": from app.engines.multimodal_engine import multimodal_engine engine = multimodal_engine else: from app.engines.text_engine import text_engine engine = text_engine peft_method = config.get("peft_method", "lora") if is_main: _write_log(type="status", status="loading_model") _remote_log(f"Step 2: Loading model: {model_id}...") # 加载模型 — 每个 rank 各自加载到自己的 GPU quantization_mode = "4bit" if peft_method == "qlora" else None await engine.load_model(model_id, quantization=quantization_mode) if is_main: _remote_log(" Model loaded successfully") # 构建 PEFT 配置 if is_main: _remote_log("Step 3: Building PEFT config...") peft_config = engine.get_peft_config(peft_method, config) # PPO 训练需要预下载奖励模型(只在 rank 0 下载) reward_type = config.get("reward_type", "heuristic") reward_model_path = config.get("reward_model_path") if reward_type == "model" and reward_model_path: if is_main: _remote_log(f"Step 3.5: Pre-downloading reward model: {reward_model_path}...") reward_local = str(_MODELS_DIR / reward_model_path.replace("/", "_")) if not (Path(reward_local) / "config.json").exists(): from huggingface_hub import snapshot_download snapshot_download( repo_id=reward_model_path, local_dir=reward_local, local_dir_use_symlinks=False, ) config["reward_model_path"] = reward_local if is_main: _write_log(type="status", status="training") _remote_log("Step 4: Starting training...") _remote_log("NOTE: First step may take 2-5 minutes due to Triton kernel compilation (autotuning). This is normal.") _remote_log(f"Total steps: {config.get('epochs', 3)} epochs, batch_size per GPU={config.get('batch_size', 8)}") # 训练 — 传入文件日志回调(只在 rank 0 写日志) start_time = time.time() file_cb = FileProgressCallback(job_id) if is_main else None callbacks = [file_cb] if file_cb else [] adapter_path = await engine.train( job_id=job_id, dataset_path=processed_path, peft_config=peft_config, training_args=config, callbacks=callbacks, ) if is_main: elapsed = round(time.time() - start_time, 2) _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed) _remote_log(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)") _remote_log(f"=== Training job finished: {job_id} ===") return adapter_path except Exception as e: tb = traceback.format_exc() if is_main: _write_log(type="error", message=str(e), traceback=tb) _remote_log(f"[rank {rank}] ERROR: {e}") _remote_log(tb) if is_main: _remote_log(f"=== Training job failed: {job_id} ===") raise def _patch_fla_shared_memory(): """修复 fla 库 Triton kernel 共享内存溢出问题。 Qwen3.5 等混合架构模型的 Gated Delta Rule 层使用 fla 库的 Triton kernel, 反向传播时 chunk kernel 的 block size 为 64,需要约 106KB 共享内存, 但沐曦/部分 NVIDIA GPU 硬件上限仅 64KB(65536 字节),导致 OutOfResources。 修复方式:精确替换 fla 库中控制 block size 的常量和 kernel 名称, 避免误改普通代码中的数字字面量。 """ try: import shutil # 定位 fla 包 conda_path = '/opt/conda/lib/python3.10/site-packages/fla' if os.path.isdir(conda_path): fla_base = conda_path else: import site fla_base = None for sp in site.getsitepackages() + [site.getusersitepackages() if hasattr(site, 'getusersitepackages') else '']: candidate = os.path.join(sp, 'fla') if os.path.isdir(candidate): fla_base = candidate break if not fla_base: _remote_log("fla package not found, skipping shared memory patch") return _remote_log(f"fla package found at: {fla_base}") # 检查 fla 源码是否被旧版补丁损坏(语法错误) chunk_py = os.path.join(fla_base, 'ops', 'gated_delta_rule', 'chunk.py') source_corrupted = False if os.path.exists(chunk_py): try: with open(chunk_py, 'r') as f: compile(f.read(), chunk_py, 'exec') except SyntaxError as e: source_corrupted = True _remote_log(f"fla source corrupted (SyntaxError: {e}), will reinstall...") # 幂等检查 marker_path = os.path.join(fla_base, '_PATCHED_SM32') if os.path.exists(marker_path) and not source_corrupted: # 检查标记版本:v1 是旧版补丁(用激进正则,已污染源码),需要重装后重新打补丁 with open(marker_path) as mf: marker_content = mf.read() if 'v2' in marker_content: _remote_log("fla shared memory patch v2 already applied, skipping") to_remove = [k for k in sys.modules if k.startswith('fla')] for k in to_remove: del sys.modules[k] return else: source_corrupted = True _remote_log("Old patch v1 detected, will reinstall fla...") if source_corrupted: _remote_log("WARNING: fla source is corrupted (SyntaxError). " "Please rebuild the container to restore clean source.") return patched_files = [] for root, dirs, files in os.walk(fla_base): for fname in files: if not fname.endswith('.py'): continue fpath = os.path.join(root, fname) try: with open(fpath, 'r') as f: original = f.read() c = original changes = [] # 0. 修复 fla/utils.py 中的 maca 设备映射(沐曦 GPU 兼容) if fname == 'utils.py' and "!= 'hip'" in c: # 把 maca 也映射到 cuda c = c.replace( "device = get_available_device() if get_available_device() != 'hip' else 'cuda'", "device = get_available_device() if get_available_device() not in ('hip', 'maca') else 'cuda'" ) changes.append('maca->cuda mapping') # 1. kernel 函数名后缀: blockdim64 → blockdim32 if 'blockdim64' in c: c = c.replace('blockdim64', 'blockdim32') changes.append('blockdim64->blockdim32') # 2. 精确匹配 fla 中常见的 block size 变量赋值 # BT = 64, BK = 64, BV = 64, chunk_size = 64, BLOCK_SIZE = 64 等 # 用 \b 匹配完整变量名,避免误改其他代码 for var in ['BT', 'BK', 'BV', 'chunk_size', 'BLOCK_SIZE', 'BLOCK_M', 'BLOCK_N', 'BLOCK_K', 'BLOCK_V', 'block_size', 'block_m', 'block_n', 'block_k', 'block_v']: pattern = rf'\b{var}\s*=\s*64\b' replacement = f'{var} = 32' new_c = re.sub(pattern, replacement, c) if new_c != c: changes.append(f'{var}=64->32') c = new_c # 3. Triton autotune 装饰器中的 configs 参数值 # 例如: configs=[..., 64, ...] 或 tl.constexpr = 64 # 只替换 tl.constexpr = 64 的情况 pattern = r'tl\.constexpr\s*=\s*64\b' new_c = re.sub(pattern, 'tl.constexpr = 32', c) if new_c != c: changes.append('tl.constexpr 64->32') c = new_c # 4. num_stages 降低(减少流水线阶段,进一步降低共享内存) pattern = r'num_stages\s*=\s*([3-9]|[1-9]\d+)' new_c = re.sub(pattern, 'num_stages=1', c) if new_c != c: changes.append('num_stages->1') c = new_c if c != original: with open(fpath, 'w') as f: f.write(c) patched_files.append(f"{os.path.relpath(fpath, fla_base)}({', '.join(changes)})") except Exception as e: _remote_log(f" Warning: failed to patch {fpath}: {e}") continue # 清理 __pycache__ cache_count = 0 for root, dirs, files in os.walk(fla_base): if '__pycache__' in dirs: shutil.rmtree(os.path.join(root, '__pycache__'), ignore_errors=True) cache_count += 1 # 清除已缓存的 fla 模块 to_remove = [k for k in sys.modules if k.startswith('fla')] for k in to_remove: del sys.modules[k] # 写入标记文件(幂等),包含版本号 v2 with open(marker_path, 'w') as f: f.write(f"v2 patched at {datetime.now(timezone.utc).isoformat()}\n") _remote_log(f"fla shared memory patch done: {len(patched_files)} files, " f"{cache_count} caches cleared, {len(to_remove)} modules evicted") for pf in patched_files: _remote_log(f" patched: {pf}") except Exception as e: _remote_log(f"Warning: fla shared memory patch failed: {e}") import traceback as tb _remote_log(tb.format_exc()) def main(): """命令行入口。 单 GPU: python -m app.engines.remote_train 多 GPU: torchrun --nproc_per_node=N -m app.engines.remote_train """ # 解析 DDP 环境变量(torchrun 自动设置) rank = int(os.environ.get("RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) if world_size > 1: # DDP 模式:只有 rank 0 执行 fla 补丁,其他 rank 等待补丁完成 if rank == 0: _remote_log(f"DDP mode: rank={rank}, local_rank={local_rank}, world_size={world_size}") _patch_fla_shared_memory() # 写一个分布式同步标记,通知其他 rank 补丁已完成 marker = _DATA_DIR / ".fla_patch_done" marker.write_text(f"patched by rank 0 at {datetime.now(timezone.utc).isoformat()}") else: # 等待 rank 0 完成补丁(最多等 60 秒) marker = _DATA_DIR / ".fla_patch_done" waited = 0 while not marker.exists() and waited < 60: time.sleep(1) waited += 1 if not marker.exists(): _remote_log(f"WARNING: rank {rank} timed out waiting for fla patch marker") else: # 单 GPU 模式 _patch_fla_shared_memory() if len(sys.argv) < 6: print("Usage: python -m app.engines.remote_train ") sys.exit(1) job_id = sys.argv[1] model_id = sys.argv[2] model_type = sys.argv[3] dataset_id = sys.argv[4] config_path = sys.argv[5] with open(config_path, encoding="utf-8") as f: config = json.load(f) asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config, rank=rank, local_rank=local_rank, world_size=world_size)) if __name__ == "__main__": main()