Просмотр исходного кода

修复无法下载图片压缩包

lxylxy123321 6 часов назад
Родитель
Сommit
3e11b1a3ea
2 измененных файлов с 160 добавлено и 50 удалено
  1. 91 50
      backend/app/services/dataset_service.py
  2. 69 0
      backend/scripts/test_ms_api.py

+ 91 - 50
backend/app/services/dataset_service.py

@@ -122,67 +122,108 @@ def _extract_archives(ds_dir: Path):
 
 
 def _download_modelscope_data_files(dataset_id: str, ds_dir: Path):
-    """通过 ModelScope REST API 下载"数据文件"区的文件
+    """通过 ModelScope API 下载"数据文件"区的压缩包
     CLI download 只能下载 git 仓库中的元数据文件(如 train.csv),
     图片数据集的压缩包存放在"数据文件"区,需要通过 API 单独下载。"""
     import urllib.request
     import urllib.error
+    import urllib.parse
 
-    api_base = "https://modelscope.cn"
-    api_url = f"{api_base}/api/v1/datasets/{dataset_id}/repo/tree?Recursive=true&PageSize=500"
+    api_base = "https://www.modelscope.cn"
+    archive_exts = (".zip", ".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar")
+    existing = {f.name for f in ds_dir.rglob("*") if f.is_file()}
 
-    logger.info(f"Fetching data file list from API: {dataset_id}")
+    # Step 1: 获取数据集的数字 hub ID(tree API 需要用数字 ID,不是 namespace/name)
+    hub_id = None
     try:
-        req = urllib.request.Request(api_url, headers={"User-Agent": "FineTuning-Backend"})
+        info_url = f"{api_base}/api/v1/datasets/{dataset_id}"
+        logger.info(f"Fetching dataset hub ID: {info_url}")
+        req = urllib.request.Request(info_url, headers={"User-Agent": "FineTuning-Backend"})
         with urllib.request.urlopen(req, timeout=30) as resp:
-            result = json.loads(resp.read().decode())
-
-        files_data = result.get("Data", result)
-        if isinstance(files_data, dict):
-            files_data = files_data.get("Files", files_data.get("files", []))
-        if not isinstance(files_data, list):
-            logger.debug(f"Unexpected API response format: {type(files_data)}")
-            return
-
-        # 已存在的文件名集合(CLI 已下载的)
-        existing = {f.name for f in ds_dir.rglob("*") if f.is_file()}
-        archive_exts = (".zip", ".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar")
-
-        downloaded = []
-        for file_info in files_data:
-            name = file_info.get("Name", file_info.get("name", ""))
-            path_in_repo = file_info.get("Path", file_info.get("path", name))
-            if not name:
-                continue
-            if not any(name.lower().endswith(ext) for ext in archive_exts):
-                continue
-            if name in existing:
-                logger.info(f"Archive already exists, skipping: {name}")
-                continue
+            info = json.loads(resp.read().decode())
+        hub_id = info.get("Data", {}).get("Id") or info.get("Data", {}).get("id")
+        if hub_id:
+            logger.info(f"Got dataset hub ID: {hub_id}")
+    except Exception as e:
+        logger.warning(f"Failed to get dataset hub ID: {e}")
 
-            dl_url = (f"{api_base}/api/v1/datasets/{dataset_id}/repo"
-                      f"?Revision=master&FilePath={path_in_repo}")
-            dest = ds_dir / name
-            logger.info(f"Downloading data file from API: {name}")
+    # Step 2: 用 hub ID 列出仓库中的所有文件
+    files = []
+    if hub_id:
+        try:
+            tree_url = (f"{api_base}/api/v1/datasets/{hub_id}/repo/tree"
+                        f"?Revision=master&Root=/&Recursive=True&PageNumber=1&PageSize=10000")
+            logger.info(f"Listing dataset files: {tree_url}")
+            req = urllib.request.Request(tree_url, headers={"User-Agent": "FineTuning-Backend"})
+            with urllib.request.urlopen(req, timeout=30) as resp:
+                result = json.loads(resp.read().decode())
+            files = result.get("Data", {}).get("Files", [])
+            if files:
+                logger.info(f"Found {len(files)} files in dataset repo")
+                for f in files:
+                    fn = f.get("Name", f.get("name", ""))
+                    if fn:
+                        logger.debug(f"  - {fn}")
+        except urllib.error.HTTPError as e:
+            logger.warning(f"Tree API error ({e.code}) for hub_id={hub_id}")
+        except Exception as e:
+            logger.warning(f"Failed to list files via tree API: {e}")
 
-            try:
-                dl_req = urllib.request.Request(dl_url, headers={"User-Agent": "FineTuning-Backend"})
-                with urllib.request.urlopen(dl_req, timeout=600) as dl_resp:
-                    dest.write_bytes(dl_resp.read())
-                downloaded.append(name)
-                logger.info(f"Downloaded data file: {name} ({dest.stat().st_size / 1024 / 1024:.1f}MB)")
-            except Exception as e:
-                logger.warning(f"Failed to download data file {name}: {e}")
-
-        if downloaded:
-            logger.info(f"Downloaded {len(downloaded)} data file(s) from API: {downloaded}")
-        else:
-            logger.info("No additional data files (archives) found via API")
+    if not files:
+        # fallback: 直接用 namespace/name 格式尝试
+        try:
+            tree_url = (f"{api_base}/api/v1/datasets/{dataset_id}/repo/tree"
+                        f"?Revision=master&Root=/&Recursive=True&PageNumber=1&PageSize=10000")
+            logger.info(f"Fallback tree URL: {tree_url}")
+            req = urllib.request.Request(tree_url, headers={"User-Agent": "FineTuning-Backend"})
+            with urllib.request.urlopen(req, timeout=30) as resp:
+                result = json.loads(resp.read().decode())
+            files = result.get("Data", {}).get("Files", [])
+            if files:
+                logger.info(f"Fallback found {len(files)} files")
+        except Exception as e:
+            logger.warning(f"Fallback tree API also failed: {e}")
+
+    if not files:
+        logger.warning(f"Could not list any files for dataset {dataset_id}")
+        return
+
+    # Step 3: 筛选压缩包文件并下载
+    # 下载 URL 用 namespace/name 格式: /api/v1/datasets/{ns}/{name}/repo?Source=SDK&Revision=master&FilePath=...
+    namespace, name = dataset_id.split("/", 1)
+    downloaded = []
+    for file_info in files:
+        fname = file_info.get("Name", file_info.get("name", ""))
+        fpath = file_info.get("Path", file_info.get("path", fname))
+        if not fname:
+            continue
+        if not any(fname.lower().endswith(ext) for ext in archive_exts):
+            continue
+        if fname in existing:
+            logger.info(f"Archive already exists, skipping: {fname}")
+            continue
 
-    except urllib.error.HTTPError as e:
-        logger.warning(f"ModelScope API error ({e.code}): cannot fetch data files for {dataset_id}")
-    except Exception as e:
-        logger.warning(f"Failed to fetch data file list from API: {e}")
+        params = urllib.parse.urlencode({
+            "Source": "SDK", "Revision": "master",
+            "FilePath": fpath, "View": "false",
+        })
+        dl_url = f"{api_base}/api/v1/datasets/{namespace}/{name}/repo?{params}"
+        dest = ds_dir / fname
+        logger.info(f"Downloading data file: {fname} from {dl_url}")
+
+        try:
+            req = urllib.request.Request(dl_url, headers={"User-Agent": "FineTuning-Backend"})
+            with urllib.request.urlopen(req, timeout=600) as resp:
+                dest.write_bytes(resp.read())
+            downloaded.append(fname)
+            logger.info(f"Downloaded: {fname} ({dest.stat().st_size / 1024 / 1024:.1f}MB)")
+        except Exception as e:
+            logger.warning(f"Failed to download {fname}: {e}")
+
+    if downloaded:
+        logger.info(f"Downloaded {len(downloaded)} data file(s): {downloaded}")
+    else:
+        logger.info("No downloadable archives found in dataset repo")
 
 
 async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:

+ 69 - 0
backend/scripts/test_ms_api.py

@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+"""测试 ModelScope API 能否正确获取数据集文件列表并下载压缩包。"""
+import json
+import urllib.request
+import urllib.parse
+import sys
+
+api_base = "https://www.modelscope.cn"
+dataset_id = sys.argv[1] if len(sys.argv) > 1 else "tany0699/carBrands50"
+
+print(f"测试数据集: {dataset_id}\n")
+
+# Step 1: 获取数字 hub ID
+print("=== Step1: 获取 hub ID ===")
+try:
+    info_url = f"{api_base}/api/v1/datasets/{dataset_id}"
+    print(f"请求: {info_url}")
+    req = urllib.request.Request(info_url, headers={"User-Agent": "Test"})
+    with urllib.request.urlopen(req, timeout=30) as resp:
+        info = json.loads(resp.read().decode())
+    hub_id = info.get("Data", {}).get("Id") or info.get("Data", {}).get("id")
+    print(f"hub_id = {hub_id}\n")
+except Exception as e:
+    print(f"失败: {e}\n")
+    hub_id = None
+
+# Step 2: 列出文件
+print("=== Step2: 列出文件 ===")
+files = []
+for test_id in filter(None, [hub_id, dataset_id]):
+    try:
+        tree_url = (f"{api_base}/api/v1/datasets/{test_id}/repo/tree"
+                    f"?Revision=master&Root=/&Recursive=True&PageNumber=1&PageSize=10000")
+        print(f"请求: {tree_url}")
+        req = urllib.request.Request(tree_url, headers={"User-Agent": "Test"})
+        with urllib.request.urlopen(req, timeout=30) as resp:
+            result = json.loads(resp.read().decode())
+        files = result.get("Data", {}).get("Files", [])
+        print(f"成功! 共 {len(files)} 个文件:")
+        for f in files:
+            name = f.get("Name", f.get("name", ""))
+            size = f.get("Size", f.get("size", ""))
+            print(f"  {name}  (size={size})")
+        if files:
+            break
+    except Exception as e:
+        print(f"失败: {e}")
+print()
+
+# Step 3: 筛选压缩包
+print("=== Step3: 压缩包文件 ===")
+archive_exts = (".zip", ".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar")
+namespace, ds_name = dataset_id.split("/", 1)
+found = False
+for f in files:
+    name = f.get("Name", f.get("name", ""))
+    if any(name.lower().endswith(ext) for ext in archive_exts):
+        path = f.get("Path", f.get("path", name))
+        params = urllib.parse.urlencode({
+            "Source": "SDK", "Revision": "master",
+            "FilePath": path, "View": "false",
+        })
+        dl_url = f"{api_base}/api/v1/datasets/{namespace}/{ds_name}/repo?{params}"
+        print(f"  {name}")
+        print(f"    路径: {path}")
+        print(f"    下载URL: {dl_url}")
+        found = True
+if not found:
+    print("  未找到压缩包文件")