|
|
@@ -5,8 +5,6 @@ from datetime import datetime, timezone
|
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
|
|
|
|
-import urllib.request
|
|
|
-import urllib.parse
|
|
|
|
|
|
from fastapi import UploadFile
|
|
|
|
|
|
@@ -62,54 +60,29 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
|
|
|
|
|
|
|
|
|
def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
    """Download a ModelScope dataset through the SDK and export one split as JSONL.

    The dataset is loaded with ``MsDataset.load`` (using
    ``settings.processed_dir`` as the SDK cache directory), a single split is
    selected -- ``"train"`` when present, otherwise the first one available --
    and every record of that split is written as one JSON object per line to
    ``<processed_dir>/ms_<dataset_id with '/' replaced by '_'>/data.jsonl``.

    Args:
        dataset_id: ModelScope dataset identifier, e.g. ``"namespace/name"``.

    Returns:
        A ``(ds_dir, jsonl_path, record_count)`` tuple: the output directory,
        the path of the written JSONL file, and the number of records written.

    Raises:
        ValueError: If the SDK returns an empty mapping of splits.
    """
    # Imported lazily so the module can be imported without the SDK installed.
    from modelscope import MsDataset

    ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
    ds_dir.mkdir(parents=True, exist_ok=True)

    # Go through the SDK instead of hand-built API requests, which hit 404s.
    ms_ds = MsDataset.load(dataset_id, cache_dir=str(settings.processed_dir))

    # Decide which split to export.
    if isinstance(ms_ds, dict):
        # NOTE(review): for multi-split datasets the SDK appears to return a
        # plain mapping of split name -> dataset. Iterating that mapping
        # directly would yield the split *names*, not records, so a concrete
        # split has to be picked first.
        if not ms_ds:
            raise ValueError(f"No splits found in dataset {dataset_id}")
        split_name = "train" if "train" in ms_ds else next(iter(ms_ds))
        split = ms_ds[split_name]
    else:
        split_names = getattr(ms_ds, "split_names", None)
        if split_names:
            split_name = "train" if "train" in split_names else split_names[0]
            split = ms_ds[split_name]
        else:
            # Single dataset object: iterate it as-is.
            split = ms_ds

    # Normalise everything to JSONL, streaming record by record so the whole
    # split never has to be held in memory just to be counted.
    jsonl_path = ds_dir / "data.jsonl"
    record_count = 0
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for item in split:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
            record_count += 1

    return ds_dir, jsonl_path, record_count
|
|
|
|