vision_engine.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import os
  2. import json
  3. from pathlib import Path
  4. from typing import Any
  5. # 远程训练节点没有 pydantic-settings/数据库,直接用环境变量
  6. from types import SimpleNamespace
  7. _data_dir = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
  8. settings = SimpleNamespace(
  9. data_dir=_data_dir,
  10. processed_dir=_data_dir / "processed",
  11. adapters_dir=_data_dir / "adapters",
  12. models_dir=_data_dir / "models",
  13. )
  14. import logging
  15. logger = logging.getLogger(__name__)
  16. from app.engines.base import BaseEngine
  17. class VisionEngine(BaseEngine):
  18. """视觉模型训练引擎 (ViT/CLIP/图像分类)。"""
  19. def __init__(self):
  20. self._processor = None
  21. self._model = None
  22. async def load_model(self, model_id: str, **kwargs: Any) -> None:
  23. """下载并加载视觉模型。"""
  24. import torch
  25. from transformers import AutoImageProcessor, AutoModelForImageClassification
  26. local_path = str(settings.models_dir / model_id.replace("/", "_"))
  27. if not (Path(local_path) / "config.json").exists():
  28. ms_path = settings.models_dir / model_id
  29. if (ms_path / "config.json").exists():
  30. local_path = str(ms_path)
  31. else:
  32. from huggingface_hub import snapshot_download
  33. snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
  34. self._processor = AutoImageProcessor.from_pretrained(local_path, trust_remote_code=True)
  35. self._model = AutoModelForImageClassification.from_pretrained(
  36. local_path,
  37. dtype=torch.float16,
  38. device_map="auto",
  39. trust_remote_code=True,
  40. )
  41. logger.info(f"Loaded vision model: {model_id}")
  42. def get_peft_config(self, method: str, params: dict[str, Any]) -> Any:
  43. from peft import LoraConfig, TaskType
  44. target_modules = params.get("lora_target_modules", "all-linear")
  45. if isinstance(target_modules, str) and target_modules == "all-linear":
  46. target_modules = ["linear", "q_proj", "v_proj"]
  47. return LoraConfig(
  48. r=params.get("lora_r", 16),
  49. lora_alpha=params.get("lora_alpha", 32),
  50. lora_dropout=params.get("lora_dropout", 0.05),
  51. target_modules=target_modules,
  52. task_type=TaskType.IMAGE_CLS,
  53. )
  54. async def preprocess_dataset(
  55. self, dataset_path: str, output_path: str, **kwargs: Any
  56. ) -> str:
  57. """图像数据集预处理(提取 image_path + label)。"""
  58. from app.preprocessors import preprocess_file
  59. processed = preprocess_file(dataset_path, output_path, "sft", "raw")
  60. logger.info(f"Preprocessed {len(processed)} vision samples")
  61. return output_path
  62. async def train(
  63. self,
  64. job_id: str,
  65. dataset_path: str,
  66. peft_config: Any,
  67. training_args: dict[str, Any],
  68. callbacks: list | None = None,
  69. ) -> str:
  70. from peft import get_peft_model
  71. from transformers import DataCollatorWithPadding, Trainer, TrainingArguments
  72. from datasets import Dataset as HFDataset
  73. # Load and preprocess data
  74. data = []
  75. with open(dataset_path, "r", encoding="utf-8") as f:
  76. for line in f:
  77. line = line.strip()
  78. if line:
  79. data.append(json.loads(line))
  80. def transform(examples):
  81. images = []
  82. labels = []
  83. for item in examples:
  84. if "image_path" in item and Path(item["image_path"]).exists():
  85. from PIL import Image
  86. images.append(self._processor(Image.open(item["image_path"]).convert("RGB"))["pixel_values"])
  87. labels.append(int(item.get("label", 0)))
  88. elif "text" in item:
  89. # fallback: use text as label for classification
  90. labels.append(item.get("label", 0))
  91. if images:
  92. return {"pixel_values": images, "labels": labels}
  93. return {"pixel_values": [], "labels": []}
  94. hf_dataset = HFDataset.from_list(data)
  95. hf_dataset.set_transform(transform)
  96. # 计算总步数(AdaLoRA 需要在 get_peft_model 之前设置 total_step)
  97. epochs = training_args.get("epochs", 3)
  98. batch_size = training_args.get("batch_size", 4)
  99. learning_rate = training_args.get("learning_rate", 2e-4)
  100. dataset_len = len(hf_dataset)
  101. max_steps = max(1, (dataset_len * epochs) // batch_size)
  102. from peft import AdaLoraConfig
  103. if isinstance(peft_config, AdaLoraConfig):
  104. peft_config.total_step = max_steps
  105. self._model = get_peft_model(self._model, peft_config)
  106. self._model.print_trainable_parameters()
  107. output_dir = str(settings.adapters_dir / job_id)
  108. tr_args = TrainingArguments(
  109. output_dir=output_dir,
  110. num_train_epochs=epochs,
  111. max_steps=max_steps,
  112. per_device_train_batch_size=batch_size,
  113. learning_rate=learning_rate,
  114. save_strategy="epoch",
  115. logging_steps=10,
  116. fp16=True,
  117. optim="adamw_torch",
  118. remove_unused_columns=False,
  119. report_to="none",
  120. dataloader_num_workers=4,
  121. dataloader_pin_memory=False,
  122. )
  123. all_callbacks = callbacks if callbacks else [_ProgressCallback(job_id)]
  124. trainer = Trainer(
  125. model=self._model,
  126. args=tr_args,
  127. train_dataset=hf_dataset,
  128. data_collator=DataCollatorWithPadding(self._processor),
  129. callbacks=all_callbacks,
  130. )
  131. try:
  132. trainer.train()
  133. self._model.save_pretrained(output_dir)
  134. self._processor.save_pretrained(output_dir)
  135. logger.info(f"Vision training completed for job {job_id}")
  136. except Exception as e:
  137. logger.error(f"Vision training failed for job {job_id}: {e}")
  138. raise
  139. return output_dir
  140. def get_model_info(self, model_id: str) -> dict[str, Any]:
  141. model_dir = settings.models_dir / model_id.replace("/", "_")
  142. config_path = model_dir / "config.json"
  143. if config_path.exists():
  144. with open(config_path) as f:
  145. config = json.load(f)
  146. return {
  147. "model_type": config.get("model_type", "vision"),
  148. "context_length": config.get("max_position_embeddings", 2048),
  149. "hidden_size": config.get("hidden_size", 0),
  150. "num_layers": config.get("num_hidden_layers", 0),
  151. }
  152. return {"model_type": "vision", "context_length": 2048}
  153. try:
  154. from transformers import TrainerCallback as _TrainerCallbackBase
  155. except ImportError:
  156. _TrainerCallbackBase = object # 151 主节点无 transformers,仅做占位
  157. class _ProgressCallback(_TrainerCallbackBase):
  158. def __init__(self, job_id: str):
  159. super().__init__()
  160. self.job_id = job_id
  161. def on_log(self, args, state, control, logs=None, **kwargs):
  162. if logs and "loss" in logs:
  163. import asyncio
  164. asyncio.create_task(
  165. send_progress(self.job_id, epoch=int(state.epoch or 0), step=state.global_step,
  166. total_steps=state.max_steps or 0, loss=logs["loss"], learning_rate=logs.get("learning_rate", 0))
  167. )
  168. def on_epoch_end(self, args, state, control, **kwargs):
  169. import asyncio
  170. asyncio.create_task(send_epoch_done(self.job_id, epoch=int(state.epoch or 0), eval_loss=None, eval_accuracy=None))
  171. def on_train_end(self, args, state, control, **kwargs):
  172. import asyncio
  173. asyncio.create_task(send_completed(self.job_id, total_time_seconds=getattr(state, "train_runtime", 0),
  174. adapter_path=str(settings.adapters_dir / self.job_id)))
  175. from app.core.websocket import send_completed, send_epoch_done, send_progress
  176. vision_engine = VisionEngine()