lxylxy123321 1 неделя назад
Родитель
Сommit
ee787c1ffd
2 измененных файлов с 70 добавлено и 30 удалено
  1. 58 24
      backend/app/services/dataset_service.py
  2. 12 6
      backend/app/services/model_service.py

+ 58 - 24
backend/app/services/dataset_service.py

@@ -66,7 +66,23 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
     """从 HuggingFace 或 ModelScope 下载数据集。"""
     try:
         if req.use_modelscope:
-            ds_dir, jsonl_path, record_count = await asyncio.to_thread(_download_modelscope_dataset, req.dataset_id)
+            import subprocess
+
+            ds_dir = settings.processed_dir / f"ms_{req.dataset_id.replace('/', '_')}"
+            ds_dir.mkdir(parents=True, exist_ok=True)
+            # 用独立进程调用 CLI,完全隔离 FastAPI 事件循环
+            proc = subprocess.run(
+                [
+                    "modelscope", "download",
+                    "--dataset", req.dataset_id,
+                    "--local_dir", str(ds_dir),
+                ],
+                capture_output=True, text=True, timeout=3600,
+            )
+            if proc.returncode != 0:
+                raise RuntimeError(f"modelscope CLI failed: {proc.stderr}")
+            # 扫描下载的文件,找训练数据
+            jsonl_path, record_count = _scan_and_convert_to_jsonl(ds_dir)
         else:
             from datasets import load_dataset
 
@@ -103,34 +119,52 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
         return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
 
 
-def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
-    """按官方文档推荐方式:MsDataset 加载并转为 JSONL。"""
-    from modelscope import MsDataset
-
-    # 按官方文档推荐方式加载,优先使用 train split
-    try:
-        ds = MsDataset.load(dataset_id, split='train')
-    except Exception:
-        # 部分数据集可能没有 train split,尝试加载完整数据集
-        ds = MsDataset.load(dataset_id)
-
-    ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
-    ds_dir.mkdir(parents=True, exist_ok=True)
-
-    # 如果是 DatasetDict(有多个 split),取第一个 split 的数据
-    split_data = ds if not hasattr(ds, "keys") else ds[list(ds.keys())[0]]
+def _scan_and_convert_to_jsonl(ds_dir: Path) -> tuple[Path, int]:
+    """扫描 CLI 下载的数据集目录,找训练数据文件并转为 JSONL。"""
+    # 找所有可能的数据文件
+    data_files = []
+    for ext in ("*.jsonl", "*.json", "*.csv"):
+        data_files.extend(ds_dir.rglob(ext))
+    # 过滤掉元数据文件
+    data_files = [f for f in data_files if f.name not in META_FILENAMES]
 
-    # 如果是 DatasetDict,取第一个 split
-    split_data = ds if not hasattr(ds, "keys") else ds[list(ds.keys())[0]]
+    if not data_files:
+        raise RuntimeError(f"No dataset files found in {ds_dir}")
 
     jsonl_path = ds_dir / "data.jsonl"
     record_count = 0
-    with open(jsonl_path, "w", encoding="utf-8") as f:
-        for item in split_data:
-            f.write(json.dumps(item, ensure_ascii=False) + "\n")
-            record_count += 1
 
-    return ds_dir, jsonl_path, record_count
+    with open(jsonl_path, "w", encoding="utf-8") as out:
+        for data_file in data_files:
+            if data_file.suffix == ".jsonl":
+                with open(data_file, "r", encoding="utf-8") as f:
+                    for line in f:
+                        line = line.strip()
+                        if line:
+                            out.write(line + "\n")
+                            record_count += 1
+            elif data_file.suffix == ".json":
+                try:
+                    with open(data_file, "r", encoding="utf-8") as f:
+                        data = json.load(f)
+                        if isinstance(data, list):
+                            for item in data:
+                                out.write(json.dumps(item, ensure_ascii=False) + "\n")
+                                record_count += 1
+                        elif isinstance(data, dict):
+                            out.write(json.dumps(data, ensure_ascii=False) + "\n")
+                            record_count += 1
+                except Exception:
+                    pass
+            elif data_file.suffix == ".csv":
+                import csv
+                with open(data_file, "r", encoding="utf-8") as f:
+                    reader = csv.DictReader(f)
+                    for row in reader:
+                        out.write(json.dumps(dict(row), ensure_ascii=False) + "\n")
+                        record_count += 1
+
+    return jsonl_path, record_count
 
 
 async def upload_dataset(file: UploadFile) -> dict[str, Any]:

+ 12 - 6
backend/app/services/model_service.py

@@ -43,15 +43,21 @@ async def download_model(model_id: str, use_modelscope: bool = False) -> dict[st
     """从 HF 或 ModelScope 下载模型到本地缓存。"""
     try:
         if use_modelscope:
-            import asyncio
-
-            from modelscope.hub.snapshot_download import snapshot_download as ms_download
+            import subprocess
 
             download_dir = str(settings.models_dir / model_id.replace("/", "_"))
-            # 在线程池中执行,避免与 FastAPI 事件循环冲突
-            local_path = await asyncio.to_thread(
-                ms_download, model_id, local_dir=download_dir
+            # 用独立进程调用 CLI,完全隔离 FastAPI 事件循环,避免 __aenter__ 错误
+            proc = subprocess.run(
+                [
+                    "modelscope", "download",
+                    "--model", model_id,
+                    "--local_dir", download_dir,
+                ],
+                capture_output=True, text=True, timeout=3600,
             )
+            if proc.returncode != 0:
+                raise RuntimeError(f"modelscope CLI failed: {proc.stderr}")
+            local_path = download_dir
         else:
             from huggingface_hub import snapshot_download