import asyncio
import json
import logging
import os
from collections.abc import Callable, Coroutine
from pathlib import Path
from types import SimpleNamespace
from typing import Any

# Remote training nodes have no pydantic-settings / database available, so the
# settings object is built directly from environment variables.
_data_dir = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
settings = SimpleNamespace(
    data_dir=_data_dir,
    processed_dir=_data_dir / "processed",
    adapters_dir=_data_dir / "adapters",
    models_dir=_data_dir / "models",
)

logger = logging.getLogger(__name__)

# Kept after the settings/logger definitions, matching the original statement
# order of this module.
from app.engines.base import BaseEngine


def _find_local_model_dir(model_id: str) -> Path | None:
    """Return the local directory holding ``model_id``'s ``config.json``.

    Two layouts are probed, in this order: the flattened ``org_name``
    directory, then the nested ``org/name`` directory. Returns ``None`` when
    neither contains a ``config.json``.
    """
    flattened = settings.models_dir / model_id.replace("/", "_")
    nested = settings.models_dir / model_id
    for candidate in (flattened, nested):
        if (candidate / "config.json").exists():
            return candidate
    return None


def _read_jsonl(path: str) -> list[Any]:
    """Parse a JSON-Lines file, ignoring blank lines."""
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(line) for line in map(str.strip, handle) if line]


def _has_image(record: Any) -> bool:
    """Tell whether ``record`` is a dict whose ``image_path`` exists on disk."""
    if not isinstance(record, dict):
        return False
    image_path = record.get("image_path")
    return bool(image_path) and Path(image_path).exists()


def _rows_from_batch(batch: dict[str, list[Any]]) -> list[dict[str, Any]]:
    """Turn a columnar batch ``{column: [values...]}`` into a list of row dicts."""
    columns = list(batch)
    if not columns:
        return []
    return [
        dict(zip(columns, values))
        for values in zip(*(batch[column] for column in columns))
    ]


def _ignore_event(*args: Any, **kwargs: Any) -> None:
    """No-op handler used for trainer events this module does not care about."""
    return None


