"""远程训练入口脚本 — 在算力节点上执行。 不依赖 app.config / app.core.logging,避免引入 pydantic-settings / sqlalchemy 等额外包。 """ import asyncio import json import os import re import sys import time import traceback from datetime import datetime, timezone from pathlib import Path # 禁用 FlashAttention os.environ["PYTORCH_NO_FLASH"] = "1" os.environ["FLASH_ATTENTION_ENABLED"] = "0" # 解决 PyTorch 显存碎片化问题(避免 reserved unallocated 占用大量显存) os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # 禁用 torch.compile,避免 fork 大量 inductor worker 进程 os.environ["PT2_COMPILE"] = "0" os.environ["TORCHINDUCTOR_MAX_WORKERS"] = "1" # 限制训练只用 GPU 3(GPU 0/1 被 VLLM 占用,GPU 2 已占用) # CUDA_VISIBLE_DEVICES 将 3 映射为容器内的 cuda:0 # device_map 中使用相对编号 0(即物理 GPU 3) os.environ["CUDA_VISIBLE_DEVICES"] = "3" # 启用 MPS 多进程服务,允许与 VLLM 共享 GPU os.environ["MACA_MPS_MODE"] = "1" _progress_log_file = None # 直接从环境变量读取配置,避免引入 pydantic-settings _DATA_DIR = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data")) _PROCESSED_DIR = _DATA_DIR / "processed" _ADAPTERS_DIR = _DATA_DIR / "adapters" _MODELS_DIR = _DATA_DIR / "models" def _remote_log(msg: str): """打印到 stderr(即远程训练日志 /tmp/train_{job_id}.log)。""" print(f"[remote_train] {msg}", file=sys.stderr) def _init_log_file(job_id: str): """初始化进度日志文件(通过 SSHFS 共享给主节点读取)。""" global _progress_log_file log_dir = _DATA_DIR / "logs" log_dir.mkdir(parents=True, exist_ok=True) _progress_log_file = log_dir / f"{job_id}.jsonl" _write_log(type="start", job_id=job_id) def _write_log(**kwargs): """追加一行 JSON 到共享日志文件。""" if _progress_log_file: entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs} with open(_progress_log_file, "a", encoding="utf-8") as f: f.write(json.dumps(entry, ensure_ascii=False) + "\n") f.flush() class FileProgressCallback: """HuggingFace Trainer 回调 — 写进度到共享日志文件。 只实现关心的回调,其余通过 __getattr__ 自动忽略。 """ def __init__(self, job_id: str): self.job_id = job_id def on_log(self, args, state, control, logs=None, **kwargs): if logs and "loss" in logs: _write_log(type="progress", epoch=int(state.epoch or 0), step=state.global_step, total_steps=state.max_steps or 0, loss=round(logs["loss"], 4), learning_rate=round(logs.get("learning_rate", 0), 8)) def on_epoch_begin(self, args, state, control, **kwargs): _write_log(type="epoch_begin", epoch=int(state.epoch or 0)) def on_epoch_end(self, args, state, control, metrics=None, **kwargs): _write_log(type="epoch_done", epoch=int(state.epoch or 0), eval_loss=metrics.get("eval_loss") if metrics and hasattr(metrics, "get") else None, eval_accuracy=metrics.get("eval_accuracy") if metrics and hasattr(metrics, "get") else None) def on_train_end(self, args, state, control, **kwargs): _write_log(type="completed", total_time_seconds=getattr(state, "train_runtime", 0), adapter_path=args.output_dir) def on_train_begin(self, args, state, control, **kwargs): _write_log(type="status", status="training") def on_save(self, args, state, control, **kwargs): _write_log(type="save", step=state.global_step) def on_evaluate(self, args, state, control, metrics=None, **kwargs): if metrics: _write_log(type="evaluate", epoch=int(state.epoch or 0), eval_loss=metrics.get("eval_loss"), eval_accuracy=metrics.get("eval_accuracy")) def __getattr__(self, name): """Trainer 期望其他回调方法存在,返回一个空函数自动忽略。""" return lambda *args, **kwargs: None async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict): """执行单个训练任务(远程调用入口)。""" _init_log_file(job_id) _remote_log(f"=== Training job started: {job_id} ===") _remote_log(f"model_id={model_id}, model_type={model_type}") _remote_log(f"dataset_path={dataset_path}") _remote_log(f"config={json.dumps(config, ensure_ascii=False)[:200]}") try: # dataset_path 由主节点直接传入 if not dataset_path or not Path(dataset_path).exists(): raise FileNotFoundError(f"Dataset not found: {dataset_path}") _remote_log(f"Dataset file exists: {dataset_path}") _write_log(type="status", status="preprocessing") _remote_log("Step 1: Preprocessing dataset...") # 预处理 processed_path = str(_PROCESSED_DIR / f"{job_id}_processed.jsonl") task_type = config.get("task_type", "sft") template = config.get("dataset_template", "auto") _remote_log(f" task_type={task_type}, template={template}") _remote_log(f" output_path={processed_path}") # 选择引擎 _remote_log(f" Selecting engine for model_type={model_type}...") if model_type == "vision": from app.engines.vision_engine import vision_engine engine = vision_engine elif model_type == "multimodal": from app.engines.multimodal_engine import multimodal_engine engine = multimodal_engine else: from app.engines.text_engine import text_engine engine = text_engine _remote_log(f" Engine loaded: {engine.__class__.__name__}") peft_method = config.get("peft_method", "lora") _remote_log(f" PEFT method: {peft_method}") _remote_log(" Running preprocess_dataset...") await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template) _remote_log(f" Preprocessing done, output: {processed_path}") _write_log(type="status", status="loading_model") _remote_log(f"Step 2: Loading model: {model_id}...") # 加载模型 quantization_mode = "4bit" if peft_method == "qlora" else None _remote_log(f" Quantization: {quantization_mode}") await engine.load_model(model_id, quantization=quantization_mode) _remote_log(" Model loaded successfully") # 构建 PEFT 配置 _remote_log("Step 3: Building PEFT config...") peft_config = engine.get_peft_config(peft_method, config) _remote_log(" PEFT config built") # PPO 训练需要预下载奖励模型 reward_type = config.get("reward_type", "heuristic") reward_model_path = config.get("reward_model_path") if reward_type == "model" and reward_model_path: _remote_log(f"Step 3.5: Pre-downloading reward model: {reward_model_path}...") reward_local = str(_MODELS_DIR / reward_model_path.replace("/", "_")) if not (Path(reward_local) / "config.json").exists(): from huggingface_hub import snapshot_download snapshot_download( repo_id=reward_model_path, local_dir=reward_local, local_dir_use_symlinks=False, ) _remote_log(f" Reward model downloaded to: {reward_local}") else: _remote_log(f" Reward model already exists: {reward_local}") config["reward_model_path"] = reward_local # 覆盖为本地路径 _write_log(type="status", status="training") _remote_log("Step 4: Starting training...") # 训练 — 传入文件日志回调替代 WebSocket 回调 start_time = time.time() file_cb = FileProgressCallback(job_id) adapter_path = await engine.train( job_id=job_id, dataset_path=processed_path, peft_config=peft_config, training_args=config, callbacks=[file_cb], ) elapsed = round(time.time() - start_time, 2) _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed) _remote_log(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)") _remote_log(f"=== Training job finished: {job_id} ===") return adapter_path except Exception as e: tb = traceback.format_exc() _write_log(type="error", message=str(e), traceback=tb) _remote_log(f"ERROR: {e}") _remote_log(tb) _remote_log(f"=== Training job failed: {job_id} ===") raise def _patch_fla_shared_memory(): """修复 fla 库 Triton kernel 共享内存溢出问题。 Qwen3.5 等混合架构模型的 Gated Delta Rule 层使用 fla 库的 Triton kernel, 反向传播时 chunk kernel 的 block size 为 64,需要约 106KB 共享内存, 但沐曦/部分 NVIDIA GPU 硬件上限仅 64KB(65536 字节),导致 OutOfResources。 修复方式:精确替换 fla 库中控制 block size 的常量和 kernel 名称, 避免误改普通代码中的数字字面量。 """ try: import shutil # 定位 fla 包 conda_path = '/opt/conda/lib/python3.10/site-packages/fla' if os.path.isdir(conda_path): fla_base = conda_path else: import site fla_base = None for sp in site.getsitepackages() + [site.getusersitepackages() if hasattr(site, 'getusersitepackages') else '']: candidate = os.path.join(sp, 'fla') if os.path.isdir(candidate): fla_base = candidate break if not fla_base: _remote_log("fla package not found, skipping shared memory patch") return _remote_log(f"fla package found at: {fla_base}") # 检查 fla 源码是否被旧版补丁损坏(语法错误) chunk_py = os.path.join(fla_base, 'ops', 'gated_delta_rule', 'chunk.py') source_corrupted = False if os.path.exists(chunk_py): try: with open(chunk_py, 'r') as f: compile(f.read(), chunk_py, 'exec') except SyntaxError as e: source_corrupted = True _remote_log(f"fla source corrupted (SyntaxError: {e}), will reinstall...") # 幂等检查 marker_path = os.path.join(fla_base, '_PATCHED_SM32') if os.path.exists(marker_path) and not source_corrupted: # 检查标记版本:v1 是旧版补丁(用激进正则,已污染源码),需要重装后重新打补丁 with open(marker_path) as mf: marker_content = mf.read() if 'v2' in marker_content: _remote_log("fla shared memory patch v2 already applied, skipping") to_remove = [k for k in sys.modules if k.startswith('fla')] for k in to_remove: del sys.modules[k] return else: source_corrupted = True _remote_log("Old patch v1 detected, will reinstall fla...") if source_corrupted: _remote_log("Reinstalling fla to restore clean source...") import subprocess # 尝试多个可能的包名 for pkg_name in ['fla', 'flash-linear-attention']: result = subprocess.run( [sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-deps', pkg_name], capture_output=True, text=True, timeout=120, ) if result.returncode == 0: _remote_log(f"fla reinstalled successfully via '{pkg_name}'") break else: _remote_log(f"pip install '{pkg_name}' failed: {result.stderr[:200]}") # 清理旧标记 if os.path.exists(marker_path): os.remove(marker_path) _remote_log("Reapplying patch v2...") patched_files = [] for root, dirs, files in os.walk(fla_base): for fname in files: if not fname.endswith('.py'): continue fpath = os.path.join(root, fname) try: with open(fpath, 'r') as f: original = f.read() c = original changes = [] # 1. kernel 函数名后缀: blockdim64 → blockdim32 if 'blockdim64' in c: c = c.replace('blockdim64', 'blockdim32') changes.append('blockdim64->blockdim32') # 2. 精确匹配 fla 中常见的 block size 变量赋值 # BT = 64, BK = 64, BV = 64, chunk_size = 64, BLOCK_SIZE = 64 等 # 用 \b 匹配完整变量名,避免误改其他代码 for var in ['BT', 'BK', 'BV', 'chunk_size', 'BLOCK_SIZE', 'BLOCK_M', 'BLOCK_N', 'BLOCK_K', 'BLOCK_V', 'block_size', 'block_m', 'block_n', 'block_k', 'block_v']: pattern = rf'\b{var}\s*=\s*64\b' replacement = f'{var} = 32' new_c = re.sub(pattern, replacement, c) if new_c != c: changes.append(f'{var}=64->32') c = new_c # 3. Triton autotune 装饰器中的 configs 参数值 # 例如: configs=[..., 64, ...] 或 tl.constexpr = 64 # 只替换 tl.constexpr = 64 的情况 pattern = r'tl\.constexpr\s*=\s*64\b' new_c = re.sub(pattern, 'tl.constexpr = 32', c) if new_c != c: changes.append('tl.constexpr 64->32') c = new_c # 4. num_stages 降低(减少流水线阶段,进一步降低共享内存) pattern = r'num_stages\s*=\s*([3-9]|[1-9]\d+)' new_c = re.sub(pattern, 'num_stages=1', c) if new_c != c: changes.append('num_stages->1') c = new_c if c != original: with open(fpath, 'w') as f: f.write(c) patched_files.append(f"{os.path.relpath(fpath, fla_base)}({', '.join(changes)})") except Exception as e: _remote_log(f" Warning: failed to patch {fpath}: {e}") continue # 清理 __pycache__ cache_count = 0 for root, dirs, files in os.walk(fla_base): if '__pycache__' in dirs: shutil.rmtree(os.path.join(root, '__pycache__'), ignore_errors=True) cache_count += 1 # 清除已缓存的 fla 模块 to_remove = [k for k in sys.modules if k.startswith('fla')] for k in to_remove: del sys.modules[k] # 写入标记文件(幂等),包含版本号 v2 with open(marker_path, 'w') as f: f.write(f"v2 patched at {datetime.now(timezone.utc).isoformat()}\n") _remote_log(f"fla shared memory patch done: {len(patched_files)} files, " f"{cache_count} caches cleared, {len(to_remove)} modules evicted") for pf in patched_files: _remote_log(f" patched: {pf}") except Exception as e: _remote_log(f"Warning: fla shared memory patch failed: {e}") import traceback as tb _remote_log(tb.format_exc()) def main(): """命令行入口:python -m app.engines.remote_train """ # 在导入任何 fla 模块之前,修补 Triton kernel 共享内存问题 _patch_fla_shared_memory() if len(sys.argv) < 6: print("Usage: python -m app.engines.remote_train ") sys.exit(1) job_id = sys.argv[1] model_id = sys.argv[2] model_type = sys.argv[3] dataset_id = sys.argv[4] config_path = sys.argv[5] with open(config_path, encoding="utf-8") as f: config = json.load(f) asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config)) if __name__ == "__main__": main()