Просмотр исходного кода

修复模型微调显存占用过大问题

lxylxy123321 6 дней назад
Родитель
Коммит
4128ea1efe

+ 25 - 0
backend/app/core/job_queue.py

@@ -196,6 +196,9 @@ class JobQueue:
                 if not dataset_path:
                     raise FileNotFoundError(f"Dataset not found: {dataset_id}")
 
+                # 启动新训练前清理容器内所有残留的 python 进程(释放 GPU ring buffer)
+                await self._cleanup_remote_processes()
+
                 self.update_job(job_id, status=JobStatus.TRAINING)
                 await self._notify_callbacks()
 
@@ -271,6 +274,28 @@ class JobQueue:
 
         return None
 
+    async def _cleanup_remote_processes(self):
+        """Best-effort SIGKILL of all python processes left in the compute container (over SSH).
+
+        Runs before a new training job; intended to free the GPU ring buffer held by stale trainers.
+        Zombie (<defunct>) entries are signalled too, but a zombie only disappears once its parent reaps it.
+        """
+        from app.config import get_settings
+        from app.core.remote_executor import ssh_exec
+
+        container = get_settings().compute_node_docker_container
+        # List the PID column of every `ps aux` row that mentions python (zombies included).
+        cmd = f"docker exec {container} bash -c 'ps aux | grep \"[p]ython\" | grep -v grep | awk \"{{print \\$2}}\"'"
+        code, stdout, _ = await asyncio.to_thread(ssh_exec, cmd, timeout=15)
+        # Keep numeric tokens only, so nothing unexpected is interpolated into the kill command.
+        pids = [tok for tok in stdout.split() if tok.isdigit()] if code == 0 else []
+        if pids:
+            # One SSH round trip for all PIDs; no `wait`, since a fresh shell cannot wait on non-children.
+            pid_args = " ".join(pids)
+            kill_cmd = f"docker exec {container} bash -c 'kill -9 {pid_args} 2>/dev/null'"
+            await asyncio.to_thread(ssh_exec, kill_cmd, timeout=15)
+            logger.info(f"Sent SIGKILL to {len(pids)} remote python processes in container {container}")
+
     async def _lookup_dataset_db(self, dataset_id: str) -> str | None:
         """从数据库查找数据集路径。"""
         from app.core.db import async_session, DatasetRecord

+ 5 - 3
backend/app/core/remote_executor.py

