| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395 |
- """远程训练入口脚本 — 在算力节点上执行。
- 不依赖 app.config / app.core.logging,避免引入 pydantic-settings / sqlalchemy 等额外包。
- """
- import asyncio
- import json
- import os
- import re
- import sys
- import time
- import traceback
- from datetime import datetime, timezone
- from pathlib import Path
# Process-wide environment tweaks. They are applied at import time, in the
# order listed below.
os.environ.update({
    # Disable FlashAttention.
    "PYTORCH_NO_FLASH": "1",
    "FLASH_ATTENTION_ENABLED": "0",
    # Work around PyTorch GPU-memory fragmentation (keeps "reserved but
    # unallocated" memory from taking up a large share of the device).
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    # Disable torch.compile so it does not fork a crowd of inductor workers.
    "PT2_COMPILE": "0",
    "TORCHINDUCTOR_MAX_WORKERS": "1",
    # Restrict training to GPU 3 (GPUs 0/1 are taken by VLLM, GPU 2 is busy).
    # CUDA_VISIBLE_DEVICES exposes physical GPU 3 as cuda:0 inside the
    # process, so device_map entries use the relative index 0.
    "CUDA_VISIBLE_DEVICES": "3",
    # Enable the MPS multi-process service so the GPU can be shared with VLLM.
    "MACA_MPS_MODE": "1",
})

# Path of the JSONL progress log; stays None until _init_log_file() sets it.
_progress_log_file = None

