|
@@ -400,29 +400,120 @@ async def recover_stale_downloads() -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
    """Download a ModelScope dataset, preferring ``MsDataset.load()`` over the CLI.

    The dataset is first loaded through ``_download_via_msdataset``; when that
    yields records they are written to ``<ds_dir>/data.jsonl`` (one JSON object
    per line). If loading raises, or yields no records, the function falls back
    to ``_download_modelscope_dataset_cli``.

    Args:
        dataset_id: ModelScope dataset identifier, e.g. ``"namespace/name"``.

    Returns:
        On the MsDataset path, ``(ds_dir, jsonl_path, record_count)``; otherwise
        whatever ``_download_modelscope_dataset_cli`` returns.

    Raises:
        Whatever the CLI fallback raises; errors on the MsDataset path are
        logged and trigger the fallback instead of propagating.
    """
    ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
    ds_dir.mkdir(parents=True, exist_ok=True)

    jsonl_path = ds_dir / "data.jsonl"
    # Write to a sibling temp file and rename on success, so a serialization
    # error halfway through never leaves a truncated data.jsonl inside ds_dir
    # (the CLI fallback scans ds_dir recursively for data files).
    tmp_path = ds_dir / "data.jsonl.tmp"

    # Prefer MsDataset.load(): it can also fetch images from the "data files" area.
    try:
        records, record_count = _download_via_msdataset(dataset_id, ds_dir)
        if records:
            with open(tmp_path, "w", encoding="utf-8") as f:
                for item in records:
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")
            tmp_path.replace(jsonl_path)
            logger.info(f"MsDataset.load() 成功: {dataset_id} ({record_count} records)")
            return ds_dir, jsonl_path, record_count
        logger.warning(f"MsDataset.load() returned no records for {dataset_id}, falling back to CLI")
    except Exception as e:
        # Best-effort cleanup of the partial temp file; never let cleanup
        # itself prevent the fallback from running.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError as cleanup_err:
            logger.debug(f"could not remove {tmp_path}: {cleanup_err}")
        logger.warning(f"MsDataset.load() failed for {dataset_id}: {e}, falling back to CLI")

    # Fallback: CLI download (git repository files only, no "data files" images).
    return _download_modelscope_dataset_cli(dataset_id, ds_dir)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _download_via_msdataset(dataset_id: str, ds_dir: Path) -> tuple[list[dict], int]:
    """Load a dataset with ``MsDataset.load()`` and turn its rows into plain records.

    The splits ``train``, ``validation`` and ``test`` are tried in that order;
    the first one that loads to a truthy dataset is used. If none does, one
    more attempt is made without a ``split`` argument.

    Every ``PIL.Image`` value found directly in a row is saved as a JPEG under
    ``<ds_dir>/images/`` (named by a running six-digit counter) and replaced in
    the record by its relative path ``images/<name>.jpg``. All other values are
    copied unchanged. Rows that are not ``dict`` instances are skipped.

    Args:
        dataset_id: ModelScope dataset identifier, optionally ``"namespace/name"``.
        ds_dir: Directory that receives the ``images/`` sub-directory.

    Returns:
        ``(records, len(records))``; ``([], 0)`` when nothing could be loaded
        or the loaded object is not iterable.
    """
    from modelscope.msdatasets import MsDataset
    from PIL import Image

    # Modes Pillow's JPEG encoder can write directly; every other mode
    # (RGBA, P, LA, I, I;16, F, ...) must be converted first, otherwise
    # save() raises and a single odd image would abort the whole dataset.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
    if namespace:
        load_name, load_kwargs = ds_name, {"namespace": namespace}
    else:
        load_name, load_kwargs = dataset_id, {}
    images_dir = ds_dir / "images"

    # Try the common split names in turn.
    ds = None
    for split in ("train", "validation", "test"):
        try:
            ds = MsDataset.load(load_name, split=split, **load_kwargs)
            if ds:
                logger.info(f"MsDataset.load() loaded split '{split}': {len(ds) if hasattr(ds, '__len__') else '?'} records")
                break
        except Exception as e:
            logger.debug(f"split '{split}' failed: {e}")

    if not ds:
        # Last resort: load without a split argument.
        try:
            ds = MsDataset.load(load_name, **load_kwargs)
        except Exception as e:
            logger.warning(f"MsDataset.load() without split also failed: {e}")
            return [], 0

    if not ds:
        return [], 0

    # Nothing to do with an object that cannot be iterated row by row.
    if not hasattr(ds, '__iter__'):
        return [], 0

    records: list[dict] = []
    img_counter = 0

    for row in ds:
        if not isinstance(row, dict):
            continue

        record = {}
        for k, v in row.items():
            if isinstance(v, Image.Image):
                # Image object: persist it to disk and store the relative path.
                if img_counter == 0:
                    # Create the directory once, right before the first image.
                    images_dir.mkdir(parents=True, exist_ok=True)
                img_name = f"{img_counter:06d}.jpg"
                img_path = images_dir / img_name
                if v.mode not in jpeg_modes:
                    v = v.convert("RGB")
                v.save(str(img_path), format="JPEG", quality=90)
                record[k] = f"images/{img_name}"
                img_counter += 1
            else:
                record[k] = v

        records.append(record)

        # Progress log every 500 records.
        if len(records) % 500 == 0:
            logger.info(f"  处理中... {len(records)} records, {img_counter} images saved")

    if img_counter > 0:
        logger.info(f"共保存 {img_counter} 张图片到 {images_dir}")

    return records, len(records)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _download_modelscope_dataset_cli(dataset_id: str, ds_dir: Path) -> tuple[Path, Path, int]:
|
|
|
|
|
+ """CLI 方式下载数据集(fallback,只下载 git 仓库文件)。"""
|
|
|
|
|
+ import subprocess
|
|
|
|
|
+
|
|
|
cmd = ["modelscope", "download", "--dataset", dataset_id, "--local_dir", str(ds_dir)]
|
|
cmd = ["modelscope", "download", "--dataset", dataset_id, "--local_dir", str(ds_dir)]
|
|
|
- logger.info(f"Running: {' '.join(cmd)}")
|
|
|
|
|
|
|
+ logger.info(f"Fallback CLI: {' '.join(cmd)}")
|
|
|
proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
|
proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
|
|
if proc.returncode != 0:
|
|
if proc.returncode != 0:
|
|
|
logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
|
|
logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
|
|
|
raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
|
|
raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
|
|
|
|
|
|
|
|
- # CLI 下载完成后,通过 API 额外下载"数据文件"区的压缩包(CLI 只下载 git 元数据)
|
|
|
|
|
- _download_modelscope_data_files(dataset_id, ds_dir)
|
|
|
|
|
-
|
|
|
|
|
- # 检测并解压图片压缩包(图片数据集通常把图片放在压缩包中)
|
|
|
|
|
- _extract_archives(ds_dir)
|
|
|
|
|
-
|
|
|
|
|
# 扫描下载目录中的所有文件
|
|
# 扫描下载目录中的所有文件
|
|
|
all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
|
|
all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
|
|
|
- logger.info(f"ModelScope CLI downloaded {len(all_files)} files to {ds_dir}")
|
|
|
|
|
|
|
+ logger.info(f"CLI downloaded {len(all_files)} files to {ds_dir}")
|
|
|
|
|
|
|
|
# 识别训练数据文件
|
|
# 识别训练数据文件
|
|
|
data_files = [f for f in all_files if _is_training_data_file(f)]
|
|
data_files = [f for f in all_files if _is_training_data_file(f)]
|