remote_executor.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. """SSH 远程执行模块 — 在算力节点上运行 GPU 任务。"""
  2. import json
  3. import os
  4. import subprocess
  5. from typing import Any
  6. from app.config import get_settings
  7. from app.core.logging import logger
# Application settings are resolved once, at import time, and shared by every
# helper in this module (SSH target, credentials, container name, timeouts).
settings = get_settings()
  9. def _get_ssh_prefix() -> list[str]:
  10. """构建 ssh/scp 命令前缀,支持密钥或密码登录。"""
  11. prefix = ["-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10"]
  12. return prefix
  13. def scp_to_remote(local_path: str, remote_path: str) -> tuple[int, str, str]:
  14. """通过 SCP 把本地文件复制到远端主机,返回 (exit_code, stdout, stderr)。"""
  15. target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
  16. scp_args = ["scp", *_get_ssh_prefix(), "-P", str(settings.compute_node_ssh_port)]
  17. if settings.compute_node_ssh_key:
  18. scp_args += ["-i", settings.compute_node_ssh_key]
  19. elif settings.compute_node_ssh_password:
  20. scp_args = ["sshpass", "-p", settings.compute_node_ssh_password] + scp_args
  21. scp_args += [local_path, f"{target}:{remote_path}"]
  22. try:
  23. proc = subprocess.run(scp_args, capture_output=True, text=True, timeout=30)
  24. clean_stderr = "\n".join(line for line in proc.stderr.split("\n")
  25. if not line.startswith("Warning:"))
  26. return proc.returncode, proc.stdout, clean_stderr
  27. except Exception as e:
  28. logger.error(f"SCP failed: {e}")
  29. return -1, "", str(e)
  30. def ssh_exec(cmd: str, timeout: int | None = None) -> tuple[int, str, str]:
  31. """通过 SSH 在算力节点执行命令,返回 (exit_code, stdout, stderr)。"""
  32. if not settings.use_remote_compute:
  33. raise RuntimeError("未配置算力节点(compute_node_host 为空)")
  34. target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
  35. ssh_args = [
  36. "ssh", *_get_ssh_prefix(),
  37. "-p", str(settings.compute_node_ssh_port),
  38. target,
  39. cmd,
  40. ]
  41. # sshpass 需要包裹 ssh 命令,而不是作为 ssh 的参数
  42. if settings.compute_node_ssh_key:
  43. ssh_args = ["ssh", "-i", settings.compute_node_ssh_key] + ssh_args[1:]
  44. elif settings.compute_node_ssh_password:
  45. ssh_args = ["sshpass", "-p", settings.compute_node_ssh_password] + ssh_args
  46. timeout = timeout or settings.compute_node_ssh_timeout
  47. try:
  48. proc = subprocess.run(
  49. ssh_args,
  50. capture_output=True,
  51. text=True,
  52. timeout=timeout,
  53. )
  54. # 过滤 known_hosts 警告,这些不算真正的错误
  55. clean_stderr = "\n".join(line for line in proc.stderr.split("\n")
  56. if not line.startswith("Warning:"))
  57. return proc.returncode, proc.stdout, clean_stderr
  58. except subprocess.TimeoutExpired:
  59. logger.error(f"SSH command timeout after {timeout}s: {cmd[:100]}")
  60. return -1, "", f"Command timed out after {timeout}s"
  61. except Exception as e:
  62. logger.error(f"SSH exec failed: {e}")
  63. return -1, "", str(e)
  64. def run_training_remote(
  65. job_id: str,
  66. model_id: str,
  67. model_type: str,
  68. dataset_id: str,
  69. config: dict[str, Any],
  70. ) -> str | None:
  71. """在算力节点启动训练任务(通过 docker exec,后台执行)。
  72. 在容器内用 nohup 启动训练,返回 PID 以便后续检测。
  73. 配置通过 base64 编码写入远端临时文件,避免 shell 引号/转义问题。
  74. """
  75. import base64
  76. config_json = json.dumps(config, ensure_ascii=False)
  77. config_b64 = base64.b64encode(config_json.encode()).decode()
  78. config_file = f"/tmp/config_{job_id}.json"
  79. # 远端容器内执行的脚本:解码 base64 → 写临时文件 → 启动训练
  80. inner_script = (
  81. f"echo '{config_b64}' | base64 -d > {config_file} && "
  82. f"nohup {settings.compute_node_python} -m app.engines.remote_train "
  83. f"{job_id} {model_id} {model_type} {dataset_id} {config_file} "
  84. f">/tmp/train_{job_id}.log 2>&1 & echo $!"
  85. )
  86. remote_cmd = f"docker exec {settings.compute_node_docker_container} bash -c '{inner_script}'"
  87. code, stdout, stderr = ssh_exec(remote_cmd, timeout=30)
  88. if code != 0:
  89. logger.error(f"Remote training launch failed: {stderr}")
  90. return None
  91. pid = stdout.strip()
  92. logger.info(f"Remote training launched in container: job={job_id}, container_pid={pid}")
  93. return pid
  94. def is_process_running(pid: str) -> bool:
  95. """检查远程训练进程是否还在运行。
  96. 通过 docker exec 进入容器检查 PID 是否存在。
  97. """
  98. cmd = f"docker exec {settings.compute_node_docker_container} bash -c 'kill -0 {pid} 2>/dev/null && echo running || echo stopped'"
  99. code, stdout, stderr = ssh_exec(cmd, timeout=10)
  100. return code == 0 and "running" in stdout
  101. def run_inference_remote(
  102. model_id: str,
  103. adapter_id: str,
  104. prompt: str,
  105. max_new_tokens: int,
  106. temperature: float,
  107. top_p: float,
  108. repetition_penalty: float,
  109. do_sample: bool,
  110. ) -> dict[str, Any] | None:
  111. """在算力节点执行推理。"""
  112. safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n")
  113. remote_cmd = (
  114. f"docker exec {settings.compute_node_docker_container} "
  115. f"{settings.compute_node_python} -c \""
  116. "import asyncio, json; "
  117. "from app.config import get_settings; "
  118. "settings = get_settings(); "
  119. "from app.services.inference_service import run_inference_single; "
  120. f"result = asyncio.run(run_inference_single("
  121. f"'{model_id}', '{adapter_id}', '{safe_prompt}', "
  122. f"{max_new_tokens}, {temperature}, {top_p}, {repetition_penalty}, {str(do_sample).lower()}"
  123. ")); "
  124. "print(json.dumps(result, ensure_ascii=False))\" 2>&1"
  125. )
  126. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  127. if code != 0:
  128. logger.error(f"Remote inference failed: {stderr}")
  129. return {"error": stderr.strip() or "Remote inference failed"}
  130. # 提取最后一行 JSON
  131. for line in reversed(stdout.strip().split("\n")):
  132. line = line.strip()
  133. if line.startswith("{"):
  134. try:
  135. return json.loads(line)
  136. except json.JSONDecodeError:
  137. continue
  138. return {"error": f"Invalid JSON response: {stdout[:500]}"}