Procházet zdrojové kódy

修复数据集下载问题

lxylxy123321 před 1 týdnem
rodič
revize
9e88076d95
2 změnil soubory, kde provedl 74 přidání a 20 odebrání
  1. 74 19
      backend/app/services/dataset_service.py
  2. 0 1
      backend/requirements.txt

+ 74 - 19
backend/app/services/dataset_service.py

@@ -1,9 +1,13 @@
+import asyncio
 import json
 import json
 import uuid
 import uuid
 from datetime import datetime, timezone
 from datetime import datetime, timezone
 from pathlib import Path
 from pathlib import Path
 from typing import Any
 from typing import Any
 
 
+import urllib.request
+import urllib.parse
+
 from fastapi import UploadFile
 from fastapi import UploadFile
 
 
 from app.config import get_settings
 from app.config import get_settings
@@ -18,47 +22,98 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
     """从 HuggingFace 或 ModelScope 下载数据集。"""
     """从 HuggingFace 或 ModelScope 下载数据集。"""
     try:
     try:
         if req.use_modelscope:
         if req.use_modelscope:
-            from modelscope.msdatasets import MsDataset
-
-            ds = MsDataset.load(req.dataset_id)
-            ds_dir = settings.processed_dir / f"ms_{req.dataset_id.replace('/', '_')}"
+            # 用 asyncio.to_thread 包裹同步下载,避免阻塞事件循环
+            ds_dir, jsonl_path, record_count = await asyncio.to_thread(_download_modelscope_dataset, req.dataset_id)
         else:
         else:
             from datasets import load_dataset
             from datasets import load_dataset
 
 
             ds = load_dataset(req.dataset_id)
             ds = load_dataset(req.dataset_id)
             ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
             ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
-
-        ds_dir.mkdir(parents=True, exist_ok=True)
-        # 保存为 JSONL
-        if "train" in ds:
-            split = ds["train"]
-        else:
-            split = ds[list(ds.keys())[0]]
-        output_path = ds_dir / "data.jsonl"
-        with open(output_path, "w", encoding="utf-8") as f:
-            for item in split:
-                f.write(json.dumps(item, ensure_ascii=False) + "\n")
+            ds_dir.mkdir(parents=True, exist_ok=True)
+            if "train" in ds:
+                split = ds["train"]
+            else:
+                split = ds[list(ds.keys())[0]]
+            output_path = ds_dir / "data.jsonl"
+            with open(output_path, "w", encoding="utf-8") as f:
+                for item in split:
+                    f.write(json.dumps(item, ensure_ascii=False) + "\n")
+            jsonl_path = output_path
+            record_count = len(split) if hasattr(split, "__len__") else 0
 
 
         # 写入数据库
         # 写入数据库
         record = DatasetRecord(
         record = DatasetRecord(
             id=str(uuid.uuid4()),
             id=str(uuid.uuid4()),
             name=req.dataset_id,
             name=req.dataset_id,
             format="jsonl",
             format="jsonl",
-            record_count=len(split),
-            file_path=str(output_path),
+            record_count=record_count,
+            file_path=str(jsonl_path),
             created_at=datetime.now(timezone.utc),
             created_at=datetime.now(timezone.utc),
         )
         )
         async with async_session() as session:
         async with async_session() as session:
             session.add(record)
             session.add(record)
             await session.commit()
             await session.commit()
 
 
-        logger.info(f"Downloaded dataset: {req.dataset_id} ({len(split)} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
-        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(output_path))
+        logger.info(f"Downloaded dataset: {req.dataset_id} ({record_count} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
+        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(jsonl_path))
     except Exception as e:
     except Exception as e:
         logger.error(f"Dataset download failed: {e}")
         logger.error(f"Dataset download failed: {e}")
         return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
         return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
 
 
 
 
+def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
+    """同步下载 ModelScope 数据集(不依赖 MsDataset,避免与新 datasets 库的兼容问题)。"""
+    base_url = "https://www.modelscope.cn/api/v1/datasets/{}/repo".format(dataset_id)
+    ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
+    ds_dir.mkdir(parents=True, exist_ok=True)
+
+    # 获取文件列表
+    params = urllib.parse.urlencode({"Source": "SDK", "Revision": "master"})
+    resp = urllib.request.urlopen(f"{base_url}?{params}&FilePath=", timeout=60)
+    file_list = json.loads(resp.read()).get("Data", {}).get("Files", [])
+    files = [f.get("Key", "") for f in file_list]
+
+    # 找目标文件
+    target_file = None
+    for name in ("train.jsonl", "train.json", "data.jsonl", "data.json"):
+        if name in files:
+            target_file = name
+            break
+    if not target_file:
+        for f in files:
+            if f.endswith((".jsonl", ".json")):
+                target_file = f
+                break
+    if not target_file:
+        raise ValueError(f"No JSON/JSONL files found in dataset {dataset_id}")
+
+    # 下载文件
+    params = urllib.parse.urlencode({"Source": "SDK", "Revision": "master", "FilePath": target_file})
+    resp = urllib.request.urlopen(f"{base_url}?{params}", timeout=600)
+    content = resp.read()
+
+    temp_path = ds_dir / target_file
+    temp_path.write_bytes(content)
+
+    # 读取数据
+    if target_file.endswith(".jsonl"):
+        with open(temp_path, "r", encoding="utf-8") as f:
+            records = [json.loads(line.strip()) for line in f if line.strip()]
+    else:
+        records = json.loads(content)
+        if not isinstance(records, list):
+            records = [records]
+
+    record_count = len(records)
+    # 统一转为 JSONL
+    jsonl_path = ds_dir / "data.jsonl"
+    with open(jsonl_path, "w", encoding="utf-8") as f:
+        for item in records:
+            f.write(json.dumps(item, ensure_ascii=False) + "\n")
+
+    return ds_dir, jsonl_path, record_count
+
+
 async def upload_dataset(file: UploadFile) -> dict[str, Any]:
 async def upload_dataset(file: UploadFile) -> dict[str, Any]:
     """保存上传文件并写入数据库。"""
     """保存上传文件并写入数据库。"""
     upload_dir = settings.uploads_dir
     upload_dir = settings.uploads_dir

+ 0 - 1
backend/requirements.txt

@@ -24,7 +24,6 @@ scikit-learn>=1.5.0
 pillow>=10.4.0
 pillow>=10.4.0
 huggingface_hub>=0.25.0
 huggingface_hub>=0.25.0
 modelscope>=1.15.0
 modelscope>=1.15.0
-addict>=2.4.0
 pandas>=2.2.0
 pandas>=2.2.0
 pyarrow>=17.0.0
 pyarrow>=17.0.0
 sentencepiece>=0.2.0
 sentencepiece>=0.2.0