# Configuration is read straight from environment variables so that
# pydantic-settings does not have to be installed on the compute node.
_DATA_DIR = Path(os.getenv("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
_PROCESSED_DIR = _DATA_DIR.joinpath("processed")
_ADAPTERS_DIR = _DATA_DIR.joinpath("adapters")
_MODELS_DIR = _DATA_DIR.joinpath("models")
def _remote_log(msg: str):
    """Write one tagged line to stderr.

    The original author notes that stderr is what ends up in the remote
    training log (/tmp/train_{job_id}.log).
    """
    sys.stderr.write(f"[remote_train] {msg}\n")
def _init_log_file(job_id: str):
    """Prepare the per-job progress log and record the "start" event.

    The file lives at ``<data dir>/logs/<job_id>.jsonl``; per the original
    note it is shared with the main node through SSHFS.
    """
    global _progress_log_file
    logs_root = _DATA_DIR.joinpath("logs")
    logs_root.mkdir(parents=True, exist_ok=True)
    _progress_log_file = logs_root.joinpath(f"{job_id}.jsonl")
    # The global must be set before this call, otherwise nothing is written.
    _write_log(type="start", job_id=job_id)
def _write_log(**kwargs):
    """Append one JSON line (a UTC ``ts`` plus *kwargs*) to the shared log file.

    Does nothing while the log file has not been initialised.
    """
    if not _progress_log_file:
        return
    record = {"ts": datetime.now(timezone.utc).isoformat()}
    record.update(kwargs)
    with open(_progress_log_file, "a", encoding="utf-8") as sink:
        sink.write(json.dumps(record, ensure_ascii=False) + "\n")
        sink.flush()
class FileProgressCallback:
    """HuggingFace Trainer callback that mirrors progress into the shared log file.

    Only the events of interest are implemented. Any other ``on_*`` hook the
    Trainer looks up is answered with a no-op through ``__getattr__``. Every
    other missing attribute raises ``AttributeError`` as usual, so ``hasattr``,
    ``copy.deepcopy`` and similar protocol probes keep working.
    """

    # Monotonic timestamp taken in on_train_begin; None until training starts.
    _train_started_at = None

    def __init__(self, job_id: str):
        self.job_id = job_id

    @staticmethod
    def _ignore_event(*args, **kwargs):
        """No-op standing in for every Trainer event this class does not handle."""
        return None

    @staticmethod
    def _epoch_of(state) -> int:
        """Return the current epoch as an int (0 when the state has none yet)."""
        return int(state.epoch or 0)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record a "progress" entry for every log payload that carries a loss."""
        if logs and "loss" in logs:
            _write_log(type="progress", epoch=self._epoch_of(state),
                       step=state.global_step, total_steps=state.max_steps or 0,
                       loss=round(logs["loss"], 4),
                       learning_rate=round(logs.get("learning_rate", 0), 8))

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Record the start of an epoch."""
        _write_log(type="epoch_begin", epoch=self._epoch_of(state))

    def on_epoch_end(self, args, state, control, metrics=None, **kwargs):
        """Record the end of an epoch, with eval metrics when the caller supplies them."""
        # NOTE(review): the Trainer may never pass `metrics` to this event, in
        # which case both eval fields are always None — confirm.
        has_metrics = bool(metrics) and hasattr(metrics, "get")
        _write_log(type="epoch_done", epoch=self._epoch_of(state),
                   eval_loss=metrics.get("eval_loss") if has_metrics else None,
                   eval_accuracy=metrics.get("eval_accuracy") if has_metrics else None)

    def on_train_begin(self, args, state, control, **kwargs):
        """Start the wall-clock timer and record the "training" status."""
        self._train_started_at = time.monotonic()
        _write_log(type="status", status="training")

    def on_train_end(self, args, state, control, **kwargs):
        """Record the "completed" entry with the training duration and output dir."""
        # Prefer a runtime exposed by the state object; otherwise fall back to
        # the timer started in on_train_begin, so the value is not stuck at 0
        # when the state has no `train_runtime` attribute.
        runtime = getattr(state, "train_runtime", None)
        if runtime is None:
            started = self._train_started_at
            runtime = 0 if started is None else round(time.monotonic() - started, 2)
        _write_log(type="completed", total_time_seconds=runtime,
                   adapter_path=args.output_dir)

    def on_save(self, args, state, control, **kwargs):
        """Record a checkpoint save."""
        _write_log(type="save", step=state.global_step)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record evaluation metrics, when there are any."""
        if metrics:
            _write_log(type="evaluate", epoch=self._epoch_of(state),
                       eval_loss=metrics.get("eval_loss"),
                       eval_accuracy=metrics.get("eval_accuracy"))

    def __getattr__(self, name):
        """Answer unimplemented ``on_*`` events with a no-op; reject anything else."""
        if name.startswith("on_"):
            return self._ignore_event
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict):
    """Run a single training job end to end (entry point for remote invocation).

    Steps: preprocess the dataset, load the model, build the PEFT config,
    optionally pre-download a reward model, then train. Progress is mirrored
    to the shared JSONL log and to stderr.

    Args:
        job_id: Identifier used for the log file and the processed dataset name.
        model_id: Model identifier handed to the engine's ``load_model``.
        model_type: "vision", "multimodal", or anything else for the text engine.
        dataset_path: Path of the raw dataset, supplied by the main node.
        config: Training options. ``reward_model_path`` is overwritten in place
            with the local directory when a reward model is pre-downloaded.

    Returns:
        Whatever the engine's ``train`` returns (logged as the adapter path).

    Raises:
        FileNotFoundError: If ``dataset_path`` is empty or does not exist.
        Exception: Any failure is logged (message and traceback) and re-raised.
    """
    _init_log_file(job_id)
    _remote_log(f"=== Training job started: {job_id} ===")
    _remote_log(f"model_id={model_id}, model_type={model_type}")
    _remote_log(f"dataset_path={dataset_path}")
    _remote_log(f"config={json.dumps(config, ensure_ascii=False)[:200]}")
    try:
        # dataset_path is passed in directly by the main node.
        if not dataset_path or not Path(dataset_path).exists():
            raise FileNotFoundError(f"Dataset not found: {dataset_path}")
        _remote_log(f"Dataset file exists: {dataset_path}")
        _write_log(type="status", status="preprocessing")
        _remote_log("Step 1: Preprocessing dataset...")

        # Preprocessing. Make sure the output directory exists first, so the
        # engine is never handed a path inside a missing directory.
        _PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
        processed_path = str(_PROCESSED_DIR / f"{job_id}_processed.jsonl")
        task_type = config.get("task_type", "sft")
        template = config.get("dataset_template", "auto")
        _remote_log(f" task_type={task_type}, template={template}")
        _remote_log(f" output_path={processed_path}")

        # Pick the engine. Imports are deferred so only the selected engine's
        # dependencies get loaded.
        _remote_log(f" Selecting engine for model_type={model_type}...")
        if model_type == "vision":
            from app.engines.vision_engine import vision_engine
            engine = vision_engine
        elif model_type == "multimodal":
            from app.engines.multimodal_engine import multimodal_engine
            engine = multimodal_engine
        else:
            from app.engines.text_engine import text_engine
            engine = text_engine
        _remote_log(f" Engine loaded: {engine.__class__.__name__}")
        peft_method = config.get("peft_method", "lora")
        _remote_log(f" PEFT method: {peft_method}")
        _remote_log(" Running preprocess_dataset...")
        await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
        _remote_log(f" Preprocessing done, output: {processed_path}")
        _write_log(type="status", status="loading_model")
        _remote_log(f"Step 2: Loading model: {model_id}...")

        # Load the model; QLoRA is the only method that requests 4-bit quantization.
        quantization_mode = "4bit" if peft_method == "qlora" else None
        _remote_log(f" Quantization: {quantization_mode}")
        await engine.load_model(model_id, quantization=quantization_mode)
        _remote_log(" Model loaded successfully")

        # Build the PEFT config.
        _remote_log("Step 3: Building PEFT config...")
        peft_config = engine.get_peft_config(peft_method, config)
        _remote_log(" PEFT config built")

        # PPO training needs the reward model downloaded ahead of time.
        reward_type = config.get("reward_type", "heuristic")
        reward_model_path = config.get("reward_model_path")
        if reward_type == "model" and reward_model_path:
            _remote_log(f"Step 3.5: Pre-downloading reward model: {reward_model_path}...")
            reward_local = str(_MODELS_DIR / reward_model_path.replace("/", "_"))
            # A config.json in the target directory is taken as "already downloaded".
            if not (Path(reward_local) / "config.json").exists():
                from huggingface_hub import snapshot_download
                snapshot_download(
                    repo_id=reward_model_path,
                    local_dir=reward_local,
                    local_dir_use_symlinks=False,
                )
                _remote_log(f" Reward model downloaded to: {reward_local}")
            else:
                _remote_log(f" Reward model already exists: {reward_local}")
            config["reward_model_path"] = reward_local  # replace with the local path

        _write_log(type="status", status="training")
        _remote_log("Step 4: Starting training...")

        # Train, passing the file-log callback instead of a WebSocket callback.
        start_time = time.time()
        file_cb = FileProgressCallback(job_id)
        adapter_path = await engine.train(
            job_id=job_id,
            dataset_path=processed_path,
            peft_config=peft_config,
            training_args=config,
            callbacks=[file_cb],
        )
        elapsed = round(time.time() - start_time, 2)
        _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed)
        _remote_log(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)")
        _remote_log(f"=== Training job finished: {job_id} ===")
        return adapter_path
    except Exception as e:
        # Top-level boundary: record the failure in both logs, then propagate.
        tb = traceback.format_exc()
        _write_log(type="error", message=str(e), traceback=tb)
        _remote_log(f"ERROR: {e}")
        _remote_log(tb)
        _remote_log(f"=== Training job failed: {job_id} ===")
        raise
# Marker file written into the fla package once the patch has been applied.
_FLA_PATCH_MARKER = '_PATCHED_SM32'

# Block-size variable names whose `= 64` assignments are lowered to 32.
_FLA_BLOCK_VARS = (
    'BT', 'BK', 'BV', 'chunk_size', 'BLOCK_SIZE',
    'BLOCK_M', 'BLOCK_N', 'BLOCK_K', 'BLOCK_V',
    'block_size', 'block_m', 'block_n', 'block_k', 'block_v',
)
# \b on both sides so only the exact variable and the exact literal 64 match.
_FLA_BLOCK_ASSIGN_RES = tuple(
    (var, re.compile(rf'\b{var}\s*=\s*64\b')) for var in _FLA_BLOCK_VARS
)
_FLA_CONSTEXPR_RE = re.compile(r'tl\.constexpr\s*=\s*64\b')
# Multi-digit alternative first plus a trailing \b, so `num_stages=30` is
# replaced as a whole instead of being turned into `num_stages=10`.
_FLA_NUM_STAGES_RE = re.compile(r'num_stages\s*=\s*(?:[1-9]\d+|[3-9])\b')


def _evict_fla_modules() -> int:
    """Drop `fla` and its submodules from sys.modules; return how many were removed.

    Only the package itself and dotted submodules are matched, so unrelated
    packages whose names merely start with "fla" are left alone.
    """
    stale = [name for name in sys.modules if name == 'fla' or name.startswith('fla.')]
    for name in stale:
        del sys.modules[name]
    return len(stale)


def _locate_fla_package():
    """Return the directory of the installed fla package, or None if not found."""
    conda_path = '/opt/conda/lib/python3.10/site-packages/fla'
    if os.path.isdir(conda_path):
        return conda_path
    import site
    search_dirs = list(site.getsitepackages())
    if hasattr(site, 'getusersitepackages'):
        search_dirs.append(site.getusersitepackages())
    for sp in search_dirs:
        if not sp:
            # An empty entry would resolve to ./fla in the working directory.
            continue
        candidate = os.path.join(sp, 'fla')
        if os.path.isdir(candidate):
            return candidate
    return None


def _rewrite_fla_source(text: str):
    """Apply the block-size rewrites to one source text.

    Returns ``(new_text, changes)`` where *changes* lists a short label for
    every rewrite rule that actually modified the text.
    """
    changes = []
    # 1. Kernel function-name suffix: blockdim64 -> blockdim32.
    if 'blockdim64' in text:
        text = text.replace('blockdim64', 'blockdim32')
        changes.append('blockdim64->blockdim32')
    # 2. Block-size assignments such as `BT = 64`, `chunk_size = 64`.
    for var, pattern in _FLA_BLOCK_ASSIGN_RES:
        text, hits = pattern.subn(f'{var} = 32', text)
        if hits:
            changes.append(f'{var}=64->32')
    # 3. Only the `tl.constexpr = 64` form is rewritten; lists of autotune
    #    config values are deliberately left untouched.
    text, hits = _FLA_CONSTEXPR_RE.subn('tl.constexpr = 32', text)
    if hits:
        changes.append('tl.constexpr 64->32')
    # 4. Fewer pipeline stages (any value >= 3 becomes 1) to cut shared memory
    #    further.
    text, hits = _FLA_NUM_STAGES_RE.subn('num_stages=1', text)
    if hits:
        changes.append('num_stages->1')
    return text, changes


def _patch_fla_shared_memory(fla_base: "str | None" = None):
    """Work around shared-memory overflow in the fla library's Triton kernels.

    Per the original author's notes: the Gated Delta Rule layers of hybrid
    models such as Qwen3.5 use fla's Triton kernels; in the backward pass the
    chunk kernel uses a block size of 64 and needs roughly 106 KB of shared
    memory, while MetaX and some NVIDIA GPUs cap it at 64 KB (65536 bytes),
    which ends in OutOfResources.

    The fix rewrites, in place, the specific constants and kernel names in the
    installed fla sources that control the block size, without touching
    unrelated numeric literals. A marker file makes the operation idempotent.
    If the sources look corrupted, or were patched by the older v1 scheme,
    fla is reinstalled with pip first.

    Args:
        fla_base: Directory of the fla package to patch. When omitted, the
            package is located automatically (conda path first, then the
            site-packages directories).

    The function is best-effort: any failure is logged and swallowed.
    """
    try:
        import shutil

        # Locate the fla package.
        if fla_base is None:
            fla_base = _locate_fla_package()
        elif not os.path.isdir(fla_base):
            fla_base = None
        if not fla_base:
            _remote_log("fla package not found, skipping shared memory patch")
            return
        _remote_log(f"fla package found at: {fla_base}")

        # Check whether an older patch left the sources syntactically broken.
        chunk_py = os.path.join(fla_base, 'ops', 'gated_delta_rule', 'chunk.py')
        source_corrupted = False
        if os.path.exists(chunk_py):
            try:
                with open(chunk_py, 'r', encoding='utf-8') as f:
                    compile(f.read(), chunk_py, 'exec')
            except SyntaxError as e:
                source_corrupted = True
                _remote_log(f"fla source corrupted (SyntaxError: {e}), will reinstall...")

        # Idempotency check.
        marker_path = os.path.join(fla_base, _FLA_PATCH_MARKER)
        if os.path.exists(marker_path) and not source_corrupted:
            # A v1 marker means the old, over-eager regex patch polluted the
            # sources: reinstall, then patch again.
            with open(marker_path, encoding='utf-8') as mf:
                marker_content = mf.read()
            if 'v2' in marker_content:
                _remote_log("fla shared memory patch v2 already applied, skipping")
                _evict_fla_modules()
                return
            source_corrupted = True
            _remote_log("Old patch v1 detected, will reinstall fla...")

        if source_corrupted:
            _remote_log("Reinstalling fla to restore clean source...")
            import subprocess
            # NOTE(review): pip installs into the running interpreter's
            # environment, which may differ from `fla_base` — confirm.
            # Try the possible distribution names in turn.
            for pkg_name in ['fla', 'flash-linear-attention']:
                result = subprocess.run(
                    [sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-deps', pkg_name],
                    capture_output=True, text=True, timeout=120,
                )
                if result.returncode == 0:
                    _remote_log(f"fla reinstalled successfully via '{pkg_name}'")
                    break
                _remote_log(f"pip install '{pkg_name}' failed: {result.stderr[:200]}")
            # Remove the stale marker.
            if os.path.exists(marker_path):
                os.remove(marker_path)
            _remote_log("Reapplying patch v2...")

        patched_files = []
        for root, _dirs, files in os.walk(fla_base):
            for fname in files:
                if not fname.endswith('.py'):
                    continue
                fpath = os.path.join(root, fname)
                try:
                    # Python sources are UTF-8 by default; do not depend on the locale.
                    with open(fpath, 'r', encoding='utf-8') as f:
                        original = f.read()
                    patched, changes = _rewrite_fla_source(original)
                    if patched != original:
                        with open(fpath, 'w', encoding='utf-8') as f:
                            f.write(patched)
                        patched_files.append(f"{os.path.relpath(fpath, fla_base)}({', '.join(changes)})")
                except Exception as e:
                    # One unreadable file must not abort the whole patch.
                    _remote_log(f" Warning: failed to patch {fpath}: {e}")
                    continue

        # Remove __pycache__ so stale bytecode is not picked up.
        cache_count = 0
        for root, dirs, _files in os.walk(fla_base):
            if '__pycache__' in dirs:
                shutil.rmtree(os.path.join(root, '__pycache__'), ignore_errors=True)
                dirs.remove('__pycache__')  # nothing left to descend into
                cache_count += 1

        # Forget already-imported fla modules so the next import sees the patch.
        evicted = _evict_fla_modules()

        # Write the marker file (idempotency), tagged with the patch version v2.
        with open(marker_path, 'w', encoding='utf-8') as f:
            f.write(f"v2 patched at {datetime.now(timezone.utc).isoformat()}\n")
        _remote_log(f"fla shared memory patch done: {len(patched_files)} files, "
                    f"{cache_count} caches cleared, {evicted} modules evicted")
        for pf in patched_files:
            _remote_log(f" patched: {pf}")
    except Exception as e:
        # Best-effort: training goes on even when the patch could not be applied.
        _remote_log(f"Warning: fla shared memory patch failed: {e}")
        _remote_log(traceback.format_exc())
def main():
    """CLI entry point.

    Usage: python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>

    Exits with status 1 when fewer than five arguments are given. Otherwise
    applies the fla patch, loads the JSON config file and runs the training
    job to completion.
    """
    # Validate the command line first, so a malformed invocation does not
    # rewrite installed packages before bailing out.
    if len(sys.argv) < 6:
        print("Usage: python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>")
        sys.exit(1)
    job_id, model_id, model_type, dataset_path, config_path = sys.argv[1:6]

    # Patch the Triton-kernel shared-memory issue before any fla module is imported.
    _patch_fla_shared_memory()

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    asyncio.run(run_training(job_id, model_id, model_type, dataset_path, config))


if __name__ == "__main__":
    main()
|