Răsfoiți Sursa

修复dpo报错

lxylxy123321 16 ore în urmă
părinte
comite
10e8f5649f
3 fișiere modificate, cu 18 adăugiri și 11 ștergeri
  1. 12 9
      backend/app/core/remote_executor.py
  2. 4 0
      backend/app/engines/text_engine.py
  3. 2 2
      test.py

+ 12 - 9
backend/app/core/remote_executor.py

@@ -37,27 +37,30 @@ def scp_to_remote(local_path: str, remote_path: str) -> tuple[int, str, str]:
 
 
 def scp_from_remote_dir(remote_path: str, local_path: str, timeout: int = 600) -> tuple[int, str, str]:
     """Copy the contents of a remote directory to ``local_path`` with rsync over SSH.

     rsync is used instead of ``scp -r`` to avoid the "Maximum directory depth
     exceeded" error caused by symbolic links; ``--copy-links`` copies what the
     links point to rather than the links themselves.

     Args:
         remote_path: Directory on the compute node whose contents are copied.
         local_path: Local destination directory.
         timeout: Seconds to wait for rsync. The 10-minute default is sized for
             a ~20G adapter (about 3 minutes on gigabit Ethernet).

     Returns:
         ``(exit_code, stdout, stderr)``. stderr has lines starting with
         ``"Warning:"`` removed. On timeout or any other failure to run the
         command, ``exit_code`` is ``-1`` and stderr carries the reason.
     """
     # Function-scope import so this block does not depend on the module header.
     import shlex

     target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"

     # Build the remote-shell command as a list first. rsync receives it as ONE
     # string via -e and splits it on spaces (honoring quotes), so every piece
     # must be quoted; a raw key path containing a space would otherwise be
     # split into two arguments.
     ssh_cmd = [
         "ssh",
         "-p", str(settings.compute_node_ssh_port),
         "-o", "StrictHostKeyChecking=no",
         "-o", "ConnectTimeout=30",
     ]
     if settings.compute_node_ssh_key:
         ssh_cmd += ["-i", settings.compute_node_ssh_key]

     rsync_args = ["rsync", "-az", "--copy-links", "-e", shlex.join(ssh_cmd)]
     # Password auth is only used when no key is configured (key takes priority).
     if not settings.compute_node_ssh_key and settings.compute_node_ssh_password:
         rsync_args = ["sshpass", "-p", settings.compute_node_ssh_password, *rsync_args]

     # A trailing "/" on the rsync source means "copy the directory's contents,
     # not the directory itself"; normalize so callers passing "dir/" do not
     # produce "dir//".
     source_dir = remote_path.rstrip("/") + "/"
     rsync_args += [f"{target}:{source_dir}", local_path]

     try:
         proc = subprocess.run(rsync_args, capture_output=True, text=True, timeout=timeout)
     except subprocess.TimeoutExpired:
         logger.error(f"rsync from remote timeout after {timeout}s: {remote_path}")
         return -1, "", f"rsync timed out after {timeout}s"
     except Exception as e:
         # Boundary handler: callers expect a result tuple, never an exception.
         logger.error(f"rsync from remote failed: {e}")
         return -1, "", str(e)

     # Drop "Warning:" lines (e.g. ssh host-key notices) so callers only see real errors.
     clean_stderr = "\n".join(
         line for line in proc.stderr.split("\n") if not line.startswith("Warning:")
     )
     return proc.returncode, proc.stdout, clean_stderr
 
 

+ 4 - 0
backend/app/engines/text_engine.py

@@ -271,6 +271,10 @@ class TextEngine(BaseEngine):
         elif task_type == "dpo":
             from copy import deepcopy
 
+            # 兼容旧版 transformers(缺少 MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
+            import transformers.models.auto.modeling_auto as _ma
+            if not hasattr(_ma, "MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES"):
+                _ma.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {}
             from trl import DPOConfig, DPOTrainer
 
             # 显式创建 reference model 并冻结,避免 AdaLora 多 adapter 冲突

+ 2 - 2
test.py

@@ -1,13 +1,13 @@
 import os

 from openai import OpenAI

 # Smoke test for a deployed model behind the deployment proxy.
 # The API key is read from the environment and must never be committed;
 # any key that was previously checked in should be rotated.
 # OPENAI_BASE_URL may override the proxy endpoint; the default is the
 # deployment this script currently targets.
 client = OpenAI(
     base_url=os.environ.get(
         "OPENAI_BASE_URL",
         "http://192.168.92.151:3000/api/v1/deployment/proxy/3c0f3e87-f0df-45a3-b6bc-2ba64ed38e89/v1",
     ),
     api_key=os.environ["OPENAI_API_KEY"],  # raises KeyError if the key is not set
 )

 response = client.chat.completions.create(
     model="local-model",
     messages=[{"role": "user", "content": "什么是 API?"}],
     max_tokens=512,
     temperature=0.7
 )