remote_train.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. """远程训练入口脚本 — 在算力节点上执行。
  2. 不依赖 app.config / app.core.logging,避免引入 pydantic-settings / sqlalchemy 等额外包。
  3. """
  4. import asyncio
  5. import json
  6. import os
  7. import re
  8. import sys
  9. import time
  10. import traceback
  11. from datetime import datetime, timezone
  12. from pathlib import Path
  13. # 禁用 FlashAttention
  14. os.environ["PYTORCH_NO_FLASH"] = "1"
  15. os.environ["FLASH_ATTENTION_ENABLED"] = "0"
  16. # 解决 PyTorch 显存碎片化问题(避免 reserved unallocated 占用大量显存)
  17. os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
  18. # 禁用 torch.compile,避免 fork 大量 inductor worker 进程
  19. os.environ["PT2_COMPILE"] = "0"
  20. os.environ["TORCHINDUCTOR_MAX_WORKERS"] = "1"
  21. # CUDA_VISIBLE_DEVICES 由 docker exec 层设置,此处不再覆盖
  22. # 单 GPU 模式: "3" (物理 GPU 3 → 逻辑 cuda:0)
  23. # 多 GPU 模式: "2,3" (物理 GPU 2,3 → 逻辑 cuda:0,1)
  24. # 启用 MPS 多进程服务,允许与 VLLM 共享 GPU
  25. os.environ["MACA_MPS_MODE"] = "1"
  26. # Triton 编译缓存 — 持久化编译产物,避免每次训练都重新编译
  27. os.environ["TRITON_CACHE_DIR"] = "/root/Fine-tuning/backend/data/.triton_cache"
  28. os.environ["TRITON_HOME"] = "/root/Fine-tuning/backend/data/.triton_cache"
  29. # 减少 Triton 编译时的冗余输出
  30. os.environ["MLIR_ENABLE_DUMP"] = "0"
  31. os.environ["TRITON_INTERPRET"] = "0"
  32. _progress_log_file = None
  33. # 直接从环境变量读取配置,避免引入 pydantic-settings
  34. _DATA_DIR = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
  35. _PROCESSED_DIR = _DATA_DIR / "processed"
  36. _ADAPTERS_DIR = _DATA_DIR / "adapters"
  37. _MODELS_DIR = _DATA_DIR / "models"
  38. def _remote_log(msg: str):
  39. """打印到 stderr(即远程训练日志 /tmp/train_{job_id}.log)。"""
  40. print(f"[remote_train] {msg}", file=sys.stderr)
  41. def _init_log_file(job_id: str):
  42. """初始化进度日志文件(通过 SSHFS 共享给主节点读取)。"""
  43. global _progress_log_file
  44. log_dir = _DATA_DIR / "logs"
  45. log_dir.mkdir(parents=True, exist_ok=True)
  46. _progress_log_file = log_dir / f"{job_id}.jsonl"
  47. _write_log(type="start", job_id=job_id)
  48. def _write_log(**kwargs):
  49. """追加一行 JSON 到共享日志文件。"""
  50. if _progress_log_file:
  51. entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs}
  52. with open(_progress_log_file, "a", encoding="utf-8") as f:
  53. f.write(json.dumps(entry, ensure_ascii=False) + "\n")
  54. f.flush()
  55. class FileProgressCallback:
  56. """HuggingFace Trainer 回调 — 写进度到共享日志文件。
  57. 只实现关心的回调,其余通过 __getattr__ 自动忽略。
  58. """
  59. def __init__(self, job_id: str):
  60. self.job_id = job_id
  61. def on_step_end(self, args, state, control, **kwargs):
  62. """每步结束记录进度,便于观察训练是否在推进。"""
  63. _write_log(type="step", step=state.global_step,
  64. total_steps=state.max_steps or 0,
  65. epoch=round(state.epoch or 0, 2))
  66. if state.global_step % 5 == 0 or state.global_step <= 3:
  67. _remote_log(f"Step {state.global_step}/{state.max_steps} done (epoch {state.epoch:.2f})")
  68. def on_log(self, args, state, control, logs=None, **kwargs):
  69. if logs and "loss" in logs:
  70. _write_log(type="progress", epoch=int(state.epoch or 0),
  71. step=state.global_step, total_steps=state.max_steps or 0,
  72. loss=round(logs["loss"], 4),
  73. learning_rate=round(logs.get("learning_rate", 0), 8))
  74. def on_epoch_begin(self, args, state, control, **kwargs):
  75. _write_log(type="epoch_begin", epoch=int(state.epoch or 0))
  76. def on_epoch_end(self, args, state, control, metrics=None, **kwargs):
  77. _write_log(type="epoch_done", epoch=int(state.epoch or 0),
  78. eval_loss=metrics.get("eval_loss") if metrics and hasattr(metrics, "get") else None,
  79. eval_accuracy=metrics.get("eval_accuracy") if metrics and hasattr(metrics, "get") else None)
  80. def on_train_end(self, args, state, control, **kwargs):
  81. _write_log(type="completed", total_time_seconds=getattr(state, "train_runtime", 0),
  82. adapter_path=args.output_dir)
  83. def on_train_begin(self, args, state, control, **kwargs):
  84. _write_log(type="status", status="training")
  85. def on_save(self, args, state, control, **kwargs):
  86. _write_log(type="save", step=state.global_step)
  87. def on_evaluate(self, args, state, control, metrics=None, **kwargs):
  88. if metrics:
  89. _write_log(type="evaluate", epoch=int(state.epoch or 0),
  90. eval_loss=metrics.get("eval_loss"),
  91. eval_accuracy=metrics.get("eval_accuracy"))
  92. def __getattr__(self, name):
  93. """Trainer 期望其他回调方法存在,返回一个空函数自动忽略。"""
  94. return lambda *args, **kwargs: None
  95. async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict,
  96. rank: int = 0, local_rank: int = 0, world_size: int = 1):
  97. """执行单个训练任务(远程调用入口)。"""
  98. is_main = (rank == 0)
  99. # 只有 rank 0 写 JSONL 进度日志,避免多进程文件竞争
  100. if is_main:
  101. _init_log_file(job_id)
  102. _remote_log(f"[rank {rank}] === Training job started: {job_id} ===")
  103. if is_main:
  104. _remote_log(f"model_id={model_id}, model_type={model_type}")
  105. _remote_log(f"dataset_path={dataset_path}")
  106. _remote_log(f"config={json.dumps(config, ensure_ascii=False)[:200]}")
  107. if world_size > 1:
  108. _remote_log(f"DDP: world_size={world_size}, batch_size per GPU={config.get('batch_size', 8)}")
  109. try:
  110. # dataset_path 由主节点直接传入
  111. if not dataset_path or not Path(dataset_path).exists():
  112. raise FileNotFoundError(f"Dataset not found: {dataset_path}")
  113. if is_main:
  114. _write_log(type="status", status="preprocessing")
  115. _remote_log("Step 1: Preprocessing dataset...")
  116. # 预处理 — DDP 模式下只有 rank 0 执行,其他 rank 等待
  117. processed_path = str(_PROCESSED_DIR / f"{job_id}_processed.jsonl")
  118. task_type = config.get("task_type", "sft")
  119. template = config.get("dataset_template", "auto")
  120. if is_main:
  121. _remote_log(f" task_type={task_type}, template={template}")
  122. # 选择引擎
  123. if model_type == "vision":
  124. from app.engines.vision_engine import vision_engine
  125. engine = vision_engine
  126. elif model_type == "multimodal":
  127. from app.engines.multimodal_engine import multimodal_engine
  128. engine = multimodal_engine
  129. else:
  130. from app.engines.text_engine import text_engine
  131. engine = text_engine
  132. _remote_log(f" Engine loaded: {engine.__class__.__name__}")
  133. _remote_log(" Running preprocess_dataset...")
  134. await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
  135. _remote_log(f" Preprocessing done, output: {processed_path}")
  136. # 写同步标记,通知其他 rank 预处理完成
  137. if world_size > 1:
  138. done_marker = _PROCESSED_DIR / f"{job_id}_preprocess_done"
  139. done_marker.write_text("done")
  140. else:
  141. # 非 rank 0 等待预处理完成
  142. done_marker = _PROCESSED_DIR / f"{job_id}_preprocess_done"
  143. waited = 0
  144. while not done_marker.exists() and waited < 120:
  145. await asyncio.sleep(1)
  146. waited += 1
  147. # 选择引擎(所有 rank 都需要)
  148. if model_type == "vision":
  149. from app.engines.vision_engine import vision_engine
  150. engine = vision_engine
  151. elif model_type == "multimodal":
  152. from app.engines.multimodal_engine import multimodal_engine
  153. engine = multimodal_engine
  154. else:
  155. from app.engines.text_engine import text_engine
  156. engine = text_engine
  157. peft_method = config.get("peft_method", "lora")
  158. if is_main:
  159. _write_log(type="status", status="loading_model")
  160. _remote_log(f"Step 2: Loading model: {model_id}...")
  161. # 加载模型 — 每个 rank 各自加载到自己的 GPU
  162. quantization_mode = "4bit" if peft_method == "qlora" else None
  163. await engine.load_model(model_id, quantization=quantization_mode)
  164. if is_main:
  165. _remote_log(" Model loaded successfully")
  166. # 构建 PEFT 配置
  167. if is_main:
  168. _remote_log("Step 3: Building PEFT config...")
  169. peft_config = engine.get_peft_config(peft_method, config)
  170. # PPO 训练需要预下载奖励模型(只在 rank 0 下载)
  171. reward_type = config.get("reward_type", "heuristic")
  172. reward_model_path = config.get("reward_model_path")
  173. if reward_type == "model" and reward_model_path:
  174. if is_main:
  175. _remote_log(f"Step 3.5: Pre-downloading reward model: {reward_model_path}...")
  176. reward_local = str(_MODELS_DIR / reward_model_path.replace("/", "_"))
  177. if not (Path(reward_local) / "config.json").exists():
  178. from huggingface_hub import snapshot_download
  179. snapshot_download(
  180. repo_id=reward_model_path,
  181. local_dir=reward_local,
  182. local_dir_use_symlinks=False,
  183. )
  184. config["reward_model_path"] = reward_local
  185. if is_main:
  186. _write_log(type="status", status="training")
  187. _remote_log("Step 4: Starting training...")
  188. _remote_log("NOTE: First step may take 2-5 minutes due to Triton kernel compilation (autotuning). This is normal.")
  189. _remote_log(f"Total steps: {config.get('epochs', 3)} epochs, batch_size per GPU={config.get('batch_size', 8)}")
  190. # 训练 — 传入文件日志回调(只在 rank 0 写日志)
  191. start_time = time.time()
  192. file_cb = FileProgressCallback(job_id) if is_main else None
  193. callbacks = [file_cb] if file_cb else []
  194. adapter_path = await engine.train(
  195. job_id=job_id,
  196. dataset_path=processed_path,
  197. peft_config=peft_config,
  198. training_args=config,
  199. callbacks=callbacks,
  200. )
  201. if is_main:
  202. elapsed = round(time.time() - start_time, 2)
  203. _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed)
  204. _remote_log(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)")
  205. _remote_log(f"=== Training job finished: {job_id} ===")
  206. return adapter_path
  207. except Exception as e:
  208. tb = traceback.format_exc()
  209. if is_main:
  210. _write_log(type="error", message=str(e), traceback=tb)
  211. _remote_log(f"[rank {rank}] ERROR: {e}")
  212. _remote_log(tb)
  213. if is_main:
  214. _remote_log(f"=== Training job failed: {job_id} ===")
  215. raise
  216. def _patch_fla_shared_memory():
  217. """修复 fla 库 Triton kernel 共享内存溢出问题。
  218. Qwen3.5 等混合架构模型的 Gated Delta Rule 层使用 fla 库的 Triton kernel,
  219. 反向传播时 chunk kernel 的 block size 为 64,需要约 106KB 共享内存,
  220. 但沐曦/部分 NVIDIA GPU 硬件上限仅 64KB(65536 字节),导致 OutOfResources。
  221. 修复方式:精确替换 fla 库中控制 block size 的常量和 kernel 名称,
  222. 避免误改普通代码中的数字字面量。
  223. """
  224. try:
  225. import shutil
  226. # 定位 fla 包
  227. conda_path = '/opt/conda/lib/python3.10/site-packages/fla'
  228. if os.path.isdir(conda_path):
  229. fla_base = conda_path
  230. else:
  231. import site
  232. fla_base = None
  233. for sp in site.getsitepackages() + [site.getusersitepackages() if hasattr(site, 'getusersitepackages') else '']:
  234. candidate = os.path.join(sp, 'fla')
  235. if os.path.isdir(candidate):
  236. fla_base = candidate
  237. break
  238. if not fla_base:
  239. _remote_log("fla package not found, skipping shared memory patch")
  240. return
  241. _remote_log(f"fla package found at: {fla_base}")
  242. # 检查 fla 源码是否被旧版补丁损坏(语法错误)
  243. chunk_py = os.path.join(fla_base, 'ops', 'gated_delta_rule', 'chunk.py')
  244. source_corrupted = False
  245. if os.path.exists(chunk_py):
  246. try:
  247. with open(chunk_py, 'r') as f:
  248. compile(f.read(), chunk_py, 'exec')
  249. except SyntaxError as e:
  250. source_corrupted = True
  251. _remote_log(f"fla source corrupted (SyntaxError: {e}), will reinstall...")
  252. # 幂等检查
  253. marker_path = os.path.join(fla_base, '_PATCHED_SM32')
  254. if os.path.exists(marker_path) and not source_corrupted:
  255. # 检查标记版本:v1 是旧版补丁(用激进正则,已污染源码),需要重装后重新打补丁
  256. with open(marker_path) as mf:
  257. marker_content = mf.read()
  258. if 'v2' in marker_content:
  259. _remote_log("fla shared memory patch v2 already applied, skipping")
  260. to_remove = [k for k in sys.modules if k.startswith('fla')]
  261. for k in to_remove:
  262. del sys.modules[k]
  263. return
  264. else:
  265. source_corrupted = True
  266. _remote_log("Old patch v1 detected, will reinstall fla...")
  267. if source_corrupted:
  268. _remote_log("WARNING: fla source is corrupted (SyntaxError). "
  269. "Please rebuild the container to restore clean source.")
  270. return
  271. patched_files = []
  272. for root, dirs, files in os.walk(fla_base):
  273. for fname in files:
  274. if not fname.endswith('.py'):
  275. continue
  276. fpath = os.path.join(root, fname)
  277. try:
  278. with open(fpath, 'r') as f:
  279. original = f.read()
  280. c = original
  281. changes = []
  282. # 0. 修复 fla/utils.py 中的 maca 设备映射(沐曦 GPU 兼容)
  283. if fname == 'utils.py' and "!= 'hip'" in c:
  284. # 把 maca 也映射到 cuda
  285. c = c.replace(
  286. "device = get_available_device() if get_available_device() != 'hip' else 'cuda'",
  287. "device = get_available_device() if get_available_device() not in ('hip', 'maca') else 'cuda'"
  288. )
  289. changes.append('maca->cuda mapping')
  290. # 1. kernel 函数名后缀: blockdim64 → blockdim32
  291. if 'blockdim64' in c:
  292. c = c.replace('blockdim64', 'blockdim32')
  293. changes.append('blockdim64->blockdim32')
  294. # 2. 精确匹配 fla 中常见的 block size 变量赋值
  295. # BT = 64, BK = 64, BV = 64, chunk_size = 64, BLOCK_SIZE = 64 等
  296. # 用 \b 匹配完整变量名,避免误改其他代码
  297. for var in ['BT', 'BK', 'BV', 'chunk_size', 'BLOCK_SIZE',
  298. 'BLOCK_M', 'BLOCK_N', 'BLOCK_K', 'BLOCK_V',
  299. 'block_size', 'block_m', 'block_n', 'block_k', 'block_v']:
  300. pattern = rf'\b{var}\s*=\s*64\b'
  301. replacement = f'{var} = 32'
  302. new_c = re.sub(pattern, replacement, c)
  303. if new_c != c:
  304. changes.append(f'{var}=64->32')
  305. c = new_c
  306. # 3. Triton autotune 装饰器中的 configs 参数值
  307. # 例如: configs=[..., 64, ...] 或 tl.constexpr = 64
  308. # 只替换 tl.constexpr = 64 的情况
  309. pattern = r'tl\.constexpr\s*=\s*64\b'
  310. new_c = re.sub(pattern, 'tl.constexpr = 32', c)
  311. if new_c != c:
  312. changes.append('tl.constexpr 64->32')
  313. c = new_c
  314. # 4. num_stages 降低(减少流水线阶段,进一步降低共享内存)
  315. pattern = r'num_stages\s*=\s*([3-9]|[1-9]\d+)'
  316. new_c = re.sub(pattern, 'num_stages=1', c)
  317. if new_c != c:
  318. changes.append('num_stages->1')
  319. c = new_c
  320. if c != original:
  321. with open(fpath, 'w') as f:
  322. f.write(c)
  323. patched_files.append(f"{os.path.relpath(fpath, fla_base)}({', '.join(changes)})")
  324. except Exception as e:
  325. _remote_log(f" Warning: failed to patch {fpath}: {e}")
  326. continue
  327. # 清理 __pycache__
  328. cache_count = 0
  329. for root, dirs, files in os.walk(fla_base):
  330. if '__pycache__' in dirs:
  331. shutil.rmtree(os.path.join(root, '__pycache__'), ignore_errors=True)
  332. cache_count += 1
  333. # 清除已缓存的 fla 模块
  334. to_remove = [k for k in sys.modules if k.startswith('fla')]
  335. for k in to_remove:
  336. del sys.modules[k]
  337. # 写入标记文件(幂等),包含版本号 v2
  338. with open(marker_path, 'w') as f:
  339. f.write(f"v2 patched at {datetime.now(timezone.utc).isoformat()}\n")
  340. _remote_log(f"fla shared memory patch done: {len(patched_files)} files, "
  341. f"{cache_count} caches cleared, {len(to_remove)} modules evicted")
  342. for pf in patched_files:
  343. _remote_log(f" patched: {pf}")
  344. except Exception as e:
  345. _remote_log(f"Warning: fla shared memory patch failed: {e}")
  346. import traceback as tb
  347. _remote_log(tb.format_exc())
  348. def main():
  349. """命令行入口。
  350. 单 GPU: python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>
  351. 多 GPU: torchrun --nproc_per_node=N -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>
  352. """
  353. # 解析 DDP 环境变量(torchrun 自动设置)
  354. rank = int(os.environ.get("RANK", "0"))
  355. local_rank = int(os.environ.get("LOCAL_RANK", "0"))
  356. world_size = int(os.environ.get("WORLD_SIZE", "1"))
  357. if world_size > 1:
  358. # DDP 模式:只有 rank 0 执行 fla 补丁,其他 rank 等待补丁完成
  359. if rank == 0:
  360. _remote_log(f"DDP mode: rank={rank}, local_rank={local_rank}, world_size={world_size}")
  361. _patch_fla_shared_memory()
  362. # 写一个分布式同步标记,通知其他 rank 补丁已完成
  363. marker = _DATA_DIR / ".fla_patch_done"
  364. marker.write_text(f"patched by rank 0 at {datetime.now(timezone.utc).isoformat()}")
  365. else:
  366. # 等待 rank 0 完成补丁(最多等 60 秒)
  367. marker = _DATA_DIR / ".fla_patch_done"
  368. waited = 0
  369. while not marker.exists() and waited < 60:
  370. time.sleep(1)
  371. waited += 1
  372. if not marker.exists():
  373. _remote_log(f"WARNING: rank {rank} timed out waiting for fla patch marker")
  374. else:
  375. # 单 GPU 模式
  376. _patch_fla_shared_memory()
  377. if len(sys.argv) < 6:
  378. print("Usage: python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>")
  379. sys.exit(1)
  380. job_id = sys.argv[1]
  381. model_id = sys.argv[2]
  382. model_type = sys.argv[3]
  383. dataset_id = sys.argv[4]
  384. config_path = sys.argv[5]
  385. with open(config_path, encoding="utf-8") as f:
  386. config = json.load(f)
  387. asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config,
  388. rank=rank, local_rank=local_rank, world_size=world_size))
  389. if __name__ == "__main__":
  390. main()