lxylxy123321 1 неделя назад
Родитель
Сommit
b34f33cf72
1 измененных файлов с 36 добавлено и 14 удалено
  1. 36 14
      backend/app/services/dataset_service.py

+ 36 - 14
backend/app/services/dataset_service.py

@@ -60,27 +60,49 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
 
 
 def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
-    """通过 ModelScope SDK 下载数据集,与模型下载保持一致的可靠性。"""
-    from modelscope import MsDataset
+    """用 snapshot_download 下载数据集文件,完全绕过 datasets 库,避免版本兼容问题。"""
+    from modelscope import snapshot_download
 
     ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
     ds_dir.mkdir(parents=True, exist_ok=True)
 
-    # 使用 SDK 下载,避免手动 API 的 404 问题
-    ms_ds = MsDataset.load(dataset_id, cache_dir=str(settings.processed_dir))
-
-    # 确定使用的 split
-    if hasattr(ms_ds, "split_names") and ms_ds.split_names:
-        split_name = "train" if "train" in ms_ds.split_names else ms_ds.split_names[0]
-        split = ms_ds[split_name]
-    else:
-        split = ms_ds
-
-    # 统一转为 JSONL
+    # 用 snapshot_download 下载数据集文件到本地
+    local_path = snapshot_download(dataset_id, cache_dir=str(settings.processed_dir))
+
+    # 扫描下载目录中的 JSON/JSONL 文件
+    data_files = []
+    for p in Path(local_path).rglob("*"):
+        if p.is_file() and p.suffix in (".json", ".jsonl"):
+            data_files.append(p)
+
+    if not data_files:
+        raise ValueError(f"No JSON/JSONL files found in dataset {dataset_id}")
+
+    # 优先取 train / data 开头的文件
+    target = None
+    for name in ("train.jsonl", "train.json", "data.jsonl", "data.json"):
+        for f in data_files:
+            if f.name == name:
+                target = f
+                break
+        if target:
+            break
+    if not target:
+        target = data_files[0]
+
+    # 读取并统一转为 JSONL
     jsonl_path = ds_dir / "data.jsonl"
     record_count = 0
+    content = target.read_text(encoding="utf-8")
+    if target.suffix == ".jsonl":
+        records = [json.loads(line.strip()) for line in content.splitlines() if line.strip()]
+    else:
+        records = json.loads(content)
+        if not isinstance(records, list):
+            records = [records]
+
     with open(jsonl_path, "w", encoding="utf-8") as f:
-        for item in split:
+        for item in records:
             f.write(json.dumps(item, ensure_ascii=False) + "\n")
             record_count += 1