|
|
@@ -1,9 +1,13 @@
|
|
|
+import asyncio
|
|
|
import json
|
|
|
import uuid
|
|
|
from datetime import datetime, timezone
|
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
|
|
|
|
+import urllib.request
|
|
|
+import urllib.parse
|
|
|
+
|
|
|
from fastapi import UploadFile
|
|
|
|
|
|
from app.config import get_settings
|
|
|
@@ -18,47 +22,98 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
|
|
|
"""从 HuggingFace 或 ModelScope 下载数据集。"""
|
|
|
try:
|
|
|
if req.use_modelscope:
|
|
|
- from modelscope.msdatasets import MsDataset
|
|
|
-
|
|
|
- ds = MsDataset.load(req.dataset_id)
|
|
|
- ds_dir = settings.processed_dir / f"ms_{req.dataset_id.replace('/', '_')}"
|
|
|
+ # 用 asyncio.to_thread 包裹同步下载,避免阻塞事件循环
|
|
|
+ ds_dir, jsonl_path, record_count = await asyncio.to_thread(_download_modelscope_dataset, req.dataset_id)
|
|
|
else:
|
|
|
from datasets import load_dataset
|
|
|
|
|
|
ds = load_dataset(req.dataset_id)
|
|
|
ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
|
|
|
-
|
|
|
- ds_dir.mkdir(parents=True, exist_ok=True)
|
|
|
- # 保存为 JSONL
|
|
|
- if "train" in ds:
|
|
|
- split = ds["train"]
|
|
|
- else:
|
|
|
- split = ds[list(ds.keys())[0]]
|
|
|
- output_path = ds_dir / "data.jsonl"
|
|
|
- with open(output_path, "w", encoding="utf-8") as f:
|
|
|
- for item in split:
|
|
|
- f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
|
|
+ ds_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+ if "train" in ds:
|
|
|
+ split = ds["train"]
|
|
|
+ else:
|
|
|
+ split = ds[list(ds.keys())[0]]
|
|
|
+ output_path = ds_dir / "data.jsonl"
|
|
|
+ with open(output_path, "w", encoding="utf-8") as f:
|
|
|
+ for item in split:
|
|
|
+ f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
|
|
+ jsonl_path = output_path
|
|
|
+ record_count = len(split) if hasattr(split, "__len__") else 0
|
|
|
|
|
|
# 写入数据库
|
|
|
record = DatasetRecord(
|
|
|
id=str(uuid.uuid4()),
|
|
|
name=req.dataset_id,
|
|
|
format="jsonl",
|
|
|
- record_count=len(split),
|
|
|
- file_path=str(output_path),
|
|
|
+ record_count=record_count,
|
|
|
+ file_path=str(jsonl_path),
|
|
|
created_at=datetime.now(timezone.utc),
|
|
|
)
|
|
|
async with async_session() as session:
|
|
|
session.add(record)
|
|
|
await session.commit()
|
|
|
|
|
|
- logger.info(f"Downloaded dataset: {req.dataset_id} ({len(split)} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
|
|
|
- return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(output_path))
|
|
|
+ logger.info(f"Downloaded dataset: {req.dataset_id} ({record_count} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
|
|
|
+ return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(jsonl_path))
|
|
|
except Exception as e:
|
|
|
logger.error(f"Dataset download failed: {e}")
|
|
|
return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
|
|
|
|
|
|
|
|
|
def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
    """Synchronously download a ModelScope dataset and normalize it to JSONL.

    Talks to the ModelScope HTTP API directly instead of going through
    ``MsDataset`` (avoids compatibility problems with newer ``datasets``
    releases). The function blocks on network and disk I/O, so async callers
    should run it in a worker thread.

    Args:
        dataset_id: ModelScope dataset identifier, e.g. ``"owner/name"``.

    Returns:
        A ``(ds_dir, jsonl_path, record_count)`` tuple: the directory created
        for the dataset, the path of the normalized ``data.jsonl`` file, and
        the number of records written to it.

    Raises:
        ValueError: If the repository listing contains no ``.json``/``.jsonl``
            file.
        urllib.error.URLError: If an HTTP request fails.
        json.JSONDecodeError: If the listing or the downloaded file is not
            valid JSON.
    """
    # Function-scope import keeps this block self-contained.
    import shutil

    # Percent-encode the id so unusual characters cannot corrupt the URL;
    # "/" is kept because it separates owner and dataset name.
    base_url = "https://www.modelscope.cn/api/v1/datasets/{}/repo".format(
        urllib.parse.quote(dataset_id, safe="/")
    )
    ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
    ds_dir.mkdir(parents=True, exist_ok=True)

    # Fetch the repository file listing (empty FilePath).
    listing_query = urllib.parse.urlencode(
        {"Source": "SDK", "Revision": "master", "FilePath": ""}
    )
    with urllib.request.urlopen(f"{base_url}?{listing_query}", timeout=60) as resp:
        payload = json.loads(resp.read())
    # "Data"/"Files" may be missing or null in the response; treat both as empty.
    file_list = (payload.get("Data") or {}).get("Files") or []
    files = [
        key
        for entry in file_list
        if isinstance(key := entry.get("Key"), str) and key
    ]

    # Pick the target file: well-known names first, then any JSON/JSONL file.
    preferred_names = ("train.jsonl", "train.json", "data.jsonl", "data.json")
    target_file = next((name for name in preferred_names if name in files), None)
    if target_file is None:
        target_file = next(
            (key for key in files if key.endswith((".jsonl", ".json"))), None
        )
    if target_file is None:
        raise ValueError(f"No JSON/JSONL files found in dataset {dataset_id}")

    # Download the file, streaming it to disk instead of buffering it in memory.
    download_query = urllib.parse.urlencode(
        {"Source": "SDK", "Revision": "master", "FilePath": target_file}
    )
    # Keep only the final path component: remote keys may contain
    # sub-directories (or "..") that must not leak into the local path.
    raw_path = ds_dir / Path(target_file).name
    with urllib.request.urlopen(
        f"{base_url}?{download_query}", timeout=600
    ) as resp, open(raw_path, "wb") as raw_file:
        shutil.copyfileobj(resp, raw_file)

    # Load the records. They are fully materialized before data.jsonl is
    # written because raw_path may be the very same file as jsonl_path.
    if target_file.endswith(".jsonl"):
        with open(raw_path, "r", encoding="utf-8") as f:
            records = [
                json.loads(stripped) for line in f if (stripped := line.strip())
            ]
    else:
        records = json.loads(raw_path.read_bytes())
        if not isinstance(records, list):
            # A single top-level JSON value is treated as one record.
            records = [records]

    record_count = len(records)

    # Normalize everything to JSONL, one record per line.
    jsonl_path = ds_dir / "data.jsonl"
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for item in records:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    return ds_dir, jsonl_path, record_count
|
|
|
+
|
|
|
+
|
|
|
async def upload_dataset(file: UploadFile) -> dict[str, Any]:
|
|
|
"""保存上传文件并写入数据库。"""
|
|
|
upload_dir = settings.uploads_dir
|