فهرست منبع

仅使用GPU时用253的容器

lxylxy123321 1 هفته پیش
والد
کامیت
4ac59d0563

+ 5 - 18
backend/Dockerfile

@@ -1,32 +1,19 @@
-# 使用本地已有的沐曦 maca + PyTorch 2.8 + Python 3.10 镜像
-# 驱动版本 maca 3.5.3.502,PyTorch 2.8,兼容性好
-FROM cr.metax-tech.com/public-ai-release/maca/vllm-metax:0.19.0-maca.ai3.5.3.502-torch2.8-py310-ubuntu22.04-amd64
+# 主节点(151)后端 — 轻量级 Python 镜像,不含 GPU 依赖
+# 仅负责 API/DB/WebSocket/SSH 调度,实际训练/推理在 253 算力节点执行
+FROM python:3.10-slim
 
 
 WORKDIR /app
 WORKDIR /app
 
 
-# 设置 conda Python 路径(镜像使用 /opt/conda)
-ENV PATH="/opt/conda/bin:$PATH"
-
 RUN apt-get update && apt-get install -y git openssh-client sshpass && rm -rf /var/lib/apt/lists/*
 RUN apt-get update && apt-get install -y git openssh-client sshpass && rm -rf /var/lib/apt/lists/*
 
 
-# 升级 pip
-RUN /opt/conda/bin/pip install --no-cache-dir --upgrade pip
-
-# 复制依赖文件并安装(跳过 torch,镜像已自带)
 COPY requirements.txt .
 COPY requirements.txt .
-RUN /opt/conda/bin/pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-cache-dir -r requirements.txt
 
 
-# 复制应用代码
 COPY . .
 COPY . .
 
 
-# 沐曦 maca 环境变量(镜像中通常已设置,这里显式声明)
-ENV MACA_PATH=/opt/maca
-ENV LD_LIBRARY_PATH=/opt/maca/lib:/opt/maca/mxgpu_llvm/lib:/opt/maca/ompi/lib:${LD_LIBRARY_PATH}
-ENV MACA_CLANG_PATH=/opt/maca/mxgpu_llvm/bin
-
 EXPOSE 8010
 EXPOSE 8010
 
 
 HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
 HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
-    CMD /opt/conda/bin/python -c "import urllib.request; urllib.request.urlopen('http://localhost:8010/health')" || exit 1
+    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8010/health')" || exit 1
 
 
 CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8010"]
 CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8010"]

+ 1 - 0
backend/app/config.py

@@ -95,6 +95,7 @@ class Settings(BaseSettings):
     compute_node_ssh_user: str = "root"
     compute_node_ssh_user: str = "root"
     compute_node_ssh_password: str = ""  # SSH 密码(与密钥二选一)
     compute_node_ssh_password: str = ""  # SSH 密码(与密钥二选一)
     compute_node_ssh_key: str = ""  # SSH 私钥路径
     compute_node_ssh_key: str = ""  # SSH 私钥路径
+    compute_node_docker_container: str = "finetune-trainer"  # 算力节点上的训练容器名
     compute_node_python: str = "/opt/conda/bin/python"
     compute_node_python: str = "/opt/conda/bin/python"
     compute_node_workdir: str = "/root/Fine-tuning/backend"
     compute_node_workdir: str = "/root/Fine-tuning/backend"
     compute_node_remote_data_dir: str = "/root/Fine-tuning/backend/data"
     compute_node_remote_data_dir: str = "/root/Fine-tuning/backend/data"

+ 120 - 8
backend/app/core/job_queue.py

@@ -1,4 +1,5 @@
 import asyncio
 import asyncio
+import json
 from datetime import datetime, timezone
 from datetime import datetime, timezone
 from enum import Enum
 from enum import Enum
 from typing import Any, Callable, Coroutine, Optional
 from typing import Any, Callable, Coroutine, Optional
@@ -193,17 +194,15 @@ class JobQueue:
                 self.update_job(job_id, status=JobStatus.TRAINING)
                 self.update_job(job_id, status=JobStatus.TRAINING)
                 await self._notify_callbacks()
                 await self._notify_callbacks()
 
 
-                from app.core.remote_executor import run_training_remote
-                success = run_training_remote(job_id, model_id, model_type, dataset_id, config)
+                from app.core.remote_executor import run_training_remote, is_process_running
+                pid = run_training_remote(job_id, model_id, model_type, dataset_id, config)
 
 
-                if not success:
+                if not pid:
                     raise RuntimeError("Failed to launch remote training")
                     raise RuntimeError("Failed to launch remote training")
 
 
-                # 远程模式下,训练完成后通过日志解析或轮询获取 adapter_path
-                # 这里暂时标记为完成,adapter_path 由远程脚本回写
-                self.update_job(job_id, status=JobStatus.COMPLETED,
-                                adapter_path=str(settings.adapters_dir / job_id))
-                await self._notify_callbacks()
+                # 轮询共享日志文件解析进度
+                await self._poll_remote_progress(job_id, pid)
+
                 logger.info(f"Remote training launched for job {job_id}")
                 logger.info(f"Remote training launched for job {job_id}")
             else:
             else:
                 # 本地训练模式
                 # 本地训练模式
@@ -276,6 +275,119 @@ class JobQueue:
             from app.engines.text_engine import text_engine
             from app.engines.text_engine import text_engine
             return text_engine
             return text_engine
 
 
+    async def _poll_remote_progress(self, job_id: str, pid: str):
+        """Follow a remote training run through its shared JSONL progress log.
+
+        Every ``poll_interval`` seconds the newly appended part of
+        ``<data_dir>/logs/<job_id>.jsonl`` is read, each entry is applied to the
+        job record and forwarded over WebSocket. The coroutine returns when a
+        ``completed`` or ``error`` entry has been recorded, when the job is
+        cancelled, when the remote process is gone without a final entry, or
+        after ``max_polls`` polls.
+
+        Args:
+            job_id: Job identifier; also the stem of the log file name.
+            pid: PID of the training process inside the compute-node container,
+                as returned by ``run_training_remote``.
+        """
+        from app.config import get_settings
+        from app.core.remote_executor import is_process_running, ssh_exec
+        from app.core.websocket import send_progress, send_epoch_done, send_completed, send_error
+
+        settings = get_settings()
+        log_file = settings.data_dir / "logs" / f"{job_id}.jsonl"
+        poll_interval = 5  # seconds between polls
+        max_polls = 8640  # 8640 * 5 s = 12 hours
+        offset = 0  # number of log bytes already consumed
+
+        def read_new_entries(final: bool = False) -> list[dict]:
+            """Return the JSON objects appended to the log since the last call.
+
+            Unless ``final`` is set, a trailing line without a newline is left
+            unread, so an entry the writer is still flushing is parsed on a
+            later poll instead of being dropped as invalid JSON.
+            """
+            nonlocal offset
+            try:
+                # Binary mode keeps ``offset`` a plain byte count.
+                with open(log_file, "rb") as f:
+                    f.seek(offset)
+                    chunk = f.read()
+            except FileNotFoundError:
+                return []  # the log has not been created yet
+            except OSError as e:
+                logger.warning(f"Error reading remote log file: {e}")
+                return []
+            if not final:
+                chunk = chunk[:chunk.rfind(b"\n") + 1]
+            offset += len(chunk)
+            entries = []
+            for raw in chunk.splitlines():
+                if not raw.strip():
+                    continue
+                try:
+                    entry = json.loads(raw)
+                except ValueError:
+                    # Covers both malformed JSON and undecodable bytes.
+                    continue
+                if isinstance(entry, dict):
+                    entries.append(entry)
+            return entries
+
+        async def apply_entries(entries: list[dict]) -> bool:
+            """Apply entries in order; return True once a final state is recorded."""
+            for entry in entries:
+                entry_type = entry.get("type")
+                terminal = False
+                try:
+                    if entry_type == "progress":
+                        step = entry.get("step") or 0
+                        total_steps = entry.get("total_steps") or 0
+                        self.update_job(job_id,
+                                        epoch=entry.get("epoch", 0),
+                                        current_step=step,
+                                        total_steps=total_steps,
+                                        loss=entry.get("loss"),
+                                        progress=round(step / max(total_steps, 1) * 100, 1))
+                        await self._notify_callbacks()
+                        await send_progress(job_id, **{k: v for k, v in entry.items() if k != "type"})
+
+                    elif entry_type == "epoch_begin":
+                        self.update_job(job_id, current_epoch=entry.get("epoch", 0))
+                        await self._notify_callbacks()
+
+                    elif entry_type == "epoch_done":
+                        await self._notify_callbacks()
+                        await send_epoch_done(job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")})
+
+                    elif entry_type == "completed":
+                        adapter_path = entry.get("adapter_path", str(settings.adapters_dir / job_id))
+                        self.update_job(job_id,
+                                        status=JobStatus.COMPLETED,
+                                        adapter_path=adapter_path,
+                                        progress=100.0)
+                        # The final state is stored; a failing push below must
+                        # not let the job be re-labelled as failed later on.
+                        terminal = True
+                        await self._notify_callbacks()
+                        await send_completed(job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")})
+
+                    elif entry_type == "error":
+                        self.update_job(job_id,
+                                        status=JobStatus.FAILED,
+                                        error_message=entry.get("message", "Unknown error"))
+                        terminal = True
+                        await self._notify_callbacks()
+                        await send_error(job_id, entry.get("message", "Unknown error"))
+                except Exception as e:
+                    # Best effort: one bad entry must not drop the rest of the batch.
+                    logger.warning(f"Error reading remote log file: {e}")
+                if terminal:
+                    return True
+            return False
+
+        async def fail(message: str):
+            """Mark the job as failed and notify listeners."""
+            self.update_job(job_id, status=JobStatus.FAILED, error_message=message)
+            await self._notify_callbacks()
+            await send_error(job_id, message)
+
+        for _ in range(max_polls):
+            if self.is_cancelled(job_id):
+                # Kill the remote process inside the container. The SSH call
+                # blocks, so it runs in a worker thread to keep the loop free.
+                await asyncio.to_thread(
+                    ssh_exec,
+                    f"docker exec {settings.compute_node_docker_container} bash -c 'kill {pid} 2>/dev/null'",
+                    timeout=10,
+                )
+                self.update_job(job_id, status=JobStatus.CANCELLED)
+                await self._notify_callbacks()
+                await send_error(job_id, "Training cancelled")
+                return
+
+            # Sample liveness before reading, so entries written just before
+            # the process exits are still picked up by the read below.
+            process_alive = await asyncio.to_thread(is_process_running, pid)
+
+            if await apply_entries(read_new_entries()):
+                return
+
+            if not process_alive:
+                # Re-check once after a short pause before drawing conclusions.
+                await asyncio.sleep(2)
+                if not await asyncio.to_thread(is_process_running, pid):
+                    # Drain whatever is left, including an unterminated last line.
+                    if await apply_entries(read_new_entries(final=True)):
+                        return
+                    # The process is gone and never reported a final state.
+                    await fail(f"Remote process exited unexpectedly (pid={pid})")
+                    return
+
+            await asyncio.sleep(poll_interval)
+
+        # Polling budget exhausted.
+        await fail("Remote training timed out")
+
     @property
     @property
     def jobs(self) -> dict[str, TrainingJob]:
     def jobs(self) -> dict[str, TrainingJob]:
         return dict(self._jobs)
         return dict(self._jobs)

+ 21 - 15
backend/app/core/remote_executor.py

@@ -56,34 +56,40 @@ def run_training_remote(
     model_type: str,
     model_type: str,
     dataset_id: str,
     dataset_id: str,
     config: dict[str, Any],
     config: dict[str, Any],
-) -> bool:
-    """在算力节点启动训练任务(后台执行,不阻塞)。
+) -> str | None:
+    """在算力节点启动训练任务(通过 docker exec,后台执行)。
 
 
-    使用 nohup + & 让训练在后台运行,通过 WebSocket 回传进度
+    在容器内用 nohup 启动训练,返回 PID 以便后续检测
     """
     """
     config_json = json.dumps(config, ensure_ascii=False)
     config_json = json.dumps(config, ensure_ascii=False)
-    # 转义双引号避免 shell 解析问题
     config_escaped = config_json.replace('"', '\\"')
     config_escaped = config_json.replace('"', '\\"')
 
 
-    log_path = os.path.join(settings.compute_node_workdir, f"logs/{job_id}.log")
-    log_dir = os.path.dirname(log_path)
-
     remote_cmd = (
     remote_cmd = (
-        f"mkdir -p {log_dir} && "
-        f"cd {settings.compute_node_workdir} && "
-        f"nohup {settings.compute_node_python} -m app.engines.remote_train "
+        f"docker exec {settings.compute_node_docker_container} "
+        f"bash -c 'nohup {settings.compute_node_python} -m app.engines.remote_train "
         f"'{job_id}' '{model_id}' '{model_type}' '{dataset_id}' '{config_escaped}' "
         f"'{job_id}' '{model_id}' '{model_type}' '{dataset_id}' '{config_escaped}' "
-        f"> {log_path} 2>&1 & echo $!"
+        f">/tmp/train_{job_id}.log 2>&1 & echo $!'"
     )
     )
 
 
     code, stdout, stderr = ssh_exec(remote_cmd, timeout=30)
     code, stdout, stderr = ssh_exec(remote_cmd, timeout=30)
 
 
     if code != 0:
     if code != 0:
         logger.error(f"Remote training launch failed: {stderr}")
         logger.error(f"Remote training launch failed: {stderr}")
-        return False
+        return None
+
+    pid = stdout.strip()
+    logger.info(f"Remote training launched in container: job={job_id}, container_pid={pid}")
+    return pid
+
 
 
-    logger.info(f"Remote training launched: job={job_id}, pid={stdout.strip()}")
-    return True
+def is_process_running(pid: str) -> bool:
+    """Report whether the remote training process is still alive.
+
+    Runs ``kill -0 <pid>`` inside the training container (``docker exec``,
+    issued through ``ssh_exec``) and inspects the word the probe prints.
+
+    Args:
+        pid: PID of the training process inside the container.
+
+    Returns:
+        True only when the probe ran and printed ``running``. A malformed PID
+        or a failure of the remote call itself yields False.
+    """
+    pid = str(pid).strip()
+    if not (pid.isascii() and pid.isdigit()):
+        # Only a plain number may be interpolated into the shell command.
+        logger.warning(f"Invalid remote pid, treating process as stopped: {pid!r}")
+        return False
+
+    cmd = (
+        f"docker exec {settings.compute_node_docker_container} "
+        f"bash -c 'kill -0 {pid} 2>/dev/null && echo running || echo stopped'"
+    )
+    code, stdout, stderr = ssh_exec(cmd, timeout=10)
+    if code != 0:
+        # The probe script ends in an echo on either branch, so a non-zero
+        # code points at the SSH/docker call rather than at the process.
+        logger.warning(f"Remote liveness check failed (code={code}): {stderr}")
+        return False
+    return "running" in stdout
 
 
 
 
 def run_inference_remote(
 def run_inference_remote(
@@ -100,7 +106,7 @@ def run_inference_remote(
     safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n")
     safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n")
 
 
     remote_cmd = (
     remote_cmd = (
-        f"cd {settings.compute_node_workdir} && "
+        f"docker exec {settings.compute_node_docker_container} "
         f"{settings.compute_node_python} -c \""
         f"{settings.compute_node_python} -c \""
         "import asyncio, json; "
         "import asyncio, json; "
         "from app.config import get_settings; "
         "from app.config import get_settings; "

+ 3 - 2
backend/app/engines/multimodal_engine.py

@@ -73,6 +73,7 @@ class MultimodalEngine(BaseEngine):
         dataset_path: str,
         dataset_path: str,
         peft_config: Any,
         peft_config: Any,
         training_args: dict[str, Any],
         training_args: dict[str, Any],
+        callbacks: list | None = None,
     ) -> str:
     ) -> str:
         from peft import get_peft_model
         from peft import get_peft_model
         from transformers import Trainer, TrainingArguments
         from transformers import Trainer, TrainingArguments
@@ -125,13 +126,13 @@ class MultimodalEngine(BaseEngine):
             report_to="none",
             report_to="none",
         )
         )
 
 
-        callback = _ProgressCallback(job_id)
+        all_callbacks = callbacks if callbacks else [_ProgressCallback(job_id)]
         trainer = Trainer(
         trainer = Trainer(
             model=self._model,
             model=self._model,
             args=tr_args,
             args=tr_args,
             train_dataset=hf_dataset,
             train_dataset=hf_dataset,
             data_collator=collate_fn,
             data_collator=collate_fn,
-            callbacks=[callback],
+            callbacks=all_callbacks,
         )
         )
 
 
         try:
         try:

+ 139 - 60
backend/app/engines/remote_train.py

@@ -4,12 +4,73 @@ import json
 import os
 import os
 import sys
 import sys
 import signal
 import signal
+import time
+import traceback
+from datetime import datetime, timezone
 from pathlib import Path
 from pathlib import Path
 
 
 # 禁用 FlashAttention
 # 禁用 FlashAttention
 os.environ["PYTORCH_NO_FLASH"] = "1"
 os.environ["PYTORCH_NO_FLASH"] = "1"
 os.environ["FLASH_ATTENTION_ENABLED"] = "0"
 os.environ["FLASH_ATTENTION_ENABLED"] = "0"
 
 
+_progress_log_file = None
+
+
+def _init_log_file(data_dir: Path, job_id: str):
+    """Select ``<data_dir>/logs/<job_id>.jsonl`` as the progress log and write a start entry.
+
+    The ``logs`` directory is created when missing. The chosen path is kept in
+    the module-level ``_progress_log_file`` that ``_write_log`` appends to.
+    """
+    global _progress_log_file
+    logs_root = data_dir / "logs"
+    os.makedirs(logs_root, exist_ok=True)
+    _progress_log_file = logs_root / f"{job_id}.jsonl"
+    _write_log(type="start", job_id=job_id)
+
+
+def _write_log(**kwargs):
+    """Append one JSON line (UTC ``ts`` plus ``kwargs``) to the progress log.
+
+    Does nothing while no log file has been selected.
+    """
+    if not _progress_log_file:
+        return
+    record = {"ts": datetime.now(timezone.utc).isoformat()}
+    record.update(kwargs)
+    with open(_progress_log_file, "a", encoding="utf-8") as log:
+        log.write(json.dumps(record, ensure_ascii=False) + "\n")
+        log.flush()
+
+
+class FileProgressCallback:
+    """HuggingFace Trainer callback that writes training events to the shared progress log.
+
+    The class is duck-typed instead of deriving from ``TrainerCallback``. The
+    Trainer's callback handler invokes every event hook on every registered
+    callback, so hooks that are not recorded here resolve to a no-op through
+    ``__getattr__`` rather than raising ``AttributeError``.
+
+    No ``completed`` entry is written from here: that entry ends the job on
+    the polling side, and ``run_training`` writes it with the adapter path
+    returned by the engine once ``engine.train`` has finished.
+    """
+
+    def __init__(self, job_id: str):
+        self.job_id = job_id
+
+    def __getattr__(self, name):
+        # Only consulted for attributes that normal lookup did not find.
+        if name.startswith("on_"):
+            return self._ignore_event
+        raise AttributeError(name)
+
+    @staticmethod
+    def _ignore_event(*args, **kwargs):
+        """No-op handler for Trainer events this callback does not record."""
+        return None
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        _write_log(type="status", status="training")
+
+    def on_epoch_begin(self, args, state, control, **kwargs):
+        _write_log(type="epoch_begin", epoch=int(state.epoch or 0))
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        # Only log records that carry a training loss become progress entries.
+        if logs and "loss" in logs:
+            _write_log(type="progress", epoch=int(state.epoch or 0),
+                       step=state.global_step, total_steps=state.max_steps or 0,
+                       loss=round(logs["loss"], 4),
+                       learning_rate=round(logs.get("learning_rate", 0), 8))
+
+    def on_epoch_end(self, args, state, control, metrics=None, **kwargs):
+        # Metrics are optional here; missing values are written as null.
+        found = metrics if metrics and hasattr(metrics, "get") else {}
+        _write_log(type="epoch_done", epoch=int(state.epoch or 0),
+                   eval_loss=found.get("eval_loss"),
+                   eval_accuracy=found.get("eval_accuracy"))
+
+    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
+        if metrics:
+            _write_log(type="evaluate", epoch=int(state.epoch or 0),
+                       eval_loss=metrics.get("eval_loss"),
+                       eval_accuracy=metrics.get("eval_accuracy"))
+
+    def on_save(self, args, state, control, **kwargs):
+        _write_log(type="save", step=state.global_step)
+
+    def on_train_end(self, args, state, control, **kwargs):
+        # Deliberately a non-terminal status entry (see class docstring).
+        _write_log(type="status", status="train_end", output_dir=args.output_dir)
+
 
 
 async def run_training(job_id: str, model_id: str, model_type: str, dataset_id: str, config: dict):
 async def run_training(job_id: str, model_id: str, model_type: str, dataset_id: str, config: dict):
     """执行单个训练任务(远程调用入口)。"""
     """执行单个训练任务(远程调用入口)。"""
@@ -17,66 +78,84 @@ async def run_training(job_id: str, model_id: str, model_type: str, dataset_id:
     from app.core.logging import logger
     from app.core.logging import logger
 
 
     settings = get_settings()
     settings = get_settings()
-
-    # 查找数据集
-    from app.core.db import async_session, DatasetRecord
-    from sqlalchemy import select
-
-    dataset_path = None
-    async with async_session() as session:
-        result = await session.execute(select(DatasetRecord).where(
-            (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
-        ))
-        record = result.scalar_one_or_none()
-        if record:
-            dataset_path = record.file_path
-
-    if not dataset_path:
-        # 尝试 uploads 目录
-        upload_path = settings.uploads_dir / dataset_id
-        if upload_path.exists():
-            dataset_path = str(upload_path)
-
-    if not dataset_path:
-        raise FileNotFoundError(f"Dataset not found: {dataset_id}")
-
-    # 预处理
-    processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
-    task_type = config.get("task_type", "sft")
-    template = config.get("dataset_template", "alpaca")
-
-    # 选择引擎
-    if model_type == "vision":
-        from app.engines.vision_engine import vision_engine
-        engine = vision_engine
-    elif model_type == "multimodal":
-        from app.engines.multimodal_engine import multimodal_engine
-        engine = multimodal_engine
-    else:
-        from app.engines.text_engine import text_engine
-        engine = text_engine
-
-    peft_method = config.get("peft_method", "lora")
-
-    # 预处理数据集
-    await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
-
-    # 加载模型
-    await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
-
-    # 构建 PEFT 配置
-    peft_config = engine.get_peft_config(peft_method, config)
-
-    # 训练
-    adapter_path = await engine.train(
-        job_id=job_id,
-        dataset_path=processed_path,
-        peft_config=peft_config,
-        training_args=config,
-    )
-
-    logger.info(f"Remote training completed: {job_id} -> {adapter_path}")
-    return adapter_path
+    _init_log_file(settings.data_dir, job_id)
+
+    try:
+        # 查找数据集
+        from app.core.db import async_session, DatasetRecord
+        from sqlalchemy import select
+
+        dataset_path = None
+        async with async_session() as session:
+            result = await session.execute(select(DatasetRecord).where(
+                (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
+            ))
+            record = result.scalar_one_or_none()
+            if record:
+                dataset_path = record.file_path
+
+        if not dataset_path:
+            upload_path = settings.uploads_dir / dataset_id
+            if upload_path.exists():
+                dataset_path = str(upload_path)
+
+        if not dataset_path:
+            raise FileNotFoundError(f"Dataset not found: {dataset_id}")
+
+        _write_log(type="status", status="preprocessing")
+
+        # 预处理
+        processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
+        task_type = config.get("task_type", "sft")
+        template = config.get("dataset_template", "alpaca")
+
+        # 选择引擎
+        if model_type == "vision":
+            from app.engines.vision_engine import vision_engine
+            engine = vision_engine
+        elif model_type == "multimodal":
+            from app.engines.multimodal_engine import multimodal_engine
+            engine = multimodal_engine
+        else:
+            from app.engines.text_engine import text_engine
+            engine = text_engine
+
+        peft_method = config.get("peft_method", "lora")
+
+        await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
+
+        _write_log(type="status", status="loading_model")
+
+        # 加载模型
+        await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
+
+        # 构建 PEFT 配置
+        peft_config = engine.get_peft_config(peft_method, config)
+
+        _write_log(type="status", status="training")
+
+        # 训练 — 传入文件日志回调替代 WebSocket 回调
+        start_time = time.time()
+        file_cb = FileProgressCallback(job_id)
+
+        adapter_path = await engine.train(
+            job_id=job_id,
+            dataset_path=processed_path,
+            peft_config=peft_config,
+            training_args=config,
+            callbacks=[file_cb],
+        )
+
+        elapsed = round(time.time() - start_time, 2)
+        _write_log(type="completed", adapter_path=str(adapter_path), total_time=elapsed)
+
+        logger.info(f"Remote training completed: {job_id} -> {adapter_path} ({elapsed}s)")
+        return adapter_path
+
+    except Exception as e:
+        _write_log(type="error", message=str(e), traceback=traceback.format_exc())
+        logger.error(f"Remote training failed: {job_id} - {e}")
+        raise
 
 
 
 
 def main():
 def main():

+ 5 - 3
backend/app/engines/text_engine.py

@@ -126,6 +126,7 @@ class TextEngine(BaseEngine):
         dataset_path: str,
         dataset_path: str,
         peft_config: Any,
         peft_config: Any,
         training_args: dict[str, Any],
         training_args: dict[str, Any],
+        callbacks: list | None = None,
     ) -> str:
     ) -> str:
         """执行训练。"""
         """执行训练。"""
         from peft import get_peft_model
         from peft import get_peft_model
@@ -167,7 +168,8 @@ class TextEngine(BaseEngine):
             **({"deepspeed": deepspeed_config} if deepspeed_config else {}),
             **({"deepspeed": deepspeed_config} if deepspeed_config else {}),
         )
         )
 
 
-        callback = _ProgressCallback(job_id)
+        # 本地模式用 WebSocket 回调,远程模式用传入的文件日志回调
+        all_callbacks = callbacks if callbacks else [_ProgressCallback(job_id)]
 
 
         if task_type == "sft":
         if task_type == "sft":
             from transformers import Trainer
             from transformers import Trainer
@@ -177,7 +179,7 @@ class TextEngine(BaseEngine):
                 args=tr_args,
                 args=tr_args,
                 train_dataset=dataset,
                 train_dataset=dataset,
                 data_collator=DataCollatorForSeq2Seq(self._tokenizer),
                 data_collator=DataCollatorForSeq2Seq(self._tokenizer),
-                callbacks=[callback],
+                callbacks=all_callbacks,
             )
             )
         else:
         else:
             from trl import (
             from trl import (
@@ -229,7 +231,7 @@ class TextEngine(BaseEngine):
                     args=tr_args,
                     args=tr_args,
                     train_dataset=dataset,
                     train_dataset=dataset,
                     data_collator=DataCollatorForSeq2Seq(self._tokenizer),
                     data_collator=DataCollatorForSeq2Seq(self._tokenizer),
-                    callbacks=[callback],
+                    callbacks=all_callbacks,
                 )
                 )
 
 
         try:
         try:

+ 3 - 2
backend/app/engines/vision_engine.py

@@ -73,6 +73,7 @@ class VisionEngine(BaseEngine):
         dataset_path: str,
         dataset_path: str,
         peft_config: Any,
         peft_config: Any,
         training_args: dict[str, Any],
         training_args: dict[str, Any],
+        callbacks: list | None = None,
     ) -> str:
     ) -> str:
         from peft import get_peft_model
         from peft import get_peft_model
         from transformers import DataCollatorWithPadding, Trainer, TrainingArguments
         from transformers import DataCollatorWithPadding, Trainer, TrainingArguments
@@ -125,13 +126,13 @@ class VisionEngine(BaseEngine):
             report_to="none",
             report_to="none",
         )
         )
 
 
-        callback = _ProgressCallback(job_id)
+        all_callbacks = callbacks if callbacks else [_ProgressCallback(job_id)]
         trainer = Trainer(
         trainer = Trainer(
             model=self._model,
             model=self._model,
             args=tr_args,
             args=tr_args,
             train_dataset=hf_dataset,
             train_dataset=hf_dataset,
             data_collator=DataCollatorWithPadding(self._processor),
             data_collator=DataCollatorWithPadding(self._processor),
-            callbacks=[callback],
+            callbacks=all_callbacks,
         )
         )
 
 
         try:
         try:

+ 1 - 0
docker-compose.yml

@@ -41,6 +41,7 @@ services:
       - COMPUTE_NODE_SSH_PASSWORD=ictrek
       - COMPUTE_NODE_SSH_PASSWORD=ictrek
       # - COMPUTE_NODE_SSH_KEY=/root/.ssh/id_rsa  # 优先用密钥,密码为备选
       # - COMPUTE_NODE_SSH_KEY=/root/.ssh/id_rsa  # 优先用密钥,密码为备选
       - COMPUTE_NODE_PYTHON=/opt/conda/bin/python
       - COMPUTE_NODE_PYTHON=/opt/conda/bin/python
+      - COMPUTE_NODE_DOCKER_CONTAINER=finetune-trainer
       - COMPUTE_NODE_WORKDIR=/root/Fine-tuning/backend
       - COMPUTE_NODE_WORKDIR=/root/Fine-tuning/backend
       - COMPUTE_NODE_REMOTE_DATA_DIR=/root/Fine-tuning/backend/data
       - COMPUTE_NODE_REMOTE_DATA_DIR=/root/Fine-tuning/backend/data
       - COMPUTE_NODE_REMOTE_ENV=production
       - COMPUTE_NODE_REMOTE_ENV=production

+ 19 - 49
result.txt

@@ -1,49 +1,19 @@
-(base) [root@localhost Fine-tuning]# docker exec finetune-backend pip install --upgrade datasets
-Looking in indexes: http://mirrors.aliyun.com/pypi/simple
-Requirement already satisfied: datasets in /opt/conda/lib/python3.10/site-packages (4.8.5)
-Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from datasets) (3.29.0)
-Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from datasets) (1.26.4)
-Requirement already satisfied: pyarrow>=21.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (24.0.0)
-Requirement already satisfied: dill<0.4.2,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.4.1)
-Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets) (2.3.3)
-Requirement already satisfied: requests>=2.32.2 in /opt/conda/lib/python3.10/site-packages (from datasets) (2.32.3)
-Requirement already satisfied: httpx<1.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.28.1)
-Requirement already satisfied: tqdm>=4.66.3 in /opt/conda/lib/python3.10/site-packages (from datasets) (4.67.1)
-Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets) (3.7.0)
-Requirement already satisfied: multiprocess<0.70.20 in /opt/conda/lib/python3.10/site-packages (from datasets) (0.70.19)
-Requirement already satisfied: fsspec<=2026.2.0,>=2023.1.0 in /opt/conda/lib/python3.10/site-packages (from fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (2025.5.1)
-Requirement already satisfied: huggingface-hub<2.0,>=0.25.0 in /opt/conda/lib/python3.10/site-packages (from datasets) (1.14.0)
-Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from datasets) (26.2)
-Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from datasets) (6.0.3)
-Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /opt/conda/lib/python3.10/site-packages (from fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (3.13.5)
-Requirement already satisfied: anyio in /opt/conda/lib/python3.10/site-packages (from httpx<1.0.0->datasets) (4.13.0)
-Requirement already satisfied: certifi in /opt/conda/lib/python3.10/site-packages (from httpx<1.0.0->datasets) (2026.4.22)
-Requirement already satisfied: httpcore==1.* in /opt/conda/lib/python3.10/site-packages (from httpx<1.0.0->datasets) (1.0.9)
-Requirement already satisfied: idna in /opt/conda/lib/python3.10/site-packages (from httpx<1.0.0->datasets) (3.10)
-Requirement already satisfied: h11>=0.16 in /opt/conda/lib/python3.10/site-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)
-Requirement already satisfied: hf-xet<2.0.0,>=1.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.4.3)
-Requirement already satisfied: typer>=0.20.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (0.25.0)
-Requirement already satisfied: typing-extensions>=4.1.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (4.15.0)
-Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (2.6.1)
-Requirement already satisfied: aiosignal>=1.4.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (1.4.0)
-Requirement already satisfied: async-timeout<6.0,>=4.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (5.0.1)
-Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (26.1.0)
-Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (1.8.0)
-Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (6.7.1)
-Requirement already satisfied: propcache>=0.2.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (0.4.1)
-Requirement already satisfied: yarl<2.0,>=1.17.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (1.23.0)
-Requirement already satisfied: charset_normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets) (3.4.1)
-Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets) (2.3.0)
-Requirement already satisfied: click>=8.2.1 in /opt/conda/lib/python3.10/site-packages (from typer>=0.20.0->huggingface-hub<2.0,>=0.25.0->datasets) (8.2.1)
-Requirement already satisfied: shellingham>=1.3.0 in /opt/conda/lib/python3.10/site-packages (from typer>=0.20.0->huggingface-hub<2.0,>=0.25.0->datasets) (1.5.4)
-Requirement already satisfied: rich>=13.8.0 in /opt/conda/lib/python3.10/site-packages (from typer>=0.20.0->huggingface-hub<2.0,>=0.25.0->datasets) (15.0.0)
-Requirement already satisfied: annotated-doc>=0.0.2 in /opt/conda/lib/python3.10/site-packages (from typer>=0.20.0->huggingface-hub<2.0,>=0.25.0->datasets) (0.0.4)
-Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/conda/lib/python3.10/site-packages (from rich>=13.8.0->typer>=0.20.0->huggingface-hub<2.0,>=0.25.0->datasets) (4.0.0)
-Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/conda/lib/python3.10/site-packages (from rich>=13.8.0->typer>=0.20.0->huggingface-hub<2.0,>=0.25.0->datasets) (2.19.2)
-Requirement already satisfied: mdurl~=0.1 in /opt/conda/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=13.8.0->typer>=0.20.0->huggingface-hub<2.0,>=0.25.0->datasets) (0.1.2)
-Requirement already satisfied: exceptiongroup>=1.0.2 in /opt/conda/lib/python3.10/site-packages (from anyio->httpx<1.0.0->datasets) (1.3.0)
-Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2.9.0.post0)
-Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2026.1.post1)
-Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets) (2026.2)
-Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)
-WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
+(base) [root@localhost backend]# sshpass -p 'lq123456!' ssh -o StrictHostKeyChecking=no lq@192.168.92.151 "ls -la /home/lq/Fine-tuning/"
+total 104
+drwxrwxr-x 5 lq lq  4096 May 19 03:50 .
+drwxr-xr-x 8 lq lq  4096 May 19 08:05 ..
+drwxrwxr-x 4 lq lq  4096 May 19 03:59 backend
+-rw-rw-r-- 1 lq lq  5050 May 19 03:49 CLAUDE.md
+-rw-rw-r-- 1 lq lq  2031 May 19 03:50 DEPLOY.md
+-rw-rw-r-- 1 lq lq  1892 May 19 03:50 docker-compose.yml
+-rw-rw-r-- 1 lq lq    86 May 19 03:49 .dockerignore
+-rw-rw-r-- 1 lq lq   543 May 19 03:49 .env
+-rw-rw-r-- 1 lq lq   375 May 19 03:49 .env.example
+drwxrwxr-x 3 lq lq  4096 May 19 03:49 frontend
+drwxrwxr-x 8 lq lq  4096 May 19 03:50 .git
+-rw-rw-r-- 1 lq lq   210 May 19 03:49 .gitignore
+-rw-rw-r-- 1 lq lq     0 May 19 03:49 peft-finetune-frontend@0.1.0
+-rw-rw-r-- 1 lq lq  3485 May 19 03:49 README.md
+-rw-rw-r-- 1 lq lq  7036 May 19 03:49 result.txt
+-rw-rw-r-- 1 lq lq 15237 May 19 03:50 样本中心提供API接口文档_外部.md
+-rw-rw-r-- 1 lq lq 20943 May 19 03:50 统一认证平台接入流程及API接口文档(1)(1).md