lxylxy123321 2 часов назад
Родитель
Сommit
069369429b
2 измененных файлов с 159 добавлено и 71 удалено
  1. 103 12
      backend/app/services/dataset_service.py
  2. 56 59
      backend/scripts/test_ms_api.py

+ 103 - 12
backend/app/services/dataset_service.py

@@ -400,29 +400,120 @@ async def recover_stale_downloads() -> None:
 
 
 def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
-    """用 modelscope CLI 下载数据集,完全绕过 datasets 库,避免版本兼容问题。"""
-    import subprocess
-
+    """用 MsDataset.load() 下载数据集,支持图片数据集(自动从 CDN 下载图片)。
+    如果 MsDataset.load() 失败,fallback 到 CLI 方式。"""
+    namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
     ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
     ds_dir.mkdir(parents=True, exist_ok=True)
 
-    # 使用 CLI 方式下载,避免 snapshot_download API 的路径问题
+    # 优先用 MsDataset.load(),它能自动下载"数据文件"区的图片
+    try:
+        records, record_count = _download_via_msdataset(dataset_id, ds_dir)
+        if records:
+            jsonl_path = ds_dir / "data.jsonl"
+            with open(jsonl_path, "w", encoding="utf-8") as f:
+                for item in records:
+                    f.write(json.dumps(item, ensure_ascii=False) + "\n")
+            logger.info(f"MsDataset.load() 成功: {dataset_id} ({record_count} records)")
+            return ds_dir, jsonl_path, record_count
+    except Exception as e:
+        logger.warning(f"MsDataset.load() failed for {dataset_id}: {e}, falling back to CLI")
+
+    # fallback: CLI 方式(只下载 git 仓库文件,不含数据文件区图片)
+    return _download_modelscope_dataset_cli(dataset_id, ds_dir)
+
+
+def _download_via_msdataset(dataset_id: str, ds_dir: Path) -> tuple[list[dict], int]:
+    """用 MsDataset.load() 下载数据集,处理图片列(PIL.Image → 保存到磁盘)。"""
+    from modelscope.msdatasets import MsDataset
+    from PIL import Image
+
+    namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
+    images_dir = ds_dir / "images"
+
+    # 尝试加载不同 split
+    ds = None
+    for split in ("train", "validation", "test"):
+        try:
+            if namespace:
+                ds = MsDataset.load(ds_name, namespace=namespace, split=split)
+            else:
+                ds = MsDataset.load(dataset_id, split=split)
+            if ds:
+                logger.info(f"MsDataset.load() loaded split '{split}': {len(ds) if hasattr(ds, '__len__') else '?'} records")
+                break
+        except Exception as e:
+            logger.debug(f"split '{split}' failed: {e}")
+
+    if not ds:
+        # 不带 split 参数试试
+        try:
+            if namespace:
+                ds = MsDataset.load(ds_name, namespace=namespace)
+            else:
+                ds = MsDataset.load(dataset_id)
+        except Exception as e:
+            logger.warning(f"MsDataset.load() without split also failed: {e}")
+            return [], 0
+
+    if not ds:
+        return [], 0
+
+    # 检查是否 iterable
+    if not hasattr(ds, '__iter__'):
+        return [], 0
+
+    records = []
+    img_counter = 0
+    columns = None
+
+    for row in ds:
+        if not isinstance(row, dict):
+            continue
+        if columns is None:
+            columns = list(row.keys())
+
+        record = {}
+        for k, v in row.items():
+            if isinstance(v, Image.Image):
+                # 图片对象:保存到磁盘,记录相对路径
+                images_dir.mkdir(parents=True, exist_ok=True)
+                img_name = f"{img_counter:06d}.jpg"
+                img_path = images_dir / img_name
+                if v.mode in ("RGBA", "P", "LA"):
+                    v = v.convert("RGB")
+                v.save(str(img_path), format="JPEG", quality=90)
+                record[k] = f"images/{img_name}"
+                img_counter += 1
+            else:
+                record[k] = v
+
+        records.append(record)
+
+        # 进度日志
+        if len(records) % 500 == 0:
+            logger.info(f"  处理中... {len(records)} records, {img_counter} images saved")
+
+    if img_counter > 0:
+        logger.info(f"共保存 {img_counter} 张图片到 {images_dir}")
+
+    return records, len(records)
+
+
+def _download_modelscope_dataset_cli(dataset_id: str, ds_dir: Path) -> tuple[Path, Path, int]:
+    """CLI 方式下载数据集(fallback,只下载 git 仓库文件)。"""
+    import subprocess
+
     cmd = ["modelscope", "download", "--dataset", dataset_id, "--local_dir", str(ds_dir)]
-    logger.info(f"Running: {' '.join(cmd)}")
+    logger.info(f"Fallback CLI: {' '.join(cmd)}")
     proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
     if proc.returncode != 0:
         logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
         raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
 
-    # CLI 下载完成后,通过 API 额外下载"数据文件"区的压缩包(CLI 只下载 git 元数据)
-    _download_modelscope_data_files(dataset_id, ds_dir)
-
-    # 检测并解压图片压缩包(图片数据集通常把图片放在压缩包中)
-    _extract_archives(ds_dir)
-
     # 扫描下载目录中的所有文件
     all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
