소스 검색

修复数据集下载问题

lxylxy123321 1 주 전
부모
커밋
9e88076d95
2개의 변경된 파일74개의 추가작업 그리고 20개의 파일을 삭제
  1. 74 19
      backend/app/services/dataset_service.py
  2. 0 1
      backend/requirements.txt

+ 74 - 19
backend/app/services/dataset_service.py

@@ -1,9 +1,13 @@
+import asyncio
 import json
 import uuid
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any
 
+import urllib.request
+import urllib.parse
+
 from fastapi import UploadFile
 
 from app.config import get_settings
@@ -18,47 +22,98 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
     """从 HuggingFace 或 ModelScope 下载数据集。"""
     try:
         if req.use_modelscope:
-            from modelscope.msdatasets import MsDataset
-
-            ds = MsDataset.load(req.dataset_id)
-            ds_dir = settings.processed_dir / f"ms_{req.dataset_id.replace('/', '_')}"
+            # 用 asyncio.to_thread 包裹同步下载,避免阻塞事件循环
+            ds_dir, jsonl_path, record_count = await asyncio.to_thread(_download_modelscope_dataset, req.dataset_id)
         else:
             from datasets import load_dataset
 
             ds = load_dataset(req.dataset_id)
             ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
-
-        ds_dir.mkdir(parents=True, exist_ok=True)
-        # 保存为 JSONL
-        if "train" in ds:
-            split = ds["train"]
-        else:
-            split = ds[list(ds.keys())[0]]
-        output_path = ds_dir / "data.jsonl"
-        with open(output_path, "w", encoding="utf-8") as f:
-            for item in split:
-                f.write(json.dumps(item, ensure_ascii=False) + "\n")
+            ds_dir.mkdir(parents=True, exist_ok=True)
+            if "train" in ds:
+                split = ds["train"]
+            else:
+                split = ds[list(ds.keys())[0]]
+            output_path = ds_dir / "data.jsonl"
+            with open(output_path, "w", encoding="utf-8") as f:
+                for item in split:
+                    f.write(json.dumps(item, ensure_ascii=False) + "\n")
+            jsonl_path = output_path
+            record_count = len(split) if hasattr(split, "__len__") else 0
 
         # 写入数据库
         record = DatasetRecord(
             id=str(uuid.uuid4()),
             name=req.dataset_id,
             format="jsonl",
-            record_count=len(split),
-            file_path=str(output_path),
+            record_count=record_count,
+            file_path=str(jsonl_path),
             created_at=datetime.now(timezone.utc),
         )
         async with async_session() as session:
             session.add(record)
             await session.commit()
 
-        logger.info(f"Downloaded dataset: {req.dataset_id} ({len(split)} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
-        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(output_path))
+        logger.info(f"Downloaded dataset: {req.dataset_id} ({record_count} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
+        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(jsonl_path))
     except Exception as e:
         logger.error(f"Dataset download failed: {e}")
         return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
 
 
+def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
+    """同步下载 ModelScope 数据集(不依赖 MsDataset,避免与新 datasets 库的兼容问题)。"""
+    base_url = "https://www.modelscope.cn/api/v1/datasets/{}/repo".format(dataset_id)
+    ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
+    ds_dir.mkdir(parents=True, exist_ok=True)
+
+    # 获取文件列表
+    params = urllib.parse.urlencode({"Source": "SDK", "Revision": "master"})
+    resp = urllib.request.urlopen(f"{base_url}?{params}&FilePath=", timeout=60)
+    file_list = json.loads(resp.read()).get("Data", {}).get("Files", [])
+    files = [f.get("Key", "") for f in file_list]
+
+    # 找目标文件
+    target_file = None
+    for name in ("train.jsonl", "train.json", "data.jsonl", "data.json"):
+        if name in files:
+            target_file = name
+            break
+    if not target_file:
+        for f in files:
+            if f.endswith((".jsonl", ".json")):
+                target_file = f
+                break
+    if not target_file:
+        raise ValueError(f"No JSON/JSONL files found in dataset {dataset_id}")
+
+    # 下载文件
+    params = urllib.parse.urlencode({"Source": "SDK", "Revision": "master", "FilePath": target_file})
+    resp = urllib.request.urlopen(f"{base_url}?{params}", timeout=600)
+    content = resp.read()
+
+    temp_path = ds_dir / target_file
+    temp_path.write_bytes(content)
+
+    # 读取数据
+    if target_file.endswith(".jsonl"):
+        with open(temp_path, "r", encoding="utf-8") as f:
+            records = [json.loads(line.strip()) for line in f if line.strip()]
+    else:
+        records = json.loads(content)
+        if not isinstance(records, list):
+            records = [records]
+
+    record_count = len(records)
+    # 统一转为 JSONL
+    jsonl_path = ds_dir / "data.jsonl"
+    with open(jsonl_path, "w", encoding="utf-8") as f:
+        for item in records:
+            f.write(json.dumps(item, ensure_ascii=False) + "\n")
+
+    return ds_dir, jsonl_path, record_count
+
+
 async def upload_dataset(file: UploadFile) -> dict[str, Any]:
     """保存上传文件并写入数据库。"""
     upload_dir = settings.uploads_dir

+ 0 - 1
backend/requirements.txt

@@ -24,7 +24,6 @@ scikit-learn>=1.5.0
 pillow>=10.4.0
 huggingface_hub>=0.25.0
 modelscope>=1.15.0
-addict>=2.4.0
 pandas>=2.2.0
 pyarrow>=17.0.0
 sentencepiece>=0.2.0