| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- import json
- from pathlib import Path
- from typing import Any
- from app.config import get_settings
- from app.core.logging import logger
- from app.engines.base import BaseEngine
- settings = get_settings()  # module-level settings object; models_dir and adapters_dir are read from it below
class MultimodalEngine(BaseEngine):
    """Training engine for multimodal (vision-language) models.

    The model is loaded through ``LlavaForConditionalGeneration`` together
    with its ``AutoProcessor``; both are kept on the instance so that
    ``train`` can reuse what ``load_model`` prepared.
    """

    def __init__(self):
        # Filled in by load_model(); None means nothing has been loaded yet.
        self._processor = None
        self._model = None

    async def load_model(self, model_id: str, **kwargs: Any) -> None:
        """Download (when necessary) and load the model and its processor.

        Args:
            model_id: Hub repository id. Also used to derive the local
                directory name when the model service returns no path.
            **kwargs: Accepted for interface compatibility; not read here.
        """
        import asyncio

        import torch
        from transformers import AutoProcessor, LlavaForConditionalGeneration

        # Prefer the actual path recorded by the model service (compatible
        # with the directory layout produced by ModelScope downloads).
        from app.services.model_service import resolve_model_path

        model_path = await resolve_model_path(model_id)
        if model_path:
            local_path = model_path
        else:
            local_path = str(settings.models_dir / model_id.replace("/", "_"))

        # A missing config.json is taken to mean the snapshot is not on disk.
        if not (Path(local_path) / "config.json").exists():
            from huggingface_hub import snapshot_download

            # The download is blocking network I/O; run it in a worker thread
            # so this coroutine does not stall the event loop meanwhile.
            await asyncio.to_thread(
                snapshot_download,
                repo_id=model_id,
                local_dir=local_path,
                local_dir_use_symlinks=False,
            )

        self._processor = AutoProcessor.from_pretrained(local_path, trust_remote_code=True)
        self._model = LlavaForConditionalGeneration.from_pretrained(
            local_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        logger.info(f"Loaded multimodal model: {model_id}")

    def get_peft_config(self, method: str, params: dict[str, Any]) -> Any:
        """Build the LoRA configuration used for fine-tuning.

        Args:
            method: Requested fine-tuning method. Not consulted here; a
                ``LoraConfig`` is always returned.
            params: Hyper-parameters. Reads ``lora_target_modules``
                (default ``"all-linear"``), ``lora_r`` (16), ``lora_alpha``
                (32) and ``lora_dropout`` (0.05).

        Returns:
            A ``peft.LoraConfig`` with task type ``CAUSAL_LM``.
        """
        from peft import LoraConfig, TaskType

        target_modules = params.get("lora_target_modules", "all-linear")
        if isinstance(target_modules, str) and target_modules == "all-linear":
            # Expand the shorthand into an explicit list of module names.
            target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]

        return LoraConfig(
            r=params.get("lora_r", 16),
            lora_alpha=params.get("lora_alpha", 32),
            lora_dropout=params.get("lora_dropout", 0.05),
            target_modules=target_modules,
            task_type=TaskType.CAUSAL_LM,
        )

    async def preprocess_dataset(
        self, dataset_path: str, output_path: str, **kwargs: Any
    ) -> str:
        """Preprocess a multimodal dataset (image + text pairs).

        Delegates to ``preprocess_file`` with the ``"sft"`` / ``"raw"``
        arguments and returns ``output_path`` unchanged.
        """
        from app.preprocessors import preprocess_file

        processed = preprocess_file(dataset_path, output_path, "sft", "raw")
        logger.info(f"Preprocessed {len(processed)} multimodal samples")
        return output_path

    async def train(
        self,
        job_id: str,
        dataset_path: str,
        peft_config: Any,
        training_args: dict[str, Any],
    ) -> str:
        """Fine-tune the loaded model on a JSONL dataset and save the adapter.

        Args:
            job_id: Job identifier; names the output directory under
                ``settings.adapters_dir`` and is passed to the progress
                callback.
            dataset_path: UTF-8 JSONL file, one JSON object per non-blank
                line. ``text`` and ``image_path`` are the keys read.
            peft_config: Adapter configuration handed to ``get_peft_model``.
            training_args: Reads ``epochs`` (3), ``batch_size`` (4) and
                ``learning_rate`` (2e-4).

        Returns:
            The directory the adapter and processor were saved to.

        Raises:
            RuntimeError: If ``load_model`` has not populated the model and
                processor yet.
            Exception: Any error raised during training or saving is logged
                and re-raised unchanged.
        """
        from peft import get_peft_model
        from transformers import Trainer, TrainingArguments
        from datasets import Dataset as HFDataset

        # Fail early with a clear message instead of an opaque error from
        # get_peft_model(None, ...) after the dataset has been read.
        if self._model is None or self._processor is None:
            raise RuntimeError("Model is not loaded; call load_model() before train()")

        with open(dataset_path, "r", encoding="utf-8") as f:
            data = [json.loads(row) for row in map(str.strip, f) if row]

        def _open_rgb(path):
            from PIL import Image

            # Image.open keeps the file open until the data is read; the
            # context manager closes it as soon as the RGB copy exists.
            with Image.open(path) as img:
                return img.convert("RGB")

        def collate_fn(examples):
            # Missing, None and empty values are all treated as "absent".
            # NOTE(review): records lacking a key presumably arrive with a
            # None value after Dataset.from_list -- confirm.
            texts = [item.get("text") or "" for item in examples]
            image_paths = [p for p in (item.get("image_path") for item in examples) if p]
            # Paths that are not regular files are skipped, as before.
            images = [_open_rgb(p) for p in image_paths if Path(p).is_file()]

            call_kwargs = {"text": texts, "return_tensors": "pt", "padding": True}
            if images:
                call_kwargs["images"] = images
            # Without usable images this is the text-only fallback.
            inputs = self._processor(**call_kwargs)

            labels = inputs["input_ids"].clone()
            if "attention_mask" in inputs:
                # -100 is the ignore index of the loss: padding positions
                # must not contribute to it.
                labels = labels.masked_fill(inputs["attention_mask"] == 0, -100)
            inputs["labels"] = labels
            return inputs

        hf_dataset = HFDataset.from_list(data)

        self._model = get_peft_model(self._model, peft_config)
        self._model.print_trainable_parameters()

        output_dir = str(settings.adapters_dir / job_id)
        epochs = training_args.get("epochs", 3)
        batch_size = training_args.get("batch_size", 4)
        learning_rate = training_args.get("learning_rate", 2e-4)

        tr_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            save_strategy="epoch",
            logging_steps=10,
            fp16=True,
            optim="adamw_torch",
            # collate_fn reads raw "text"/"image_path" columns, so the
            # trainer must not drop columns the model does not accept.
            remove_unused_columns=False,
            report_to="none",
        )

        callback = _ProgressCallback(job_id)
        trainer = Trainer(
            model=self._model,
            args=tr_args,
            train_dataset=hf_dataset,
            data_collator=collate_fn,
            callbacks=[callback],
        )

        try:
            # NOTE(review): this is a synchronous call inside a coroutine, so
            # the event loop does not advance until training returns.
            trainer.train()
            self._model.save_pretrained(output_dir)
            self._processor.save_pretrained(output_dir)
            logger.info(f"Multimodal training completed for job {job_id}")
        except Exception as e:
            logger.error(f"Multimodal training failed for job {job_id}: {e}")
            raise

        return output_dir

    def get_model_info(self, model_id: str) -> dict[str, Any]:
        """Return basic metadata read from the model's local ``config.json``.

        Looks only in ``settings.models_dir / model_id.replace("/", "_")``.
        When no config file is found there, a generic fallback dict is
        returned instead.
        """
        model_dir = settings.models_dir / model_id.replace("/", "_")
        config_path = model_dir / "config.json"
        if not config_path.exists():
            return {"model_type": "multimodal", "context_length": 4096}

        # JSON files are UTF-8; do not depend on the platform default.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        return {
            "model_type": config.get("model_type", "multimodal"),
            "context_length": config.get("max_position_embeddings", 2048),
            "hidden_size": config.get("hidden_size", 0),
            "num_layers": config.get("num_hidden_layers", 0),
        }
