Ver código fonte

支持数据集从平台下载

lxylxy123321 2 semanas atrás
pai
commit
39dbc3320b

+ 9 - 0
backend/app/api/datasets.py

@@ -1,14 +1,23 @@
 from fastapi import APIRouter, UploadFile, File, Query
 
 from app.schemas.dataset import (
+    DatasetDownloadRequest,
+    DatasetDownloadResponse,
     DatasetPreviewResponse,
     DatasetUploadResponse,
     DatasetValidationResult,
 )
+from app.services import dataset_service
 
 router = APIRouter()
 
 
+@router.post("/download", response_model=DatasetDownloadResponse)
+async def download_dataset(req: DatasetDownloadRequest):
+    """从 HuggingFace 或 ModelScope 下载数据集。"""
+    return dataset_service.download_dataset(req)
+
+
 @router.post("/upload", response_model=DatasetUploadResponse)
 async def upload_dataset(file: UploadFile = File(...)):
     """上传数据集文件(JSONL / CSV / Parquet / JSON)。"""

+ 13 - 1
backend/app/schemas/dataset.py

@@ -1,6 +1,6 @@
 from enum import Enum
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, Field
 
 
 class DatasetFormat(str, Enum):
@@ -19,6 +19,18 @@ class DatasetUploadResponse(BaseModel):
     created_at: str
 
 
+# Request body for the dataset download endpoint: which dataset to fetch and
+# which hub to fetch it from.
+class DatasetDownloadRequest(BaseModel):
+    # Required dataset identifier; the service passes it unchanged to the hub loader.
+    dataset_id: str = Field(..., description="HuggingFace or ModelScope dataset ID, e.g. 'glue', 'MRPC'")
+    # False (the default) selects HuggingFace; True selects ModelScope.
+    use_modelscope: bool = Field(default=False, description="Use ModelScope instead of HuggingFace")
+
+
+# Outcome of a dataset download request, returned by the download endpoint.
+class DatasetDownloadResponse(BaseModel):
+    dataset_id: str  # the dataset ID taken from the request
+    status: str  # "downloading" | "completed" | "failed"
+    path: str | None = None  # local target path; filled in when the load call did not raise
+    error: str | None = None  # exception text; filled in when the load call raised
+
+
 class DatasetPreviewRow(BaseModel):
     row_index: int
     data: dict

+ 31 - 0
backend/app/services/dataset_service.py

@@ -4,10 +4,41 @@ from typing import Any
 from fastapi import UploadFile
 from app.config import get_settings
 from app.core.logging import logger
+from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
 
 settings = get_settings()
 
 
+async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
+    """从 HuggingFace 或 ModelScope 下载数据集。"""
+    import os
+    import uuid
+
+    download_dir = settings.processed_dir
+    download_dir.mkdir(parents=True, exist_ok=True)
+
+    if req.use_modelscope:
+        try:
+            from modelscope.msdatasets import MsDataset
+            MsDataset.load(req.dataset_id, split="train")
+            path = str(download_dir / f"ms_{req.dataset_id.replace('/', '_')}")
+            logger.info(f"Downloaded dataset from ModelScope: {req.dataset_id}")
+            return DatasetDownloadResponse(dataset_id=req.dataset_id, status="downloading", path=path)
+        except Exception as e:
+            logger.error(f"ModelScope dataset download failed: {e}")
+            return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
+    else:
+        try:
+            from datasets import load_dataset
+            load_dataset(req.dataset_id)
+            path = str(download_dir / f"hf_{req.dataset_id.replace('/', '_')}")
+            logger.info(f"Downloaded dataset from HuggingFace: {req.dataset_id}")
+            return DatasetDownloadResponse(dataset_id=req.dataset_id, status="downloading", path=path)
+        except Exception as e:
+            logger.error(f"HuggingFace dataset download failed: {e}")
+            return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
+
+
 async def upload_dataset(file: UploadFile) -> dict[str, Any]:
     """保存上传文件并检测格式。"""
     upload_dir = settings.uploads_dir

+ 14 - 1
frontend/src/api/client.ts

@@ -23,6 +23,12 @@ const api = {
       form.append('file', file)
       return fetch('/api/v1/datasets/upload', { method: 'POST', body: form }).then(r => r.json()) as Promise<DatasetInfo>
     },
+    // Requests a server-side dataset download. fetch() also resolves on HTTP
+    // error statuses, so reject explicitly instead of casting an error body
+    // to DatasetDownloadResponse.
+    download: (datasetId: string, useModelscope = false) =>
+      fetch('/api/v1/datasets/download', {
+        method: 'POST',
+        headers: { 'Content-Type': 'application/json' },
+        body: JSON.stringify({ dataset_id: datasetId, use_modelscope: useModelscope }),
+      }).then(r => {
+        if (!r.ok) throw new Error(`HTTP ${r.status}`)
+        return r.json() as Promise<DatasetDownloadResponse>
+      }),
     preview: (id: string, rows = 10) =>
       fetch(`/api/v1/datasets/${id}/preview?rows=${rows}`).then(r => r.json()) as Promise<DatasetPreview>,
     validate: (id: string) =>
