| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- """SSH 远程执行模块 — 在算力节点上运行 GPU 任务。"""
import json
import os
import shlex
import subprocess
from typing import Any

from app.config import get_settings
from app.core.logging import logger
- settings = get_settings()  # Settings object obtained once at import time; the functions in this module read its use_remote_compute flag and compute_node_* fields.
def _get_ssh_prefix() -> list[str]:
    """Build the argument prefix used for ssh/scp invocations.

    The base options disable strict host-key checking and set a 10 second
    connect timeout. Authentication is then chosen from the settings:

    * if ``settings.compute_node_ssh_key`` is set, ``-i <key>`` is appended;
    * otherwise, if ``settings.compute_node_ssh_password`` is set, the list is
      prefixed with ``sshpass -p <password>``;
    * otherwise only the base options are returned.

    Note:
        In the password case the returned list starts with the ``sshpass``
        wrapper command rather than with an ssh option.

    Returns:
        A new list of command-line tokens.
    """
    # NOTE(review): StrictHostKeyChecking=no skips host-key verification.
    base_options = ["-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10"]

    key_path = settings.compute_node_ssh_key
    if key_path:
        return [*base_options, "-i", key_path]

    password = settings.compute_node_ssh_password
    if password:
        return ["sshpass", "-p", password, *base_options]

    return base_options
def ssh_exec(cmd: str, timeout: int | None = None) -> tuple[int, str, str]:
    """Run ``cmd`` on the compute node over SSH.

    Args:
        cmd: Command line passed as the final ssh argument.
        timeout: Seconds to wait for the local ssh process. A falsy value
            (``None`` or ``0``) falls back to
            ``settings.compute_node_ssh_timeout``.

    Returns:
        ``(exit_code, stdout, stderr)`` of the ssh process. If the process
        times out or cannot be run at all, ``(-1, "", <message>)`` is
        returned instead of raising.

    Raises:
        RuntimeError: If ``settings.use_remote_compute`` is falsy.
    """
    if not settings.use_remote_compute:
        raise RuntimeError("未配置算力节点(compute_node_host 为空)")

    options = _get_ssh_prefix()
    launcher: list[str] = []
    env = None
    # For password login the prefix starts with ``sshpass -p <password>``.
    # sshpass is a wrapper that must come *before* the ssh executable, so it
    # is split off here. The password is handed over through the SSHPASS
    # environment variable (``sshpass -e``) to keep it out of the local
    # process list.
    if len(options) >= 3 and options[:2] == ["sshpass", "-p"]:
        password, options = options[2], options[3:]
        launcher = ["sshpass", "-e"]
        env = {**os.environ, "SSHPASS": password}

    target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
    ssh_cmd = [
        *launcher,
        "ssh",
        *options,
        "-p",
        str(settings.compute_node_ssh_port),
        target,
        cmd,
    ]
    timeout = timeout or settings.compute_node_ssh_timeout

    try:
        proc = subprocess.run(
            ssh_cmd,
            capture_output=True,
            text=True,
            # Undecodable bytes in remote output should not abort the call.
            errors="replace",
            timeout=timeout,
            env=env,
        )
    except subprocess.TimeoutExpired:
        logger.error(f"SSH command timeout after {timeout}s: {cmd[:100]}")
        return -1, "", f"Command timed out after {timeout}s"
    except Exception as e:
        # Boundary handler: callers rely on the tuple contract, so every
        # failure to run the process (e.g. missing ssh binary) is reported
        # through the return value.
        logger.error(f"SSH exec failed: {e}")
        return -1, "", str(e)
    return proc.returncode, proc.stdout, proc.stderr
def run_training_remote(
    job_id: str,
    model_id: str,
    model_type: str,
    dataset_id: str,
    config: dict[str, Any],
) -> str | None:
    """Launch a training job on the compute node as a background process.

    The job is started through ``docker exec`` in the container named by
    ``settings.compute_node_docker_container``. Inside the container,
    ``app.engines.remote_train`` is started with ``nohup`` and backgrounded,
    its output is redirected to ``/tmp/train_<job_id>.log``, and the PID
    printed by ``echo $!`` is captured.

    Args:
        job_id: Job identifier; also used in the log file name.
        model_id: Passed through to the remote training module.
        model_type: Passed through to the remote training module.
        dataset_id: Passed through to the remote training module.
        config: Training configuration, sent as a single JSON argument.

    Returns:
        The PID reported from inside the container as a string, or ``None``
        if the launch command failed or did not report a numeric PID.
    """
    config_json = json.dumps(config, ensure_ascii=False)

    # Each value is quoted for the shell that runs inside the container, and
    # the complete script is quoted once more for the remote login shell.
    # Hand-written nested single quotes would close the outer quote and let
    # any space in the JSON split the command.
    train_args = " ".join(
        shlex.quote(str(value))
        for value in (job_id, model_id, model_type, dataset_id, config_json)
    )
    log_path = shlex.quote(f"/tmp/train_{job_id}.log")
    container_script = (
        f"nohup {settings.compute_node_python} -m app.engines.remote_train "
        f"{train_args} >{log_path} 2>&1 & echo $!"
    )
    remote_cmd = (
        f"docker exec {settings.compute_node_docker_container} "
        f"bash -c {shlex.quote(container_script)}"
    )

    code, stdout, stderr = ssh_exec(remote_cmd, timeout=30)
    if code != 0:
        logger.error(f"Remote training launch failed: {stderr}")
        return None

    # ``echo $!`` is the last thing the script prints; accept it only if it
    # is a plain decimal PID so callers never receive an empty or garbled id.
    tokens = stdout.split()
    pid = tokens[-1] if tokens else ""
    if not (pid.isascii() and pid.isdigit()):
        logger.error(
            f"Remote training launch returned no PID: job={job_id}, output={stdout[:200]!r}"
        )
        return None

    logger.info(f"Remote training launched in container: job={job_id}, container_pid={pid}")
    return pid
