Przeglądaj źródła

修复图片数据下载

lxylxy123321 6 godzin temu
rodzic
commit
13361138a5

+ 1 - 0
backend/app/schemas/dataset.py

@@ -41,6 +41,7 @@ class DatasetPreviewResponse(BaseModel):
     total_records: int
     preview_rows: list[DatasetPreviewRow]
     columns: list[str]
+    image_column: str | None = None
 
 
 class DatasetValidationResult(BaseModel):

+ 50 - 0
backend/app/services/dataset_service.py

@@ -74,6 +74,53 @@ def _is_training_data_file(path: Path) -> bool:
     return False
 
 
+def _extract_archives(ds_dir: Path) -> None:
+    """Detect and extract archives found under a dataset directory.
+
+    Image datasets usually ship their images inside an archive, which has to
+    be unpacked before the images can be shown in the preview.  Files whose
+    names end in .zip, .tar, .tar.gz, .tgz, .tar.bz2 or .tbz2
+    (case-insensitive) are unpacked into the directory that contains them.
+    Extraction is best-effort: a failure is logged as a warning and the scan
+    moves on to the next file.
+
+    Args:
+        ds_dir: Root directory that is searched recursively for archives.
+    """
+    import tarfile
+    import zipfile
+
+    # Longer suffixes come first so ".tar.gz" is matched before ".tar".
+    archive_suffixes = (".tar.gz", ".tar.bz2", ".tgz", ".tbz2", ".tar", ".zip")
+    extracted_any = False
+
+    # Snapshot the listing up front: extraction adds new files under ds_dir.
+    for f in list(ds_dir.rglob("*")):
+        if not f.is_file():
+            continue
+        name_lower = f.name.lower()
+        suffix = next(
+            (ext for ext in archive_suffixes if name_lower.endswith(ext)), None
+        )
+        if suffix is None:
+            continue
+        is_zip = suffix == ".zip"
+
+        # The archive name without its suffix marks a finished extraction.
+        # NOTE(review): members are unpacked into f.parent, not into this
+        # directory, so the skip only triggers when the archive itself holds
+        # a top-level entry with that name -- confirm this is intended.
+        extract_dir = f.parent / f.name[:-len(suffix)]
+        if extract_dir.exists():
+            logger.info(f"Archive already extracted, skipping: {f.name}")
+            continue
+
+        logger.info(f"Extracting archive: {f.name} -> {extract_dir}")
+        try:
+            if is_zip:
+                with zipfile.ZipFile(f, "r") as zf:
+                    zf.extractall(f.parent)
+            else:
+                with tarfile.open(f, "r:*") as tf:
+                    if hasattr(tarfile, "data_filter"):
+                        # Downloaded archives are untrusted: the "data" filter
+                        # refuses members that would land outside f.parent.
+                        tf.extractall(f.parent, filter="data")
+                    else:
+                        # Interpreters without extraction filters: check each
+                        # member path by hand before unpacking anything.
+                        root = f.parent.resolve()
+                        for member in tf.getmembers():
+                            target = (root / member.name).resolve()
+                            if target != root and root not in target.parents:
+                                raise tarfile.TarError(
+                                    f"unsafe member path: {member.name}"
+                                )
+                        tf.extractall(f.parent)
+            extracted_any = True
+            logger.info(f"Successfully extracted: {f.name}")
+        except Exception as e:
+            # Deliberately broad: a broken archive is logged and skipped
+            # instead of being raised to the caller.
+            logger.warning(f"Failed to extract {f.name}: {e}")
+
+    if extracted_any:
+        logger.info(f"Archive extraction completed for {ds_dir}")
+
+
 async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
     """启动数据集下载后台任务,立即返回 task_id。"""
     task_id = str(uuid.uuid4())
@@ -262,6 +309,9 @@ def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
         logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
         raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
 
+    # 下载完成后,检测并解压图片压缩包(图片数据集通常把图片放在"数据文件"区的压缩包中)
+    _extract_archives(ds_dir)
+
     # 扫描下载目录中的所有文件
     all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
     logger.info(f"ModelScope CLI downloaded {len(all_files)} files to {ds_dir}")