|
@@ -1,214 +0,0 @@
|
|
|
-"""SSH 远程执行模块 — 在算力节点上运行 GPU 任务。"""
|
|
|
|
|
-import json
|
|
|
|
|
-import os
|
|
|
|
|
-import subprocess
|
|
|
|
|
-from typing import Any
|
|
|
|
|
-
|
|
|
|
|
-from app.config import get_settings
|
|
|
|
|
-from app.core.logging import logger
|
|
|
|
|
-
|
|
|
|
|
-settings = get_settings()
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def _get_ssh_prefix() -> list[str]:
|
|
|
|
|
- """构建 ssh/scp 命令前缀,支持密钥或密码登录。"""
|
|
|
|
|
- prefix = ["-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10"]
|
|
|
|
|
- return prefix
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def scp_to_remote(local_path: str, remote_path: str) -> tuple[int, str, str]:
|
|
|
|
|
- """通过 SCP 把本地文件复制到远端主机,返回 (exit_code, stdout, stderr)。"""
|
|
|
|
|
- target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
|
|
|
|
|
- scp_args = ["scp", *_get_ssh_prefix(), "-P", str(settings.compute_node_ssh_port)]
|
|
|
|
|
- if settings.compute_node_ssh_key:
|
|
|
|
|
- scp_args += ["-i", settings.compute_node_ssh_key]
|
|
|
|
|
- elif settings.compute_node_ssh_password:
|
|
|
|
|
- scp_args = ["sshpass", "-p", settings.compute_node_ssh_password] + scp_args
|
|
|
|
|
- scp_args += [local_path, f"{target}:{remote_path}"]
|
|
|
|
|
-
|
|
|
|
|
- try:
|
|
|
|
|
- proc = subprocess.run(scp_args, capture_output=True, text=True, timeout=30)
|
|
|
|
|
- clean_stderr = "\n".join(line for line in proc.stderr.split("\n")
|
|
|
|
|
- if not line.startswith("Warning:"))
|
|
|
|
|
- return proc.returncode, proc.stdout, clean_stderr
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"SCP failed: {e}")
|
|
|
|
|
- return -1, "", str(e)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def scp_to_remote_dir(local_path: str, remote_path: str) -> tuple[int, str, str]:
|
|
|
|
|
- """通过 SCP 把本地目录递归复制到远端主机。"""
|
|
|
|
|
- target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
|
|
|
|
|
- scp_args = ["scp", "-r", *_get_ssh_prefix(), "-P", str(settings.compute_node_ssh_port)]
|
|
|
|
|
- if settings.compute_node_ssh_key:
|
|
|
|
|
- scp_args += ["-i", settings.compute_node_ssh_key]
|
|
|
|
|
- elif settings.compute_node_ssh_password:
|
|
|
|
|
- scp_args = ["sshpass", "-p", settings.compute_node_ssh_password] + scp_args
|
|
|
|
|
- scp_args += [local_path, f"{target}:{remote_path}"]
|
|
|
|
|
-
|
|
|
|
|
- try:
|
|
|
|
|
- proc = subprocess.run(scp_args, capture_output=True, text=True, timeout=120)
|
|
|
|
|
- clean_stderr = "\n".join(line for line in proc.stderr.split("\n")
|
|
|
|
|
- if not line.startswith("Warning:"))
|
|
|
|
|
- return proc.returncode, proc.stdout, clean_stderr
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"SCP dir failed: {e}")
|
|
|
|
|
- return -1, "", str(e)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def ssh_exec(cmd: str, timeout: int | None = None) -> tuple[int, str, str]:
|
|
|
|
|
- """通过 SSH 在算力节点执行命令,返回 (exit_code, stdout, stderr)。"""
|
|
|
|
|
- if not settings.use_remote_compute:
|
|
|
|
|
- raise RuntimeError("未配置算力节点(compute_node_host 为空)")
|
|
|
|
|
-
|
|
|
|
|
- target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
|
|
|
|
|
- ssh_args = [
|
|
|
|
|
- "ssh", *_get_ssh_prefix(),
|
|
|
|
|
- "-p", str(settings.compute_node_ssh_port),
|
|
|
|
|
- target,
|
|
|
|
|
- cmd,
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- # sshpass 需要包裹 ssh 命令,而不是作为 ssh 的参数
|
|
|
|
|
- if settings.compute_node_ssh_key:
|
|
|
|
|
- ssh_args = ["ssh", "-i", settings.compute_node_ssh_key] + ssh_args[1:]
|
|
|
|
|
- elif settings.compute_node_ssh_password:
|
|
|
|
|
- ssh_args = ["sshpass", "-p", settings.compute_node_ssh_password] + ssh_args
|
|
|
|
|
-
|
|
|
|
|
- timeout = timeout or settings.compute_node_ssh_timeout
|
|
|
|
|
- try:
|
|
|
|
|
- proc = subprocess.run(
|
|
|
|
|
- ssh_args,
|
|
|
|
|
- capture_output=True,
|
|
|
|
|
- text=True,
|
|
|
|
|
- timeout=timeout,
|
|
|
|
|
- )
|
|
|
|
|
- # 过滤 known_hosts 警告,这些不算真正的错误
|
|
|
|
|
- clean_stderr = "\n".join(line for line in proc.stderr.split("\n")
|
|
|
|
|
- if not line.startswith("Warning:"))
|
|
|
|
|
- return proc.returncode, proc.stdout, clean_stderr
|
|
|
|
|
- except subprocess.TimeoutExpired:
|
|
|
|
|
- logger.error(f"SSH command timeout after {timeout}s: {cmd[:100]}")
|
|
|
|
|
- return -1, "", f"Command timed out after {timeout}s"
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"SSH exec failed: {e}")
|
|
|
|
|
- return -1, "", str(e)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def run_training_remote(
|
|
|
|
|
- job_id: str,
|
|
|
|
|
- model_id: str,
|
|
|
|
|
- model_type: str,
|
|
|
|
|
- dataset_path: str,
|
|
|
|
|
- config: dict[str, Any],
|
|
|
|
|
-) -> str | None:
|
|
|
|
|
- """在算力节点启动训练任务(通过 docker exec,后台执行)。
|
|
|
|
|
-
|
|
|
|
|
- 通过 SCP 把配置文件传到远端宿主机,再在容器内启动训练。
|
|
|
|
|
- dataset_path 由主节点预先查好,直接传给远程脚本。
|
|
|
|
|
- """
|
|
|
|
|
- import tempfile
|
|
|
|
|
-
|
|
|
|
|
- # 在 151 宿主机创建临时配置文件
|
|
|
|
|
- config_tmp = tempfile.mktemp(suffix=".json", prefix=f"config_{job_id}_")
|
|
|
|
|
- with open(config_tmp, "w", encoding="utf-8") as f:
|
|
|
|
|
- json.dump(config, f, ensure_ascii=False)
|
|
|
|
|
-
|
|
|
|
|
- # SCP 到远端宿主机(使用 data_dir,这个目录已通过 bind mount 共享给容器)
|
|
|
|
|
- remote_config_path = f"{settings.compute_node_remote_data_dir}/config_{job_id}.json"
|
|
|
|
|
- ret_code, _, _ = scp_to_remote(config_tmp, f"{remote_config_path}")
|
|
|
|
|
- os.unlink(config_tmp) # 删除本地临时文件
|
|
|
|
|
-
|
|
|
|
|
- if ret_code != 0:
|
|
|
|
|
- logger.error(f"SCP config file failed: ret_code={ret_code}")
|
|
|
|
|
- return None
|
|
|
|
|
-
|
|
|
|
|
- # 把数据集路径也传到远程(SCP 到 data/uploads/ 目录)
|
|
|
|
|
- remote_dataset_name = os.path.basename(dataset_path)
|
|
|
|
|
- remote_dataset_path = f"{settings.compute_node_remote_data_dir}/datasets/{remote_dataset_name}"
|
|
|
|
|
-
|
|
|
|
|
- if os.path.isdir(dataset_path):
|
|
|
|
|
- # 目录:用 scp -r
|
|
|
|
|
- ret_code, _, _ = scp_to_remote_dir(dataset_path, remote_dataset_path)
|
|
|
|
|
- else:
|
|
|
|
|
- # 文件:普通 scp
|
|
|
|
|
- ret_code, _, _ = scp_to_remote(dataset_path, remote_dataset_path)
|
|
|
|
|
-
|
|
|
|
|
- if ret_code != 0:
|
|
|
|
|
- logger.error(f"SCP dataset failed: ret_code={ret_code}")
|
|
|
|
|
- return None
|
|
|
|
|
-
|
|
|
|
|
- # 在容器内启动训练
|
|
|
|
|
- remote_cmd = (
|
|
|
|
|
- f"docker exec -w {settings.compute_node_workdir} "
|
|
|
|
|
- f"{settings.compute_node_docker_container} "
|
|
|
|
|
- f"bash -c '"
|
|
|
|
|
- f"nohup {settings.compute_node_python} -m app.engines.remote_train "
|
|
|
|
|
- f"{job_id} {model_id} {model_type} {remote_dataset_path} {remote_config_path} "
|
|
|
|
|
- f"</dev/null >/tmp/train_{job_id}.log 2>&1 & echo $!'"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- code, stdout, stderr = ssh_exec(remote_cmd, timeout=30)
|
|
|
|
|
-
|
|
|
|
|
- if code != 0:
|
|
|
|
|
- logger.error(f"Remote training launch failed: {stderr}")
|
|
|
|
|
- return None
|
|
|
|
|
-
|
|
|
|
|
- pid = stdout.strip()
|
|
|
|
|
- logger.info(f"Remote training launched in container: job={job_id}, container_pid={pid}")
|
|
|
|
|
- return pid
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def is_process_running(pid: str) -> bool:
|
|
|
|
|
- """检查远程训练进程是否还在运行。
|
|
|
|
|
-
|
|
|
|
|
- 通过 docker exec 进入容器检查 PID 是否存在。
|
|
|
|
|
- """
|
|
|
|
|
- cmd = f"docker exec {settings.compute_node_docker_container} bash -c 'kill -0 {pid} 2>/dev/null && echo running || echo stopped'"
|
|
|
|
|
- code, stdout, stderr = ssh_exec(cmd, timeout=30)
|
|
|
|
|
- return code == 0 and "running" in stdout
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def run_inference_remote(
|
|
|
|
|
- model_id: str,
|
|
|
|
|
- adapter_id: str,
|
|
|
|
|
- prompt: str,
|
|
|
|
|
- max_new_tokens: int,
|
|
|
|
|
- temperature: float,
|
|
|
|
|
- top_p: float,
|
|
|
|
|
- repetition_penalty: float,
|
|
|
|
|
- do_sample: bool,
|
|
|
|
|
-) -> dict[str, Any] | None:
|
|
|
|
|
- """在算力节点执行推理。"""
|
|
|
|
|
- safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n")
|
|
|
|
|
-
|
|
|
|
|
- remote_cmd = (
|
|
|
|
|
- f"docker exec {settings.compute_node_docker_container} "
|
|
|
|
|
- f"{settings.compute_node_python} -c \""
|
|
|
|
|
- "import asyncio, json; "
|
|
|
|
|
- "from app.config import get_settings; "
|
|
|
|
|
- "settings = get_settings(); "
|
|
|
|
|
- "from app.services.inference_service import run_inference_single; "
|
|
|
|
|
- f"result = asyncio.run(run_inference_single("
|
|
|
|
|
- f"'{model_id}', '{adapter_id}', '{safe_prompt}', "
|
|
|
|
|
- f"{max_new_tokens}, {temperature}, {top_p}, {repetition_penalty}, {str(do_sample).lower()}"
|
|
|
|
|
- ")); "
|
|
|
|
|
- "print(json.dumps(result, ensure_ascii=False))\" 2>&1"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
|
|
|
|
|
-
|
|
|
|
|
- if code != 0:
|
|
|
|
|
- logger.error(f"Remote inference failed: {stderr}")
|
|
|
|
|
- return {"error": stderr.strip() or "Remote inference failed"}
|
|
|
|
|
-
|
|
|
|
|
- # 提取最后一行 JSON
|
|
|
|
|
- for line in reversed(stdout.strip().split("\n")):
|
|
|
|
|
- line = line.strip()
|
|
|
|
|
- if line.startswith("{"):
|
|
|
|
|
- try:
|
|
|
|
|
- return json.loads(line)
|
|
|
|
|
- except json.JSONDecodeError:
|
|
|
|
|
- continue
|
|
|
|
|
-
|
|
|
|
|
- return {"error": f"Invalid JSON response: {stdout[:500]}"}
|
|
|