Forráskód Böngészése

增加删除功能,优化界面显示

lxylxy123321 2 hete
szülő
commit
c189acc6b2

+ 8 - 5
backend/app/api/datasets.py

@@ -1,4 +1,4 @@
-from fastapi import APIRouter, UploadFile, File, Query
+from fastapi import APIRouter, UploadFile, File, Query, HTTPException
 
 from app.schemas.dataset import (
     DatasetDownloadRequest,
@@ -12,13 +12,16 @@ from app.services import dataset_service
 router = APIRouter()
 
 
-@router.post("/download", response_model=DatasetDownloadResponse)
+@router.post("/download", response_model=DatasetDownloadResponse, status_code=200)
 async def download_dataset(req: DatasetDownloadRequest):
     """从 HuggingFace 或 ModelScope 下载数据集。"""
-    return await dataset_service.download_dataset(req)
+    result = await dataset_service.download_dataset(req)
+    if result.status == "failed":
+        raise HTTPException(status_code=400, detail=result.error or "Dataset download failed")
+    return result
 
 
-@router.post("/upload", response_model=DatasetUploadResponse)
+@router.post("/upload", response_model=DatasetUploadResponse, status_code=201)
 async def upload_dataset(file: UploadFile = File(...)):
     """上传数据集文件(JSONL / CSV / Parquet / JSON)。"""
     result = await dataset_service.upload_dataset(file)
@@ -46,7 +49,7 @@ async def list_datasets():
     return [DatasetUploadResponse(**item) for item in items]
 
 
-@router.delete("/{dataset_id}")
+@router.delete("/{dataset_id}", status_code=200)
 async def delete_dataset(dataset_id: str):
     """删除数据集。"""
     return await dataset_service.delete_dataset(dataset_id)

+ 28 - 22
backend/app/api/models.py

@@ -1,4 +1,4 @@
-from fastapi import APIRouter
+from fastapi import APIRouter, HTTPException
 
 from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
 from app.services import model_service
@@ -8,26 +8,30 @@ router = APIRouter()
 
 @router.get("/", response_model=list[ModelInfo])
 async def list_models():
-    """列出所有本地缓存的模型。"""
-    models = model_service.list_cached_models()
+    """列出所有本地缓存的模型(从数据库读取)。"""
+    models = await model_service.list_cached_models()
     return [
         ModelInfo(
             id=m["id"],
-            name=m.get("name", m["id"]),
-            model_type=m.get("model_type", "text"),
-            path=m.get("path"),
-            is_downloaded=m.get("is_downloaded", True),
-            context_length=m.get("context_length"),
-            supported_peft_methods=m.get("supported_peft_methods", []),
+            name=m["name"],
+            model_type=m["model_type"],
+            path=m["path"],
+            is_downloaded=m["is_downloaded"],
+            context_length=m["context_length"],
+            supported_peft_methods=m["supported_peft_methods"],
         )
         for m in models
     ]
 
 
-@router.post("/download", response_model=ModelDownloadResponse)
+@router.post("/download", response_model=ModelDownloadResponse, status_code=200)
 async def download_model(req: ModelDownloadRequest):
     """从 HuggingFace 或 ModelScope 下载模型。"""
     result = await model_service.download_model(req.model_id, req.use_modelscope)
+
+    if result["status"] == "failed":
+        raise HTTPException(status_code=400, detail=result.get("error", "Download failed"))
+
     return ModelDownloadResponse(
         model_id=result["model_id"],
         status=result["status"],
@@ -43,16 +47,18 @@ async def get_model_info(model_id: str):
     if info:
         return ModelInfo(
             id=info["id"],
-            name=info.get("name", model_id.split("/")[-1]),
-            model_type=info.get("model_type", "text"),
-            path=info.get("path"),
-            is_downloaded=info.get("is_downloaded", True),
-            context_length=info.get("context_length"),
-            supported_peft_methods=info.get("supported_peft_methods", []),
+            name=info["name"],
+            model_type=info["model_type"],
+            path=info["path"],
+            is_downloaded=info["is_downloaded"],
+            context_length=info["context_length"],
+            supported_peft_methods=info["supported_peft_methods"],
         )
-    return ModelInfo(
-        id=model_id,
-        name=model_id.split("/")[-1],
-        model_type="text",
-        is_downloaded=False,
-    )
+    raise HTTPException(status_code=404, detail=f"Model not found: {model_id}")
+
+
+@router.delete("/{model_id}")
+async def delete_model(model_id: str):
+    """删除已缓存的模型(数据库记录 + 本地文件)。"""
+    result = await model_service.delete_model(model_id)
+    return result

+ 73 - 57
backend/app/services/model_service.py

@@ -38,18 +38,28 @@ async def download_model(model_id: str, use_modelscope: bool = False) -> dict[st
             model_type = cfg.get("model_type", "text")
             context_length = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
 
-        # 写入数据库
+        # 写入数据库(如果已存在则更新)
         async with async_session() as session:
-            record = ModelCache(
-                id=model_id,
-                name=model_id.split("/")[-1],
-                model_type=model_type,
-                path=local_path,
-                is_downloaded=1,
-                context_length=context_length,
-                supported_peft_methods=peft_methods,
-            )
-            session.add(record)
+            result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
+            existing = result.scalar_one_or_none()
+            if existing:
+                existing.name = model_id.split("/")[-1]
+                existing.model_type = model_type
+                existing.path = local_path
+                existing.is_downloaded = 1
+                existing.context_length = context_length
+                existing.supported_peft_methods = peft_methods
+            else:
+                record = ModelCache(
+                    id=model_id,
+                    name=model_id.split("/")[-1],
+                    model_type=model_type,
+                    path=local_path,
+                    is_downloaded=1,
+                    context_length=context_length,
+                    supported_peft_methods=peft_methods,
+                )
+                session.add(record)
             await session.commit()
 
         logger.info(f"Model downloaded: {model_id} -> {local_path}")
@@ -62,39 +72,37 @@ async def download_model(model_id: str, use_modelscope: bool = False) -> dict[st
         return {"model_id": model_id, "status": "failed", "error": error_msg}
 
 
-def list_cached_models() -> list[dict[str, Any]]:
-    """列出本地已缓存的模型。"""
-    models_dir = settings.models_dir
-    if not models_dir.exists():
-        return []
-
-    result = []
-    for d in models_dir.iterdir():
-        if not d.is_dir():
-            continue
-        config_path = d / "config.json"
-        info: dict[str, Any] = {
-            "id": d.name,
-            "name": d.name,
-            "model_type": "text",
-            "path": str(d),
-            "is_downloaded": True,
-            "context_length": None,
-            "supported_peft_methods": [],
-        }
-        if config_path.exists():
-            with open(config_path) as f:
-                cfg = json.load(f)
-            info["model_type"] = cfg.get("model_type", "text")
-            info["context_length"] = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
-            info["supported_peft_methods"] = ["lora", "qlora", "ia3", "adalora", "prefix_tuning"]
-        result.append(info)
-    return result
+async def list_cached_models() -> list[dict[str, Any]]:
+    """从数据库列出已缓存的模型(不扫描目录,避免 HF 缓存子目录干扰)。"""
+    async with async_session() as session:
+        result = await session.execute(select(ModelCache).order_by(ModelCache.created_at.desc()))
+        records = result.scalars().all()
+
+    models = []
+    for r in records:
+        # 验证目录是否真的存在,如果不存在则标记为未下载
+        dir_exists = r.path and Path(r.path).exists()
+        if not dir_exists:
+            # 尝试从 models_dir 下查找
+            alt_path = settings.models_dir / r.id.replace("/", "_")
+            dir_exists = alt_path.exists()
+            if dir_exists:
+                r.path = str(alt_path)
+
+        models.append({
+            "id": r.id,
+            "name": r.name,
+            "model_type": r.model_type,
+            "path": r.path,
+            "is_downloaded": dir_exists,
+            "context_length": r.context_length,
+            "supported_peft_methods": r.supported_peft_methods.split(",") if r.supported_peft_methods else [],
+        })
+    return models
 
 
 async def get_model_info(model_id: str) -> dict[str, Any] | None:
     """获取已缓存模型的元数据。"""
-    # 先查数据库
     async with async_session() as session:
         result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
         record = result.scalar_one_or_none()
@@ -104,24 +112,32 @@ async def get_model_info(model_id: str) -> dict[str, Any] | None:
                 "name": record.name,
                 "model_type": record.model_type,
                 "path": record.path,
-                "is_downloaded": bool(record.is_downloaded),
+                "is_downloaded": bool(record.is_downloaded) and Path(record.path).exists() if record.path else False,
                 "context_length": record.context_length,
                 "supported_peft_methods": record.supported_peft_methods.split(",") if record.supported_peft_methods else [],
             }
-
-    # 回退:直接从文件系统读取
-    model_dir = settings.models_dir / model_id.replace("/", "_")
-    config_path = model_dir / "config.json"
-    if config_path.exists():
-        with open(config_path) as f:
-            cfg = json.load(f)
-        return {
-            "id": model_id,
-            "name": model_id.split("/")[-1],
-            "model_type": cfg.get("model_type", "text"),
-            "path": str(model_dir),
-            "is_downloaded": True,
-            "context_length": cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048)),
-            "supported_peft_methods": ["lora", "qlora", "ia3", "adalora", "prefix_tuning"],
-        }
     return None
+
+
+async def delete_model(model_id: str) -> dict[str, Any]:
+    """删除已缓存的模型(数据库记录 + 本地文件)。"""
+    async with async_session() as session:
+        result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
+        record = result.scalar_one_or_none()
+        if not record:
+            return {"status": "not_found", "message": f"Model not found: {model_id}"}
+
+        # 删除本地文件目录
+        model_dir = Path(record.path) if record.path else settings.models_dir / record.id.replace("/", "_")
+        deleted_files = False
+        if model_dir.exists() and model_dir.is_dir():
+            import shutil
+            shutil.rmtree(model_dir, ignore_errors=True)
+            deleted_files = True
+
+        # 删除数据库记录
+        await session.delete(record)
+        await session.commit()
+
+        logger.info(f"Model deleted: {model_id} (files={deleted_files})")
+        return {"status": "deleted", "model_id": model_id, "files_deleted": deleted_files}

+ 37 - 20
frontend/src/api/client.ts

@@ -1,91 +1,108 @@
+// 统一的 fetch 包装器:非 2xx 状态码自动抛出错误
+async function apiFetch(url: string, init?: RequestInit): Promise<Response> {
+  const res = await fetch(url, init)
+  if (!res.ok) {
+    try {
+      const err = await res.json()
+      throw new Error(err.detail || err.error || `Request failed: ${res.status}`)
+    } catch (e) {
+      if (e instanceof Error) throw e
+      throw new Error(`Request failed: ${res.status}`)
+    }
+  }
+  return res
+}
+
 const api = {
   // --- Health ---
-  health: () => fetch('/health').then(r => r.json()),
+  health: () => apiFetch('/health').then(r => r.json()),
 
   // --- Models ---
   models: {
-    list: () => fetch('/api/v1/models/').then(r => r.json()) as Promise<ModelInfo[]>,
+    list: () => apiFetch('/api/v1/models/').then(r => r.json()) as Promise<ModelInfo[]>,
     download: (modelId: string, useModelscope = false) =>
-      fetch('/api/v1/models/download', {
+      apiFetch('/api/v1/models/download', {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify({ model_id: modelId, use_modelscope: useModelscope }),
       }).then(r => r.json()) as Promise<ModelDownloadResponse>,
+    delete: (modelId: string) =>
+      apiFetch(`/api/v1/models/${encodeURIComponent(modelId)}`, { method: 'DELETE' }).then(r => r.json()),
     getInfo: (modelId: string) =>
-      fetch(`/api/v1/models/${encodeURIComponent(modelId)}`).then(r => r.json()) as Promise<ModelInfo>,
+      apiFetch(`/api/v1/models/${encodeURIComponent(modelId)}`).then(r => r.json()) as Promise<ModelInfo>,
   },
 
   // --- Datasets ---
   datasets: {
-    list: () => fetch('/api/v1/datasets/').then(r => r.json()) as Promise<DatasetInfo[]>,
+    list: () => apiFetch('/api/v1/datasets/').then(r => r.json()) as Promise<DatasetInfo[]>,
     upload: (file: File) => {
       const form = new FormData()
       form.append('file', file)
-      return fetch('/api/v1/datasets/upload', { method: 'POST', body: form }).then(r => r.json()) as Promise<DatasetInfo>
+      return apiFetch('/api/v1/datasets/upload', { method: 'POST', body: form }).then(r => r.json()) as Promise<DatasetInfo>
     },
     download: (datasetId: string, useModelscope = false) =>
-      fetch('/api/v1/datasets/download', {
+      apiFetch('/api/v1/datasets/download', {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify({ dataset_id: datasetId, use_modelscope: useModelscope }),
       }).then(r => r.json()) as Promise<DatasetDownloadResponse>,
     preview: (id: string, rows = 10) =>
-      fetch(`/api/v1/datasets/${id}/preview?rows=${rows}`).then(r => r.json()) as Promise<DatasetPreview>,
+      apiFetch(`/api/v1/datasets/${id}/preview?rows=${rows}`).then(r => r.json()) as Promise<DatasetPreview>,
     validate: (id: string) =>
-      fetch(`/api/v1/datasets/${id}/validate`, { method: 'POST' }).then(r => r.json()) as Promise<DatasetValidation>,
+      apiFetch(`/api/v1/datasets/${id}/validate`, { method: 'POST' }).then(r => r.json()) as Promise<DatasetValidation>,
     delete: (id: string) =>
-      fetch(`/api/v1/datasets/${id}`, { method: 'DELETE' }).then(r => r.json()),
+      apiFetch(`/api/v1/datasets/${id}`, { method: 'DELETE' }).then(r => r.json()),
   },
 
   // --- Training ---
   training: {
-    list: () => fetch('/api/v1/training/jobs').then(r => r.json()) as Promise<TrainingJob[]>,
+    list: () => apiFetch('/api/v1/training/jobs').then(r => r.json()) as Promise<TrainingJob[]>,
     create: (cfg: TrainingConfig) =>
-      fetch('/api/v1/training/jobs', {
+      apiFetch('/api/v1/training/jobs', {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify(cfg),
       }).then(r => r.json()) as Promise<TrainingJob>,
     get: (id: string) =>
-      fetch(`/api/v1/training/jobs/${id}`).then(r => r.json()) as Promise<TrainingJob>,
+      apiFetch(`/api/v1/training/jobs/${id}`).then(r => r.json()) as Promise<TrainingJob>,
     cancel: (id: string) =>
-      fetch(`/api/v1/training/jobs/${id}/cancel`, { method: 'POST' }).then(r => r.json()),
+      apiFetch(`/api/v1/training/jobs/${id}/cancel`, { method: 'POST' }).then(r => r.json()),
   },
 
   // --- Evaluation ---
   evaluation: {
     run: (cfg: EvalConfig) =>
-      fetch('/api/v1/evaluation/run', {
+      apiFetch('/api/v1/evaluation/run', {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify(cfg),
       }).then(r => r.json()) as Promise<EvalResult>,
     results: (id: string) =>
-      fetch(`/api/v1/evaluation/${id}/results`).then(r => r.json()) as Promise<EvalResult>,
+      apiFetch(`/api/v1/evaluation/${id}/results`).then(r => r.json()) as Promise<EvalResult>,
   },
 
   // --- Deployment ---
   deployment: {
     export: (cfg: DeployConfig) =>
-      fetch('/api/v1/deployment/export', {
+      apiFetch('/api/v1/deployment/export', {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify(cfg),
       }).then(r => r.json()) as Promise<DeployResponse>,
     status: (id: string) =>
-      fetch(`/api/v1/deployment/${id}/status`).then(r => r.json()) as Promise<DeployResponse>,
+      apiFetch(`/api/v1/deployment/${id}/status`).then(r => r.json()) as Promise<DeployResponse>,
   },
 
   // --- Inference ---
   inference: {
     generate: (req: InferenceRequest) =>
-      fetch('/api/v1/inference/generate', {
+      apiFetch('/api/v1/inference/generate', {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify(req),
       }).then(r => r.json()) as Promise<InferenceResponse>,
     adapters: () =>
-      fetch('/api/v1/inference/adapters').then(r => r.json()) as Promise<AdapterInfo[]>,
+      apiFetch('/api/v1/inference/adapters').then(r => r.json()) as Promise<AdapterInfo[]>,
   },
 }
 

+ 18 - 3
frontend/src/pages/Models.tsx

@@ -22,11 +22,22 @@ export function Models() {
     setDownloading(true)
     setStatusMsg('正在下载...')
     api.models.download(modelId, useModelscope)
-      .then(res => setStatusMsg(`${res.model_id}: ${res.status}`))
-      .catch(err => setStatusMsg(`下载失败: ${err.message}`))
+      .then(res => setStatusMsg(`${res.model_id}: ${res.status}`))
+      .catch(err => setStatusMsg(`下载失败: ${err.message}`))
       .finally(() => setDownloading(false))
   }
 
+  const handleDelete = async (id: string, name: string) => {
+    if (!confirm(`确定删除模型 "${name}"?这将删除本地所有相关文件。`)) return
+    try {
+      await api.models.delete(id)
+      fetchModels()
+    } catch (err) {
+      const msg = err instanceof Error ? err.message : '删除失败'
+      setStatusMsg(`❌ ${msg}`)
+    }
+  }
+
   return (
     <div>
       <h1>模型注册</h1>
@@ -53,7 +64,7 @@ export function Models() {
         </button>
       </div>
 
-      {statusMsg && <p style={{ marginTop: 8, fontSize: 13, color: '#e94560' }}>{statusMsg}</p>}
+      {statusMsg && <p style={{ marginTop: 8, fontSize: 13, color: statusMsg.includes('❌') ? '#e94560' : '#4caf50' }}>{statusMsg}</p>}
 
       {/* Model list */}
       <div style={{ marginTop: 24 }}>
@@ -79,6 +90,7 @@ export function Models() {
                 <th>类型</th>
                 <th>状态</th>
                 <th>PEFT 支持</th>
+                <th>操作</th>
               </tr>
             </thead>
             <tbody>
@@ -91,6 +103,9 @@ export function Models() {
                     {m.is_downloaded ? '已缓存' : '未下载'}
                   </td>
                   <td>{m.supported_peft_methods.join(', ') || '-'}</td>
+                  <td>
+                    <button onClick={() => handleDelete(m.id, m.name)} style={{ padding: '2px 8px', color: '#e94560', border: '1px solid #e94560', borderRadius: 4, background: 'transparent', cursor: 'pointer' }}>删除</button>
+                  </td>
                 </tr>
               ))}
             </tbody>