"""SSH 远程执行模块 — 在算力节点上运行 GPU 任务。""" import json import os import subprocess from typing import Any from app.config import get_settings from app.core.logging import logger settings = get_settings() def _get_ssh_prefix() -> list[str]: """构建 ssh/scp 命令前缀,支持密钥或密码登录。""" prefix = ["-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10"] return prefix def scp_to_remote(local_path: str, remote_path: str) -> tuple[int, str, str]: """通过 SCP 把本地文件复制到远端主机,返回 (exit_code, stdout, stderr)。""" target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}" scp_args = ["scp", *_get_ssh_prefix(), "-P", str(settings.compute_node_ssh_port)] if settings.compute_node_ssh_key: scp_args += ["-i", settings.compute_node_ssh_key] elif settings.compute_node_ssh_password: scp_args = ["sshpass", "-p", settings.compute_node_ssh_password] + scp_args scp_args += [local_path, f"{target}:{remote_path}"] try: proc = subprocess.run(scp_args, capture_output=True, text=True, timeout=30) clean_stderr = "\n".join(line for line in proc.stderr.split("\n") if not line.startswith("Warning:")) return proc.returncode, proc.stdout, clean_stderr except Exception as e: logger.error(f"SCP failed: {e}") return -1, "", str(e) def ssh_exec(cmd: str, timeout: int | None = None) -> tuple[int, str, str]: """通过 SSH 在算力节点执行命令,返回 (exit_code, stdout, stderr)。""" if not settings.use_remote_compute: raise RuntimeError("未配置算力节点(compute_node_host 为空)") target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}" ssh_args = [ "ssh", *_get_ssh_prefix(), "-p", str(settings.compute_node_ssh_port), target, cmd, ] # sshpass 需要包裹 ssh 命令,而不是作为 ssh 的参数 if settings.compute_node_ssh_key: ssh_args = ["ssh", "-i", settings.compute_node_ssh_key] + ssh_args[1:] elif settings.compute_node_ssh_password: ssh_args = ["sshpass", "-p", settings.compute_node_ssh_password] + ssh_args timeout = timeout or settings.compute_node_ssh_timeout try: proc = subprocess.run( ssh_args, capture_output=True, text=True, timeout=timeout, ) # 过滤 known_hosts 警告,这些不算真正的错误 clean_stderr = "\n".join(line for line in proc.stderr.split("\n") if not line.startswith("Warning:")) return proc.returncode, proc.stdout, clean_stderr except subprocess.TimeoutExpired: logger.error(f"SSH command timeout after {timeout}s: {cmd[:100]}") return -1, "", f"Command timed out after {timeout}s" except Exception as e: logger.error(f"SSH exec failed: {e}") return -1, "", str(e) def run_training_remote( job_id: str, model_id: str, model_type: str, dataset_id: str, config: dict[str, Any], ) -> str | None: """在算力节点启动训练任务(通过 docker exec,后台执行)。 通过 docker exec -i 将配置传入容器内,避免宿主机/容器路径混淆。 """ import base64 config_json = json.dumps(config, ensure_ascii=False) config_b64 = base64.b64encode(config_json.encode()).decode() config_file = f"/tmp/config_{job_id}.json" # 通过 docker exec -i 把配置传入容器内,在容器里写入临时文件并启动训练 remote_cmd = ( f"echo '{config_b64}' | base64 -d | " f"docker exec -i {settings.compute_node_docker_container} bash -c '" f"cat > {config_file} && " f"cd {settings.compute_node_workdir} && " f"nohup {settings.compute_node_python} -m app.engines.remote_train " f"{job_id} {model_id} {model_type} {dataset_id} {config_file} " f">/tmp/train_{job_id}.log 2>&1 & echo $!'" ) code, stdout, stderr = ssh_exec(remote_cmd, timeout=30) if code != 0: logger.error(f"Remote training launch failed: {stderr}") return None pid = stdout.strip() logger.info(f"Remote training launched in container: job={job_id}, container_pid={pid}") return pid def is_process_running(pid: str) -> bool: """检查远程训练进程是否还在运行。 通过 docker exec 进入容器检查 PID 是否存在。 """ cmd = f"docker exec {settings.compute_node_docker_container} bash -c 'kill -0 {pid} 2>/dev/null && echo running || echo stopped'" code, stdout, stderr = ssh_exec(cmd, timeout=10) return code == 0 and "running" in stdout def run_inference_remote( model_id: str, adapter_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float, repetition_penalty: float, do_sample: bool, ) -> dict[str, Any] | None: """在算力节点执行推理。""" safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n") remote_cmd = ( f"docker exec {settings.compute_node_docker_container} " f"{settings.compute_node_python} -c \"" "import asyncio, json; " "from app.config import get_settings; " "settings = get_settings(); " "from app.services.inference_service import run_inference_single; " f"result = asyncio.run(run_inference_single(" f"'{model_id}', '{adapter_id}', '{safe_prompt}', " f"{max_new_tokens}, {temperature}, {top_p}, {repetition_penalty}, {str(do_sample).lower()}" ")); " "print(json.dumps(result, ensure_ascii=False))\" 2>&1" ) code, stdout, stderr = ssh_exec(remote_cmd, timeout=600) if code != 0: logger.error(f"Remote inference failed: {stderr}") return {"error": stderr.strip() or "Remote inference failed"} # 提取最后一行 JSON for line in reversed(stdout.strip().split("\n")): line = line.strip() if line.startswith("{"): try: return json.loads(line) except json.JSONDecodeError: continue return {"error": f"Invalid JSON response: {stdout[:500]}"}