| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177 |
- """远程训练入口脚本 — 在算力节点上执行。"""
- import asyncio
- import json
- import os
- import sys
- import signal
- import time
- import traceback
- from datetime import datetime, timezone
- from pathlib import Path
# Disable FlashAttention.
os.environ.update({
    "PYTORCH_NO_FLASH": "1",
    "FLASH_ATTENTION_ENABLED": "0",
})

# Path of the JSONL progress log; stays None until _init_log_file() sets it.
_progress_log_file = None
def _init_log_file(data_dir: Path, job_id: str):
    """Set up the progress log file for this job.

    Creates ``<data_dir>/logs`` if needed, stores
    ``<data_dir>/logs/<job_id>.jsonl`` in the module-level
    ``_progress_log_file``, and appends a ``start`` record to it.
    (Per the original note, the file is shared with the main node over SSHFS.)
    """
    global _progress_log_file

    logs_root = data_dir / "logs"
    logs_root.mkdir(parents=True, exist_ok=True)

    jsonl_name = f"{job_id}.jsonl"
    _progress_log_file = logs_root / jsonl_name

    _write_log(type="start", job_id=job_id)
def _write_log(**kwargs):
    """Append one JSON record to the shared progress log.

    The record consists of the keyword arguments plus a ``ts`` field holding
    the current UTC time in ISO-8601 form (a caller-supplied ``ts`` wins).
    Nothing is written until ``_init_log_file`` has set the log path.

    Values the JSON encoder cannot handle natively (for example ``Path``
    objects) are stored as their ``str()`` form, so a single odd value cannot
    abort the caller with ``TypeError``.
    """
    if _progress_log_file is None:
        return

    entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs}
    # default=str is deliberate: this is a diagnostic log, and losing a record
    # (or crashing the training callback) is worse than stringifying a value.
    line = json.dumps(entry, ensure_ascii=False, default=str)

    # The file is opened and closed per record; closing flushes the buffer,
    # so no explicit flush() is needed.
    with open(_progress_log_file, "a", encoding="utf-8") as f:
        f.write(line + "\n")
class FileProgressCallback:
    """HuggingFace Trainer callback that writes progress to the shared log file.

    Each handled event appends one JSON record through ``_write_log``.

    The class is duck-typed (it does not subclass ``TrainerCallback``), so any
    ``on_*`` hook that is not defined here resolves to a no-op. The Trainer's
    callback handler looks every event up by name, and an undefined hook such
    as ``on_init_end`` or ``on_step_end`` would otherwise raise
    ``AttributeError``.
    """

    def __init__(self, job_id: str):
        self.job_id = job_id
        # Set in on_train_begin so on_train_end can report a real duration.
        self._train_started = None

    def __getattr__(self, name):
        """Return a no-op for event hooks this class does not implement."""
        # Only reached when normal attribute lookup has already failed.
        if name.startswith("on_"):
            return self._ignore_event
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    @staticmethod
    def _ignore_event(*args, **kwargs):
        """Accept any Trainer event and leave the control object unchanged."""
        return None

    @staticmethod
    def _epoch(state) -> int:
        """Return ``state.epoch`` truncated to an int (0 when unset)."""
        return int(state.epoch or 0)

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the start time and log a ``training`` status."""
        self._train_started = time.monotonic()
        _write_log(type="status", status="training")

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Log the epoch that is starting."""
        _write_log(type="epoch_begin", epoch=self._epoch(state))

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log step, loss and learning rate whenever ``logs`` carries a loss."""
        if not logs or "loss" not in logs:
            return
        _write_log(
            type="progress",
            epoch=self._epoch(state),
            step=state.global_step,
            total_steps=state.max_steps or 0,
            loss=round(logs["loss"], 4),
            learning_rate=round(logs.get("learning_rate", 0), 8),
        )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log evaluation loss/accuracy when metrics are provided."""
        if metrics:
            _write_log(
                type="evaluate",
                epoch=self._epoch(state),
                eval_loss=metrics.get("eval_loss"),
                eval_accuracy=metrics.get("eval_accuracy"),
            )

    def on_save(self, args, state, control, **kwargs):
        """Log the step at which a checkpoint was saved."""
        _write_log(type="save", step=state.global_step)

    def on_epoch_end(self, args, state, control, metrics=None, **kwargs):
        """Log the end of an epoch, with eval metrics if a mapping is passed."""
        has_metrics = bool(metrics) and hasattr(metrics, "get")
        _write_log(
            type="epoch_done",
            epoch=self._epoch(state),
            eval_loss=metrics.get("eval_loss") if has_metrics else None,
            eval_accuracy=metrics.get("eval_accuracy") if has_metrics else None,
        )

    def on_train_end(self, args, state, control, **kwargs):
        """Log a ``completed`` record with the training duration in seconds."""
        # NOTE(review): TrainerState does not appear to carry a train_runtime
        # field, so the original lookup always logged 0. Keep honouring the
        # attribute if a state object does provide it; otherwise measure the
        # time since on_train_begin.
        runtime = getattr(state, "train_runtime", None)
        if runtime is None:
            started = getattr(self, "_train_started", None)
            if started is None:
                runtime = 0
            else:
                runtime = round(time.monotonic() - started, 2)
        _write_log(type="completed", total_time_seconds=runtime,
                   adapter_path=args.output_dir)
