|
@@ -74,6 +74,53 @@ def _is_training_data_file(path: Path) -> bool:
|
|
|
return False
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_archives(ds_dir: Path) -> None:
    """Find and unpack archives (zip, tar, tar.gz, tgz, tar.bz2, tbz2) under ``ds_dir``.

    Image datasets usually ship their pictures inside archives, which have to
    be unpacked before the preview can display them.

    Each archive is unpacked into the directory that contains it. An archive
    is skipped when a sibling entry named after it (file name with the archive
    suffix removed) already exists; that entry is treated as the sign of an
    earlier extraction. A failure on one archive is logged as a warning and
    does not stop the remaining archives from being processed.

    Args:
        ds_dir: Root directory of the downloaded dataset; searched recursively.
    """
    import tarfile
    import zipfile

    archive_exts = (".tar.gz", ".tar.bz2", ".tgz", ".tbz2", ".tar", ".zip")
    extracted_any = False

    # Snapshot the tree first: extraction creates new files under ds_dir and
    # the walk should only see what was there before this call.
    for f in list(ds_dir.rglob("*")):
        if not f.is_file():
            continue

        name_lower = f.name.lower()
        suffix = next((ext for ext in archive_exts if name_lower.endswith(ext)), None)
        if suffix is None:
            continue

        # The archive name without its suffix marks "already extracted".
        # NOTE(review): this only works for archives whose content lives in a
        # top-level folder with that name; other archives are unpacked again
        # on every call.
        extract_dir = f.parent / f.name[: -len(suffix)]
        if extract_dir.exists():
            logger.info(f"Archive already extracted, skipping: {f.name}")
            continue

        dest = f.parent
        # Report the directory that is really written to (the archive's parent).
        logger.info(f"Extracting archive: {f.name} -> {dest}")
        try:
            if suffix == ".zip":
                # zipfile strips absolute prefixes and ".." components from
                # member names itself, so members stay below dest.
                with zipfile.ZipFile(f, "r") as zf:
                    zf.extractall(dest)
            else:
                with tarfile.open(f, "r:*") as tf:
                    _extract_tar_safely(tf, dest)
            extracted_any = True
            logger.info(f"Successfully extracted: {f.name}")
        except Exception as e:
            # Best effort: a broken or unsafe archive must not fail the download.
            logger.warning(f"Failed to extract {f.name}: {e}")

    if extracted_any:
        logger.info(f"Archive extraction completed for {ds_dir}")


def _extract_tar_safely(tf, dest: Path) -> None:
    """Extract all members of the open tar archive ``tf`` into ``dest``.

    Downloaded archives are untrusted input, so members must not be able to
    write outside ``dest`` (``../`` components, absolute names, links that
    point elsewhere).

    Args:
        tf: An open ``tarfile.TarFile`` object.
        dest: Directory to extract into.

    Raises:
        tarfile.TarError: If a member would be placed, or would link, outside
            ``dest``.
    """
    import tarfile

    if hasattr(tarfile, "data_filter"):
        # Interpreters with extraction filters: let the stdlib enforce it.
        tf.extractall(dest, filter="data")
        return

    # Older interpreters: validate every member before extracting anything.
    dest_root = dest.resolve()

    def _inside(candidate: Path) -> bool:
        return candidate == dest_root or dest_root in candidate.parents

    for member in tf.getmembers():
        target = (dest_root / member.name).resolve()
        to_check = [target]
        if member.issym():
            # Symlink targets are relative to the directory holding the link.
            to_check.append((target.parent / member.linkname).resolve())
        elif member.islnk():
            # Hard-link targets are relative to the archive root.
            to_check.append((dest_root / member.linkname).resolve())
        if not all(_inside(p) for p in to_check):
            raise tarfile.TarError(f"unsafe member path in archive: {member.name}")
    tf.extractall(dest)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
|
|
async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
|
|
|
"""启动数据集下载后台任务,立即返回 task_id。"""
|
|
"""启动数据集下载后台任务,立即返回 task_id。"""
|
|
|
task_id = str(uuid.uuid4())
|
|
task_id = str(uuid.uuid4())
|
|
@@ -262,6 +309,9 @@ def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
|
|
|
logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
|
|
logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
|
|
|
raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
|
|
raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
|
|
|
|
|
|
|
|
|
|
+ # 下载完成后,检测并解压图片压缩包(图片数据集通常把图片放在"数据文件"区的压缩包中)
|
|
|
|
|
+ _extract_archives(ds_dir)
|
|
|
|
|
+
|
|
|
# 扫描下载目录中的所有文件
|
|
# 扫描下载目录中的所有文件
|
|
|
all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
|
|
all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
|
|
|
logger.info(f"ModelScope CLI downloaded {len(all_files)} files to {ds_dir}")
|
|
logger.info(f"ModelScope CLI downloaded {len(all_files)} files to {ds_dir}")
|