dataset_service.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from pathlib import Path
  2. from typing import Any
  3. from fastapi import UploadFile
  4. from app.config import get_settings
  5. from app.core.logging import logger
  6. from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
# Module-level settings object, obtained once at import time via get_settings().
# The functions below read `processed_dir` and `uploads_dir` from it.
settings = get_settings()
  8. async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
  9. """从 HuggingFace 或 ModelScope 下载数据集。"""
  10. import os
  11. import uuid
  12. download_dir = settings.processed_dir
  13. download_dir.mkdir(parents=True, exist_ok=True)
  14. if req.use_modelscope:
  15. try:
  16. from modelscope.msdatasets import MsDataset
  17. MsDataset.load(req.dataset_id, split="train")
  18. path = str(download_dir / f"ms_{req.dataset_id.replace('/', '_')}")
  19. logger.info(f"Downloaded dataset from ModelScope: {req.dataset_id}")
  20. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="downloading", path=path)
  21. except Exception as e:
  22. logger.error(f"ModelScope dataset download failed: {e}")
  23. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
  24. else:
  25. try:
  26. from datasets import load_dataset
  27. load_dataset(req.dataset_id)
  28. path = str(download_dir / f"hf_{req.dataset_id.replace('/', '_')}")
  29. logger.info(f"Downloaded dataset from HuggingFace: {req.dataset_id}")
  30. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="downloading", path=path)
  31. except Exception as e:
  32. logger.error(f"HuggingFace dataset download failed: {e}")
  33. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
  34. async def upload_dataset(file: UploadFile) -> dict[str, Any]:
  35. """保存上传文件并检测格式。"""
  36. upload_dir = settings.uploads_dir
  37. upload_dir.mkdir(parents=True, exist_ok=True)
  38. file_path = upload_dir / file.filename
  39. content = await file.read()
  40. file_path.write_bytes(content)
  41. fmt = _detect_format(file.filename or "")
  42. logger.info(f"Uploaded dataset: {file_path} (format={fmt})")
  43. return {"path": str(file_path), "format": fmt, "size": len(content)}
  44. def _detect_format(filename: str) -> str:
  45. ext = Path(filename).suffix.lower().lstrip(".")
  46. if ext in ("jsonl", "csv", "parquet", "json"):
  47. return ext
  48. return "unknown"