Przeglądaj źródła

修复无法下载图片压缩包

lxylxy123321 5 godzin temu
rodzic
commit
25fa4fa89b
1 zmienionych plików z 78 dodań i 3 usunięć
  1. 78 3
      backend/app/services/dataset_service.py

+ 78 - 3
backend/app/services/dataset_service.py

@@ -121,6 +121,70 @@ def _extract_archives(ds_dir: Path):
         logger.info(f"Archive extraction completed for {ds_dir}")
 
 
+def _download_modelscope_data_files(dataset_id: str, ds_dir: Path):
+    """Best-effort download of archives from a ModelScope dataset's "data files" area.
+
+    The CLI download only fetches the metadata files kept in the git repository
+    (e.g. train.csv); the archives of image datasets are stored in the separate
+    "data files" area and have to be fetched through the REST API.
+
+    The repo tree is listed with a single request (PageSize=500, no pagination),
+    and every archive (zip / tar variants) whose file name is not already present
+    somewhere under ``ds_dir`` is streamed into ``ds_dir``. Errors are logged as
+    warnings instead of being raised, so the caller can continue regardless.
+
+    Args:
+        dataset_id: ModelScope dataset identifier, interpolated into the API URLs.
+        ds_dir: Local dataset directory that receives the downloaded archives.
+    """
+    import shutil
+    import urllib.error
+    import urllib.parse
+    import urllib.request
+
+    api_base = "https://modelscope.cn"
+    api_url = f"{api_base}/api/v1/datasets/{dataset_id}/repo/tree?Recursive=true&PageSize=500"
+    headers = {"User-Agent": "FineTuning-Backend"}
+
+    logger.info(f"Fetching data file list from API: {dataset_id}")
+    try:
+        req = urllib.request.Request(api_url, headers=headers)
+        with urllib.request.urlopen(req, timeout=30) as resp:
+            result = json.loads(resp.read().decode())
+
+        # Accept either a bare list or a wrapped {"Data": {"Files": [...]}} payload.
+        files_data = result.get("Data", result) if isinstance(result, dict) else result
+        if isinstance(files_data, dict):
+            files_data = files_data.get("Files", files_data.get("files", []))
+        if not isinstance(files_data, list):
+            logger.debug(f"Unexpected API response format: {type(files_data)}")
+            return
+
+        # File names already on disk (e.g. fetched by the CLI) are not downloaded again.
+        existing = {f.name for f in ds_dir.rglob("*") if f.is_file()}
+        archive_exts = (".zip", ".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar")
+
+        downloaded = []
+        for file_info in files_data:
+            if not isinstance(file_info, dict):
+                # One malformed entry must not abort the remaining downloads.
+                continue
+            raw_name = file_info.get("Name", file_info.get("name", ""))
+            path_in_repo = file_info.get("Path", file_info.get("path", raw_name))
+            if not raw_name:
+                continue
+            # Keep only the base name so a server-supplied name cannot escape ds_dir.
+            name = Path(raw_name).name
+            if not name.lower().endswith(archive_exts):
+                continue
+            if name in existing:
+                logger.info(f"Archive already exists, skipping: {name}")
+                continue
+
+            # Percent-encode the path: spaces, "&", "#" or non-ASCII file names
+            # would otherwise produce a malformed URL ("/" is left untouched).
+            dl_url = (f"{api_base}/api/v1/datasets/{dataset_id}/repo"
+                      f"?Revision=master&FilePath={urllib.parse.quote(str(path_in_repo))}")
+            dest = ds_dir / name
+            # Download into a temporary sibling so an interrupted transfer never
+            # leaves a truncated archive under the final name.
+            tmp = dest.with_name(name + ".part")
+            logger.info(f"Downloading data file from API: {name}")
+
+            try:
+                dl_req = urllib.request.Request(dl_url, headers=headers)
+                with urllib.request.urlopen(dl_req, timeout=600) as dl_resp:
+                    with tmp.open("wb") as out:
+                        # Stream in chunks; archives can be far larger than memory.
+                        shutil.copyfileobj(dl_resp, out)
+                tmp.replace(dest)
+                downloaded.append(name)
+                logger.info(f"Downloaded data file: {name} ({dest.stat().st_size / 1024 / 1024:.1f}MB)")
+            except Exception as e:
+                logger.warning(f"Failed to download data file {name}: {e}")
+                tmp.unlink(missing_ok=True)
+
+        if downloaded:
+            logger.info(f"Downloaded {len(downloaded)} data file(s) from API: {downloaded}")
+        else:
+            logger.info("No additional data files (archives) found via API")
+
+    except urllib.error.HTTPError as e:
+        logger.warning(f"ModelScope API error ({e.code}): cannot fetch data files for {dataset_id}")
+    except Exception as e:
+        logger.warning(f"Failed to fetch data file list from API: {e}")
+
+
 async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
     """启动数据集下载后台任务,立即返回 task_id。"""
     task_id = str(uuid.uuid4())
@@ -309,7 +373,10 @@ def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
         logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
         raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
 
-    # 下载完成后,检测并解压图片压缩包(图片数据集通常把图片放在"数据文件"区的压缩包中)
+    # CLI 下载完成后,通过 API 额外下载"数据文件"区的压缩包(CLI 只下载 git 元数据)
+    _download_modelscope_data_files(dataset_id, ds_dir)
+
+    # 检测并解压图片压缩包(图片数据集通常把图片放在压缩包中)
     _extract_archives(ds_dir)
 
     # 扫描下载目录中的所有文件
@@ -644,7 +711,9 @@ async def list_datasets() -> list[dict[str, Any]]:
 
 
 async def delete_dataset(dataset_id: str) -> dict[str, Any]:
-    """删除数据集。"""
+    """删除数据集,同时清理关联的目录文件。"""
+    import shutil
+
     async with async_session() as session:
         from sqlalchemy import select
 
@@ -652,7 +721,13 @@ async def delete_dataset(dataset_id: str) -> dict[str, Any]:
         record = result.scalar_one_or_none()
         if record:
             file_path = Path(record.file_path)
-            if file_path.exists():
+            # 下载的数据集(processed_dir 下的子目录):删除整个目录
+            if file_path.exists() and settings.processed_dir in file_path.parents:
+                ds_dir = file_path.parent
+                shutil.rmtree(ds_dir, ignore_errors=True)
+                logger.info(f"Deleted dataset directory: {ds_dir}")
+            elif file_path.exists():
+                # 上传的数据集:只删文件
                 file_path.unlink()
             await session.delete(record)
             await session.commit()