remote_train.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. """远程训练入口脚本 — 在算力节点上执行。
  2. 不依赖 app.config / app.core.logging,避免引入 pydantic-settings / sqlalchemy 等额外包。
  3. """
  4. import asyncio
  5. import json
  6. import os
  7. import re
  8. import sys
  9. import time
  10. import traceback
  11. from datetime import datetime, timezone
  12. from pathlib import Path
  13. # 禁用 FlashAttention
  14. os.environ["PYTORCH_NO_FLASH"] = "1"
  15. os.environ["FLASH_ATTENTION_ENABLED"] = "0"
  16. # 解决 PyTorch 显存碎片化问题(避免 reserved unallocated 占用大量显存)
  17. os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
  18. # 禁用 torch.compile,避免 fork 大量 inductor worker 进程
  19. os.environ["PT2_COMPILE"] = "0"
  20. os.environ["TORCHINDUCTOR_MAX_WORKERS"] = "1"
  21. # 限制训练只用 GPU 3(GPU 0/1 被 VLLM 占用,GPU 2 已占用)
  22. # CUDA_VISIBLE_DEVICES 将 3 映射为容器内的 cuda:0
  23. # device_map 中使用相对编号 0(即物理 GPU 3)
  24. os.environ["CUDA_VISIBLE_DEVICES"] = "3"
  25. # 启用 MPS 多进程服务,允许与 VLLM 共享 GPU
  26. os.environ["MACA_MPS_MODE"] = "1"
  27. _progress_log_file = None
  28. # 直接从环境变量读取配置,避免引入 pydantic-settings
  29. _DATA_DIR = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
  30. _PROCESSED_DIR = _DATA_DIR / "processed"
  31. _ADAPTERS_DIR = _DATA_DIR / "adapters"
  32. _MODELS_DIR = _DATA_DIR / "models"
  33. def _remote_log(msg: str):
  34. """打印到 stderr(即远程训练日志 /tmp/train_{job_id}.log)。"""
  35. print(f"[remote_train] {msg}", file=sys.stderr)
  36. def _init_log_file(job_id: str):
  37. """初始化进度日志文件(通过 SSHFS 共享给主节点读取)。"""
  38. global _progress_log_file
  39. log_dir = _DATA_DIR / "logs"
  40. log_dir.mkdir(parents=True, exist_ok=True)
  41. _progress_log_file = log_dir / f"{job_id}.jsonl"
  42. _write_log(type="start", job_id=job_id)
  43. def _write_log(**kwargs):
  44. """追加一行 JSON 到共享日志文件。"""
  45. if _progress_log_file:
  46. entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs}
  47. with open(_progress_log_file, "a", encoding="utf-8") as f:
  48. f.write(json.dumps(entry, ensure_ascii=False) + "\n")
  49. f.flush()
  50. class FileProgressCallback:
  51. """HuggingFace Trainer 回调 — 写进度到共享日志文件。
  52. 只实现关心的回调,其余通过 __getattr__ 自动忽略。
  53. """
  54. def __init__(self, job_id: str):
  55. self.job_id = job_id
  56. def on_log(self, args, state, control, logs=None, **kwargs):
  57. if logs and "loss" in logs:
  58. _write_log(type="progress", epoch=int(state.epoch or 0),
  59. step=state.global_step, total_steps=state.max_steps or 0,
  60. loss=round(logs["loss"], 4),
  61. learning_rate=round(logs.get("learning_rate", 0), 8))
  62. def on_epoch_begin(self, args, state, control, **kwargs):
  63. _write_log(type="epoch_begin", epoch=int(state.epoch or 0))
  64. def on_epoch_end(self, args, state, control, metrics=None, **kwargs):
  65. _write_log(type="epoch_done", epoch=int(state.epoch or 0),
  66. eval_loss=metrics.get("eval_loss") if metrics and hasattr(metrics, "get") else None,
  67. eval_accuracy=metrics.get("eval_accuracy") if metrics and hasattr(metrics, "get") else None)
  68. def on_train_end(self, args, state, control, **kwargs):
  69. _write_log(type="completed", total_time_seconds=getattr(state, "train_runtime", 0),
  70. adapter_path=args.output_dir)
  71. def on_train_begin(self, args, state, control, **kwargs):
  72. _write_log(type="status", status="training")
  73. def on_save(self, args, state, control, **kwargs):
  74. _write_log(type="save", step=state.global_step)
  75. def on_evaluate(self, args, state, control, metrics=None, **kwargs):
  76. if metrics:
  77. _write_log(type="evaluate", epoch=int(state.epoch or 0),
  78. eval_loss=metrics.get("eval_loss"),
  79. eval_accuracy=metrics.get("eval_accuracy"))
  80. def __getattr__(self, name):
  81. """Trainer 期望其他回调方法存在,返回一个空函数自动忽略。"""
  82. return lambda *args, **kwargs: None
  83. async def run_training(job_id: str, model_id: str, model_type: str, dataset_path: str, config: dict):
  84. """执行单个训练任务(远程调用入口)。"""
  85. _init_log_file(job_id)
  86. _remote_log(f"=== Training job started: {job_id} ===")
  87. _remote_log(f"model_id={model_id}, model_type={model_type}")
  88. _remote_log(f"dataset_path={dataset_path}")
  89. _remote_log(f"config={json.dumps(config, ensure_ascii=False)[:200]}")
  90. try:
  91. # dataset_path 由主节点直接传入
  92. if not dataset_path or not Path(dataset_path).exists():
  93. raise FileNotFoundError(f"Dataset not found: {dataset_path}")
  94. _remote_log(f"Dataset file exists: {dataset_path}")
  95. _write_log(type="status", status="preprocessing")
  96. _remote_log("Step 1: Preprocessing dataset...")
  97. # 预处理
  98. processed_path = str(_PROCESSED_DIR / f"{job_id}_processed.jsonl")
  99. task_type = config.get("task_type", "sft")
  100. template = config.get("dataset_template", "auto")
  101. _remote_log(f" task_type={task_type}, template={template}")
  102. _remote_log(f" output_path={processed_path}")
  103. # 选择引擎
  104. _remote_log(f" Selecting engine for model_type={model_type}...")
  105. if model_type == "vision":
  106. from app.engines.vision_engine import vision_engine
  107. engine = vision_engine
  108. elif model_type == "multimodal":
  109. from app.engines.multimodal_engine import multimodal_engine
  110. engine = multimodal_engine
  111. else:
  112. from app.engines.text_engine import text_engine
  113. engine = text_engine
  114. _remote_log(f" Engine loaded: {engine.__class__.__name__}")
  115. peft_method = config.get("peft_method", "lora")
  116. _remote_log(f" PEFT method: {peft_method}")
  117. _remote_log(" Running preprocess_dataset...")
  118. await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
  119. _remote_log(f" Preprocessing done, output: {processed_path}")
  120. _write_log(type="status", status="loading_model")
  121. _remote_log(f"Step 2: Loading model: {model_id}...")
  122. # 加载模型
  123. quantization_mode = "4bit" if peft_method == "qlora" else None
  124. _remote_log(f" Quantization: {quantization_mode}")
  125. await engine.load_model(model_id, quantization=quantization_mode)
  126. _remote_log(" Model loaded successfully")
  127. # 构建 PEFT 配置
  128. _remote_log("Step 3: Building PEFT config...")
  129. peft_config = engine.get_peft_config(peft_method, config)
  130. _remote_log(" PEFT config built")
  131. # PPO 训练需要预下载奖励模型
  132. reward_type = config.get("reward_type", "heuristic")
  133. reward_model_path = config.get("reward_model_path")
  134. if reward_type == "model" and reward_model_path:
  135. _remote_log(f"Step 3.5: Pre-downloading reward model: {reward_model_path}...")
  136. reward_local = str(_MODELS_DIR / reward_model_path.replace("/", "_"))
  137. if not (Path(reward_local) / "config.json").exists():
  138. from huggingface_hub import snapshot_download
  139. snapshot_download(
  140. repo_id=reward_model_path,
  141. local_dir=reward_local,
  142. local_dir_use_symlinks=False,
  143. )
  144. _remote_log(f" Reward model downloaded to: {reward_local}")
  145. else:
  146. _remote_log(f" Reward model already exists: {reward_local}")
  147. config["reward_model_path"] = reward_local # 覆盖为本地路径
  148. _write_log(type="status", status="training")
  149. _remote_log("Step 4: Starting training...")
  150. # 训练 — 传入文件日志回调替代 WebSocket 回调
  151. start_time = time.time()
  152. file_cb = FileProgressCallback(job_id)
  153. adapter_path = await engine.train(
  154. job_id=job_id,
  155. dataset_path=processed_path,
  156. peft_config=peft_config,
  157. training_args=config,
  158. callbacks=[file_cb],
  159. )
  160. elapsed = round(time.time() - start_time, 2)
  161. _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed)
  162. _remote_log(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)")
  163. _remote_log(f"=== Training job finished: {job_id} ===")
  164. return adapter_path
  165. except Exception as e:
  166. tb = traceback.format_exc()
  167. _write_log(type="error", message=str(e), traceback=tb)
  168. _remote_log(f"ERROR: {e}")
  169. _remote_log(tb)
  170. _remote_log(f"=== Training job failed: {job_id} ===")
  171. raise
  172. def _patch_fla_shared_memory():
  173. """修复 fla 库 Triton kernel 共享内存溢出问题。
  174. Qwen3.5 等混合架构模型的 Gated Delta Rule 层使用 fla 库的 Triton kernel,
  175. 反向传播时 chunk kernel 的 block size 为 64,需要约 106KB 共享内存,
  176. 但沐曦/部分 NVIDIA GPU 硬件上限仅 64KB(65536 字节),导致 OutOfResources。
  177. 修复方式:在 fla 模块首次导入前,全面降低所有 block size 相关的值:
  178. 1. blockdim64 → blockdim32(kernel 函数名后缀)
  179. 2. 所有 = 64 的赋值/参数 → = 32(覆盖 BT/BK/BV/chunk_size 等变量名)
  180. 3. tl.constexpr 值为 128/256 的也降为 64
  181. """
  182. try:
  183. import shutil
  184. import site
  185. fla_base = None
  186. # 优先检查 conda 环境路径
  187. conda_path = '/opt/conda/lib/python3.10/site-packages/fla'
  188. if os.path.isdir(conda_path):
  189. fla_base = conda_path
  190. else:
  191. for sp in site.getsitepackages() + [site.getusersitepackages() if hasattr(site, 'getusersitepackages') else '']:
  192. candidate = os.path.join(sp, 'fla')
  193. if os.path.isdir(candidate):
  194. fla_base = candidate
  195. break
  196. if not fla_base:
  197. _remote_log("fla package not found, skipping shared memory patch")
  198. return
  199. _remote_log(f"fla package found at: {fla_base}")
  200. # 幂等检查
  201. marker_path = os.path.join(fla_base, '_PATCHED_SM32')
  202. if os.path.exists(marker_path):
  203. _remote_log("fla shared memory patch already applied (marker found), skipping")
  204. return
  205. patched_files = []
  206. for root, dirs, files in os.walk(fla_base):
  207. for fname in files:
  208. if not fname.endswith('.py'):
  209. continue
  210. fpath = os.path.join(root, fname)
  211. try:
  212. with open(fpath, 'r') as f:
  213. original = f.read()
  214. c = original
  215. changes = []
  216. # 1. blockdim64 → blockdim32(kernel 函数名后缀)
  217. if 'blockdim64' in c:
  218. c = c.replace('blockdim64', 'blockdim32')
  219. changes.append('blockdim64->blockdim32')
  220. # 2. = 64 赋值/参数 → = 32(覆盖 BT=64, BK=64, BV=64, chunk_size=64 等)
  221. def _r64(m):
  222. return f'{m.group(1)}= 32'
  223. new_c = re.sub(r'([=:])\s*64\b(?!\d)', _r64, c)
  224. if new_c != c:
  225. changes.append('=64 -> =32')
  226. c = new_c
  227. # 3. tl.constexpr = 128/256 → = 64(进一步降低大值)
  228. def _r_large(m):
  229. val = int(m.group(1))
  230. return f'tl.constexpr = {val // 2}'
  231. new_c = re.sub(r'tl\.constexpr\s*=\s*(128|256)\b', _r_large, c)
  232. if new_c != c:
  233. changes.append('constexpr 128/256 halved')
  234. c = new_c
  235. if c != original:
  236. with open(fpath, 'w') as f:
  237. f.write(c)
  238. patched_files.append(f"{os.path.relpath(fpath, fla_base)}({', '.join(changes)})")
  239. except Exception as e:
  240. _remote_log(f" Warning: failed to patch {fpath}: {e}")
  241. continue
  242. # 清理 __pycache__,确保下次 import 读新源码
  243. cache_count = 0
  244. for root, dirs, files in os.walk(fla_base):
  245. if '__pycache__' in dirs:
  246. shutil.rmtree(os.path.join(root, '__pycache__'), ignore_errors=True)
  247. cache_count += 1
  248. # 清除已缓存的 fla 模块,强制重新导入
  249. to_remove = [k for k in sys.modules if k.startswith('fla')]
  250. for k in to_remove:
  251. del sys.modules[k]
  252. # 写入标记文件,下次运行时跳过(幂等)
  253. with open(marker_path, 'w') as f:
  254. f.write(f"patched at {datetime.now(timezone.utc).isoformat()}\n")
  255. _remote_log(f"fla shared memory patch done: {len(patched_files)} files patched, "
  256. f"{cache_count} caches cleared, {len(to_remove)} modules evicted")
  257. for pf in patched_files:
  258. _remote_log(f" patched: {pf}")
  259. except Exception as e:
  260. _remote_log(f"Warning: fla shared memory patch failed: {e}")
  261. import traceback as tb
  262. _remote_log(tb.format_exc())
  263. def main():
  264. """命令行入口:python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>"""
  265. # 在导入任何 fla 模块之前,修补 Triton kernel 共享内存问题
  266. _patch_fla_shared_memory()
  267. if len(sys.argv) < 6:
  268. print("Usage: python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_path> <config_file>")
  269. sys.exit(1)
  270. job_id = sys.argv[1]
  271. model_id = sys.argv[2]
  272. model_type = sys.argv[3]
  273. dataset_id = sys.argv[4]
  274. config_path = sys.argv[5]
  275. with open(config_path, encoding="utf-8") as f:
  276. config = json.load(f)
  277. asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config))
  278. if __name__ == "__main__":
  279. main()