multimodal_engine.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import os
  2. import json
  3. from pathlib import Path
  4. from typing import Any
  5. # 远程训练节点没有 pydantic-settings/数据库,直接用环境变量
  6. from types import SimpleNamespace
  7. def _resolve_data_dir() -> Path:
  8. v = os.environ.get("DATA_DIR") or os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR")
  9. if v:
  10. return Path(v)
  11. env_file = Path(__file__).resolve().parent.parent.parent / ".env"
  12. if env_file.exists():
  13. for line in env_file.read_text():
  14. if line.strip().startswith("DATA_DIR="):
  15. return Path(line.split("=", 1)[1].strip())
  16. return Path("/root/Fine-tuning/backend/data")
  17. _data_dir = _resolve_data_dir()
  18. settings = SimpleNamespace(
  19. data_dir=_data_dir,
  20. processed_dir=_data_dir / "processed",
  21. adapters_dir=_data_dir / "adapters",
  22. models_dir=_data_dir / "models",
  23. )
  24. import logging
  25. logger = logging.getLogger(__name__)
  26. from app.engines.base import BaseEngine
  27. class MultimodalEngine(BaseEngine):
  28. """多模态模型训练引擎 (LLaVA/Qwen-VL 等视觉语言模型)。"""
  29. def __init__(self):
  30. self._processor = None
  31. self._model = None
  32. async def load_model(self, model_id: str, **kwargs: Any) -> None:
  33. """下载并加载多模态模型。"""
  34. import torch
  35. from transformers import AutoProcessor, LlavaForConditionalGeneration
  36. local_path = str(settings.models_dir / model_id.replace("/", "_"))
  37. if not (Path(local_path) / "config.json").exists():
  38. ms_path = settings.models_dir / model_id
  39. if (ms_path / "config.json").exists():
  40. local_path = str(ms_path)
  41. else:
  42. from huggingface_hub import snapshot_download
  43. snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
  44. self._processor = AutoProcessor.from_pretrained(local_path, trust_remote_code=True)
  45. self._model = LlavaForConditionalGeneration.from_pretrained(
  46. local_path,
  47. dtype=torch.float16,
  48. device_map="auto",
  49. trust_remote_code=True,
  50. )
  51. logger.info(f"Loaded multimodal model: {model_id}")
  52. def get_peft_config(self, method: str, params: dict[str, Any]) -> Any:
  53. from peft import LoraConfig, TaskType
  54. target_modules = params.get("lora_target_modules", "all-linear")
  55. if isinstance(target_modules, str) and target_modules == "all-linear":
  56. target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
  57. return LoraConfig(
  58. r=params.get("lora_r", 16),
  59. lora_alpha=params.get("lora_alpha", 32),
  60. lora_dropout=params.get("lora_dropout", 0.05),
  61. target_modules=target_modules,
  62. task_type=TaskType.CAUSAL_LM,
  63. )
  64. async def preprocess_dataset(
  65. self, dataset_path: str, output_path: str, **kwargs: Any
  66. ) -> str:
  67. """多模态数据集预处理 (image + text pairs)。"""
  68. from app.preprocessors import preprocess_file
  69. processed = preprocess_file(dataset_path, output_path, "sft", "raw")
  70. logger.info(f"Preprocessed {len(processed)} multimodal samples")
  71. return output_path
  72. async def train(
  73. self,
  74. job_id: str,
  75. dataset_path: str,
  76. peft_config: Any,
  77. training_args: dict[str, Any],
  78. callbacks: list | None = None,
  79. ) -> str:
  80. from peft import get_peft_model
  81. from transformers import Trainer, TrainingArguments
  82. from datasets import Dataset as HFDataset
  83. data = []
  84. with open(dataset_path, "r", encoding="utf-8") as f:
  85. for line in f:
  86. line = line.strip()
  87. if line:
  88. data.append(json.loads(line))
  89. def collate_fn(examples):
  90. texts = [item.get("text", "") for item in examples]
  91. image_paths = [item.get("image_path", "") for item in examples if "image_path" in item]
  92. if image_paths:
  93. from PIL import Image
  94. images = [Image.open(p).convert("RGB") for p in image_paths if Path(p).exists()]
  95. if images:
  96. inputs = self._processor(text=texts, images=images, return_tensors="pt", padding=True)
  97. inputs["labels"] = inputs["input_ids"].clone()
  98. return inputs
  99. # fallback: text-only
  100. inputs = self._processor(text=texts, return_tensors="pt", padding=True)
  101. inputs["labels"] = inputs["input_ids"].clone()
  102. return inputs
  103. hf_dataset = HFDataset.from_list(data)
  104. # 计算总步数(AdaLoRA 需要在 get_peft_model 之前设置 total_step)
  105. epochs = training_args.get("epochs", 3)
  106. batch_size = training_args.get("batch_size", 4)
  107. learning_rate = training_args.get("learning_rate", 2e-4)
  108. dataset_len = len(hf_dataset)
  109. max_steps = max(1, (dataset_len * epochs) // batch_size)
  110. from peft import AdaLoraConfig
  111. if isinstance(peft_config, AdaLoraConfig):
  112. peft_config.total_step = max_steps
  113. self._model = get_peft_model(self._model, peft_config)
  114. self._model.print_trainable_parameters()
  115. output_dir = str(settings.adapters_dir / job_id)
  116. tr_args = TrainingArguments(
  117. output_dir=output_dir,
  118. num_train_epochs=epochs,
  119. max_steps=max_steps,
  120. per_device_train_batch_size=batch_size,
  121. learning_rate=learning_rate,
  122. save_strategy="epoch",
  123. logging_steps=10,
  124. fp16=True,
  125. optim="adamw_torch",
  126. remove_unused_columns=False,
  127. report_to="none",
  128. dataloader_num_workers=4,
  129. dataloader_pin_memory=False,
  130. )
  131. all_callbacks = callbacks if callbacks else [_ProgressCallback(job_id)]
  132. trainer = Trainer(
  133. model=self._model,
  134. args=tr_args,
  135. train_dataset=hf_dataset,
  136. data_collator=collate_fn,
  137. callbacks=all_callbacks,
  138. )
  139. try:
  140. trainer.train()
  141. self._model.save_pretrained(output_dir)
  142. self._processor.save_pretrained(output_dir)
  143. logger.info(f"Multimodal training completed for job {job_id}")
  144. except Exception as e:
  145. logger.error(f"Multimodal training failed for job {job_id}: {e}")
  146. raise
  147. return output_dir
  148. def get_model_info(self, model_id: str) -> dict[str, Any]:
  149. model_dir = settings.models_dir / model_id.replace("/", "_")
  150. config_path = model_dir / "config.json"
  151. if config_path.exists():
  152. with open(config_path) as f:
  153. config = json.load(f)
  154. return {
  155. "model_type": config.get("model_type", "multimodal"),
  156. "context_length": config.get("max_position_embeddings", 2048),
  157. "hidden_size": config.get("hidden_size", 0),
  158. "num_layers": config.get("num_hidden_layers", 0),
  159. }
  160. return {"model_type": "multimodal", "context_length": 4096}
  161. try:
  162. from transformers import TrainerCallback as _TrainerCallbackBase
  163. except ImportError:
  164. _TrainerCallbackBase = object # 151 主节点无 transformers,仅做占位
  165. class _ProgressCallback(_TrainerCallbackBase):
  166. def __init__(self, job_id: str):
  167. super().__init__()
  168. self.job_id = job_id
  169. def on_log(self, args, state, control, logs=None, **kwargs):
  170. if logs and "loss" in logs:
  171. import asyncio
  172. asyncio.create_task(
  173. send_progress(self.job_id, epoch=int(state.epoch or 0), step=state.global_step,
  174. total_steps=state.max_steps or 0, loss=logs["loss"], learning_rate=logs.get("learning_rate", 0))
  175. )
  176. def on_epoch_end(self, args, state, control, **kwargs):
  177. import asyncio
  178. asyncio.create_task(send_epoch_done(self.job_id, epoch=int(state.epoch or 0), eval_loss=None, eval_accuracy=None))
  179. def on_train_end(self, args, state, control, **kwargs):
  180. import asyncio
  181. asyncio.create_task(send_completed(self.job_id, total_time_seconds=getattr(state, "train_runtime", 0),
  182. adapter_path=str(settings.adapters_dir / self.job_id)))
  183. from app.core.websocket import send_completed, send_epoch_done, send_progress
# Module-level engine instance, created at import time; MultimodalEngine.__init__
# only initializes empty attributes, so this performs no I/O.
multimodal_engine = MultimodalEngine()