class VisionEngine(BaseEngine):
    """Training engine for vision models (ViT / CLIP / image classification)."""

    def __init__(self):
        self._processor = None
        self._model = None

    async def load_model(self, model_id: str, **kwargs: Any) -> None:
        """Download (if needed) and load a vision model and its image processor.

        A local copy under ``settings.models_dir`` is reused when present
        (flattened ``org_name`` or nested ``org/name`` layout); otherwise the
        repository is downloaded into the flattened directory.

        Args:
            model_id: Hub repository id, e.g. ``"org/name"``.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        import torch
        from transformers import AutoImageProcessor, AutoModelForImageClassification

        model_dir = _find_local_model_dir(model_id)
        if model_dir is None:
            from huggingface_hub import snapshot_download

            model_dir = settings.models_dir / model_id.replace("/", "_")
            snapshot_download(repo_id=model_id, local_dir=str(model_dir), local_dir_use_symlinks=False)
        local_path = str(model_dir)

        self._processor = AutoImageProcessor.from_pretrained(local_path, trust_remote_code=True)
        self._model = AutoModelForImageClassification.from_pretrained(
            local_path,
            dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        logger.info(f"Loaded vision model: {model_id}")

    def get_peft_config(self, method: str, params: dict[str, Any]) -> Any:
        """Build the LoRA configuration used to wrap the vision model.

        Args:
            method: Requested PEFT method. Currently unused: a ``LoraConfig``
                is always returned.
            params: Hyper-parameters. Recognised keys are ``lora_r``,
                ``lora_alpha``, ``lora_dropout`` and ``lora_target_modules``.

        Returns:
            A ``peft.LoraConfig`` instance.
        """
        from peft import LoraConfig, TaskType

        target_modules = params.get("lora_target_modules", "all-linear")
        if isinstance(target_modules, str) and target_modules == "all-linear":
            # NOTE(review): PEFT accepts the "all-linear" shorthand natively and
            # HF ViT attention projections are presumably named query/key/value;
            # confirm this substitution matches the models actually trained.
            target_modules = ["linear", "q_proj", "v_proj"]

        return LoraConfig(
            r=params.get("lora_r", 16),
            lora_alpha=params.get("lora_alpha", 32),
            lora_dropout=params.get("lora_dropout", 0.05),
            target_modules=target_modules,
            # Stock PEFT's TaskType enum has no IMAGE_CLS member, so a direct
            # attribute access raised AttributeError. Fall back to None (the
            # generic PeftModel wrapper) unless the installed PEFT provides it.
            task_type=getattr(TaskType, "IMAGE_CLS", None),
        )

    async def preprocess_dataset(
        self, dataset_path: str, output_path: str, **kwargs: Any
    ) -> str:
        """Preprocess an image dataset (extracting image_path + label).

        Delegates to ``app.preprocessors.preprocess_file`` and returns
        ``output_path``.
        """
        from app.preprocessors import preprocess_file

        processed = preprocess_file(dataset_path, output_path, "sft", "raw")
        logger.info(f"Preprocessed {len(processed)} vision samples")
        return output_path

    def _encode_batch(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        """On-the-fly transform: load each image and run the image processor.

        ``datasets`` hands formatting transforms a columnar batch (dict of
        lists), so the batch is first converted back into rows. Exactly one
        output sample is produced per input row, keeping images and labels
        aligned.
        """
        from PIL import Image

        pixel_values = []
        labels = []
        for row in _rows_from_batch(examples):
            with Image.open(row["image_path"]) as image:
                rgb = image.convert("RGB")
            # The processor returns a batch of one; unwrap it so that every
            # sample is a single image array rather than a nested batch.
            pixel_values.append(self._processor(rgb)["pixel_values"][0])
            # Records without a label come back as None after Dataset.from_list.
            labels.append(int(row.get("label") or 0))
        return {"pixel_values": pixel_values, "labels": labels}

    async def train(
        self,
        job_id: str,
        dataset_path: str,
        peft_config: Any,
        training_args: dict[str, Any],
        callbacks: list | None = None,
    ) -> str:
        """Fine-tune the loaded model with PEFT and save the adapter.

        Args:
            job_id: Job identifier; also the adapter output sub-directory name.
            dataset_path: JSON-Lines file whose records carry ``image_path``
                and (optionally) ``label``.
            peft_config: PEFT configuration passed to ``get_peft_model``.
            training_args: Supports ``epochs``, ``batch_size`` and
                ``learning_rate``.
            callbacks: Trainer callbacks; a progress reporter is installed
                when none are supplied.

        Returns:
            The adapter output directory.

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
            ValueError: If the dataset holds no record with an existing image.
        """
        from datasets import Dataset as HFDataset
        from peft import AdaLoraConfig, get_peft_model
        from transformers import Trainer, TrainingArguments, default_data_collator

        if self._model is None or self._processor is None:
            raise RuntimeError("load_model() must be called before train()")

        # Filter unusable records once, up front: a transform that silently
        # drops rows would break index-based access into the dataset.
        records = _read_jsonl(dataset_path)
        usable = [record for record in records if _has_image(record)]
        skipped = len(records) - len(usable)
        if skipped:
            logger.warning(
                "Skipping %d of %d samples without a readable image_path",
                skipped,
                len(records),
            )
        if not usable:
            raise ValueError(f"No usable image samples found in {dataset_path}")

        hf_dataset = HFDataset.from_list(usable)
        hf_dataset.set_transform(self._encode_batch)

        # Total step count (AdaLoRA needs total_step set before get_peft_model).
        epochs = training_args.get("epochs", 3)
        batch_size = training_args.get("batch_size", 4)
        learning_rate = training_args.get("learning_rate", 2e-4)
        dataset_len = len(hf_dataset)
        # NOTE(review): max_steps takes precedence over num_train_epochs in
        # TrainingArguments, and this floor division can yield fewer steps than
        # `epochs` full passes when dataset_len is not a multiple of batch_size.
        max_steps = max(1, (dataset_len * epochs) // batch_size)

        if isinstance(peft_config, AdaLoraConfig):
            peft_config.total_step = max_steps

        self._model = get_peft_model(self._model, peft_config)
        self._model.print_trainable_parameters()

        output_dir = str(settings.adapters_dir / job_id)
        tr_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            max_steps=max_steps,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            save_strategy="epoch",
            logging_steps=10,
            fp16=True,
            optim="adamw_torch",
            # Required so the raw image_path/label columns reach the transform.
            remove_unused_columns=False,
            report_to="none",
            dataloader_num_workers=4,
            dataloader_pin_memory=False,
        )

        all_callbacks = callbacks if callbacks else [_ProgressCallback(job_id)]

        trainer = Trainer(
            model=self._model,
            args=tr_args,
            train_dataset=hf_dataset,
            # Image processors expose no tokenizer-style `.pad`, so
            # DataCollatorWithPadding cannot be used; the default collator
            # simply stacks pixel_values and labels.
            data_collator=default_data_collator,
            callbacks=all_callbacks,
        )

        try:
            trainer.train()
            self._model.save_pretrained(output_dir)
            self._processor.save_pretrained(output_dir)
            logger.info(f"Vision training completed for job {job_id}")
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"Vision training failed for job {job_id}: {e}")
            raise

        return output_dir

    def get_model_info(self, model_id: str) -> dict[str, Any]:
        """Return basic architecture facts read from the local ``config.json``.

        Looks in the same local directories as ``load_model``. When no local
        config is found, a minimal default description is returned.
        """
        model_dir = _find_local_model_dir(model_id)
        if model_dir is None:
            return {"model_type": "vision", "context_length": 2048}

        with open(model_dir / "config.json", encoding="utf-8") as f:
            config = json.load(f)
        return {
            "model_type": config.get("model_type", "vision"),
            "context_length": config.get("max_position_embeddings", 2048),
            "hidden_size": config.get("hidden_size", 0),
            "num_layers": config.get("num_hidden_layers", 0),
        }


class _ProgressCallback:
    """Trainer callback that forwards progress events over the websocket layer.

    The class does not subclass ``transformers.TrainerCallback`` (transformers
    is only imported lazily in this module). Any ``on_*`` event that is not
    implemented explicitly resolves to a no-op through ``__getattr__``, so the
    trainer can dispatch its full event set without hitting AttributeError.
    """

    def __init__(self, job_id: str):
        self.job_id = job_id
        # Strong references to in-flight tasks: the event loop itself only
        # keeps weak references, so an unreferenced task may be collected.
        self._pending: set[asyncio.Task] = set()
        self._warned_no_loop = False

    def __getattr__(self, name: str) -> Any:
        # Only reached for attributes that normal lookup did not find.
        if name.startswith("on_"):
            return _ignore_event
        raise AttributeError(name)

    def _schedule(self, sender: Callable[..., Coroutine[Any, Any, Any]], **payload: Any) -> None:
        """Schedule ``sender(job_id, **payload)`` on the running event loop.

        Progress reporting is best-effort: when no loop is running in the
        calling thread the event is skipped instead of aborting training.
        NOTE(review): if trainer.train() runs synchronously on the event-loop
        thread, scheduled tasks cannot execute until it returns.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            if not self._warned_no_loop:
                logger.warning(
                    "No running event loop; progress events for job %s are skipped",
                    self.job_id,
                )
                self._warned_no_loop = True
            return
        task = loop.create_task(sender(self.job_id, **payload))
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Report step-level loss / learning rate whenever a loss is logged."""
        if logs and "loss" in logs:
            self._schedule(
                send_progress,
                epoch=int(state.epoch or 0),
                step=state.global_step,
                total_steps=state.max_steps or 0,
                loss=logs["loss"],
                learning_rate=logs.get("learning_rate", 0),
            )

    def on_epoch_end(self, args, state, control, **kwargs):
        """Report the end of an epoch (no evaluation metrics are available)."""
        self._schedule(
            send_epoch_done,
            epoch=int(state.epoch or 0),
            eval_loss=None,
            eval_accuracy=None,
        )

    def on_train_end(self, args, state, control, **kwargs):
        """Report training completion together with the adapter directory."""
        self._schedule(
            send_completed,
            # NOTE(review): falls back to 0 whenever `state` has no
            # train_runtime attribute; confirm the trainer state provides it.
            total_time_seconds=getattr(state, "train_runtime", 0),
            adapter_path=str(settings.adapters_dir / self.job_id),
        )


# Imported at the bottom of the module, as in the original file; the senders
# are only looked up when a callback fires.
from app.core.websocket import send_completed, send_epoch_done, send_progress

vision_engine = VisionEngine()