@@ -101,6 +107,13 @@ interface DatasetInfo {
   created_at: string
 }
 
+/** Response body of POST /api/v1/datasets/download. */
+interface DatasetDownloadResponse {
+  dataset_id: string
+  /** "downloading" | "completed" | "failed", as reported by the backend. */
+  status: string
+  // The backend model defaults these fields to None, which is serialized as
+  // JSON null rather than omitted, so null has to be part of the type.
+  path?: string | null
+  error?: string | null
+}
+
 interface DatasetPreview {
   total_records: number
   preview_rows: { row_index: number; data: Record<string, unknown> }[]
@@ -179,4 +192,4 @@ interface DeployResponse {
   error?: string
 }
 
-export type { ModelInfo, ModelDownloadResponse, DatasetInfo, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse }
+export type { ModelInfo, ModelDownloadResponse, DatasetInfo, DatasetDownloadResponse, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse }

+ 42 - 0
frontend/src/pages/Datasets.tsx

@@ -4,10 +4,16 @@ import api, { DatasetInfo } from '../api/client'
 export function Datasets() {
   const [datasets, setDatasets] = useState<DatasetInfo[]>([])
   const [uploading, setUploading] = useState(false)
+  const [downloading, setDownloading] = useState(false)
   const [loading, setLoading] = useState(false)
   const [previewData, setPreviewData] = useState<{ columns: string[]; rows: { row_index: number; data: Record<string, unknown> }[] } | null>(null)
   const inputRef = useRef<HTMLInputElement>(null)
 
+  // Download form
+  const [dlDatasetId, setDlDatasetId] = useState('')
+  const [dlUseModelscope, setDlUseModelscope] = useState(false)
+  const [dlStatus, setDlStatus] = useState('')
+
   const fetchDatasets = () => {
     setLoading(true)
     api.datasets.list()
@@ -33,6 +39,16 @@ export function Datasets() {
     if (file) handleFileUpload(file)
   }
 
+  // Starts a platform download for the ID typed into the form and mirrors the
+  // outcome into dlStatus; `downloading` is reset whether it succeeds or fails.
+  const handleDownload = () => {
+    // Trim once so the emptiness check and the request use the same value;
+    // stray whitespace would otherwise be sent as part of the dataset ID.
+    const datasetId = dlDatasetId.trim()
+    if (!datasetId) return
+    setDownloading(true)
+    setDlStatus('正在下载...')
+    api.datasets.download(datasetId, dlUseModelscope)
+      .then(res => setDlStatus(`${res.dataset_id}: ${res.status}${res.error ? ` - ${res.error}` : ''}`))
+      .catch(err => setDlStatus(`下载失败: ${err.message}`))
+      .finally(() => setDownloading(false))
+  }
+
   const handlePreview = (id: string) => {
     api.datasets.preview(id, 10)
       .then(res => setPreviewData({ columns: res.columns, rows: res.preview_rows }))
@@ -73,6 +89,32 @@ export function Datasets() {
         />
       </div>
 
+      {/* Download section */}
+      <div style={{ marginTop: 24 }}>
+        <h2 style={{ margin: '0 0 12px', fontSize: 16 }}>从平台下载</h2>
+        <div style={{ display: 'flex', gap: 8, alignItems: 'center' }}>
+          <input
+            type="text"
+            placeholder="数据集 ID (如 glue, MRPC, stanfordnlp/imdb)"
+            value={dlDatasetId}
+            onChange={e => setDlDatasetId(e.target.value)}
+            style={{ padding: '8px 12px', width: 400, borderRadius: 4, border: '1px solid #ccc' }}
+          />
+          <label style={{ fontSize: 13, color: '#666', whiteSpace: 'nowrap' }}>
+            <input type="checkbox" checked={dlUseModelscope} onChange={e => setDlUseModelscope(e.target.checked)} />
+            {' '}ModelScope
+          </label>
+          <button
+            onClick={handleDownload}
+            disabled={downloading}
+            style={{ padding: '8px 16px', borderRadius: 4, border: 'none', background: '#e94560', color: '#fff', cursor: 'pointer', opacity: downloading ? 0.6 : 1 }}
+          >
+            {downloading ? '下载中...' : '下载数据集'}
+          </button>
+        </div>
+        {dlStatus && <p style={{ marginTop: 8, fontSize: 13, color: dlStatus.includes('failed') || dlStatus.includes('失败') ? '#e94560' : '#666' }}>{dlStatus}</p>}
+      </div>
+
       {/* Dataset list */}
       <div style={{ marginTop: 24 }}>
         <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>