Sfoglia il codice sorgente

增加样本中心接口,改为单卡训练

lxylxy123321 5 giorni fa
parent
commit
a95623c0f8

+ 23 - 3
backend/app/api/datasets.py

@@ -7,6 +7,7 @@ from app.schemas.dataset import (
     DatasetUploadResponse,
     DatasetValidationResult,
 )
+from app.schemas.background_task import DatasetDownloadTaskResponse
 from app.services import dataset_service
 
 router = APIRouter()
@@ -14,13 +15,32 @@ router = APIRouter()
 
 @router.post("/download", response_model=DatasetDownloadResponse, status_code=200)
 async def download_dataset(req: DatasetDownloadRequest):
-    """从 HuggingFace 或 ModelScope 下载数据集。"""
+    """Start a background dataset download task and return immediately with its task_id."""
     result = await dataset_service.download_dataset(req)
-    if result.status == "failed":
-        raise HTTPException(status_code=400, detail=result.error or "Dataset download failed")
     return result
 
 
+@router.get("/download/{task_id}")
+async def get_dataset_download_status(task_id: str):
+    """Look up a dataset download task; respond 404 when the service reports "not_found"."""
+    task_info = await dataset_service.get_dataset_download_status(task_id)
+    if task_info.get("status") != "not_found":
+        return task_info
+    raise HTTPException(status_code=404, detail="Download task not found")
+
+
+@router.get("/downloads")
+async def list_dataset_downloads():
+    """Return every dataset download task reported by the dataset service."""
+    downloads = await dataset_service.list_dataset_downloads()
+    return downloads
+
+
+@router.post("/download/{task_id}/cancel")
+async def cancel_dataset_download(task_id: str):
+    """Ask the dataset service to cancel a download task and relay its reply."""
+    outcome = await dataset_service.cancel_dataset_download(task_id)
+    return outcome
+
+
 @router.post("/upload", response_model=DatasetUploadResponse, status_code=201)
 async def upload_dataset(file: UploadFile = File(...)):
     """上传数据集文件(JSONL / CSV / Parquet / JSON)。"""

+ 1 - 1
backend/app/api/deployment.py

@@ -8,7 +8,7 @@ router = APIRouter()
 
 @router.post("/export", response_model=DeployResponse)
 async def export_adapter(config: DeployConfig):
-    """合并 adapter 与基础模型,可选导出为 GGUF。"""
+    """启动导出后台任务,立即返回 task_id。"""
     result = await deploy_service.export_adapter(
         config.job_id,
         {"merge_with_base": config.merge_with_base, "export_format": config.export_format},

+ 5 - 3
backend/app/api/evaluation.py

@@ -1,4 +1,4 @@
-from fastapi import APIRouter
+from fastapi import APIRouter, HTTPException
 
 from app.schemas.evaluation import EvalConfig, EvalResult
 from app.services import eval_service
@@ -8,13 +8,15 @@ router = APIRouter()
 
 @router.post("/run", response_model=EvalResult)
 async def run_evaluation(config: EvalConfig):
-    """对已训练的 adapter 运行评估。"""
+    """Start a background evaluation task and return immediately with its eval_id."""
     result = await eval_service.run_evaluation(config.job_id, config.model_dump())
     return EvalResult(**result)
 
 
 @router.get("/{eval_id}/results", response_model=EvalResult)
 async def get_evaluation_results(eval_id: str):
-    """获取已完成评估结果。"""
+    """Return the results or current status of an evaluation; 404 if it is not found."""
     result = await eval_service.get_evaluation_results(eval_id)
+    if result.get("status") == "not_found":
+        raise HTTPException(status_code=404, detail="Evaluation not found")
     return EvalResult(**result)

+ 26 - 9
backend/app/api/models.py

@@ -2,6 +2,7 @@ from fastapi import APIRouter, HTTPException
 
 from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
 from app.schemas.model_test import ModelTestRequest, ModelTestResponse
+from app.schemas.background_task import ModelDownloadTaskResponse
 from app.services import model_service, model_test_service
 
 router = APIRouter()
@@ -25,22 +26,38 @@ async def list_models():
     ]
 
 
-@router.post("/download", response_model=ModelDownloadResponse, status_code=200)
+@router.post("/download", response_model=ModelDownloadTaskResponse, status_code=200)
 async def download_model(req: ModelDownloadRequest):
-    """从 HuggingFace 或 ModelScope 下载模型。"""
+    """Start a background model download task and return immediately with its task_id."""
     result = await model_service.download_model(req.model_id, req.use_modelscope)
-
-    if result["status"] == "failed":
-        raise HTTPException(status_code=400, detail=result.get("error", "Download failed"))
-
-    return ModelDownloadResponse(
+    return ModelDownloadTaskResponse(
+        task_id=result["task_id"],
         model_id=result["model_id"],
         status=result["status"],
-        path=result.get("path"),
-        error=result.get("error"),
     )
 
 
+@router.get("/download/{task_id}", response_model=ModelDownloadTaskResponse)
+async def get_model_download_status(task_id: str):
+    """Look up a model download task; respond 404 when the service reports "not_found"."""
+    task_info = await model_service.get_model_download_status(task_id)
+    if task_info.get("status") != "not_found":
+        return ModelDownloadTaskResponse(**task_info)
+    raise HTTPException(status_code=404, detail="Download task not found")
+
+
+@router.get("/downloads")
+async def list_model_downloads():
+    """Return every model download task reported by the model service."""
+    downloads = await model_service.list_model_downloads()
+    return downloads
+
+
+@router.post("/download/{task_id}/cancel")
+async def cancel_model_download(task_id: str):
+    """Ask the model service to cancel a download task and relay its reply."""
+    outcome = await model_service.cancel_model_download(task_id)
+    return outcome
+
+
 @router.get("/{model_id}", response_model=ModelInfo)
 async def get_model_info(model_id: str):
     """获取已缓存模型的详细信息。"""

+ 1 - 1
backend/app/config.py

@@ -57,7 +57,7 @@ class Settings(BaseSettings):
     modelscope_endpoint: str = "https://modelscope.cn"
 
     # --- GPU / 硬件 ---
-    cuda_visible_devices: str = "0"
+    cuda_visible_devices: str = "3"
     max_memory_per_gpu: str = "0"
     use_unsloth: bool = False
 

+ 99 - 0
backend/app/core/background_tasks.py

@@ -0,0 +1,99 @@
+"""轻量后台任务管理器,按类型控制并发,用 asyncio.create_task 执行。"""
+
+import asyncio
+from datetime import datetime, timedelta, timezone
+from typing import Any, Callable, Coroutine, Optional
+
+from app.core.logging import logger
+
+
+class BackgroundTaskManager:
+    """In-memory registry and scheduler for background tasks.
+
+    Every task is a plain dict keyed by task id holding ``task_type``,
+    ``status``, ``progress``, ``error``, ``created_at`` and any caller
+    supplied metadata. Work is scheduled with ``asyncio.create_task``; an
+    optional per-type semaphore limits how many tasks of a type run at once.
+    """
+
+    _ACTIVE_STATES = ("pending", "running")
+    _FINAL_STATES = ("completed", "failed", "cancelled")
+
+    def __init__(self):
+        self._tasks: dict[str, dict[str, Any]] = {}
+        self._type_semaphores: dict[str, asyncio.Semaphore] = {}
+        self._callbacks: list[Callable[[str, dict], None]] = []
+        # Strong references to the scheduled asyncio tasks. The event loop
+        # keeps only weak references, so an unreferenced task can be garbage
+        # collected mid-flight; cancel_task() also needs the handle.
+        self._handles: dict[str, asyncio.Task] = {}
+
+    def set_concurrency(self, task_type: str, limit: int) -> None:
+        """Allow at most ``limit`` tasks of ``task_type`` to run concurrently."""
+        self._type_semaphores[task_type] = asyncio.Semaphore(limit)
+
+    def register_callback(self, callback: Callable[[str, dict], None]) -> None:
+        """Register ``callback(task_id, record_copy)`` to be called after every update."""
+        self._callbacks.append(callback)
+
+    def register_task(self, task_id: str, task_type: str, metadata: dict | None = None) -> None:
+        """Create (or replace) the record for ``task_id`` in the "pending" state.
+
+        Keys in ``metadata`` are merged last, so they override the defaults.
+        """
+        self._tasks[task_id] = {
+            "task_type": task_type,
+            "status": "pending",
+            "progress": 0.0,
+            "error": None,
+            "created_at": datetime.now(timezone.utc),
+            **(metadata or {}),
+        }
+
+    def update_task(self, task_id: str, **kwargs) -> None:
+        """Merge ``kwargs`` into a task record and notify callbacks.
+
+        Unknown task ids are ignored. Each callback receives a shallow copy
+        of the record.
+        """
+        record = self._tasks.get(task_id)
+        if record is None:
+            return
+        record.update(kwargs)
+        for cb in self._callbacks:
+            try:
+                cb(task_id, dict(record))
+            except Exception as exc:
+                # Callbacks are best-effort: report the failure instead of
+                # hiding it, but never let it break the task itself.
+                logger.warning(f"Background task callback failed for {task_id}: {exc}")
+
+    def get_task(self, task_id: str) -> Optional[dict[str, Any]]:
+        """Return the live record for ``task_id``, or ``None`` if unknown."""
+        return self._tasks.get(task_id)
+
+    def list_tasks_by_type(self, task_type: str) -> dict[str, dict[str, Any]]:
+        """Return ``{task_id: record}`` for every task of ``task_type``."""
+        return {tid: t for tid, t in self._tasks.items() if t.get("task_type") == task_type}
+
+    def _is_cancelled(self, task_id: str) -> bool:
+        """Tell whether ``task_id`` is registered and currently marked "cancelled"."""
+        record = self._tasks.get(task_id)
+        return record is not None and record["status"] == "cancelled"
+
+    async def run(self, task_id: str, task_type: str, coro: Coroutine) -> None:
+        """Schedule ``coro`` as the work of ``task_id`` and return without waiting for it.
+
+        This is a coroutine function: callers must ``await`` it, otherwise
+        nothing is scheduled. Final states written to the record:
+
+        * "completed" - ``coro`` returned; a returned dict is merged into the record.
+        * "failed" - ``coro`` raised, or returned a dict with a truthy "error".
+        * "cancelled" - set by ``cancel_task``; never overwritten here.
+        """
+        sem = self._type_semaphores.get(task_type)
+
+        async def _execute() -> None:
+            if self._is_cancelled(task_id):
+                return  # cancelled while queued: do not start the work
+            self.update_task(task_id, status="running", progress=0.0)
+            try:
+                result = await coro
+            except Exception as e:
+                if not self._is_cancelled(task_id):
+                    self.update_task(task_id, status="failed", error=str(e))
+                logger.error(f"Background task {task_id} ({task_type}) failed: {e}")
+                return
+            if self._is_cancelled(task_id):
+                return
+            outcome = result if isinstance(result, dict) else {}
+            # Merge through one dict so a result that itself contains
+            # "status"/"progress" cannot raise a duplicate-keyword TypeError.
+            if outcome.get("error"):
+                self.update_task(task_id, **{**outcome, "status": "failed"})
+            else:
+                self.update_task(
+                    task_id, **{**outcome, "status": "completed", "progress": 100.0}
+                )
+
+        async def _wrapped() -> None:
+            try:
+                if sem:
+                    async with sem:
+                        await _execute()
+                else:
+                    await _execute()
+            finally:
+                # No-op when the coroutine already ran; otherwise prevents a
+                # "coroutine was never awaited" warning for skipped work.
+                coro.close()
+                self._handles.pop(task_id, None)
+
+        self._handles[task_id] = asyncio.create_task(_wrapped())
+
+    def cancel_task(self, task_id: str) -> bool:
+        """Cancel a pending or running task.
+
+        Returns ``True`` if the task was active and is now marked
+        "cancelled", ``False`` if it is unknown or already finished.
+        """
+        record = self._tasks.get(task_id)
+        if record is None or record["status"] not in self._ACTIVE_STATES:
+            return False
+        # Flip the status first so the runner sees "cancelled" and keeps it.
+        self.update_task(task_id, status="cancelled", error="Cancelled by user")
+        handle = self._handles.get(task_id)
+        if handle is not None:
+            handle.cancel()
+        return True
+
+    def cleanup_old_tasks(self, max_age_hours: int = 24) -> None:
+        """Drop finished tasks whose ``created_at`` is older than ``max_age_hours``."""
+        cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
+        to_remove = [
+            tid
+            for tid, t in self._tasks.items()
+            if t["status"] in self._FINAL_STATES
+            and t.get("created_at")
+            and t["created_at"] < cutoff
+        ]
+        for tid in to_remove:
+            del self._tasks[tid]
+
+    @property
+    def tasks(self) -> dict[str, dict[str, Any]]:
+        """Shallow copy of the task table (the records themselves are shared)."""
+        return dict(self._tasks)
+
+
+# Global singleton; initialised in the main.py lifespan.
+background_task_manager = BackgroundTaskManager()

+ 33 - 0
backend/app/core/db.py

@@ -114,7 +114,10 @@ class EvalResultModel(Base):
 
     id = Column(String(36), primary_key=True)
     job_id = Column(String(36), nullable=False)
+    status = Column(String(32), default="pending")  # pending|running|completed|failed
     metrics = Column(Text, default="{}")
+    progress = Column(Float, default=0.0)
+    error = Column(Text, nullable=True)
     created_at = Column(DateTime, default=datetime.utcnow)
 
 
@@ -126,9 +129,39 @@ class DeployTaskModel(Base):
     status = Column(String(32), default="pending")
     output_path = Column(String(512), nullable=True)
     error = Column(Text, nullable=True)
+    progress = Column(Float, default=0.0)
+    finished_at = Column(DateTime, nullable=True)
     created_at = Column(DateTime, default=datetime.utcnow)
 
 
+class ModelDownloadTask(Base):
+    # ORM row for one model download task ("model_download_tasks" table).
+    __tablename__ = "model_download_tasks"
+
+    id = Column(String(36), primary_key=True)
+    model_id = Column(String(256), nullable=False)
+    use_modelscope = Column(Integer, default=0)  # presumably a 0/1 flag; confirm against model_service
+    status = Column(String(32), default="pending")  # pending|downloading|completed|failed
+    path = Column(String(512), nullable=True)
+    error = Column(Text, nullable=True)
+    progress = Column(Float, default=0.0)
+    created_at = Column(DateTime, default=datetime.utcnow)  # naive UTC timestamp
+    finished_at = Column(DateTime, nullable=True)
+
+
+class DatasetDownloadTask(Base):
+    # ORM row for one dataset download task ("dataset_download_tasks" table).
+    __tablename__ = "dataset_download_tasks"
+
+    id = Column(String(36), primary_key=True)  # dataset_service stores a uuid4 string here
+    dataset_id = Column(String(256), nullable=False)
+    use_modelscope = Column(Integer, default=0)  # written as 1/0 by dataset_service
+    # dataset_service.cancel_dataset_download additionally writes "cancelled".
+    status = Column(String(32), default="pending")  # pending|downloading|completed|failed|cancelled
+    path = Column(String(512), nullable=True)
+    error = Column(Text, nullable=True)
+    record_count = Column(Integer, default=0)
+    created_at = Column(DateTime, default=datetime.utcnow)  # naive UTC timestamp
+    finished_at = Column(DateTime, nullable=True)
+
+
 class UserModel(Base):
     __tablename__ = "users"
 

+ 1 - 1
backend/app/core/remote_deploy.py

@@ -6,7 +6,7 @@ from pathlib import Path
 
 os.environ["PYTORCH_NO_FLASH"] = "1"
 os.environ["MACA_MPS_MODE"] = "1"
-os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
+os.environ["CUDA_VISIBLE_DEVICES"] = "3"
 
 _DATA_DIR = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
 _ADAPTERS_DIR = _DATA_DIR / "adapters"

+ 2 - 2
backend/app/core/remote_eval.py

@@ -6,7 +6,7 @@ from pathlib import Path
 # 禁用 FlashAttention,启用 MPS
 os.environ["PYTORCH_NO_FLASH"] = "1"
 os.environ["MACA_MPS_MODE"] = "1"
-os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
+os.environ["CUDA_VISIBLE_DEVICES"] = "3"
 
 _DATA_DIR = Path(os.environ.get("COMPUTE_NODE_REMOTE_DATA_DIR", "/root/Fine-tuning/backend/data"))
 _ADAPTERS_DIR = _DATA_DIR / "adapters"
@@ -21,7 +21,7 @@ async def run_remote_eval(job_id: str) -> dict:
     import torch
     from transformers import AutoModelForCausalLM, AutoTokenizer
 
-    # 加载 adapter(CUDA_VISIBLE_DEVICES=2,3 已将物理 GPU 2,3 映射为逻辑 GPU 0,1
+    # 加载 adapter(CUDA_VISIBLE_DEVICES=3 已将物理 GPU 3 映射为逻辑 GPU 0)
     device_map = {"": 0}
 
     model = AutoModelForCausalLM.from_pretrained(

+ 2 - 2
backend/app/core/remote_executor.py

@@ -159,7 +159,7 @@ def run_training_remote(
     remote_cmd = (
         f"docker exec "
         f"-e MACA_MPS_MODE=1 "
-        f"-e CUDA_VISIBLE_DEVICES=2,3 "
+        f"-e CUDA_VISIBLE_DEVICES=3 "
         f"-e PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True "
         f"-w {settings.compute_node_workdir} "
         f"{settings.compute_node_docker_container} "
@@ -229,7 +229,7 @@ def run_inference_remote(
     remote_cmd = (
         f"docker exec "
         f"-e MACA_MPS_MODE=1 "
-        f"-e CUDA_VISIBLE_DEVICES=2,3 "
+        f"-e CUDA_VISIBLE_DEVICES=3 "
         f"-w {settings.compute_node_workdir} "
         f"{settings.compute_node_docker_container} "
         f"{settings.compute_node_python} -c \""

+ 4 - 4
backend/app/engines/remote_train.py

@@ -19,10 +19,10 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 # 禁用 torch.compile,避免 fork 大量 inductor worker 进程
 os.environ["PT2_COMPILE"] = "0"
 os.environ["TORCHINDUCTOR_MAX_WORKERS"] = "1"
-# 限制训练只用 GPU 2 和 3(GPU 0/1 被 VLLM 占用)
-# CUDA_VISIBLE_DEVICES 将 2,3 映射为容器内的 cuda:0, cuda:1
-# device_map 中使用相对编号 0(即物理 GPU 2
-os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
+# Restrict training to GPU 3 only (GPUs 0/1 are used by VLLM, GPU 2 is already occupied).
+# CUDA_VISIBLE_DEVICES maps 3 to cuda:0 inside the container.
+# device_map uses the relative index 0 (i.e. physical GPU 3).
+os.environ["CUDA_VISIBLE_DEVICES"] = "3"
 # 启用 MPS 多进程服务,允许与 VLLM 共享 GPU
 os.environ["MACA_MPS_MODE"] = "1"
 

+ 5 - 6
backend/app/engines/text_engine.py

@@ -10,10 +10,10 @@ os.environ["PT2_COMPILE"] = "0"
 os.environ["TORCHINDUCTOR_MAX_WORKERS"] = "1"
 # 解决 PyTorch 显存碎片化问题(避免 reserved unallocated 占用大量显存)
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-# 限制训练只用 GPU 2 和 3(GPU 0/1 被 VLLM 占用)
-# CUDA_VISIBLE_DEVICES 将物理 GPU 2,3 映射为容器内的 cuda:0, cuda:1
-# device_map 中使用相对编号 0(对应物理 GPU 2
-os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
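+# Restrict training to GPU 3 only (GPUs 0/1 are used by VLLM, GPU 2 is already occupied).
+# CUDA_VISIBLE_DEVICES maps physical GPU 3 to cuda:0 inside the container.
+# device_map uses the relative index 0 (corresponding to physical GPU 3).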
+os.environ["CUDA_VISIBLE_DEVICES"] = "3"
 # 启用 MPS 多进程服务,允许与 VLLM 共享 GPU
 os.environ["MACA_MPS_MODE"] = "1"
 
@@ -80,8 +80,7 @@ class TextEngine(BaseEngine):
         else:
             raise RuntimeError("No GPU detected! Training requires GPU.")
 
-        # CUDA_VISIBLE_DEVICES=2,3 已将物理 GPU 2,3 映射为逻辑 GPU 0,1
-        # device_map 直接用 0 即可(对应物理 GPU 2)
+        # CUDA_VISIBLE_DEVICES=3 已将物理 GPU 3 映射为逻辑 GPU 0
         device_map = {"": 0}
 
         load_kwargs: dict[str, Any] = {

+ 23 - 0
backend/app/schemas/background_task.py

@@ -0,0 +1,23 @@
+from pydantic import BaseModel
+
+
+class ModelDownloadTaskResponse(BaseModel):
+    # API payload describing a model download background task. Only task_id,
+    # model_id and status are required; the rest default so a freshly
+    # started task can be returned before any result exists.
+    task_id: str
+    model_id: str
+    status: str
+    use_modelscope: bool = False
+    path: str | None = None
+    error: str | None = None
+    progress: float = 0.0
+    created_at: str = ""  # empty string when no timestamp is supplied
+
+
+class DatasetDownloadTaskResponse(BaseModel):
+    # API payload describing a dataset download background task. The field
+    # names match the dicts built by dataset_service.get_dataset_download_status
+    # and list_dataset_downloads.
+    task_id: str
+    dataset_id: str
+    status: str
+    use_modelscope: bool = False
+    path: str | None = None
+    error: str | None = None
+    record_count: int = 0
+    created_at: str = ""  # ISO-8601 string, or "" when the row has no timestamp

+ 1 - 0
backend/app/schemas/deployment.py

@@ -10,5 +10,6 @@ class DeployConfig(BaseModel):
 class DeployResponse(BaseModel):
     job_id: str
     status: str
+    progress: float = 0.0
     output_path: str | None = None
     error: str | None = None

+ 5 - 2
backend/app/schemas/evaluation.py

@@ -11,5 +11,8 @@ class EvalConfig(BaseModel):
 class EvalResult(BaseModel):
+    # Evaluation record returned by the evaluation endpoints. Only id and
+    # job_id are required; the other fields default so that a task which has
+    # just been started can already be serialised.
     id: str
     job_id: str
-    metrics: dict
-    created_at: str
+    status: str = "pending"  # pending|running|completed|failed
+    progress: float = 0.0
+    metrics: dict = {}
+    error: str | None = None
+    created_at: str = ""

+ 143 - 9
backend/app/services/dataset_service.py

@@ -1,7 +1,7 @@
 import asyncio
 import json
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any
 
@@ -9,9 +9,11 @@ from typing import Any
 from fastapi import UploadFile
 
 from app.config import get_settings
-from app.core.db import async_session, DatasetRecord
+from app.core.background_tasks import background_task_manager
+from app.core.db import async_session, DatasetRecord, DatasetDownloadTask
 from app.core.logging import logger
 from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
+from sqlalchemy import select
 
 settings = get_settings()
 
@@ -73,10 +75,39 @@ def _is_training_data_file(path: Path) -> bool:
 
 
 async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
-    """从 HuggingFace 或 ModelScope 下载数据集。"""
+    """Start a dataset download in the background and return immediately.
+
+    A ``DatasetDownloadTask`` row is stored with status "pending", the task
+    is registered with the background task manager, and the download
+    coroutine is scheduled. The returned response has status "pending" and
+    carries the new task id in its ``path`` field.
+    """
+    task_id = str(uuid.uuid4())
+
+    # Persist the task so its status can be queried from the database.
+    record = DatasetDownloadTask(
+        id=task_id,
+        dataset_id=req.dataset_id,
+        use_modelscope=1 if req.use_modelscope else 0,
+        status="pending",
+    )
+    async with async_session() as session:
+        session.add(record)
+        await session.commit()
+
+    # Register, then schedule. ``run`` is a coroutine function: without
+    # ``await`` it only builds a coroutine object and nothing ever starts.
+    background_task_manager.register_task(task_id, "dataset_download", {"dataset_id": req.dataset_id})
+    await background_task_manager.run(
+        task_id, "dataset_download", _execute_dataset_download(task_id, req)
+    )
+
+    logger.info(f"Dataset download task started: {req.dataset_id} (task_id={task_id})")
+    return DatasetDownloadResponse(
+        dataset_id=req.dataset_id, status="pending", path=task_id
+    )
+
+
+async def _execute_dataset_download(task_id: str, req: DatasetDownloadRequest) -> dict:
+    """后台执行数据集下载。"""
     try:
         if req.use_modelscope:
-            ds_dir, jsonl_path, record_count = await asyncio.to_thread(_download_modelscope_dataset, req.dataset_id)
+            ds_dir, jsonl_path, record_count = await asyncio.to_thread(
+                _download_modelscope_dataset, req.dataset_id
+            )
         else:
             from datasets import load_dataset
 
@@ -94,7 +125,7 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
             jsonl_path = output_path
             record_count = len(split) if hasattr(split, "__len__") else 0
 
-        record = DatasetRecord(
+        db_record = DatasetRecord(
             id=str(uuid.uuid4()),
             name=req.dataset_id,
             format="jsonl",
@@ -103,14 +134,117 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
             created_at=datetime.utcnow(),
         )
         async with async_session() as session:
-            session.add(record)
+            session.add(db_record)
             await session.commit()
 
-        logger.info(f"Downloaded dataset: {req.dataset_id} ({record_count} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
-        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(jsonl_path))
+        await _update_dataset_download_status(task_id, "completed", path=str(jsonl_path), record_count=record_count)
+
+        logger.info(f"Dataset downloaded: {req.dataset_id} ({record_count} records)")
+        return {"path": str(jsonl_path), "record_count": record_count}
+
     except Exception as e:
         logger.error(f"Dataset download failed: {e}")
-        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
+        await _update_dataset_download_status(task_id, "failed", error=str(e))
+        return {"error": str(e)}
+
+
+async def _update_dataset_download_status(task_id: str, status: str, path: str = None, error: str = None, record_count: int = 0):
+    """Write a download task's new status to its database row and to the in-memory record.
+
+    ``path``, ``error`` and ``record_count`` are copied onto the row only
+    when truthy; ``finished_at`` is stamped when ``status`` is "completed"
+    or "failed". A missing row is skipped silently. The in-memory record is
+    then updated with the values exactly as given (including ``None``/``0``).
+    """
+    async with async_session() as session:
+        result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
+        record = result.scalar_one_or_none()
+        if record:
+            record.status = status
+            if path:
+                record.path = path
+            if error:
+                record.error = error
+            if record_count:
+                record.record_count = record_count
+            if status in ("completed", "failed"):
+                record.finished_at = datetime.utcnow()
+            await session.commit()
+
+    background_task_manager.update_task(
+        task_id, status=status, path=path, error=error, record_count=record_count,
+    )
+
+
+async def get_dataset_download_status(task_id: str) -> dict[str, Any]:
+    """Return a status dict for a dataset download task.
+
+    The database row is preferred. If there is none, the in-memory task
+    record is used instead (it yields fewer keys: no ``use_modelscope``,
+    ``path`` or ``created_at``). If neither exists the result is
+    ``{"task_id": task_id, "status": "not_found"}``.
+    """
+    async with async_session() as session:
+        result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
+        record = result.scalar_one_or_none()
+        if record:
+            return {
+                "task_id": record.id,
+                "dataset_id": record.dataset_id,
+                "status": record.status,
+                "use_modelscope": bool(record.use_modelscope),
+                "path": record.path,
+                "error": record.error,
+                "record_count": record.record_count,
+                "created_at": record.created_at.isoformat() if record.created_at else "",
+            }
+    # No database row: fall back to the background task manager's record.
+    mem = background_task_manager.get_task(task_id)
+    if mem:
+        return {
+            "task_id": task_id,
+            "dataset_id": mem.get("dataset_id", ""),
+            "status": mem["status"],
+            "error": mem.get("error"),
+            "record_count": mem.get("record_count", 0),
+        }
+    return {"task_id": task_id, "status": "not_found"}
+
+
+async def list_dataset_downloads() -> list[dict[str, Any]]:
+    """Return all dataset download tasks from the database, newest first."""
+    newest_first = select(DatasetDownloadTask).order_by(DatasetDownloadTask.created_at.desc())
+    async with async_session() as session:
+        rows = (await session.execute(newest_first)).scalars().all()
+
+    listing: list[dict[str, Any]] = []
+    for row in rows:
+        created = row.created_at.isoformat() if row.created_at else ""
+        listing.append(
+            {
+                "task_id": row.id,
+                "dataset_id": row.dataset_id,
+                "status": row.status,
+                "use_modelscope": bool(row.use_modelscope),
+                "path": row.path,
+                "error": row.error,
+                "record_count": row.record_count,
+                "created_at": created,
+            }
+        )
+    return listing
+
+
+async def cancel_dataset_download(task_id: str) -> dict[str, Any]:
+    """Cancel a dataset download and report the task's resulting status.
+
+    The in-memory task is cancelled first; a database row still in
+    "pending"/"downloading" is then marked "cancelled". The returned
+    ``status`` is "cancelled" when something was cancelled, the row's
+    existing status when the task had already finished, and "not_found"
+    when the task is unknown to both the database and the task manager.
+    """
+    cancelled_in_memory = background_task_manager.cancel_task(task_id)
+    async with async_session() as session:
+        result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
+        record = result.scalar_one_or_none()
+        if record is None:
+            status = "cancelled" if cancelled_in_memory else "not_found"
+        elif record.status in ("pending", "downloading"):
+            record.status = "cancelled"
+            record.error = "Cancelled by user"
+            record.finished_at = datetime.utcnow()
+            # Keep the value in a local: the ORM object may be expired after commit.
+            status = "cancelled"
+            await session.commit()
+        else:
+            # Already finished: report the real state instead of claiming a cancel.
+            status = record.status
+    return {"task_id": task_id, "status": status}
+
+
+async def recover_stale_downloads() -> None:
+    """Mark every download task still "pending" or "downloading" as failed.
+
+    Each such row gets the error "Server restarted, task interrupted" and a
+    ``finished_at`` timestamp; the number of rows touched is logged.
+    """
+    unfinished = select(DatasetDownloadTask).where(
+        DatasetDownloadTask.status.in_(["pending", "downloading"])
+    )
+    async with async_session() as session:
+        stale = (await session.execute(unfinished)).scalars().all()
+        if not stale:
+            return
+        for task in stale:
+            task.status = "failed"
+            task.error = "Server restarted, task interrupted"
+            task.finished_at = datetime.utcnow()
+        await session.commit()
+        logger.info(f"Recovered {len(stale)} stale dataset download tasks")
 
 
 def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:

+ 68 - 38
backend/app/services/deploy_service.py

@@ -1,9 +1,11 @@
+import json
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any
 
 from app.config import get_settings
+from app.core.background_tasks import background_task_manager
 from app.core.db import async_session, DeployTaskModel
 from app.core.logging import logger
 from app.core.remote_executor import ssh_exec
@@ -13,38 +15,49 @@ settings = get_settings()
 
 
 async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
-    """合并 adapter 与基础模型,并可选导出为 GGUF。"""
+    """Start an adapter export in the background and return immediately.
+
+    Returns ``{"job_id": job_id, "status": "pending"}``. The generated task
+    id is stored in the database and logged, but is not part of the result.
+    """
     task_id = str(uuid.uuid4())
     merge_with_base = config.get("merge_with_base", False)
     export_format = config.get("export_format", "safetensors")
 
-    # 写入数据库
+    # Persist the task so its status can be queried from the database.
     task = DeployTaskModel(
         id=task_id,
         job_id=job_id,
         status="pending",
-        created_at=datetime.utcnow(),
     )
     async with async_session() as session:
         session.add(task)
         await session.commit()
 
+    # Register, then schedule. ``run`` is a coroutine function: without
+    # ``await`` it only builds a coroutine object and nothing ever starts.
+    background_task_manager.register_task(task_id, "deployment", {"job_id": job_id})
+    await background_task_manager.run(
+        task_id, "deployment", _execute_export(task_id, job_id, merge_with_base, export_format)
+    )
+
+    logger.info(f"Deploy task started: job={job_id} (task_id={task_id})")
+    return {"job_id": job_id, "status": "pending"}
+
+
+async def _execute_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
+    """后台执行导出。"""
     try:
         # 远程模式:通过 SSH 在算力节点执行
         if settings.use_remote_compute:
             result = await _run_remote_export(task_id, job_id, merge_with_base, export_format)
             return result
 
-        # 本地模式(原有逻辑)
+        # 本地模式
         adapter_path = settings.adapters_dir / job_id
         if not adapter_path.exists():
-            return _update_task_status(task_id, "failed", error="Adapter not found")
-
-        output_path = settings.adapters_dir / f"{job_id}_merged"
+            raise ValueError("Adapter not found")
 
         import torch
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
+        output_path = settings.adapters_dir / f"{job_id}_merged"
+
         if merge_with_base:
             base_model_id = _get_base_model_id_local(job_id)
             if base_model_id:
@@ -57,7 +70,7 @@ async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
                     AutoModelForCausalLM.from_pretrained(
                         str(adapter_path), torch_dtype=torch.float16
                     ),
-                    adapter_path,
+                    str(adapter_path),
                 )
                 merged = merged.merge_and_unload()
                 merged.save_pretrained(output_path)
@@ -71,11 +84,13 @@ async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
             gguf_path = output_path.with_suffix(".gguf")
             _export_to_gguf_local(output_path, gguf_path)
 
-        return _update_task_status(task_id, "completed", output_path=str(output_path))
+        await _update_deploy_status(task_id, "completed", output_path=str(output_path))
+        return {"output_path": str(output_path)}
 
     except Exception as e:
         logger.error(f"Export failed for job {job_id}: {e}")
-        return _update_task_status(task_id, "failed", error=str(e))
+        await _update_deploy_status(task_id, "failed", error=str(e))
+        return {"error": str(e)}
 
 
 async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
@@ -83,7 +98,7 @@ async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, e
     remote_cmd = (
         f"docker exec "
         f"-e MACA_MPS_MODE=1 "
-        f"-e CUDA_VISIBLE_DEVICES=2,3 "
+        f"-e CUDA_VISIBLE_DEVICES=3 "
         f"-w {settings.compute_node_workdir} "
         f"{settings.compute_node_docker_container} "
         f"{settings.compute_node_python} -c \""
@@ -96,8 +111,7 @@ async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, e
     code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
 
     if code != 0:
-        logger.error(f"Remote export failed: {stderr}")
-        return _update_task_status(task_id, "failed", error=stderr.strip())
+        raise RuntimeError(f"Remote export failed: {stderr}")
 
     for line in reversed(stdout.strip().split("\n")):
         line = line.strip()
@@ -105,34 +119,32 @@ async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, e
             try:
                 result = json.loads(line)
                 if "error" in result:
-                    return _update_task_status(task_id, "failed", error=result["error"])
-                return _update_task_status(task_id, "completed", output_path=result.get("output_path"))
+                    raise RuntimeError(result["error"])
+                await _update_deploy_status(task_id, "completed", output_path=result.get("output_path"))
+                return {"output_path": result.get("output_path")}
             except json.JSONDecodeError:
                 continue
 
-    return _update_task_status(task_id, "failed", error=f"Invalid response: {stdout[:500]}")
-
-
-def _update_task_status(task_id: str, status: str, output_path: str = None, error: str = None):
-    import asyncio
+    raise RuntimeError(f"Invalid response: {stdout[:500]}")
 
-    async def _update():
-        async with async_session() as session:
-            result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
-            record = result.scalar_one_or_none()
-            if record:
-                record.status = status
-                if output_path:
-                    record.output_path = output_path
-                if error:
-                    record.error = error
-                await session.commit()
 
-    asyncio.get_event_loop().run_until_complete(_update())
-    base = {"job_id": "", "status": status, "output_path": output_path}
-    if error:
-        base["error"] = error
-    return base
+async def _update_deploy_status(task_id: str, status: str, output_path: str = None, error: str = None):
+    """Write a deploy task's new status to its database row and to the in-memory record.
+
+    ``output_path`` and ``error`` are copied onto the row only when truthy;
+    ``finished_at`` is stamped when ``status`` is "completed" or "failed".
+    A missing row is skipped silently. The in-memory record is then updated
+    with the values exactly as given (including ``None``).
+    """
+    async with async_session() as session:
+        result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
+        record = result.scalar_one_or_none()
+        if record:
+            record.status = status
+            if output_path:
+                record.output_path = output_path
+            if error:
+                record.error = error
+            if status in ("completed", "failed"):
+                record.finished_at = datetime.utcnow()
+            await session.commit()
+
+    background_task_manager.update_task(
+        task_id, status=status, output_path=output_path, error=error,
+    )
 
 
 def _get_base_model_id_local(job_id: str):
@@ -166,7 +178,25 @@ async def get_deploy_status(task_id: str) -> dict[str, Any]:
             return {
                 "job_id": record.job_id,
                 "status": record.status,
+                "progress": record.progress,
                 "output_path": record.output_path,
                 "error": record.error,
             }
-    return {"job_id": "", "status": "not_found", "output_path": None, "error": None}
+    return {"job_id": "", "status": "not_found", "progress": 0.0, "output_path": None, "error": None}
+
+
+async def recover_stale_deploys() -> None:
+    """Mark every deploy task still "pending" or "running" as failed.
+
+    Each such row gets the error "Server restarted, task interrupted" and a
+    ``finished_at`` timestamp; the number of rows touched is logged.
+    """
+    unfinished = select(DeployTaskModel).where(
+        DeployTaskModel.status.in_(["pending", "running"])
+    )
+    async with async_session() as session:
+        stale = (await session.execute(unfinished)).scalars().all()
+        if not stale:
+            return
+        for task in stale:
+            task.status = "failed"
+            task.error = "Server restarted, task interrupted"
+            task.finished_at = datetime.utcnow()
+        await session.commit()
+        logger.info(f"Recovered {len(stale)} stale deploy tasks")

+ 84 - 46
backend/app/services/eval_service.py

@@ -1,9 +1,10 @@
 import json
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Any
 
 from app.config import get_settings
+from app.core.background_tasks import background_task_manager
 from app.core.db import async_session, EvalResultModel
 from app.core.logging import logger
 from app.core.remote_executor import ssh_exec
@@ -13,20 +14,43 @@ settings = get_settings()
 
 
 async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
-    """在已训练的 adapter 上运行评估(perplexity)。"""
+    """Start the evaluation as a background task and return its eval_id immediately.
+
+    Returns a dict with "id" (the new eval_id), "job_id" and "status" set to
+    "pending"; the evaluation itself runs in _execute_evaluation.
+    """
     eval_id = str(uuid.uuid4())
 
-    # 远程训练模式:把评估任务也发到远程容器执行
-    if settings.use_remote_compute:
-        logger.info(f"Running remote evaluation for job {job_id}")
-        return await _run_remote_evaluation(eval_id, job_id)
+    # Persist a "pending" row before the background work starts.
+    record = EvalResultModel(
+        id=eval_id,
+        job_id=job_id,
+        status="pending",
+        metrics="{}",
+    )
+    async with async_session() as session:
+        session.add(record)
+        await session.commit()
+
+    # Register the task, then hand the coroutine to the background task manager.
+    background_task_manager.register_task(eval_id, "evaluation", {"job_id": job_id})
+    background_task_manager.run(
+        eval_id, "evaluation", _execute_evaluation(eval_id, job_id, config)
+    )
 
-    adapter_path = settings.adapters_dir / job_id
+    logger.info(f"Evaluation task started: job={job_id} (eval_id={eval_id})")
+    return {"id": eval_id, "job_id": job_id, "status": "pending"}
 
-    if not adapter_path.exists():
-        return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": "Adapter not found"}
 
+async def _execute_evaluation(eval_id: str, job_id: str, config: dict[str, Any]) -> dict:
+    """后台执行评估。"""
     try:
+        # 远程训练模式:把评估任务也发到远程容器执行
+        if settings.use_remote_compute:
+            logger.info(f"Running remote evaluation for job {job_id}")
+            result = await _run_remote_evaluation(eval_id, job_id)
+            return {"metrics": result.get("metrics", {})}
+
+        adapter_path = settings.adapters_dir / job_id
+        if not adapter_path.exists():
+            raise ValueError("Adapter not found")
+
         import torch
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -34,24 +58,13 @@ async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
         model = AutoModelForCausalLM.from_pretrained(adapter_path, torch_dtype=torch.float16, device_map="auto")
         tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
 
-        # 加载评估数据
-        async with async_session() as session:
-            from app.core.db import TrainingJobModel
-            result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
-            record = result.scalar_one_or_none()
-
-        if record:
-            dataset_path = record.dataset_id
-
-        metrics = {}
-        model.eval()
-
-        # 计算 perplexity(使用 adapter 自身的数据或默认样例)
+        # 计算 perplexity
         sample_texts = [
             "The quick brown fox jumps over the lazy dog.",
             "Hello, how are you doing today?",
         ]
         losses = []
+        model.eval()
         with torch.no_grad():
             for text in sample_texts:
                 inputs = tokenizer(text, return_tensors="pt").to(model.device)
@@ -67,23 +80,29 @@ async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
             "num_samples": len(sample_texts),
         }
 
-        # 保存结果
-        eval_record = EvalResultModel(
-            id=eval_id,
-            job_id=job_id,
-            metrics=json.dumps(metrics),
-            created_at=datetime.utcnow(),
-        )
+        # 更新 DB
         async with async_session() as session:
-            session.add(eval_record)
-            await session.commit()
+            result = await session.execute(select(EvalResultModel).where(EvalResultModel.id == eval_id))
+            eval_record = result.scalar_one_or_none()
+            if eval_record:
+                eval_record.metrics = json.dumps(metrics)
+                eval_record.status = "completed"
+                eval_record.progress = 100.0
+                await session.commit()
 
         logger.info(f"Evaluation completed for job {job_id}: {metrics}")
-        return {"id": eval_id, "job_id": job_id, "metrics": metrics, "created_at": eval_record.created_at.isoformat()}
+        return {"metrics": metrics}
 
     except Exception as e:
         logger.error(f"Evaluation failed for job {job_id}: {e}")
-        return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": str(e)}
+        async with async_session() as session:
+            result = await session.execute(select(EvalResultModel).where(EvalResultModel.id == eval_id))
+            eval_record = result.scalar_one_or_none()
+            if eval_record:
+                eval_record.status = "failed"
+                eval_record.error = str(e)
+                await session.commit()
+        return {"error": str(e)}
 
 
 async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
@@ -91,7 +110,7 @@ async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
     remote_cmd = (
         f"docker exec "
         f"-e MACA_MPS_MODE=1 "
-        f"-e CUDA_VISIBLE_DEVICES=2,3 "
+        f"-e CUDA_VISIBLE_DEVICES=3 "
         f"-w {settings.compute_node_workdir} "
         f"{settings.compute_node_docker_container} "
         f"{settings.compute_node_python} -c \""
@@ -104,8 +123,7 @@ async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
     code, stdout, stderr = ssh_exec(remote_cmd, timeout=300)
 
     if code != 0:
-        logger.error(f"Remote evaluation failed: {stderr}")
-        return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": stderr.strip()}
+        raise RuntimeError(f"Remote evaluation failed: {stderr}")
 
     # 提取最后一行 JSON
     for line in reversed(stdout.strip().split("\n")):
@@ -114,21 +132,22 @@ async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
             try:
                 result = json.loads(line)
                 # 保存结果到本地数据库
-                eval_record = EvalResultModel(
-                    id=eval_id,
-                    job_id=job_id,
-                    metrics=json.dumps(result.get("metrics", {})),
-                    created_at=datetime.utcnow(),
-                )
+                metrics = result.get("metrics", {})
                 async with async_session() as session:
+                    eval_record = EvalResultModel(
+                        id=eval_id,
+                        job_id=job_id,
+                        metrics=json.dumps(metrics),
+                        status="completed",
+                        created_at=datetime.utcnow(),
+                    )
                     session.add(eval_record)
                     await session.commit()
-                return {"id": eval_id, "job_id": job_id, "metrics": result.get("metrics", {}),
-                        "created_at": eval_record.created_at.isoformat()}
+                return {"id": eval_id, "job_id": job_id, "metrics": metrics}
             except json.JSONDecodeError:
                 continue
 
-    return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": f"Invalid response: {stdout[:500]}"}
+    raise RuntimeError(f"Invalid response: {stdout[:500]}")
 
 
 async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
@@ -140,7 +159,26 @@ async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
             return {
                 "id": record.id,
                 "job_id": record.job_id,
+                "status": record.status,
+                "progress": record.progress,
                 "metrics": json.loads(record.metrics) if record.metrics else {},
+                "error": record.error,
                 "created_at": record.created_at.isoformat(),
             }
-    return {"id": eval_id, "job_id": "", "metrics": {}, "created_at": ""}
+    return {"id": eval_id, "job_id": "", "status": "not_found", "metrics": {}}
+
+
+async def recover_stale_evaluations() -> None:
+    """Fail every evaluation row still marked "pending" or "running".
+
+    Each matching row gets status "failed" and the error text
+    "Server restarted, task interrupted". Nothing is committed or logged
+    when no row matches.
+    """
+    stale_query = select(EvalResultModel).where(
+        EvalResultModel.status.in_(["pending", "running"])
+    )
+    async with async_session() as session:
+        stale = (await session.execute(stale_query)).scalars().all()
+        if not stale:
+            return
+        for evaluation in stale:
+            evaluation.status = "failed"
+            evaluation.error = "Server restarted, task interrupted"
+        await session.commit()
+        logger.info(f"Recovered {len(stale)} stale evaluation tasks")

+ 1 - 1
backend/app/services/inference_service.py

@@ -79,7 +79,7 @@ def _generate_local(
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        # CUDA_VISIBLE_DEVICES=2,3 已将物理 GPU 2,3 映射为逻辑 GPU 0,1
+        # CUDA_VISIBLE_DEVICES=3 已将物理 GPU 3 映射为逻辑 GPU 0
         import torch
         device_map = {"": 0}
         torch.cuda.set_device(0)

+ 147 - 8
backend/app/services/model_service.py

@@ -1,10 +1,13 @@
 import os
 import json
+import uuid
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any
 
 from app.config import get_settings
-from app.core.db import async_session, ModelCache
+from app.core.background_tasks import background_task_manager
+from app.core.db import async_session, ModelCache, ModelDownloadTask
 from app.core.logging import logger
 from sqlalchemy import select
 
@@ -40,13 +43,46 @@ async def resolve_model_path(model_id: str) -> str | None:
 
 
 async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
-    """从 HF 或 ModelScope 下载模型到本地缓存。"""
+    """Start a background model download and return its task_id immediately.
+
+    If the in-memory task registry already holds an unfinished "model_download"
+    task for the same model_id, that task's id and status are returned with
+    "duplicate": True and nothing new is started.
+    """
+    task_id = str(uuid.uuid4())
+
+    # Reuse an in-flight download of the same model instead of starting another.
+    # NOTE(review): register_task only runs after the awaited DB commit below, so
+    # two concurrent requests for one model could both pass this check - confirm
+    # whether that is acceptable.
+    for tid, t in background_task_manager.tasks.items():
+        if (
+            t.get("task_type") == "model_download"
+            and t.get("model_id") == model_id
+            and t.get("status") in ("pending", "downloading", "running")
+        ):
+            return {"task_id": tid, "model_id": model_id, "status": t["status"], "duplicate": True}
+
+    # Persist a "pending" row before the background work starts.
+    record = ModelDownloadTask(
+        id=task_id,
+        model_id=model_id,
+        use_modelscope=1 if use_modelscope else 0,
+        status="pending",
+    )
+    async with async_session() as session:
+        session.add(record)
+        await session.commit()
+
+    # Register the task, then hand the coroutine to the background task manager.
+    background_task_manager.register_task(task_id, "model_download", {"model_id": model_id})
+    background_task_manager.run(
+        task_id, "model_download", _execute_model_download(task_id, model_id, use_modelscope)
+    )
+
+    logger.info(f"Model download task started: {model_id} (task_id={task_id})")
+    return {"task_id": task_id, "model_id": model_id, "status": "pending"}
+
+
+async def _execute_model_download(task_id: str, model_id: str, use_modelscope: bool) -> dict:
+    """后台执行模型下载。"""
     try:
         if use_modelscope:
             import subprocess
 
             download_dir = str(settings.models_dir / model_id.replace("/", "_"))
-            # 用独立进程调用 CLI,完全隔离 FastAPI 事件循环,避免 __aenter__ 错误
             proc = subprocess.run(
                 [
                     "modelscope", "download",
@@ -67,7 +103,7 @@ async def download_model(model_id: str, use_modelscope: bool = False) -> dict[st
                 local_dir_use_symlinks=False,
             )
 
-        # 读取 config.json 获取模型信息
+        # 读取 config.json
         config_path = Path(local_path) / "config.json"
         model_type = "text"
         context_length = 2048
@@ -79,7 +115,7 @@ async def download_model(model_id: str, use_modelscope: bool = False) -> dict[st
             model_type = cfg.get("model_type", "text")
             context_length = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
 
-        # 写入数据库(如果已存在则更新)
+        # 更新 ModelCache
         async with async_session() as session:
             result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
             existing = result.scalar_one_or_none()
@@ -103,8 +139,12 @@ async def download_model(model_id: str, use_modelscope: bool = False) -> dict[st
                 session.add(record)
             await session.commit()
 
+        # 更新下载任务
+        await _update_model_download_status(task_id, "completed", path=local_path)
+
         logger.info(f"Model downloaded: {model_id} -> {local_path}")
-        return {"model_id": model_id, "status": "completed", "path": local_path}
+        return {"path": local_path}
+
     except Exception as e:
         import traceback
         tb = traceback.format_exc()
@@ -113,8 +153,107 @@ async def download_model(model_id: str, use_modelscope: bool = False) -> dict[st
         error_msg = str(e)
         if "Connection" in error_msg or "timeout" in error_msg.lower() or "network" in error_msg.lower():
             error_msg += "\n提示: 可能是 HuggingFace 网络问题。尝试使用 ModelScope 下载。"
-        return {"model_id": model_id, "status": "failed", "error": error_msg}
-        return {"model_id": model_id, "status": "failed", "error": error_msg}
+        await _update_model_download_status(task_id, "failed", error=error_msg)
+        return {"error": error_msg}
+
+
+async def _update_model_download_status(
+    task_id: str,
+    status: str,
+    path: str | None = None,
+    error: str | None = None,
+) -> None:
+    """Record a model download task's new status in the DB and in the task manager.
+
+    Args:
+        task_id: Primary key of the ModelDownloadTask row; the same id is passed
+            to background_task_manager.update_task.
+        status: New status string. "completed" and "failed" additionally record
+            a finish time in both places.
+        path: Written to the row only when truthy.
+        error: Written to the row only when truthy.
+    """
+    # Read the clock once so the DB row and the in-memory task carry the same
+    # finish time (previously two separate readings, one naive and one aware).
+    finished_at = datetime.now(timezone.utc) if status in ("completed", "failed") else None
+
+    async with async_session() as session:
+        result = await session.execute(select(ModelDownloadTask).where(ModelDownloadTask.id == task_id))
+        record = result.scalar_one_or_none()
+        # A missing row is skipped silently; the in-memory update below still runs.
+        if record:
+            record.status = status
+            if path:
+                record.path = path
+            if error:
+                record.error = error
+            if finished_at is not None:
+                # Stored as naive UTC, the same value datetime.utcnow() produced,
+                # without calling the deprecated function.
+                record.finished_at = finished_at.replace(tzinfo=None)
+            await session.commit()
+
+    background_task_manager.update_task(
+        task_id, status=status, path=path, error=error,
+        finished_at=finished_at.isoformat() if finished_at is not None else None,
+    )
+
+
+async def get_model_download_status(task_id: str) -> dict[str, Any]:
+    """Return the status of one model download task.
+
+    The DB row is preferred; if there is none, the in-memory task registry is
+    consulted; if neither knows the id, the status is "not_found".
+    """
+    lookup = select(ModelDownloadTask).where(ModelDownloadTask.id == task_id)
+    async with async_session() as session:
+        row = (await session.execute(lookup)).scalar_one_or_none()
+        if row:
+            return {
+                "task_id": row.id,
+                "model_id": row.model_id,
+                "status": row.status,
+                "use_modelscope": bool(row.use_modelscope),
+                "path": row.path,
+                "error": row.error,
+                "progress": row.progress,
+                "created_at": row.created_at.isoformat() if row.created_at else "",
+            }
+
+    # No DB row: fall back to the in-memory registry.
+    cached = background_task_manager.get_task(task_id)
+    if not cached:
+        return {"task_id": task_id, "status": "not_found"}
+    return {
+        "task_id": task_id,
+        "model_id": cached.get("model_id", ""),
+        "status": cached["status"],
+        "error": cached.get("error"),
+        "progress": cached.get("progress", 0),
+    }
+
+
+async def list_model_downloads() -> list[dict[str, Any]]:
+    """Return every model download task as a plain dict, newest created_at first."""
+    newest_first = select(ModelDownloadTask).order_by(ModelDownloadTask.created_at.desc())
+    async with async_session() as session:
+        tasks = (await session.execute(newest_first)).scalars().all()
+
+    listing = []
+    for task in tasks:
+        listing.append({
+            "task_id": task.id,
+            "model_id": task.model_id,
+            "status": task.status,
+            "use_modelscope": bool(task.use_modelscope),
+            "path": task.path,
+            "error": task.error,
+            "created_at": task.created_at.isoformat() if task.created_at else "",
+        })
+    return listing
+
+
+async def cancel_model_download(task_id: str) -> dict[str, Any]:
+    """Cancel a model download task and report the status it ends up in.
+
+    The background task manager is always asked to cancel the task. The DB row
+    is switched to "cancelled" only while it is still "pending" or
+    "downloading".
+
+    Returns:
+        {"task_id", "status"} where status is "cancelled" when the row was
+        switched, the row's existing status when it was already in another
+        state, or "not_found" when no row has this id. (Previously "cancelled"
+        was returned in every case.)
+    """
+    background_task_manager.cancel_task(task_id)
+    async with async_session() as session:
+        result = await session.execute(select(ModelDownloadTask).where(ModelDownloadTask.id == task_id))
+        record = result.scalar_one_or_none()
+        if record is None:
+            return {"task_id": task_id, "status": "not_found"}
+        # Read the status before committing so the return value never needs the
+        # row to be reloaded after the commit.
+        status = record.status
+        if status in ("pending", "downloading"):
+            status = "cancelled"
+            record.status = status
+            record.error = "Cancelled by user"
+            record.finished_at = datetime.utcnow()
+            await session.commit()
+    return {"task_id": task_id, "status": status}
+
+
+async def recover_stale_downloads() -> None:
+    """Mark download tasks interrupted by a restart as failed.
+
+    Rows still in "pending" or "downloading" get status "failed", the error
+    text "Server restarted, task interrupted" and a finished_at timestamp.
+    Nothing is committed or logged when no row matches.
+    """
+    interrupted_states = ["pending", "downloading"]
+    async with async_session() as session:
+        found = await session.execute(
+            select(ModelDownloadTask).where(ModelDownloadTask.status.in_(interrupted_states))
+        )
+        interrupted = found.scalars().all()
+        if not interrupted:
+            return
+        for download in interrupted:
+            download.status = "failed"
+            download.error = "Server restarted, task interrupted"
+            download.finished_at = datetime.utcnow()
+        await session.commit()
+        logger.info(f"Recovered {len(interrupted)} stale model download tasks")
 
 
 async def list_cached_models() -> list[dict[str, Any]]:

+ 3 - 3
backend/app/services/training_service.py

@@ -54,10 +54,10 @@ async def create_training_job(config: dict[str, Any]) -> dict[str, Any]:
         await session.commit()
 
     # 加入 JobQueue
-    # 如果启用 DeepSpeed,生成配置文件
+    # DeepSpeed 需要多 GPU,单卡模式已禁用
     if config.get("deepspeed", False):
-        ds_config_path = _generate_deepspeed_config()
-        config["deepspeed"] = ds_config_path
+        config["deepspeed"] = False
+        logger.warning("DeepSpeed requires multiple GPUs, but only GPU 3 is available. DeepSpeed disabled.")
 
     job = TrainingJob(
         id=job_id,

+ 16 - 0
backend/main.py

@@ -31,6 +31,22 @@ async def lifespan(app: FastAPI):
     job_queue.register_callback(update_job_in_db)
     await job_queue.start()
 
+    # 初始化后台任务管理器
+    from app.core.background_tasks import background_task_manager
+
+    background_task_manager.set_concurrency("model_download", 5)
+    background_task_manager.set_concurrency("dataset_download", 5)
+    background_task_manager.set_concurrency("evaluation", 1)
+    background_task_manager.set_concurrency("deployment", 1)
+
+    # 恢复因重启中断的任务
+    from app.services import model_service, dataset_service, eval_service, deploy_service
+
+    await model_service.recover_stale_downloads()
+    await dataset_service.recover_stale_downloads()
+    await eval_service.recover_stale_evaluations()
+    await deploy_service.recover_stale_deploys()
+
     yield
 
     # 关闭时:停止 JobQueue

+ 37 - 3
frontend/src/api/client.ts

@@ -83,7 +83,13 @@ const api = {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify({ model_id: modelId, use_modelscope: useModelscope }),
-      }).then(r => r.json()) as Promise<ModelDownloadResponse>,
+      }).then(r => r.json()) as Promise<ModelDownloadTaskResponse>,
+    downloadStatus: (taskId: string) =>
+      apiFetch(`/api/v1/models/download/${taskId}`).then(r => r.json()) as Promise<ModelDownloadTaskResponse>,
+    listDownloads: () =>
+      apiFetch('/api/v1/models/downloads').then(r => r.json()) as Promise<ModelDownloadTaskResponse[]>,
+    cancelDownload: (taskId: string) =>
+      apiFetch(`/api/v1/models/download/${taskId}/cancel`, { method: 'POST' }).then(r => r.json()),
     delete: (modelId: string) =>
       apiFetch(`/api/v1/models/${encodeURIComponent(modelId)}`, { method: 'DELETE' }).then(r => r.json()),
     getInfo: (modelId: string) =>
@@ -109,7 +115,13 @@ const api = {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify({ dataset_id: datasetId, use_modelscope: useModelscope }),
-      }).then(r => r.json()) as Promise<DatasetDownloadResponse>,
+      }).then(r => r.json()) as Promise<DatasetDownloadTaskResponse>,
+    downloadStatus: (taskId: string) =>
+      apiFetch(`/api/v1/datasets/download/${taskId}`).then(r => r.json()) as Promise<DatasetDownloadTaskResponse>,
+    listDownloads: () =>
+      apiFetch('/api/v1/datasets/downloads').then(r => r.json()) as Promise<DatasetDownloadTaskResponse[]>,
+    cancelDownload: (taskId: string) =>
+      apiFetch(`/api/v1/datasets/download/${taskId}/cancel`, { method: 'POST' }).then(r => r.json()),
     preview: (id: string, rows = 10) =>
       apiFetch(`/api/v1/datasets/${id}/preview?rows=${rows}`).then(r => r.json()) as Promise<DatasetPreview>,
     validate: (id: string) =>
@@ -220,6 +232,17 @@ interface ModelDownloadResponse {
   error?: string
 }
 
+/**
+ * Download-task record used as the response type of `api.models.download`,
+ * `api.models.downloadStatus` and `api.models.listDownloads`.
+ * Only `task_id`, `model_id` and `status` are required.
+ */
+interface ModelDownloadTaskResponse {
+  task_id: string
+  model_id: string
+  // The pages compare this against 'pending', 'running', 'downloading',
+  // 'completed', 'failed' and 'cancelled'.
+  status: string
+  use_modelscope?: boolean
+  path?: string
+  error?: string
+  progress?: number
+  created_at?: string
+}
+
 interface DatasetInfo {
   id: string
   name: string
@@ -236,6 +259,17 @@ interface DatasetDownloadResponse {
   error?: string
 }
 
+/**
+ * Download-task record used as the response type of `api.datasets.download`,
+ * `api.datasets.downloadStatus` and `api.datasets.listDownloads`.
+ * Only `task_id`, `dataset_id` and `status` are required.
+ */
+interface DatasetDownloadTaskResponse {
+  task_id: string
+  dataset_id: string
+  // The pages compare this against 'pending', 'running', 'downloading',
+  // 'completed', 'failed' and 'cancelled'.
+  status: string
+  use_modelscope?: boolean
+  path?: string
+  error?: string
+  record_count?: number
+  created_at?: string
+}
+
 interface DatasetPreview {
   total_records: number
   preview_rows: { row_index: number; data: Record<string, unknown> }[]
@@ -382,4 +416,4 @@ interface KbImportResponse {
   child_table: string
 }
 
-export type { ModelInfo, ModelTestRequest, ModelTestResponse, ModelDownloadResponse, DatasetInfo, DatasetDownloadResponse, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse, AdapterInfo, InferenceRequest, InferenceResponse, KnowledgeBaseItem, KnowledgeBaseListResponse, KnowledgeBaseDetailResponse, KbImportResponse }
+export type { ModelInfo, ModelTestRequest, ModelTestResponse, ModelDownloadResponse, ModelDownloadTaskResponse, DatasetInfo, DatasetDownloadResponse, DatasetDownloadTaskResponse, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse, AdapterInfo, InferenceRequest, InferenceResponse, KnowledgeBaseItem, KnowledgeBaseListResponse, KnowledgeBaseDetailResponse, KbImportResponse }

+ 102 - 8
frontend/src/pages/Datasets.tsx

@@ -1,6 +1,6 @@
 import { useState, useEffect, useRef, memo, useCallback } from 'react'
-import api, { DatasetInfo, KnowledgeBaseItem } from '../api/client'
-import { Database, Upload, Loader2, FolderOpen } from 'lucide-react'
+import api, { DatasetInfo, KnowledgeBaseItem, DatasetDownloadTaskResponse } from '../api/client'
+import { Database, Upload, Loader2, FolderOpen, CheckCircle, XCircle } from 'lucide-react'
 
 const DatasetRow = memo(function DatasetRow({ d, onPreview, onDelete }: {
   d: DatasetInfo
@@ -65,6 +65,10 @@ export function Datasets() {
   const [kbPage, setKbPage] = useState(1)
   const [kbTotal, setKbTotal] = useState(0)
 
+  // Active downloads tracking
+  const [activeDownloads, setActiveDownloads] = useState<Map<string, DatasetDownloadTaskResponse>>(new Map())
+  const downloadPollIntervals = useRef<Map<string, ReturnType<typeof setInterval>>>(new Map())
+
   useEffect(() => {
     fetchDatasets()
   }, [])
@@ -99,16 +103,67 @@ export function Datasets() {
     if (file) handleFileUpload(file)
   }
 
-  const handleDownload = () => {
+  // Submit a dataset download task; the result is tracked by polling rather than
+  // by waiting for this request.
+  const handleDownload = async () => {
     if (!dlDatasetId.trim()) return
     setDownloading(true)
-    setDlStatus('正在下载...')
-    api.datasets.download(dlDatasetId, dlUseModelscope)
-      .then(res => setDlStatus(`${res.dataset_id}: ${res.status}${res.error ? ` - ${res.error}` : ''}`))
-      .catch(err => setDlStatus(`下载失败: ${err.message}`))
-      .finally(() => setDownloading(false))
+    setDlStatus('正在提交下载任务...')
+    try {
+      const res = await api.datasets.download(dlDatasetId, dlUseModelscope)
+      setDlStatus(`下载任务已提交: ${res.dataset_id}`)
+      // Show the task immediately, then keep it fresh through the status poller.
+      setActiveDownloads(prev => new Map(prev).set(res.task_id, res))
+      startDatasetDownloadPolling(res.task_id)
+    } catch (err) {
+      setDlStatus(`下载失败: ${err instanceof Error ? err.message : '未知错误'}`)
+    } finally {
+      // `downloading` only covers the submit request, not the download itself.
+      setDownloading(false)
+    }
+  }
+
+  // Poll the status of `taskId` every 3 s, mirroring each response into
+  // `activeDownloads`, until the task reports a terminal status.
+  const startDatasetDownloadPolling = (taskId: string) => {
+    // Never run two timers for one task: the second would overwrite the stored
+    // handle, so the first could no longer be cleared on cancel or unmount.
+    if (downloadPollIntervals.current.has(taskId)) return
+
+    const stopPolling = () => {
+      const handle = downloadPollIntervals.current.get(taskId)
+      if (handle) clearInterval(handle)
+      downloadPollIntervals.current.delete(taskId)
+    }
+
+    const interval = setInterval(() => {
+      api.datasets.downloadStatus(taskId)
+        .then(res => {
+          // A response can land after the task was removed locally (e.g. the user
+          // cancelled); do not re-add the entry in that case.
+          if (!downloadPollIntervals.current.has(taskId)) return
+          setActiveDownloads(prev => {
+            const next = new Map(prev)
+            next.set(taskId, res)
+            return next
+          })
+          if (res.status === 'completed') {
+            stopPolling()
+            fetchDatasets()
+            setDlStatus(`${res.dataset_id} 下载完成 (${res.record_count} 条记录)`)
+          } else if (res.status === 'failed') {
+            stopPolling()
+            setDlStatus(`${res.dataset_id} 下载失败: ${res.error}`)
+          } else if (res.status === 'cancelled') {
+            // Terminal as well; without this branch the timer ran until unmount.
+            stopPolling()
+          }
+        })
+        .catch(() => {
+          // Best effort: a failed poll is retried on the next tick.
+        })
+    }, 3000)
+    downloadPollIntervals.current.set(taskId, interval)
+  }
+
+  // Ask the backend to cancel `taskId`, stop its poll timer and drop it from the list.
+  const handleCancelDatasetDownload = (taskId: string) => {
+    // Fire-and-forget, but with a handler so a failed request is reported
+    // instead of surfacing as an unhandled promise rejection.
+    api.datasets.cancelDownload(taskId).catch(err => {
+      console.error('Failed to cancel dataset download', err)
+    })
+    const interval = downloadPollIntervals.current.get(taskId)
+    if (interval) {
+      clearInterval(interval)
+      downloadPollIntervals.current.delete(taskId)
+    }
+    setActiveDownloads(prev => {
+      const next = new Map(prev)
+      next.delete(taskId)
+      return next
+    })
   }
 
+  // On unmount, clear every poll timer that is still registered.
+  useEffect(() => {
+    return () => {
+      downloadPollIntervals.current.forEach(interval => clearInterval(interval))
+    }
+  }, [])
+
   const handlePreview = (id: string) => {
     api.datasets.preview(id, 10)
       .then(res => setPreviewData({ columns: res.columns, rows: res.preview_rows }))
@@ -237,6 +292,45 @@ export function Datasets() {
         }}>{dlStatus}</p>}
       </div>
 
+      {/* Active Downloads */}
+      {activeDownloads.size > 0 && (
+        <div style={{
+          marginTop: 24, background: '#fff', borderRadius: 10, padding: 20,
+          boxShadow: '0 1px 3px rgba(0,0,0,0.06)', border: '1px solid rgba(0,0,0,0.04)',
+        }}>
+          <h2 style={{ margin: '0 0 12px', fontSize: 15, fontWeight: 600 }}>下载任务</h2>
+          <div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
+            {Array.from(activeDownloads.values()).map(dl => (
+              <div key={dl.task_id} style={{
+                display: 'flex', alignItems: 'center', gap: 12, padding: 12,
+                borderRadius: 8, background: '#f8fafc', border: '1px solid #e2e8f0',
+              }}>
+                {dl.status === 'completed' ? (
+                  <CheckCircle size={20} color="#059669" />
+                ) : dl.status === 'failed' ? (
+                  <XCircle size={20} color="#f43f5e" />
+                ) : (
+                  <Loader2 size={20} color="#14b8a6" style={{ animation: 'lucide-spin 1s linear infinite' }} />
+                )}
+                <div style={{ flex: 1 }}>
+                  <div style={{ fontSize: 13, fontWeight: 500 }}>{dl.dataset_id}</div>
+                  <div style={{ fontSize: 12, color: '#64748b' }}>
+                    {dl.status === 'completed' ? `已完成 (${dl.record_count} 条记录)` : dl.status === 'failed' ? `失败: ${dl.error}` :
+                     dl.status === 'cancelled' ? '已取消' : '下载中...'}
+                  </div>
+                </div>
+                {(dl.status === 'pending' || dl.status === 'running' || dl.status === 'downloading') && (
+                  <button onClick={() => handleCancelDatasetDownload(dl.task_id)} style={{
+                    padding: '4px 12px', color: '#f43f5e', border: '1px solid #f43f5e',
+                    borderRadius: 6, background: 'transparent', cursor: 'pointer', fontSize: 12,
+                  }}>取消</button>
+                )}
+              </div>
+            ))}
+          </div>
+        </div>
+      )}
+
       {/* Sample Center section */}
       <div style={{
         marginTop: 24, background: '#fff', borderRadius: 10, padding: 20,

+ 112 - 9
frontend/src/pages/Models.tsx

@@ -1,6 +1,6 @@
-import { useState, useEffect, memo } from 'react'
-import api, { ModelInfo } from '../api/client'
-import { Cpu, CheckCircle, XCircle } from 'lucide-react'
+import { useState, useEffect, memo, useRef } from 'react'
+import api, { ModelInfo, ModelDownloadTaskResponse } from '../api/client'
+import { Cpu, CheckCircle, XCircle, Loader2 } from 'lucide-react'
 
 const ModelRow = memo(function ModelRow({ m, onTest, onDelete }: {
   m: ModelInfo
@@ -62,6 +62,10 @@ export function Models() {
   const [statusType, setStatusType] = useState<'success' | 'error' | ''>('')
   const [statusContent, setStatusContent] = useState('')
 
+  // Active downloads tracking
+  const [activeDownloads, setActiveDownloads] = useState<Map<string, ModelDownloadTaskResponse>>(new Map())
+  const pollIntervals = useRef<Map<string, ReturnType<typeof setInterval>>>(new Map())
+
   // Test state
   const [testModelId, setTestModelId] = useState('')
   const [testPrompt, setTestPrompt] = useState('')
@@ -81,17 +85,77 @@ export function Models() {
       .finally(() => setLoading(false))
   }
 
-  const handleDownload = () => {
+  // Submit a model download task; the result is tracked by polling rather than
+  // by waiting for this request.
+  const handleDownload = async () => {
     if (!modelId.trim()) return
     setDownloading(true)
     setStatusType('')
-    setStatusContent('正在下载...')
-    api.models.download(modelId, useModelscope)
-      .then(res => { setStatusType('success'); setStatusContent(`${res.model_id}: ${res.status}`) })
-      .catch(err => { setStatusType('error'); setStatusContent(`下载失败: ${err.message}`) })
-      .finally(() => setDownloading(false))
+    setStatusContent('正在提交下载任务...')
+    try {
+      const res = await api.models.download(modelId, useModelscope)
+      setStatusType('success')
+      setStatusContent(`下载任务已提交: ${res.model_id}`)
+      // Add to active downloads
+      setActiveDownloads(prev => new Map(prev).set(res.task_id, res))
+      // Start polling
+      startModelDownloadPolling(res.task_id)
+    } catch (err) {
+      setStatusType('error')
+      setStatusContent(`下载失败: ${err instanceof Error ? err.message : '未知错误'}`)
+    } finally {
+      // `downloading` only covers the submit request, not the download itself.
+      setDownloading(false)
+    }
+  }
+
+  // Poll the status of `taskId` every 3 s, mirroring each response into
+  // `activeDownloads`, until the task reports a terminal status.
+  const startModelDownloadPolling = (taskId: string) => {
+    // The download endpoint can answer a duplicate request with an existing task
+    // id; a second timer would overwrite the stored handle and leak the first.
+    if (pollIntervals.current.has(taskId)) return
+
+    const stopPolling = () => {
+      const handle = pollIntervals.current.get(taskId)
+      if (handle) clearInterval(handle)
+      pollIntervals.current.delete(taskId)
+    }
+
+    const interval = setInterval(() => {
+      api.models.downloadStatus(taskId)
+        .then(res => {
+          // A response can land after the task was removed locally (e.g. the user
+          // cancelled); do not re-add the entry in that case.
+          if (!pollIntervals.current.has(taskId)) return
+          setActiveDownloads(prev => {
+            const next = new Map(prev)
+            next.set(taskId, res)
+            return next
+          })
+          if (res.status === 'completed') {
+            stopPolling()
+            fetchModels()
+            setStatusType('success')
+            setStatusContent(`${res.model_id} 下载完成`)
+          } else if (res.status === 'failed') {
+            stopPolling()
+            setStatusType('error')
+            setStatusContent(`${res.model_id} 下载失败: ${res.error}`)
+          } else if (res.status === 'cancelled') {
+            // Terminal as well; without this branch the timer ran until unmount.
+            stopPolling()
+          }
+        })
+        .catch(() => {
+          // ignore polling errors
+        })
+    }, 3000)
+    pollIntervals.current.set(taskId, interval)
+  }
+
+  // Ask the backend to cancel `taskId`, stop its poll timer and drop it from the list.
+  const handleCancelDownload = (taskId: string) => {
+    // Fire-and-forget, but with a handler so a failed request is reported
+    // instead of surfacing as an unhandled promise rejection.
+    api.models.cancelDownload(taskId).catch(err => {
+      console.error('Failed to cancel model download', err)
+    })
+    const interval = pollIntervals.current.get(taskId)
+    if (interval) {
+      clearInterval(interval)
+      pollIntervals.current.delete(taskId)
+    }
+    setActiveDownloads(prev => {
+      const next = new Map(prev)
+      next.delete(taskId)
+      return next
+    })
   }
 
+  // Cleanup polling on unmount
+  useEffect(() => {
+    return () => {
+      pollIntervals.current.forEach(interval => clearInterval(interval))
+    }
+  }, [])
+
   const handleDelete = async (id: string, name: string) => {
     if (!confirm(`确定删除模型 "${name}"?这将删除本地所有相关文件。`)) return
     try {
@@ -190,6 +254,45 @@ export function Models() {
         )}
       </div>
 
+      {/* Active Downloads */}
+      {activeDownloads.size > 0 && (
+        <div style={{
+          marginTop: 24, background: '#fff', borderRadius: 10, padding: 20,
+          boxShadow: '0 1px 3px rgba(0,0,0,0.06)', border: '1px solid rgba(0,0,0,0.04)',
+        }}>
+          <h2 style={{ margin: '0 0 12px', fontSize: 15, fontWeight: 600 }}>下载任务</h2>
+          <div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
+            {Array.from(activeDownloads.values()).map(dl => (
+              <div key={dl.task_id} style={{
+                display: 'flex', alignItems: 'center', gap: 12, padding: 12,
+                borderRadius: 8, background: '#f8fafc', border: '1px solid #e2e8f0',
+              }}>
+                {dl.status === 'completed' ? (
+                  <CheckCircle size={20} color="#059669" />
+                ) : dl.status === 'failed' ? (
+                  <XCircle size={20} color="#f43f5e" />
+                ) : (
+                  <Loader2 size={20} color="#14b8a6" style={{ animation: 'lucide-spin 1s linear infinite' }} />
+                )}
+                <div style={{ flex: 1 }}>
+                  <div style={{ fontSize: 13, fontWeight: 500 }}>{dl.model_id}</div>
+                  <div style={{ fontSize: 12, color: '#64748b' }}>
+                    {dl.status === 'completed' ? '已完成' : dl.status === 'failed' ? `失败: ${dl.error}` :
+                     dl.status === 'cancelled' ? '已取消' : '下载中...'}
+                  </div>
+                </div>
+                {(dl.status === 'pending' || dl.status === 'running' || dl.status === 'downloading') && (
+                  <button onClick={() => handleCancelDownload(dl.task_id)} style={{
+                    padding: '4px 12px', color: '#f43f5e', border: '1px solid #f43f5e',
+                    borderRadius: 6, background: 'transparent', cursor: 'pointer', fontSize: 12,
+                  }}>取消</button>
+                )}
+              </div>
+            ))}
+          </div>
+        </div>
+      )}
+
       {/* Model list */}
       <div style={{ marginTop: 24 }}>
         <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>

+ 3 - 13
frontend/src/pages/Training.tsx

@@ -322,7 +322,6 @@ export function Training() {
   const [loraR, setLoraR] = useState(16)
   const [seqLen, setSeqLen] = useState(2048)
   const [gradAcc, setGradAcc] = useState(4)
-  const [deepspeed, setDeepspeed] = useState(false)
 
   const [jobs, setJobs] = useState<TrainingJob[]>([])
   const [loading, setLoading] = useState(false)
@@ -400,7 +399,7 @@ export function Training() {
       peft_method: peftMethod, task_type: taskType, dataset_template: template,
       epochs, batch_size: batchSize, gradient_accumulation: gradAcc,
       max_seq_length: seqLen, learning_rate: parseFloat(lr),
-      lora_r: loraR, lora_alpha: loraR * 2, deepspeed,
+      lora_r: loraR, lora_alpha: loraR * 2,
     })
       .then(() => {
         setModelId('')
@@ -504,17 +503,8 @@ export function Training() {
           </div>
         </div>
 
-        {/* 高级选项 */}
-        <div style={{
-          fontSize: 13, fontWeight: 600, color: '#14b8a6', marginBottom: 12,
-          paddingBottom: 6, borderBottom: '2px solid #ccfbf1',
-        }}>高级选项</div>
-        <div style={{ display: 'flex', gap: 16, alignItems: 'center', marginBottom: 20 }}>
-          <label style={{ display: 'flex', alignItems: 'center', gap: 6, fontSize: 13, cursor: 'pointer' }}>
-            <input type="checkbox" checked={deepspeed} onChange={e => setDeepspeed(e.target.checked)} />
-            DeepSpeed ZeRO-2 (多 GPU)
-          </label>
-        </div>
+        {/* Advanced options — DeepSpeed is disabled (single-GPU training on GPU 3 only) */}
+
 
         {/* 错误提示 */}
         {createError && (

+ 38 - 21
result.txt

@@ -1,21 +1,38 @@
-(base) [root@localhost ~]# docker exec finetune-trainer bash -c 'tail -n 20 $(ls -t /tmp/train_*.log | head -1)'
-[remote_train]   Preprocessing done, output: /root/Fine-tuning/backend/data/processed/fb18c7a8-e275-4014-b6a3-dea08f3f7adb_processed.jsonl
-[remote_train] Step 2: Loading model: Qwen/Qwen1.5-0.5B...
-[remote_train]   Quantization: None
-Loading weights: 100%|██████████| 291/291 [00:04<00:00, 59.89it/s] 
-[remote_train]   Model loaded successfully
-[remote_train] Step 3: Building PEFT config...
-[remote_train]   PEFT config built
-[remote_train] Step 4: Starting training...
-Map: 100%|██████████| 274147/274147 [00:15<00:00, 18259.13 examples/s]
-/opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:1348: UserWarning: Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, but `ensure_weight_tying` is not set to True. This can lead to complications, for example when merging the adapter or converting your model to formats other than safetensors. Check the discussion here: https://github.com/huggingface/peft/issues/2777
-  warnings.warn(msg)
-[transformers] warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
-/opt/conda/lib/python3.10/site-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
-  warnings.warn(_BETA_TRANSFORMS_WARNING)
-/opt/conda/lib/python3.10/site-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
-  warnings.warn(_BETA_TRANSFORMS_WARNING)
-trainable params: 5,593,088 || all params: 469,580,800 || trainable%: 1.1911
-  0%|          | 0/4284 [00:00<?, ?it/s]/opt/conda/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:108: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /workspace/framework/mcPytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:183.)
-  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- 19%|█▉        | 812/4284 [21:54<1:34:25,  1.63s/it](base) [root@localhost ~]# 
+(base) [root@localhost ~]# mx-smi
+mx-smi  version: 2.2.9
+
+=================== MetaX System Management Interface Log ===================
+Timestamp                                         : Fri May 22 03:09:03 2026
+
+Attached GPUs                                     : 4
++---------------------------------------------------------------------------------+
+| MX-SMI 2.2.9                       Kernel Mode Driver Version: 3.4.4            |
+| MACA Version: 3.3.0.15             BIOS Version: 1.30.0.0                       |
+|------------------+-----------------+---------------------+----------------------|
+| Board       Name | GPU   Persist-M | Bus-id              | GPU-Util      sGPU-M |
+| Pwr:Usage/Cap    | Temp       Perf | Memory-Usage        | GPU-State            |
+|==================+=================+=====================+======================|
+| 0     MetaX N260 | 0           Off | 0000:b5:00.0        | 0%          Disabled |
+| 53W / 225W       | 43C          P9 | 60459/65536 MiB     | Available            |
++------------------+-----------------+---------------------+----------------------+
+| 1     MetaX N260 | 1           Off | 0000:b6:00.0        | 0%          Disabled |
+| 50W / 225W       | 42C          P9 | 60459/65536 MiB     | Available            |
++------------------+-----------------+---------------------+----------------------+
+| 2     MetaX N260 | 2           Off | 0000:b9:00.0        | 62%         Disabled |
+| 130W / 225W      | 64C          P9 | 41042/65536 MiB     | Available            |
++------------------+-----------------+---------------------+----------------------+
+| 3     MetaX N260 | 3           Off | 0000:bd:00.0        | 60%         Disabled |
+| 126W / 225W      | 61C          P9 | 39916/65536 MiB     | Available            |
++------------------+-----------------+---------------------+----------------------+
+
++---------------------------------------------------------------------------------+
+| Process:                                                                        |
+|  GPU                    PID         Process Name                 GPU Memory     |
+|                                                                  Usage(MiB)     |
+|=================================================================================|
+|  0                  1007916         VLLM::Worker_TP              59790          |
+|  1                  1007917         VLLM::Worker_TP              59790          |
+|  2                  1217897         python                       5846           |
+|  2                  1229576         python                       34528          |
+|  3                  1217897         python                       5384           |
+|  3                  1229576         python                       33864