-    logger.info(f"ModelScope CLI downloaded {len(all_files)} files to {ds_dir}")
+    logger.info(f"CLI downloaded {len(all_files)} files to {ds_dir}")
 
     # 识别训练数据文件
     data_files = [f for f in all_files if _is_training_data_file(f)]

+ 56 - 59
backend/scripts/test_ms_api.py

@@ -1,69 +1,66 @@
 #!/usr/bin/env python3
-"""测试 ModelScope API 能否正确获取数据集文件列表并下载压缩包。"""
-import json
-import urllib.request
-import urllib.parse
+"""测试 MsDataset.load() 能否正确下载图片数据集。"""
 import sys
+import json
 
-api_base = "https://www.modelscope.cn"
 dataset_id = sys.argv[1] if len(sys.argv) > 1 else "tany0699/carBrands50"
+namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
 
-print(f"测试数据集: {dataset_id}\n")
+print(f"测试数据集: {dataset_id}")
+print(f"namespace: {namespace}, name: {ds_name}\n")
 
-# Step 1: 获取数字 hub ID
-print("=== Step1: 获取 hub ID ===")
+print("=== 用 MsDataset.load() 下载 ===")
 try:
-    info_url = f"{api_base}/api/v1/datasets/{dataset_id}"
-    print(f"请求: {info_url}")
-    req = urllib.request.Request(info_url, headers={"User-Agent": "Test"})
-    with urllib.request.urlopen(req, timeout=30) as resp:
-        info = json.loads(resp.read().decode())
-    hub_id = info.get("Data", {}).get("Id") or info.get("Data", {}).get("id")
-    print(f"hub_id = {hub_id}\n")
-except Exception as e:
-    print(f"失败: {e}\n")
-    hub_id = None
+    from modelscope.msdatasets import MsDataset
+
+    ds = None
+    for split in ("train", "validation", "test"):
+        try:
+            if namespace:
+                ds = MsDataset.load(ds_name, namespace=namespace, split=split)
+            else:
+                ds = MsDataset.load(dataset_id, split=split)
+            if ds:
+                print(f"加载 split='{split}' 成功, 共 {len(ds) if hasattr(ds, '__len__') else '?'} 条")
+                break
+        except Exception as e:
+            print(f"split='{split}' 失败: {e}")
 
-# Step 2: 列出文件
-print("=== Step2: 列出文件 ===")
-files = []
-for test_id in filter(None, [hub_id, dataset_id]):
-    try:
-        tree_url = (f"{api_base}/api/v1/datasets/{test_id}/repo/tree"
-                    f"?Revision=master&Root=/&Recursive=True&PageNumber=1&PageSize=10000")
-        print(f"请求: {tree_url}")
-        req = urllib.request.Request(tree_url, headers={"User-Agent": "Test"})
-        with urllib.request.urlopen(req, timeout=30) as resp:
-            result = json.loads(resp.read().decode())
-        files = result.get("Data", {}).get("Files", [])
-        print(f"成功! 共 {len(files)} 个文件:")
-        for f in files:
-            name = f.get("Name", f.get("name", ""))
-            size = f.get("Size", f.get("size", ""))
-            print(f"  {name}  (size={size})")
-        if files:
+    if not ds:
+        try:
+            if namespace:
+                ds = MsDataset.load(ds_name, namespace=namespace)
+            else:
+                ds = MsDataset.load(dataset_id)
+            print(f"不带 split 加载成功, 类型: {type(ds)}")
+        except Exception as e:
+            print(f"不带 split 也失败: {e}")
+            sys.exit(1)
+
+    if not hasattr(ds, "__iter__"):
+        print(f"数据集不可迭代, 类型: {type(ds)}")
+        sys.exit(1)
+
+    # 查看前 3 条数据
+    print("\n=== 前 3 条数据 ===")
+    count = 0
+    for row in ds:
+        if count >= 3:
             break
-    except Exception as e:
-        print(f"失败: {e}")
-print()
+        print(f"\n--- Record {count} ---")
+        for k, v in row.items():
+            vtype = type(v).__name__
+            if vtype == "Image":
+                print(f"  {k}: PIL.Image (size={v.size}, mode={v.mode})")
+            elif isinstance(v, str) and len(v) > 100:
+                print(f"  {k}: str (len={len(v)}) '{v[:100]}...'")
+            else:
+                print(f"  {k}: {vtype} = {v}")
+        count += 1
 
-# Step 3: 筛选压缩包
-print("=== Step3: 压缩包文件 ===")
-archive_exts = (".zip", ".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar")
-namespace, ds_name = dataset_id.split("/", 1)
-found = False
-for f in files:
-    name = f.get("Name", f.get("name", ""))
-    if any(name.lower().endswith(ext) for ext in archive_exts):
-        path = f.get("Path", f.get("path", name))
-        params = urllib.parse.urlencode({
-            "Source": "SDK", "Revision": "master",
-            "FilePath": path, "View": "false",
-        })
-        dl_url = f"{api_base}/api/v1/datasets/{namespace}/{ds_name}/repo?{params}"
-        print(f"  {name}")
-        print(f"    路径: {path}")
-        print(f"    下载URL: {dl_url}")
-        found = True
-if not found:
-    print("  未找到压缩包文件")
+    print(f"\n=== 完成 ===")
+
+except Exception as e:
+    print(f"失败: {e}")
+    import traceback
+    traceback.print_exc()