Selaa lähdekoodia

修复文件传输和数据下载

lxylxy123321 1 viikko sitten
vanhempi
sitoutus
a8a5c82282
2 muutettua tiedostoa jossa 33 lisäystä ja 45 poistoa
  1. 14 37
      backend/app/services/dataset_service.py
  2. 19 8
      backend/app/services/model_test_service.py

+ 14 - 37
backend/app/services/dataset_service.py

@@ -69,45 +69,22 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
     """从 HuggingFace 或 ModelScope 下载数据集。"""
     try:
         if req.use_modelscope:
-            import subprocess
-
             ds_dir = settings.processed_dir / f"ms_{req.dataset_id.replace('/', '_')}"
             ds_dir.mkdir(parents=True, exist_ok=True)
-            # 优先尝试用 modelscope 的 load_dataset API 加载数据
-            try:
-                from modelscope.msdatasets import MsDataset
-                ms_ds = MsDataset.load(req.dataset_id)
-                if hasattr(ms_ds, '__getitem__') and hasattr(ms_ds, '__len__'):
-                    split = ms_ds
-                elif isinstance(ms_ds, dict):
-                    split = ms_ds.get("train") or ms_ds.get("default") or list(ms_ds.values())[0]
-                else:
-                    split = ms_ds
-                output_path = ds_dir / "data.jsonl"
-                record_count = 0
-                with open(output_path, "w", encoding="utf-8") as f:
-                    for item in split:
-                        f.write(json.dumps({k: str(v) for k, v in item.items()}, ensure_ascii=False) + "\n")
-                        record_count += 1
-                if record_count == 0:
-                    raise RuntimeError("MsDataset loaded but returned 0 records")
-                jsonl_path = output_path
-            except (ImportError, RuntimeError) as e:
-                # 回退到 CLI 下载方式
-                logger.warning(f"MsDataset.load failed: {e}, falling back to CLI download")
-                proc = subprocess.run(
-                    [
-                        "modelscope", "download",
-                        "--dataset", req.dataset_id,
-                        "--local_dir", str(ds_dir),
-                    ],
-                    capture_output=True, text=True, timeout=3600,
-                )
-                if proc.returncode != 0:
-                    raise RuntimeError(f"modelscope CLI failed: {proc.stderr}")
-                jsonl_path, record_count = _scan_and_convert_to_jsonl(ds_dir)
-                if record_count == 0:
-                    raise RuntimeError("No training data found in downloaded dataset files")
+
+            # 使用 ModelScope SDK 加载数据集
+            from modelscope.msdatasets import MsDataset
+            ms_ds = MsDataset.load(req.dataset_id, subset_name='default', split='train')
+            output_path = ds_dir / "data.jsonl"
+            record_count = 0
+            with open(output_path, "w", encoding="utf-8") as f:
+                for item in ms_ds:
+                    f.write(json.dumps(item, ensure_ascii=False) + "\n")
+                    record_count += 1
+
+            if record_count == 0:
+                raise RuntimeError("MsDataset loaded but returned 0 records")
+            jsonl_path = output_path
         else:
             from datasets import load_dataset
 

+ 19 - 8
backend/app/services/model_test_service.py

@@ -17,9 +17,10 @@ async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temp
 def _test_model_remote(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
     """通过 SSH 在算力节点执行模型测试。
 
-    先 scp 脚本到远端,再 docker exec 执行文件,完全避开 heredoc/引号/管道问题。
+    流程:scp 到远端宿主机 → docker cp 传入容器 → docker exec 执行 → 清理
     """
     import json
+    import os
     import tempfile
     from app.core.remote_executor import scp_to_remote, ssh_exec
 
@@ -86,29 +87,39 @@ gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)
 print(json.dumps({'generated_text': gen}))
 """ % str(temperature > 0).lower()
 
-    # 写入本地临时文件 → scp 到远端 → docker exec 执行 → 清理
-    remote_script = "/tmp/remote_model_test.py"
+    remote_script = "/tmp/_model_test.py"
     with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as tmp:
         tmp.write(python_script)
         tmp.flush()
         tmp_path = tmp.name
 
     try:
-        code, out, err = scp_to_remote(tmp_path, remote_script)
+        # Step 1: SCP 到远端宿主机
+        host_tmp = "/tmp/_model_test_host.py"
+        code, out, err = scp_to_remote(tmp_path, host_tmp)
         if code != 0:
             logger.error(f"SCP failed: {err}")
             return {"error": f"Failed to upload script: {err.strip()}"}
 
-        remote_cmd = f"docker exec -w {workdir} {container} {python} {remote_script} '{model_id}' '{prompt.replace(chr(39), chr(92)+chr(39))}' {max_new_tokens} {temperature} {top_p}"
-        code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
+        # Step 2: docker cp 把文件从宿主机传入容器
+        cp_cmd = f"docker cp {host_tmp} {container}:/tmp/_model_test.py"
+        code, out, err = ssh_exec(cp_cmd, timeout=10)
+        if code != 0:
+            logger.error(f"docker cp failed: {err}")
+            return {"error": f"Failed to copy script to container: {err.strip()}"}
+
+        # Step 3: docker exec 执行容器内的脚本
+        safe_prompt = prompt.replace("'", "\\'")
+        run_cmd = f"docker exec -w {workdir} {container} {python} /tmp/_model_test.py '{model_id}' '{safe_prompt}' {max_new_tokens} {temperature} {top_p}"
+        code, stdout, stderr = ssh_exec(run_cmd, timeout=600)
 
         if code != 0:
             logger.error(f"Remote model test failed: {stderr}")
             return {"error": stderr.strip() or "Remote test failed"}
     finally:
-        import os
         os.unlink(tmp_path)
-        ssh_exec(f"rm -f {remote_script}", timeout=10)
+        ssh_exec(f"rm -f /tmp/_model_test_host.py", timeout=10)
+        ssh_exec(f"docker exec {container} rm -f /tmp/_model_test.py", timeout=10)
 
     logger.info(f"Remote test result: code={code}, stdout_len={len(stdout)}, stderr_len={len(stderr)}")
     if stdout: