| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- from pathlib import Path
- from typing import Any
- from fastapi import UploadFile
- from app.config import get_settings
- from app.core.logging import logger
- from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
- settings = get_settings()
async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
    """Download a dataset from HuggingFace or ModelScope.

    The hub is chosen by ``req.use_modelscope``. The hub client's blocking
    load call is run in a worker thread so the event loop stays responsive
    while the download is in progress.

    Args:
        req: Download request; ``dataset_id`` and ``use_modelscope`` are read.

    Returns:
        A response with ``status="downloading"`` and a ``path`` under
        ``settings.processed_dir`` on success, or ``status="failed"`` with
        the error text when anything in the download step raises.
    """
    # Function-scope import, matching the block's existing local-import style.
    import asyncio

    download_dir = settings.processed_dir
    download_dir.mkdir(parents=True, exist_ok=True)

    if req.use_modelscope:
        source, prefix = "ModelScope", "ms"
    else:
        source, prefix = "HuggingFace", "hf"

    try:
        # Hub clients are imported lazily inside the try so that a missing
        # optional dependency is reported as a failed download, not a crash.
        if req.use_modelscope:
            from modelscope.msdatasets import MsDataset

            await asyncio.to_thread(MsDataset.load, req.dataset_id, split="train")
        else:
            from datasets import load_dataset

            await asyncio.to_thread(load_dataset, req.dataset_id)

        # NOTE(review): this function never writes to `path`; the hub client
        # is called without a target directory, so the data presumably lands
        # in the client's own cache. Also, "downloading" is reported after
        # the load call has already returned. Confirm both against callers.
        path = str(download_dir / f"{prefix}_{req.dataset_id.replace('/', '_')}")
        logger.info(f"Downloaded dataset from {source}: {req.dataset_id}")
        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="downloading", path=path)
    except Exception as e:
        # Deliberate catch-all: failures are reported in the response body
        # rather than propagated to the caller.
        logger.error(f"{source} dataset download failed: {e}")
        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
async def upload_dataset(file: UploadFile) -> dict[str, Any]:
    """Save an uploaded file under the uploads directory and detect its format.

    Args:
        file: The uploaded file; its ``filename`` and full content are read.

    Returns:
        A dict with ``path`` (where the file was written), ``format`` (as
        reported by ``_detect_format``) and ``size`` (content length in bytes).

    Raises:
        ValueError: If the upload has no usable filename (missing, empty,
            ``.`` or ``..``).
    """
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)

    # The filename comes from the client. Keep only its final component so
    # that directory parts, "../" sequences or an absolute path cannot place
    # the file outside upload_dir. Backslashes are normalised first so that
    # Windows-style paths are reduced to their basename as well.
    raw_name = file.filename or ""
    safe_name = Path(raw_name.replace("\\", "/")).name
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload filename: {raw_name!r}")

    file_path = upload_dir / safe_name
    content = await file.read()
    file_path.write_bytes(content)

    fmt = _detect_format(safe_name)
    logger.info(f"Uploaded dataset: {file_path} (format={fmt})")
    return {"path": str(file_path), "format": fmt, "size": len(content)}
- def _detect_format(filename: str) -> str:
- ext = Path(filename).suffix.lower().lstrip(".")
- if ext in ("jsonl", "csv", "parquet", "json"):
- return ext
- return "unknown"
|