|
@@ -1,20 +1,312 @@
|
|
|
|
|
+import asyncio
|
|
|
|
|
+import json
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from typing import Any
|
|
|
|
|
+
|
|
|
|
|
+from app.config import get_settings
|
|
|
|
|
+from app.core.logging import logger
|
|
|
from app.engines.base import BaseEngine
|
|
from app.engines.base import BaseEngine
|
|
|
|
|
|
|
|
|
|
+settings = get_settings()
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class TextEngine(BaseEngine):
|
|
class TextEngine(BaseEngine):
|
|
|
- """Training engine for LLaMA, Qwen, and other text-only LLMs."""
|
|
|
|
|
|
|
+ """文本模型训练引擎 (LLaMA/Qwen/ChatGLM 等因果语言模型)。"""
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self):
|
|
|
|
|
+ self._tokenizer = None
|
|
|
|
|
+ self._model = None
|
|
|
|
|
+
|
|
|
|
|
+ async def load_model(self, model_id: str, **kwargs: Any) -> None:
|
|
|
|
|
+ """下载并加载基础模型。"""
|
|
|
|
|
+ import torch
|
|
|
|
|
+ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
+
|
|
|
|
|
+ local_path = str(settings.models_dir / model_id.replace("/", "_"))
|
|
|
|
|
+
|
|
|
|
|
+ # 如果本地没有,从 HF 下载
|
|
|
|
|
+ if not (Path(local_path) / "config.json").exists():
|
|
|
|
|
+ from huggingface_hub import snapshot_download
|
|
|
|
|
+
|
|
|
|
|
+ snapshot_download(
|
|
|
|
|
+ repo_id=model_id,
|
|
|
|
|
+ local_dir=local_path,
|
|
|
|
|
+ local_dir_use_symlinks=False,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ quantization = kwargs.get("quantization", None)
|
|
|
|
|
+ load_kwargs: dict[str, Any] = {
|
|
|
|
|
+ "torch_dtype": torch.float16,
|
|
|
|
|
+ "device_map": "auto",
|
|
|
|
|
+ }
|
|
|
|
|
+ if quantization == "4bit" or quantization == "qlora":
|
|
|
|
|
+ load_kwargs["load_in_4bit"] = True
|
|
|
|
|
+ load_kwargs["bnb_4bit_quant_type"] = "nf4"
|
|
|
|
|
+ load_kwargs["bnb_4bit_use_double_quant"] = True
|
|
|
|
|
+ elif quantization == "8bit":
|
|
|
|
|
+ load_kwargs["load_in_8bit"] = True
|
|
|
|
|
+
|
|
|
|
|
+ self._tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
|
|
|
|
|
+ if self._tokenizer.pad_token is None:
|
|
|
|
|
+ self._tokenizer.pad_token = self._tokenizer.eos_token
|
|
|
|
|
+
|
|
|
|
|
+ self._model = AutoModelForCausalLM.from_pretrained(local_path, **load_kwargs)
|
|
|
|
|
+ logger.info(f"Loaded model: {model_id}")
|
|
|
|
|
+
|
|
|
|
|
+ def get_peft_config(self, method: str, params: dict[str, Any]) -> Any:
|
|
|
|
|
+ """根据 PEFT 方法返回对应的配置对象。"""
|
|
|
|
|
+ from app.peft import (
|
|
|
|
|
+ build_adalora_config,
|
|
|
|
|
+ build_ia3_config,
|
|
|
|
|
+ build_lora_config,
|
|
|
|
|
+ build_prefix_tuning_config,
|
|
|
|
|
+ build_qlora_config,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ builders = {
|
|
|
|
|
+ "lora": build_lora_config,
|
|
|
|
|
+ "qlora": build_qlora_config,
|
|
|
|
|
+ "ia3": build_ia3_config,
|
|
|
|
|
+ "adalora": build_adalora_config,
|
|
|
|
|
+ "prefix_tuning": build_prefix_tuning_config,
|
|
|
|
|
+ }
|
|
|
|
|
+ builder = builders.get(method, build_lora_config)
|
|
|
|
|
+ return builder(params)
|
|
|
|
|
+
|
|
|
|
|
+ async def preprocess_dataset(
|
|
|
|
|
+ self,
|
|
|
|
|
+ dataset_path: str,
|
|
|
|
|
+ output_path: str,
|
|
|
|
|
+ task_type: str = "sft",
|
|
|
|
|
+ template: str = "alpaca",
|
|
|
|
|
+ **kwargs: Any,
|
|
|
|
|
+ ) -> str:
|
|
|
|
|
+ """将数据集预处理为训练格式。"""
|
|
|
|
|
+ from app.preprocessors import preprocess_file
|
|
|
|
|
+
|
|
|
|
|
+ processed = preprocess_file(dataset_path, output_path, task_type, template)
|
|
|
|
|
+ logger.info(f"Preprocessed {len(processed)} samples for {task_type}/{template}")
|
|
|
|
|
+ return output_path
|
|
|
|
|
+
|
|
|
|
|
+ async def train(
|
|
|
|
|
+ self,
|
|
|
|
|
+ job_id: str,
|
|
|
|
|
+ dataset_path: str,
|
|
|
|
|
+ peft_config: Any,
|
|
|
|
|
+ training_args: dict[str, Any],
|
|
|
|
|
+ ) -> str:
|
|
|
|
|
+ """执行训练。"""
|
|
|
|
|
+ from peft import get_peft_model
|
|
|
|
|
+ from transformers import DataCollatorForSeq2Seq, TrainingArguments
|
|
|
|
|
+
|
|
|
|
|
+ task_type = training_args.get("task_type", "sft")
|
|
|
|
|
+ epochs = training_args.get("epochs", 3)
|
|
|
|
|
+ batch_size = training_args.get("batch_size", 4)
|
|
|
|
|
+ gradient_accumulation = training_args.get("gradient_accumulation", 4)
|
|
|
|
|
+ learning_rate = training_args.get("learning_rate", 2e-4)
|
|
|
|
|
+ max_seq_length = training_args.get("max_seq_length", 2048)
|
|
|
|
|
+ warmup_ratio = training_args.get("warmup_ratio", 0.05)
|
|
|
|
|
+ save_strategy = training_args.get("save_strategy", "epoch")
|
|
|
|
|
+ deepspeed_config = training_args.get("deepspeed", None)
|
|
|
|
|
+
|
|
|
|
|
+ dataset = self._tokenize_dataset(dataset_path, max_seq_length)
|
|
|
|
|
+
|
|
|
|
|
+ self._model = get_peft_model(self._model, peft_config)
|
|
|
|
|
+ self._model.print_trainable_parameters()
|
|
|
|
|
+
|
|
|
|
|
+ output_dir = str(settings.adapters_dir / job_id)
|
|
|
|
|
+ tr_args = TrainingArguments(
|
|
|
|
|
+ output_dir=output_dir,
|
|
|
|
|
+ num_train_epochs=epochs,
|
|
|
|
|
+ per_device_train_batch_size=batch_size,
|
|
|
|
|
+ gradient_accumulation_steps=gradient_accumulation,
|
|
|
|
|
+ learning_rate=learning_rate,
|
|
|
|
|
+ warmup_ratio=warmup_ratio,
|
|
|
|
|
+ save_strategy=save_strategy,
|
|
|
|
|
+ logging_strategy="steps",
|
|
|
|
|
+ logging_steps=10,
|
|
|
|
|
+ fp16=True,
|
|
|
|
|
+ optim="adamw_torch",
|
|
|
|
|
+ remove_unused_columns=False,
|
|
|
|
|
+ report_to="none",
|
|
|
|
|
+ **({"deepspeed": deepspeed_config} if deepspeed_config else {}),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ callback = _ProgressCallback(job_id)
|
|
|
|
|
+
|
|
|
|
|
+ if task_type == "sft":
|
|
|
|
|
+ from transformers import Trainer
|
|
|
|
|
+
|
|
|
|
|
+ trainer = Trainer(
|
|
|
|
|
+ model=self._model,
|
|
|
|
|
+ args=tr_args,
|
|
|
|
|
+ train_dataset=dataset,
|
|
|
|
|
+ data_collator=DataCollatorForSeq2Seq(self._tokenizer),
|
|
|
|
|
+ callbacks=[callback],
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ from trl import (
|
|
|
|
|
+ DPOConfig,
|
|
|
|
|
+ DPOTrainer,
|
|
|
|
|
+ KTOConfig,
|
|
|
|
|
+ KTOTrainer,
|
|
|
|
|
+ ORPOConfig,
|
|
|
|
|
+ ORPOTrainer,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ base_trainer_kwargs = dict(
|
|
|
|
|
+ output_dir=output_dir,
|
|
|
|
|
+ num_train_epochs=epochs,
|
|
|
|
|
+ per_device_train_batch_size=batch_size,
|
|
|
|
|
+ gradient_accumulation_steps=gradient_accumulation,
|
|
|
|
|
+ learning_rate=learning_rate,
|
|
|
|
|
+ warmup_ratio=warmup_ratio,
|
|
|
|
|
+ save_strategy=save_strategy,
|
|
|
|
|
+ logging_steps=10,
|
|
|
|
|
+ fp16=True,
|
|
|
|
|
+ report_to="none",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if task_type == "dpo":
|
|
|
|
|
+ trainer = DPOTrainer(
|
|
|
|
|
+ model=self._model,
|
|
|
|
|
+ args=DPOConfig(**base_trainer_kwargs),
|
|
|
|
|
+ train_dataset=dataset,
|
|
|
|
|
+ processing_class=self._tokenizer,
|
|
|
|
|
+ )
|
|
|
|
|
+ elif task_type == "orpo":
|
|
|
|
|
+ trainer = ORPOTrainer(
|
|
|
|
|
+ model=self._model,
|
|
|
|
|
+ args=ORPOConfig(**base_trainer_kwargs),
|
|
|
|
|
+ train_dataset=dataset,
|
|
|
|
|
+ processing_class=self._tokenizer,
|
|
|
|
|
+ )
|
|
|
|
|
+ elif task_type == "kto":
|
|
|
|
|
+ trainer = KTOTrainer(
|
|
|
|
|
+ model=self._model,
|
|
|
|
|
+ args=KTOConfig(**base_trainer_kwargs),
|
|
|
|
|
+ train_dataset=dataset,
|
|
|
|
|
+ processing_class=self._tokenizer,
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ trainer = Trainer(
|
|
|
|
|
+ model=self._model,
|
|
|
|
|
+ args=tr_args,
|
|
|
|
|
+ train_dataset=dataset,
|
|
|
|
|
+ data_collator=DataCollatorForSeq2Seq(self._tokenizer),
|
|
|
|
|
+ callbacks=[callback],
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ trainer.train()
|
|
|
|
|
+ self._model.save_pretrained(output_dir)
|
|
|
|
|
+ self._tokenizer.save_pretrained(output_dir)
|
|
|
|
|
+ logger.info(f"Training completed for job {job_id}")
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"Training failed for job {job_id}: {e}")
|
|
|
|
|
+ raise
|
|
|
|
|
+
|
|
|
|
|
+ return output_dir
|
|
|
|
|
+
|
|
|
|
|
+ def get_model_info(self, model_id: str) -> dict[str, Any]:
|
|
|
|
|
+ """读取模型配置信息。"""
|
|
|
|
|
+ import json
|
|
|
|
|
+ from pathlib import Path
|
|
|
|
|
+
|
|
|
|
|
+ model_dir = settings.models_dir / model_id.replace("/", "_")
|
|
|
|
|
+ config_path = model_dir / "config.json"
|
|
|
|
|
+
|
|
|
|
|
+ if config_path.exists():
|
|
|
|
|
+ with open(config_path) as f:
|
|
|
|
|
+ config = json.load(f)
|
|
|
|
|
+ return {
|
|
|
|
|
+ "model_type": config.get("model_type", "causal_lm"),
|
|
|
|
|
+ "context_length": config.get("max_position_embeddings", config.get("max_sequence_length", 2048)),
|
|
|
|
|
+ "hidden_size": config.get("hidden_size", 0),
|
|
|
|
|
+ "num_layers": config.get("num_hidden_layers", 0),
|
|
|
|
|
+ }
|
|
|
|
|
+ return {"model_type": "causal_lm", "context_length": 2048}
|
|
|
|
|
+
|
|
|
|
|
+ def _tokenize_dataset(self, dataset_path: str, max_seq_length: int):
|
|
|
|
|
+ """Tokenize 处理后的 JSONL 数据集。"""
|
|
|
|
|
+ from datasets import Dataset as HFDataset
|
|
|
|
|
+
|
|
|
|
|
+ data = []
|
|
|
|
|
+ with open(dataset_path, "r", encoding="utf-8") as f:
|
|
|
|
|
+ for line in f:
|
|
|
|
|
+ line = line.strip()
|
|
|
|
|
+ if line:
|
|
|
|
|
+ data.append(json.loads(line))
|
|
|
|
|
+
|
|
|
|
|
+ hf_dataset = HFDataset.from_list(data)
|
|
|
|
|
+
|
|
|
|
|
+ def tokenize_fn(batch):
|
|
|
|
|
+ prompts = batch.get("prompt", [""] * len(data))
|
|
|
|
|
+ completions = batch.get("completion", [""] * len(data))
|
|
|
|
|
+
|
|
|
|
|
+ if isinstance(prompts, str):
|
|
|
|
|
+ prompts = [prompts]
|
|
|
|
|
+ if isinstance(completions, str):
|
|
|
|
|
+ completions = [completions]
|
|
|
|
|
+
|
|
|
|
|
+ full_texts = [f"{p}\n{c}" for p, c in zip(prompts, completions)]
|
|
|
|
|
+ tokenized = self._tokenizer(
|
|
|
|
|
+ full_texts, truncation=True, max_length=max_seq_length, padding=False,
|
|
|
|
|
+ )
|
|
|
|
|
+ tokenized["labels"] = list(tokenized["input_ids"])
|
|
|
|
|
+ return tokenized
|
|
|
|
|
+
|
|
|
|
|
+ return hf_dataset.map(tokenize_fn, batched=True)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class _ProgressCallback:
|
|
|
|
|
+ """自定义训练进度回调,通过 WebSocket 发送进度。"""
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self, job_id: str):
|
|
|
|
|
+ self.job_id = job_id
|
|
|
|
|
+
|
|
|
|
|
+ def on_log(self, args, state, control, logs=None, **kwargs):
|
|
|
|
|
+ if logs and "loss" in logs:
|
|
|
|
|
+ asyncio.create_task(
|
|
|
|
|
+ send_progress(
|
|
|
|
|
+ self.job_id,
|
|
|
|
|
+ epoch=int(state.epoch or 0),
|
|
|
|
|
+ step=state.global_step,
|
|
|
|
|
+ total_steps=state.max_steps or 0,
|
|
|
|
|
+ loss=logs["loss"],
|
|
|
|
|
+ learning_rate=logs.get("learning_rate", 0),
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def on_epoch_end(self, args, state, control, **kwargs):
|
|
|
|
|
+ asyncio.create_task(
|
|
|
|
|
+ send_epoch_done(self.job_id, epoch=int(state.epoch or 0), eval_loss=None, eval_accuracy=None)
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def on_train_end(self, args, state, control, **kwargs):
|
|
|
|
|
+ asyncio.create_task(
|
|
|
|
|
+ send_completed(
|
|
|
|
|
+ self.job_id,
|
|
|
|
|
+ total_time_seconds=getattr(state, "train_runtime", 0),
|
|
|
|
|
+ adapter_path=str(settings.adapters_dir / self.job_id),
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def on_train_begin(self, args, state, control, **kwargs):
|
|
|
|
|
+ pass
|
|
|
|
|
+
|
|
|
|
|
+ def on_step_end(self, args, state, control, **kwargs):
|
|
|
|
|
+ pass
|
|
|
|
|
|
|
|
- async def load_model(self, model_id: str, **kwargs):
|
|
|
|
|
- raise NotImplementedError
|
|
|
|
|
|
|
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
|
|
|
|
+ pass
|
|
|
|
|
|
|
|
- def get_peft_config(self, method: str, params: dict):
|
|
|
|
|
- raise NotImplementedError
|
|
|
|
|
|
|
+ def on_save(self, args, state, control, **kwargs):
|
|
|
|
|
+ pass
|
|
|
|
|
|
|
|
- async def preprocess_dataset(self, dataset_path: str, output_path: str, **kwargs):
|
|
|
|
|
- raise NotImplementedError
|
|
|
|
|
|
|
+ def on_predict(self, args, state, control, metrics=None, **kwargs):
|
|
|
|
|
+ pass
|
|
|
|
|
|
|
|
- async def train(self, job_id: str, dataset_path: str, peft_config, training_args: dict):
|
|
|
|
|
- raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
- def get_model_info(self, model_id: str):
|
|
|
|
|
- raise NotImplementedError
|
|
|
|
|
|
|
+# 全局单例
|
|
|
|
|
+text_engine = TextEngine()
|