|
|
@@ -121,6 +121,70 @@ def _extract_archives(ds_dir: Path):
|
|
|
logger.info(f"Archive extraction completed for {ds_dir}")
|
|
|
|
|
|
|
|
|
def _download_modelscope_data_files(dataset_id: str, ds_dir: Path):
    """Download archives from a ModelScope dataset's "data files" area via the REST API.

    The CLI download only fetches the metadata files tracked in the git
    repository (e.g. train.csv). Image datasets keep their archives in the
    separate "data files" area, which has to be fetched through the API.

    This is a best-effort step: every failure is logged and swallowed so the
    caller can continue with whatever the CLI already downloaded.

    Args:
        dataset_id: ModelScope dataset identifier, interpolated into the API URL path.
        ds_dir: Local dataset directory; archives are written directly into it.

    Returns:
        None.
    """
    import contextlib
    import shutil
    import urllib.error
    import urllib.parse
    import urllib.request

    api_base = "https://modelscope.cn"
    api_url = f"{api_base}/api/v1/datasets/{dataset_id}/repo/tree?Recursive=true&PageSize=500"
    headers = {"User-Agent": "FineTuning-Backend"}
    archive_exts = (".zip", ".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar")

    logger.info(f"Fetching data file list from API: {dataset_id}")
    try:
        req = urllib.request.Request(api_url, headers=headers)
        with urllib.request.urlopen(req, timeout=30) as resp:
            result = json.loads(resp.read().decode())

        # The file list may sit under "Data" -> "Files"/"files", or directly under "Data".
        files_data = result.get("Data", result)
        if isinstance(files_data, dict):
            files_data = files_data.get("Files", files_data.get("files", []))
        if not isinstance(files_data, list):
            logger.debug(f"Unexpected API response format: {type(files_data)}")
            return

        # Names of files already on disk (i.e. fetched by the CLI).
        existing = {f.name for f in ds_dir.rglob("*") if f.is_file()}

        downloaded = []
        for file_info in files_data:
            # Skip malformed entries instead of aborting the whole listing.
            if not isinstance(file_info, dict):
                continue
            raw_name = file_info.get("Name", file_info.get("name", ""))
            if not raw_name or not isinstance(raw_name, str):
                continue
            path_in_repo = file_info.get("Path") or file_info.get("path") or raw_name

            # Keep only the last path component so a crafted name cannot
            # make us write outside ds_dir.
            name = Path(raw_name).name
            if not name.lower().endswith(archive_exts):
                continue
            if name in existing:
                logger.info(f"Archive already exists, skipping: {name}")
                continue

            # Percent-encode the repo path; "/" is kept so plain paths produce
            # the same URL as before.
            encoded_path = urllib.parse.quote(str(path_in_repo), safe="/")
            dl_url = (f"{api_base}/api/v1/datasets/{dataset_id}/repo"
                      f"?Revision=master&FilePath={encoded_path}")
            dest = ds_dir / name
            # Download to a sidecar file and rename on success, so an
            # interrupted transfer never leaves a truncated archive behind.
            tmp_dest = dest.with_name(dest.name + ".part")
            logger.info(f"Downloading data file from API: {name}")

            try:
                dl_req = urllib.request.Request(dl_url, headers=headers)
                with urllib.request.urlopen(dl_req, timeout=600) as dl_resp, \
                        open(tmp_dest, "wb") as out:
                    # Stream in chunks rather than holding the whole archive in memory.
                    shutil.copyfileobj(dl_resp, out)
                tmp_dest.replace(dest)
                downloaded.append(name)
                existing.add(name)
                logger.info(f"Downloaded data file: {name} ({dest.stat().st_size / 1024 / 1024:.1f}MB)")
            except Exception as e:
                with contextlib.suppress(OSError):
                    tmp_dest.unlink(missing_ok=True)
                logger.warning(f"Failed to download data file {name}: {e}")

        if downloaded:
            logger.info(f"Downloaded {len(downloaded)} data file(s) from API: {downloaded}")
        else:
            logger.info("No additional data files (archives) found via API")

    except urllib.error.HTTPError as e:
        logger.warning(f"ModelScope API error ({e.code}): cannot fetch data files for {dataset_id}")
    except Exception as e:
        logger.warning(f"Failed to fetch data file list from API: {e}")
|
|
|
+
|
|
|
+
|
|
|
async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
|
|
|
"""启动数据集下载后台任务,立即返回 task_id。"""
|
|
|
task_id = str(uuid.uuid4())
|
|
|
@@ -309,7 +373,10 @@ def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
|
|
|
logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
|
|
|
raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
|
|
|
|
|
|
- # 下载完成后,检测并解压图片压缩包(图片数据集通常把图片放在"数据文件"区的压缩包中)
|
|
|
+ # CLI 下载完成后,通过 API 额外下载"数据文件"区的压缩包(CLI 只下载 git 元数据)
|
|
|
+ _download_modelscope_data_files(dataset_id, ds_dir)
|
|
|
+
|
|
|
+ # 检测并解压图片压缩包(图片数据集通常把图片放在压缩包中)
|
|
|
_extract_archives(ds_dir)
|
|
|
|
|
|
# 扫描下载目录中的所有文件
|
|
|
@@ -644,7 +711,9 @@ async def list_datasets() -> list[dict[str, Any]]:
|
|
|
|
|
|
|
|
|
async def delete_dataset(dataset_id: str) -> dict[str, Any]:
|
|
|
- """删除数据集。"""
|
|
|
+ """删除数据集,同时清理关联的目录文件。"""
|
|
|
+ import shutil
|
|
|
+
|
|
|
async with async_session() as session:
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
@@ -652,7 +721,13 @@ async def delete_dataset(dataset_id: str) -> dict[str, Any]:
|
|
|
record = result.scalar_one_or_none()
|
|
|
if record:
|
|
|
file_path = Path(record.file_path)
|
|
|
- if file_path.exists():
|
|
|
+ # 下载的数据集(processed_dir 下的子目录):删除整个目录
|
|
|
+ if file_path.exists() and settings.processed_dir in file_path.parents:
|
|
|
+ ds_dir = file_path.parent
|
|
|
+ shutil.rmtree(ds_dir, ignore_errors=True)
|
|
|
+ logger.info(f"Deleted dataset directory: {ds_dir}")
|
|
|
+ elif file_path.exists():
|
|
|
+ # 上传的数据集:只删文件
|
|
|
file_path.unlink()
|
|
|
await session.delete(record)
|
|
|
await session.commit()
|