class _ProgressCallback:
    """Trainer callback that reports the progress of one training job.

    The class does not inherit from the trainer's callback base class, so
    every event the trainer dispatches must be resolvable on it. Events
    with nothing to report are no-ops; unknown ``on_*`` events fall back to
    a no-op as well instead of raising ``AttributeError``.
    """

    def __init__(self, job_id: str):
        self.job_id = job_id
        # Set in on_train_begin; used to measure the total training time.
        self._started_at = None
        # Strong references to in-flight notification tasks. The event loop
        # itself only keeps weak references, so an unreferenced task could
        # be garbage-collected before it has run.
        self._tasks = set()

    def _spawn(self, coro):
        """Schedule ``coro`` on the running event loop and keep it alive."""
        import asyncio

        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    def _noop(self, args, state, control, **kwargs):
        """Hook with nothing to report."""
        return None

    def __getattr__(self, name):
        # Reached only when normal attribute lookup fails: tolerate trainer
        # events that are not listed explicitly below.
        if name.startswith("on_"):
            return self._noop
        raise AttributeError(name)

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the start time for the duration reported at the end."""
        import time

        self._started_at = time.monotonic()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Forward loss and learning rate whenever a log entry has a loss."""
        if logs and "loss" in logs:
            self._spawn(
                send_progress(
                    self.job_id,
                    epoch=int(state.epoch or 0),
                    step=state.global_step,
                    total_steps=state.max_steps or 0,
                    loss=logs["loss"],
                    learning_rate=logs.get("learning_rate", 0),
                )
            )

    def on_epoch_end(self, args, state, control, **kwargs):
        """Announce a finished epoch; no evaluation metrics are available."""
        self._spawn(
            send_epoch_done(
                self.job_id,
                epoch=int(state.epoch or 0),
                eval_loss=None,
                eval_accuracy=None,
            )
        )

    def on_train_end(self, args, state, control, **kwargs):
        """Announce completion together with the adapter output path."""
        import time

        # Use a runtime provided on the state when there is one; otherwise
        # fall back to the time measured since on_train_begin.
        runtime = getattr(state, "train_runtime", None)
        if runtime is None:
            if self._started_at is not None:
                runtime = time.monotonic() - self._started_at
            else:
                runtime = 0
        self._spawn(
            send_completed(
                self.job_id,
                total_time_seconds=runtime,
                adapter_path=str(settings.adapters_dir / self.job_id),
            )
        )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        pass

    def on_predict(self, args, state, control, metrics=None, **kwargs):
        pass

    # Remaining trainer events: nothing to report.
    on_init_end = _noop
    on_epoch_begin = _noop
    on_step_begin = _noop
    on_pre_optimizer_step = _noop
    on_optimizer_step = _noop
    on_substep_end = _noop
    on_step_end = _noop
    on_save = _noop
    on_prediction_step = _noop
- from app.core.websocket import send_completed, send_epoch_done, send_progress  # imported after the class definitions; presumably deferred to avoid a circular import -- TODO confirm
- multimodal_engine = MultimodalEngine()  # module-level engine instance created at import time
|