def is_process_running(pid: str) -> bool:
    """Check whether a remote training process is still running.

    Runs ``kill -0 <pid>`` inside the configured container via ``docker exec``
    and reports whether the command printed ``running``.

    Args:
        pid: Process id inside the container, as a decimal string.

    Returns:
        ``True`` if the remote check succeeded and reported the process as
        running. ``False`` if it reported stopped, if the check itself failed,
        or if ``pid`` is not a positive decimal number.
    """
    pid_text = str(pid).strip()
    # Only a positive decimal PID is interpolated into the shell command:
    # an empty or non-numeric value cannot identify a process (and could
    # inject shell syntax), and ``kill -0 0`` targets the caller's own
    # process group, which would always look "running".
    is_decimal = pid_text.isascii() and pid_text.isdigit()
    if not is_decimal or not pid_text.strip("0"):
        return False

    cmd = (
        f"docker exec {settings.compute_node_docker_container} "
        f"bash -c 'kill -0 {pid_text} 2>/dev/null && echo running || echo stopped'"
    )
    code, stdout, _stderr = ssh_exec(cmd, timeout=10)
    return code == 0 and "running" in stdout
def run_inference_remote(
    model_id: str,
    adapter_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    do_sample: bool,
) -> dict[str, Any] | None:
    """Run a single inference request on the compute node.

    A short Python program is executed inside the configured container via
    ``docker exec``. It calls
    ``app.services.inference_service.run_inference_single`` with the given
    arguments (positionally, in signature order) and prints the result as one
    JSON line.

    Args:
        model_id: Passed through to ``run_inference_single``.
        adapter_id: Passed through to ``run_inference_single``.
        prompt: Prompt text; may contain quotes, newlines or shell
            metacharacters.
        max_new_tokens: Passed through to ``run_inference_single``.
        temperature: Passed through to ``run_inference_single``.
        top_p: Passed through to ``run_inference_single``.
        repetition_penalty: Passed through to ``run_inference_single``.
        do_sample: Passed through to ``run_inference_single`` as a bool.

    Returns:
        The decoded JSON object printed by the remote program, or a dict of
        the form ``{"error": <message>}`` if the remote command failed or
        produced no parseable JSON object.
    """
    # The arguments travel as one JSON document in argv instead of being
    # spliced into Python source: the prompt can then contain any character
    # without breaking the shell or the program, and booleans arrive as real
    # Python bools. The default ASCII-only encoding keeps the payload
    # independent of the remote locale.
    call_args = json.dumps(
        [
            model_id,
            adapter_id,
            prompt,
            max_new_tokens,
            temperature,
            top_p,
            repetition_penalty,
            bool(do_sample),
        ]
    )
    runner = (
        "import asyncio, json, sys; "
        "from app.config import get_settings; "
        "settings = get_settings(); "
        "from app.services.inference_service import run_inference_single; "
        "result = asyncio.run(run_inference_single(*json.loads(sys.argv[1]))); "
        "print(json.dumps(result, ensure_ascii=False))"
    )
    remote_cmd = (
        f"docker exec {settings.compute_node_docker_container} "
        f"{settings.compute_node_python} -c {shlex.quote(runner)} "
        f"{shlex.quote(call_args)} 2>&1"
    )

    code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
    if code != 0:
        # ``2>&1`` sends the remote program's stderr to stdout, so the actual
        # error text is usually found there; local stderr only carries
        # ssh-level messages.
        detail = stderr.strip() or stdout.strip()
        logger.error(f"Remote inference failed: {detail}")
        return {"error": detail or "Remote inference failed"}

    # Other output may precede the result, so scan from the end for the last
    # line that parses as a JSON object.
    for line in reversed(stdout.strip().split("\n")):
        candidate = line.strip()
        if not candidate.startswith("{"):
            continue
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue
    return {"error": f"Invalid JSON response: {stdout[:500]}"}
|