async def _resolve_dataset_path(settings, dataset_id: str):
    """Return the dataset path for *dataset_id*, or None when nothing matches.

    A DatasetRecord whose id or name equals *dataset_id* is tried first; if it
    yields no path, an existing entry of that name under
    ``settings.uploads_dir`` is used instead.
    """
    from app.core.db import async_session, DatasetRecord
    from sqlalchemy import select

    dataset_path = None
    async with async_session() as session:
        result = await session.execute(select(DatasetRecord).where(
            (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
        ))
        record = result.scalar_one_or_none()
        if record:
            # Read the attribute while the session is still open.
            dataset_path = record.file_path

    if not dataset_path:
        upload_path = settings.uploads_dir / dataset_id
        if upload_path.exists():
            dataset_path = str(upload_path)
    return dataset_path


def _select_engine(model_type: str):
    """Import and return the engine for *model_type* (anything else: text)."""
    if model_type == "vision":
        from app.engines.vision_engine import vision_engine
        return vision_engine
    if model_type == "multimodal":
        from app.engines.multimodal_engine import multimodal_engine
        return multimodal_engine
    from app.engines.text_engine import text_engine
    return text_engine


async def run_training(job_id: str, model_id: str, model_type: str, dataset_id: str, config: dict):
    """Run a single training job (entry point for remote invocation).

    Steps, each announced in the progress log: locate the dataset, preprocess
    it, load the model, build the PEFT configuration, train.

    Args:
        job_id: Identifier used for the log file and the processed dataset name.
        model_id: Model identifier handed to the engine's ``load_model``.
        model_type: ``"vision"``, ``"multimodal"``, or anything else for text.
        dataset_id: Dataset record id or name, or a name under the uploads dir.
        config: Training options; ``task_type``, ``dataset_template`` and
            ``peft_method`` are read here, and the whole dict is passed to the
            engine as ``training_args``.

    Returns:
        Whatever the engine's ``train`` returns (logged as the adapter path).

    Raises:
        FileNotFoundError: If no dataset can be found for *dataset_id*.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    from app.config import get_settings
    from app.core.logging import logger

    settings = get_settings()
    _init_log_file(settings.data_dir, job_id)
    try:
        # Locate the dataset.
        dataset_path = await _resolve_dataset_path(settings, dataset_id)
        if not dataset_path:
            raise FileNotFoundError(f"Dataset not found: {dataset_id}")
        _write_log(type="status", status="preprocessing")

        # Preprocess.
        processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
        task_type = config.get("task_type", "sft")
        template = config.get("dataset_template", "alpaca")
        engine = _select_engine(model_type)
        peft_method = config.get("peft_method", "lora")
        await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
        _write_log(type="status", status="loading_model")

        # Load the model; QLoRA implies 4-bit quantization.
        await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)

        # Build the PEFT configuration.
        peft_config = engine.get_peft_config(peft_method, config)
        _write_log(type="status", status="training")

        # Train, passing the file-log callback in place of the WebSocket one.
        start_time = time.time()
        file_cb = FileProgressCallback(job_id)
        adapter_path = await engine.train(
            job_id=job_id,
            dataset_path=processed_path,
            peft_config=peft_config,
            training_args=config,
            callbacks=[file_cb],
        )
        elapsed = round(time.time() - start_time, 2)
        _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed)
        logger.info(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)")
        return adapter_path
    except Exception as e:
        # The log write itself may fail (e.g. the shared directory is gone);
        # that must not replace the real failure or skip the logger call.
        try:
            _write_log(type="error", message=str(e), traceback=traceback.format_exc())
        except Exception as log_err:
            logger.error(f"Could not write error record for {job_id}: {log_err}")
        logger.error(f"Remote training failed: {job_id} - {e}")
        raise
def main():
    """Command-line entry point.

    Usage:
        python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_id> <config_json>

    ``<config_json>`` must be a JSON object. On missing arguments or an
    invalid config, a message is printed to stderr and the process exits with
    status 1 before any training work starts.
    """
    if len(sys.argv) < 6:
        print("Usage: python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_id> <config_json>",
              file=sys.stderr)
        sys.exit(1)

    job_id, model_id, model_type, dataset_id, config_json = sys.argv[1:6]

    try:
        config = json.loads(config_json)
    except json.JSONDecodeError as exc:
        print(f"Invalid <config_json>: {exc}", file=sys.stderr)
        sys.exit(1)
    # run_training calls config.get(...), so anything but an object is unusable.
    if not isinstance(config, dict):
        print("Invalid <config_json>: expected a JSON object", file=sys.stderr)
        sys.exit(1)

    asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config))


if __name__ == "__main__":
    main()
|