|
@@ -4,12 +4,73 @@ import json
|
|
|
import os
|
|
import os
|
|
|
import sys
|
|
import sys
|
|
|
import signal
|
|
import signal
|
|
|
|
|
+import time
|
|
|
|
|
+import traceback
|
|
|
|
|
+from datetime import datetime, timezone
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
# 禁用 FlashAttention
|
|
# 禁用 FlashAttention
|
|
|
os.environ["PYTORCH_NO_FLASH"] = "1"
|
|
os.environ["PYTORCH_NO_FLASH"] = "1"
|
|
|
os.environ["FLASH_ATTENTION_ENABLED"] = "0"
|
|
os.environ["FLASH_ATTENTION_ENABLED"] = "0"
|
|
|
|
|
|
|
|
|
|
+_progress_log_file = None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _init_log_file(data_dir: Path, job_id: str):
|
|
|
|
|
+ """初始化进度日志文件(通过 SSHFS 共享给主节点读取)。"""
|
|
|
|
|
+ global _progress_log_file
|
|
|
|
|
+ log_dir = data_dir / "logs"
|
|
|
|
|
+ log_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ _progress_log_file = log_dir / f"{job_id}.jsonl"
|
|
|
|
|
+ _write_log(type="start", job_id=job_id)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _write_log(**kwargs):
|
|
|
|
|
+ """追加一行 JSON 到共享日志文件。"""
|
|
|
|
|
+ if _progress_log_file:
|
|
|
|
|
+ entry = {"ts": datetime.now(timezone.utc).isoformat(), **kwargs}
|
|
|
|
|
+ with open(_progress_log_file, "a", encoding="utf-8") as f:
|
|
|
|
|
+ f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
|
|
|
|
+ f.flush()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class FileProgressCallback:
|
|
|
|
|
+ """HuggingFace Trainer 回调 — 写进度到共享日志文件。"""
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self, job_id: str):
|
|
|
|
|
+ self.job_id = job_id
|
|
|
|
|
+
|
|
|
|
|
+ def on_log(self, args, state, control, logs=None, **kwargs):
|
|
|
|
|
+ if logs and "loss" in logs:
|
|
|
|
|
+ _write_log(type="progress", epoch=int(state.epoch or 0),
|
|
|
|
|
+ step=state.global_step, total_steps=state.max_steps or 0,
|
|
|
|
|
+ loss=round(logs["loss"], 4),
|
|
|
|
|
+ learning_rate=round(logs.get("learning_rate", 0), 8))
|
|
|
|
|
+
|
|
|
|
|
+ def on_epoch_begin(self, args, state, control, **kwargs):
|
|
|
|
|
+ _write_log(type="epoch_begin", epoch=int(state.epoch or 0))
|
|
|
|
|
+
|
|
|
|
|
+ def on_epoch_end(self, args, state, control, metrics=None, **kwargs):
|
|
|
|
|
+ _write_log(type="epoch_done", epoch=int(state.epoch or 0),
|
|
|
|
|
+ eval_loss=metrics.get("eval_loss") if metrics and hasattr(metrics, "get") else None,
|
|
|
|
|
+ eval_accuracy=metrics.get("eval_accuracy") if metrics and hasattr(metrics, "get") else None)
|
|
|
|
|
+
|
|
|
|
|
+ def on_train_end(self, args, state, control, **kwargs):
|
|
|
|
|
+ _write_log(type="completed", total_time_seconds=getattr(state, "train_runtime", 0),
|
|
|
|
|
+ adapter_path=args.output_dir)
|
|
|
|
|
+
|
|
|
|
|
+ def on_train_begin(self, args, state, control, **kwargs):
|
|
|
|
|
+ _write_log(type="status", status="training")
|
|
|
|
|
+
|
|
|
|
|
+ def on_save(self, args, state, control, **kwargs):
|
|
|
|
|
+ _write_log(type="save", step=state.global_step)
|
|
|
|
|
+
|
|
|
|
|
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
|
|
|
|
+ if metrics:
|
|
|
|
|
+ _write_log(type="evaluate", epoch=int(state.epoch or 0),
|
|
|
|
|
+ eval_loss=metrics.get("eval_loss"),
|
|
|
|
|
+ eval_accuracy=metrics.get("eval_accuracy"))
|
|
|
|
|
+
|
|
|
|
|
|
|
|
async def run_training(job_id: str, model_id: str, model_type: str, dataset_id: str, config: dict):
|
|
async def run_training(job_id: str, model_id: str, model_type: str, dataset_id: str, config: dict):
|
|
|
"""执行单个训练任务(远程调用入口)。"""
|
|
"""执行单个训练任务(远程调用入口)。"""
|
|
@@ -17,66 +78,84 @@ async def run_training(job_id: str, model_id: str, model_type: str, dataset_id:
|
|
|
from app.core.logging import logger
|
|
from app.core.logging import logger
|
|
|
|
|
|
|
|
settings = get_settings()
|
|
settings = get_settings()
|
|
|
-
|
|
|
|
|
- # 查找数据集
|
|
|
|
|
- from app.core.db import async_session, DatasetRecord
|
|
|
|
|
- from sqlalchemy import select
|
|
|
|
|
-
|
|
|
|
|
- dataset_path = None
|
|
|
|
|
- async with async_session() as session:
|
|
|
|
|
- result = await session.execute(select(DatasetRecord).where(
|
|
|
|
|
- (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
|
|
|
|
|
- ))
|
|
|
|
|
- record = result.scalar_one_or_none()
|
|
|
|
|
- if record:
|
|
|
|
|
- dataset_path = record.file_path
|
|
|
|
|
-
|
|
|
|
|
- if not dataset_path:
|
|
|
|
|
- # 尝试 uploads 目录
|
|
|
|
|
- upload_path = settings.uploads_dir / dataset_id
|
|
|
|
|
- if upload_path.exists():
|
|
|
|
|
- dataset_path = str(upload_path)
|
|
|
|
|
-
|
|
|
|
|
- if not dataset_path:
|
|
|
|
|
- raise FileNotFoundError(f"Dataset not found: {dataset_id}")
|
|
|
|
|
-
|
|
|
|
|
- # 预处理
|
|
|
|
|
- processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
|
|
|
|
|
- task_type = config.get("task_type", "sft")
|
|
|
|
|
- template = config.get("dataset_template", "alpaca")
|
|
|
|
|
-
|
|
|
|
|
- # 选择引擎
|
|
|
|
|
- if model_type == "vision":
|
|
|
|
|
- from app.engines.vision_engine import vision_engine
|
|
|
|
|
- engine = vision_engine
|
|
|
|
|
- elif model_type == "multimodal":
|
|
|
|
|
- from app.engines.multimodal_engine import multimodal_engine
|
|
|
|
|
- engine = multimodal_engine
|
|
|
|
|
- else:
|
|
|
|
|
- from app.engines.text_engine import text_engine
|
|
|
|
|
- engine = text_engine
|
|
|
|
|
-
|
|
|
|
|
- peft_method = config.get("peft_method", "lora")
|
|
|
|
|
-
|
|
|
|
|
- # 预处理数据集
|
|
|
|
|
- await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
|
|
|
|
|
-
|
|
|
|
|
- # 加载模型
|
|
|
|
|
- await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
|
|
|
|
|
-
|
|
|
|
|
- # 构建 PEFT 配置
|
|
|
|
|
- peft_config = engine.get_peft_config(peft_method, config)
|
|
|
|
|
-
|
|
|
|
|
- # 训练
|
|
|
|
|
- adapter_path = await engine.train(
|
|
|
|
|
- job_id=job_id,
|
|
|
|
|
- dataset_path=processed_path,
|
|
|
|
|
- peft_config=peft_config,
|
|
|
|
|
- training_args=config,
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- logger.info(f"Remote training completed: {job_id} -> {adapter_path}")
|
|
|
|
|
- return adapter_path
|
|
|
|
|
|
|
+ _init_log_file(settings.data_dir, job_id)
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 查找数据集
|
|
|
|
|
+ from app.core.db import async_session, DatasetRecord
|
|
|
|
|
+ from sqlalchemy import select
|
|
|
|
|
+
|
|
|
|
|
+ dataset_path = None
|
|
|
|
|
+ async with async_session() as session:
|
|
|
|
|
+ result = await session.execute(select(DatasetRecord).where(
|
|
|
|
|
+ (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
|
|
|
|
|
+ ))
|
|
|
|
|
+ record = result.scalar_one_or_none()
|
|
|
|
|
+ if record:
|
|
|
|
|
+ dataset_path = record.file_path
|
|
|
|
|
+
|
|
|
|
|
+ if not dataset_path:
|
|
|
|
|
+ upload_path = settings.uploads_dir / dataset_id
|
|
|
|
|
+ if upload_path.exists():
|
|
|
|
|
+ dataset_path = str(upload_path)
|
|
|
|
|
+
|
|
|
|
|
+ if not dataset_path:
|
|
|
|
|
+ raise FileNotFoundError(f"Dataset not found: {dataset_id}")
|
|
|
|
|
+
|
|
|
|
|
+ _write_log(type="status", status="preprocessing")
|
|
|
|
|
+
|
|
|
|
|
+ # 预处理
|
|
|
|
|
+ processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
|
|
|
|
|
+ task_type = config.get("task_type", "sft")
|
|
|
|
|
+ template = config.get("dataset_template", "alpaca")
|
|
|
|
|
+
|
|
|
|
|
+ # 选择引擎
|
|
|
|
|
+ if model_type == "vision":
|
|
|
|
|
+ from app.engines.vision_engine import vision_engine
|
|
|
|
|
+ engine = vision_engine
|
|
|
|
|
+ elif model_type == "multimodal":
|
|
|
|
|
+ from app.engines.multimodal_engine import multimodal_engine
|
|
|
|
|
+ engine = multimodal_engine
|
|
|
|
|
+ else:
|
|
|
|
|
+ from app.engines.text_engine import text_engine
|
|
|
|
|
+ engine = text_engine
|
|
|
|
|
+
|
|
|
|
|
+ peft_method = config.get("peft_method", "lora")
|
|
|
|
|
+
|
|
|
|
|
+ await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
|
|
|
|
|
+
|
|
|
|
|
+ _write_log(type="status", status="loading_model")
|
|
|
|
|
+
|
|
|
|
|
+ # 加载模型
|
|
|
|
|
+ await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
|
|
|
|
|
+
|
|
|
|
|
+ # 构建 PEFT 配置
|
|
|
|
|
+ peft_config = engine.get_peft_config(peft_method, config)
|
|
|
|
|
+
|
|
|
|
|
+ _write_log(type="status", status="training")
|
|
|
|
|
+
|
|
|
|
|
+ # 训练 — 传入文件日志回调替代 WebSocket 回调
|
|
|
|
|
+ start_time = time.time()
|
|
|
|
|
+ file_cb = FileProgressCallback(job_id)
|
|
|
|
|
+
|
|
|
|
|
+ adapter_path = await engine.train(
|
|
|
|
|
+ job_id=job_id,
|
|
|
|
|
+ dataset_path=processed_path,
|
|
|
|
|
+ peft_config=peft_config,
|
|
|
|
|
+ training_args=config,
|
|
|
|
|
+ callbacks=[file_cb],
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ elapsed = round(time.time() - start_time, 2)
|
|
|
|
|
+ _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed)
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)")
|
|
|
|
|
+ return adapter_path
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ _write_log(type="error", message=str(e), traceback=traceback.format_exc())
|
|
|
|
|
+ logger.error(f"Remote training failed: {job_id} - {e}")
|
|
|
|
|
+ raise
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
def main():
|