lxylxy123321 1 неделя назад
Родитель
Коммит
025ca0ab87
1 изменённый файл: 14 добавлений и 41 удаление
  1. 14 41
      backend/app/services/dataset_service.py

+ 14 - 41
backend/app/services/dataset_service.py

@@ -5,8 +5,6 @@ from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any
 
-import urllib.request
-import urllib.parse
 
 from fastapi import UploadFile
 
@@ -62,54 +60,29 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
 
 
 def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
-    """同步下载 ModelScope 数据集(不依赖 MsDataset,避免与新 datasets 库的兼容问题)。"""
-    base_url = "https://www.modelscope.cn/api/v1/datasets/{}/repo".format(dataset_id)
+    """通过 ModelScope SDK 下载数据集,与模型下载保持一致的可靠性。"""
+    from modelscope import MsDataset
+
     ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
     ds_dir.mkdir(parents=True, exist_ok=True)
 
-    # 获取文件列表
-    params = urllib.parse.urlencode({"Source": "SDK", "Revision": "master"})
-    resp = urllib.request.urlopen(f"{base_url}?{params}&FilePath=", timeout=60)
-    file_list = json.loads(resp.read()).get("Data", {}).get("Files", [])
-    files = [f.get("Key", "") for f in file_list]
-
-    # 找目标文件
-    target_file = None
-    for name in ("train.jsonl", "train.json", "data.jsonl", "data.json"):
-        if name in files:
-            target_file = name
-            break
-    if not target_file:
-        for f in files:
-            if f.endswith((".jsonl", ".json")):
-                target_file = f
-                break
-    if not target_file:
-        raise ValueError(f"No JSON/JSONL files found in dataset {dataset_id}")
-
-    # 下载文件
-    params = urllib.parse.urlencode({"Source": "SDK", "Revision": "master", "FilePath": target_file})
-    resp = urllib.request.urlopen(f"{base_url}?{params}", timeout=600)
-    content = resp.read()
-
-    temp_path = ds_dir / target_file
-    temp_path.write_bytes(content)
-
-    # 读取数据
-    if target_file.endswith(".jsonl"):
-        with open(temp_path, "r", encoding="utf-8") as f:
-            records = [json.loads(line.strip()) for line in f if line.strip()]
+    # 使用 SDK 下载,避免手动 API 的 404 问题
+    ms_ds = MsDataset.load(dataset_id, cache_dir=str(settings.processed_dir))
+
+    # 确定使用的 split
+    if hasattr(ms_ds, "split_names") and ms_ds.split_names:
+        split_name = "train" if "train" in ms_ds.split_names else ms_ds.split_names[0]
+        split = ms_ds[split_name]
     else:
-        records = json.loads(content)
-        if not isinstance(records, list):
-            records = [records]
+        split = ms_ds
 
-    record_count = len(records)
     # 统一转为 JSONL
     jsonl_path = ds_dir / "data.jsonl"
+    record_count = 0
     with open(jsonl_path, "w", encoding="utf-8") as f:
-        for item in records:
+        for item in split:
             f.write(json.dumps(item, ensure_ascii=False) + "\n")
+            record_count += 1
 
     return ds_dir, jsonl_path, record_count