@@ -154,6 +154,8 @@ def run_training_remote(
     remote_log_dir = f"{settings.compute_node_remote_data_dir}/logs"
     _, _, _ = ssh_exec(f"mkdir -p {remote_log_dir}")
 
+    # 使用 setsid 启动训练进程,确保进程组独立,kill 时能正确清理子进程
+    # trap 确保 shell 退出时会 wait 子进程,避免产生僵尸进程
     remote_cmd = (
         f"docker exec "
         f"-e MACA_MPS_MODE=1 "
@@ -161,10 +163,10 @@ def run_training_remote(
         f"-w {settings.compute_node_workdir} "
         f"{settings.compute_node_docker_container} "
         f"bash -c '"
-        f"nohup {settings.compute_node_python} -m app.engines.remote_train "
+        f"setsid {settings.compute_node_python} -m app.engines.remote_train "
         f"{job_id} {model_id} {model_type} {remote_dataset_path} {remote_config_path} "
-        f"</dev/null >/tmp/train_{job_id}.log 2>&1 "
-        f"& echo $!'"
+        f"</dev/null >/tmp/train_{job_id}.log 2>&1 &"
+        f" disown; echo $!'"
     )
 
     code, stdout, stderr = ssh_exec(remote_cmd, timeout=30)

+ 1 - 1
backend/app/engines/multimodal_engine.py

@@ -45,7 +45,7 @@ class MultimodalEngine(BaseEngine):
         self._processor = AutoProcessor.from_pretrained(local_path, trust_remote_code=True)
         self._model = LlavaForConditionalGeneration.from_pretrained(
             local_path,
-            torch_dtype=torch.float16,
+            dtype=torch.float16,
             device_map="auto",
             trust_remote_code=True,
         )

+ 2 - 2
backend/app/engines/text_engine.py

@@ -85,7 +85,7 @@ class TextEngine(BaseEngine):
         device_map = {"": first_gpu}
 
         load_kwargs: dict[str, Any] = {
-            "torch_dtype": torch.float16,
+            "dtype": torch.float16,
             "device_map": device_map,
             "low_cpu_mem_usage": True,
             "use_safetensors": True,
@@ -207,7 +207,7 @@ class TextEngine(BaseEngine):
             optim="adamw_torch",
             remove_unused_columns=False,
             report_to="none",
-            gradient_checkpointing=False,
+            gradient_checkpointing=True,
             dataloader_num_workers=0,
             dataloader_pin_memory=False,
             **({"deepspeed": deepspeed_config} if deepspeed_config else {}),

+ 1 - 1
backend/app/engines/vision_engine.py

@@ -45,7 +45,7 @@ class VisionEngine(BaseEngine):
         self._processor = AutoImageProcessor.from_pretrained(local_path, trust_remote_code=True)
         self._model = AutoModelForImageClassification.from_pretrained(
             local_path,
-            torch_dtype=torch.float16,
+            dtype=torch.float16,
             device_map="auto",
             trust_remote_code=True,
         )

+ 16 - 18
result.txt

@@ -1,18 +1,16 @@
-(base) [root@localhost ~]# docker exec finetune-trainer /opt/conda/bin/python -c "from transformers import AutoModelForCausalLM, AutoConfig; cfg = AutoConfig.from_pretrained('/root/Fine-tuning/backend/data/models/Qwen_Qwen3.5-0.8B'); print('model_type:', cfg.model_type); print('architectures:', cfg.architectures)"
-model_type: qwen3_5
-architectures: ['Qwen3_5ForConditionalGeneration']
-(base) [root@localhost ~]# docker exec finetune-trainer /opt/conda/bin/python -c "import torch; print('torch:', torch.__version__); print('cuda:', torch.cuda.is_available()); print('devices:', torch.cuda.device_count())"
-torch: 2.8.0+metax3.5.3.9
-cuda: True
-devices: 4
-(base) [root@localhost ~]# docker exec finetune-trainer /opt/conda/bin/python -c "import torch; from transformers import AutoModelForCausalLM; m = AutoModelForCausalLM.from_pretrained('/root/Fine-tuning/backend/data/models/Qwen_Qwen3.5-0.8B', torch_dtype=torch.float16, device_map='auto'); print('Loaded OK')"
-[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
-Current Triton version 3.0.0 is below the recommended 3.2.0 version. Errors may occur and these issues will not be fixed. Please consider upgrading Triton.
-Current Python version 3.10 is below the recommended 3.11 version. It is recommended to upgrade to Python 3.11 or higher for the best experience.
-torch.compile is not available in Python 3.10, using identity decorator instead
-/opt/conda/lib/python3.10/site-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
-  warnings.warn(_BETA_TRANSFORMS_WARNING)
-/opt/conda/lib/python3.10/site-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
-  warnings.warn(_BETA_TRANSFORMS_WARNING)
-Loading weights: 100%|██████████| 320/320 [00:06<00:00, 48.32it/s]
-Loaded OK
+(base) [root@localhost ~]# docker exec finetune-trainer ps aux
+USER         PID %CPU %MEM    VSZ   RSS TTY      STAT START   TIME COMMAND
+root           1  0.0  0.0   2824  1060 ?        Ss   May21   0:00 tail -f /dev/null
+root         300  0.1  0.0      0     0 ?        Z    May21   0:30 [python] <defunct>
+root         405  0.2  0.0      0     0 ?        Z    May21   0:59 [python] <defunct>
+root        1139  0.6  0.0      0     0 ?        Z    May21   2:57 [python] <defunct>
+root        1496  0.1  0.0      0     0 ?        Z    May21   0:37 [python] <defunct>
+root        1655  0.2  0.0      0     0 ?        Z    May21   1:21 [python] <defunct>
+root       13911  2.3  0.0      0     0 ?        Z    May21   0:38 [python] <defunct>
+root       14070  4.2  0.0      0     0 ?        Z    May21   0:59 [python] <defunct>
+root       14488  7.3  0.0      0     0 ?        Z    May21   1:00 [python] <defunct>
+root       14906  147  2.1 56294212 11559636 ?   Sl   May21   7:06 /opt/conda/bin/python -m app.engines.remote_train 3485b881-b7be-4a0d-83bd-e8330d9b0fad Qwen/Qwen1.5-0.5B text /root/Fine-tuning/backend/data/datasets/data.jsonl /root/Fine-tuning/backend/data/config_3485b881-b7be-4a0d-83bd-e8330d9b0fad.json
+root       15565  0.0  0.0   7064  1592 ?        Rs   00:03   0:00 ps aux
+(base) [root@localhost ~]# docker exec finetune-trainer bash -c 'maca-smi || nvidia-smi'
+bash: line 1: maca-smi: command not found
+bash: line 1: nvidia-smi: command not found