remote_executor.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. """SSH 远程执行模块 — 在算力节点上运行 GPU 任务。"""
import json
import os
import shlex
import subprocess
import tempfile
import time
from typing import Any

from app.config import get_settings
from app.core.logging import logger
# Application settings, loaded once when this module is imported and read by the
# functions below (compute-node host/port, SSH credentials, container name,
# remote paths, ...).
settings = get_settings()
  9. def _get_ssh_prefix() -> list[str]:
  10. """构建 ssh/scp 命令前缀,支持密钥或密码登录。"""
  11. prefix = ["-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=30"]
  12. return prefix
  13. def scp_to_remote(local_path: str, remote_path: str) -> tuple[int, str, str]:
  14. """通过 SCP 把本地文件复制到远端主机,返回 (exit_code, stdout, stderr)。"""
  15. target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
  16. scp_args = ["scp", *_get_ssh_prefix(), "-P", str(settings.compute_node_ssh_port)]
  17. if settings.compute_node_ssh_key:
  18. scp_args += ["-i", settings.compute_node_ssh_key]
  19. elif settings.compute_node_ssh_password:
  20. scp_args = ["sshpass", "-p", settings.compute_node_ssh_password] + scp_args
  21. scp_args += [local_path, f"{target}:{remote_path}"]
  22. try:
  23. proc = subprocess.run(scp_args, capture_output=True, text=True, timeout=30)
  24. clean_stderr = "\n".join(line for line in proc.stderr.split("\n")
  25. if not line.startswith("Warning:"))
  26. return proc.returncode, proc.stdout, clean_stderr
  27. except Exception as e:
  28. logger.error(f"SCP failed: {e}")
  29. return -1, "", str(e)
  30. def scp_to_remote_dir(local_path: str, remote_path: str) -> tuple[int, str, str]:
  31. """通过 SCP 把本地目录递归复制到远端主机。"""
  32. target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
  33. scp_args = ["scp", "-r", *_get_ssh_prefix(), "-P", str(settings.compute_node_ssh_port)]
  34. if settings.compute_node_ssh_key:
  35. scp_args += ["-i", settings.compute_node_ssh_key]
  36. elif settings.compute_node_ssh_password:
  37. scp_args = ["sshpass", "-p", settings.compute_node_ssh_password] + scp_args
  38. scp_args += [local_path, f"{target}:{remote_path}"]
  39. try:
  40. proc = subprocess.run(scp_args, capture_output=True, text=True, timeout=120)
  41. clean_stderr = "\n".join(line for line in proc.stderr.split("\n")
  42. if not line.startswith("Warning:"))
  43. return proc.returncode, proc.stdout, clean_stderr
  44. except Exception as e:
  45. logger.error(f"SCP dir failed: {e}")
  46. return -1, "", str(e)
  47. def ssh_exec(cmd: str, timeout: int | None = None) -> tuple[int, str, str]:
  48. """通过 SSH 在算力节点执行命令,返回 (exit_code, stdout, stderr)。"""
  49. if not settings.use_remote_compute:
  50. raise RuntimeError("未配置算力节点(compute_node_host 为空)")
  51. target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
  52. ssh_args = [
  53. "ssh", *_get_ssh_prefix(),
  54. "-p", str(settings.compute_node_ssh_port),
  55. target,
  56. cmd,
  57. ]
  58. # sshpass 需要包裹 ssh 命令,而不是作为 ssh 的参数
  59. if settings.compute_node_ssh_key:
  60. ssh_args = ["ssh", "-i", settings.compute_node_ssh_key] + ssh_args[1:]
  61. elif settings.compute_node_ssh_password:
  62. ssh_args = ["sshpass", "-p", settings.compute_node_ssh_password] + ssh_args
  63. timeout = timeout or settings.compute_node_ssh_timeout
  64. try:
  65. proc = subprocess.run(
  66. ssh_args,
  67. capture_output=True,
  68. text=True,
  69. timeout=timeout,
  70. )
  71. # 过滤 known_hosts 警告,这些不算真正的错误
  72. clean_stderr = "\n".join(line for line in proc.stderr.split("\n")
  73. if not line.startswith("Warning:"))
  74. return proc.returncode, proc.stdout, clean_stderr
  75. except subprocess.TimeoutExpired:
  76. logger.error(f"SSH command timeout after {timeout}s: {cmd[:100]}")
  77. return -1, "", f"Command timed out after {timeout}s"
  78. except Exception as e:
  79. logger.error(f"SSH exec failed: {e}")
  80. return -1, "", str(e)
  81. def run_training_remote(
  82. job_id: str,
  83. model_id: str,
  84. model_type: str,
  85. dataset_path: str,
  86. config: dict[str, Any],
  87. ) -> str | None:
  88. """在算力节点启动训练任务(通过 docker exec,后台执行)。
  89. 通过 SCP 把配置文件传到远端宿主机,再在容器内启动训练。
  90. dataset_path 由主节点预先查好,直接传给远程脚本。
  91. """
  92. import tempfile
  93. # 在 151 宿主机创建临时配置文件
  94. config_tmp = tempfile.mktemp(suffix=".json", prefix=f"config_{job_id}_")
  95. with open(config_tmp, "w", encoding="utf-8") as f:
  96. json.dump(config, f, ensure_ascii=False)
  97. # SCP 到远端宿主机(使用 data_dir,这个目录已通过 bind mount 共享给容器)
  98. remote_config_path = f"{settings.compute_node_remote_data_dir}/config_{job_id}.json"
  99. remote_config_dir = os.path.dirname(remote_config_path)
  100. _, _, mkdir_stderr = ssh_exec(f"mkdir -p {remote_config_dir}")
  101. ret_code, stdout, stderr = scp_to_remote(config_tmp, f"{remote_config_path}")
  102. os.unlink(config_tmp) # 删除本地临时文件
  103. if ret_code != 0:
  104. logger.error(f"SCP config file failed: ret_code={ret_code}, stderr={stderr}")
  105. return None
  106. # 把数据集路径也传到远程(SCP 到 data/uploads/ 目录)
  107. remote_dataset_name = os.path.basename(dataset_path)
  108. remote_dataset_path = f"{settings.compute_node_remote_data_dir}/datasets/{remote_dataset_name}"
  109. # 确保远程父目录存在
  110. remote_dataset_dir = os.path.dirname(remote_dataset_path)
  111. _, _, mkdir_stderr = ssh_exec(f"mkdir -p {remote_dataset_dir}")
  112. logger.info(f"Created remote dataset directory: {remote_dataset_dir}")
  113. if os.path.isdir(dataset_path):
  114. # 目录:用 scp -r
  115. logger.info(f"Uploading dataset directory: {dataset_path} -> {remote_dataset_path}")
  116. ret_code, _, stderr = scp_to_remote_dir(dataset_path, remote_dataset_path)
  117. else:
  118. # 文件:普通 scp
  119. logger.info(f"Uploading dataset file: {dataset_path} -> {remote_dataset_path}")
  120. ret_code, _, stderr = scp_to_remote(dataset_path, remote_dataset_path)
  121. if ret_code != 0:
  122. logger.error(f"SCP dataset failed: ret_code={ret_code}, stderr={stderr}")
  123. return None
  124. logger.info(f"Dataset uploaded successfully: {remote_dataset_path}")
  125. # 在容器内启动训练
  126. # 日志写容器内的 /tmp,同时追加写到共享数据目录(宿主机可直接查看)
  127. remote_log_dir = f"{settings.compute_node_remote_data_dir}/logs"
  128. _, _, _ = ssh_exec(f"mkdir -p {remote_log_dir}")
  129. # 根据 num_gpus 构建 GPU 配置
  130. num_gpus = config.get("num_gpus", 1)
  131. if num_gpus >= 2:
  132. cuda_devices = "2,3" # 物理 GPU 2 和 3
  133. # 多 GPU:使用 torchrun 启动 DDP
  134. launch_cmd = (
  135. f"{settings.compute_node_python} -m torch.distributed.run "
  136. f"--nproc_per_node={num_gpus} --nnodes=1 "
  137. f"-m app.engines.remote_train "
  138. f"{job_id} {model_id} {model_type} {remote_dataset_path} {remote_config_path}"
  139. )
  140. extra_env = (
  141. f"-e NCCL_TIMEOUT=1800 "
  142. f"-e TORCH_DISTRIBUTED_DEBUG=OFF "
  143. )
  144. logger.info(f"Multi-GPU training: num_gpus={num_gpus}, CUDA_VISIBLE_DEVICES={cuda_devices}")
  145. else:
  146. cuda_devices = "3" # 单 GPU
  147. launch_cmd = (
  148. f"{settings.compute_node_python} -m app.engines.remote_train "
  149. f"{job_id} {model_id} {model_type} {remote_dataset_path} {remote_config_path}"
  150. )
  151. extra_env = ""
  152. # 使用 setsid 启动训练进程,确保进程组独立,kill 时能正确清理子进程
  153. remote_cmd = (
  154. f"docker exec "
  155. f"-e MACA_MPS_MODE=1 "
  156. f"-e CUDA_VISIBLE_DEVICES={cuda_devices} "
  157. f"-e PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True "
  158. f"{extra_env}"
  159. f"-w {settings.compute_node_workdir} "
  160. f"{settings.compute_node_docker_container} "
  161. f"bash -c '"
  162. f"setsid {launch_cmd} "
  163. f"</dev/null >/tmp/train_{job_id}.log 2>&1 &"
  164. f" disown; echo $!'"
  165. )
  166. code, stdout, stderr = ssh_exec(remote_cmd, timeout=30)
  167. if code != 0:
  168. logger.error(f"Remote training launch failed: {stderr}")
  169. return None
  170. pid = stdout.strip()
  171. logger.info(f"Remote training launched in container: job={job_id}, container_pid={pid}")
  172. return pid
  173. def is_process_running(pid: str, retries: int = 3) -> bool:
  174. """检查远程训练/推理进程是否还在运行。
  175. 通过 docker exec 进入容器,优先用 kill -0 检查指定 PID,
  176. 兜底用 ps 检查是否存在匹配的 Python 进程。
  177. 失败时重试,避免因单次 SSH 超时误判。
  178. """
  179. for attempt in range(retries):
  180. cmd = (
  181. f"docker exec {settings.compute_node_docker_container} bash -c '"
  182. f"if kill -0 {pid} 2>/dev/null; then "
  183. f" state=$(cat /proc/{pid}/stat 2>/dev/null | awk \"{{{{print \\$3}}}}\"); "
  184. f" if [ \"$state\" = \"Z\" ]; then echo stopped; else echo running; "
  185. f" fi; "
  186. f"else "
  187. f" echo stopped; "
  188. f"fi'"
  189. )
  190. code, stdout, stderr = ssh_exec(cmd, timeout=30)
  191. if code != 0:
  192. # SSH/docker exec 本身失败(容器可能挂了或网络抖动),重试
  193. if attempt < retries - 1:
  194. import time
  195. time.sleep(2)
  196. continue
  197. return False
  198. return "running" in stdout
  199. return False
  200. def get_remote_stderr(job_id: str) -> str | None:
  201. """读取远程训练的 stderr 日志(/tmp/train_{job_id}.log)。
  202. 用于 jsonl 日志未写入时兜底获取错误信息。
  203. """
  204. log_path = f"/tmp/train_{job_id}.log"
  205. cmd = f"docker exec {settings.compute_node_docker_container} bash -c 'tail -100 {log_path} 2>/dev/null'"
  206. code, stdout, stderr = ssh_exec(cmd, timeout=30)
  207. if code == 0 and stdout.strip():
  208. return stdout.strip()
  209. return None
  210. def run_inference_remote(
  211. model_id: str,
  212. adapter_id: str,
  213. prompt: str,
  214. max_new_tokens: int,
  215. temperature: float,
  216. top_p: float,
  217. repetition_penalty: float,
  218. do_sample: bool,
  219. ) -> dict[str, Any] | None:
  220. """在算力节点执行推理。"""
  221. safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n")
  222. remote_cmd = (
  223. f"docker exec "
  224. f"-e MACA_MPS_MODE=1 "
  225. f"-e CUDA_VISIBLE_DEVICES={settings.inference_cuda_devices} "
  226. f"-w {settings.compute_node_workdir} "
  227. f"{settings.compute_node_docker_container} "
  228. f"{settings.compute_node_python} -c \""
  229. "import asyncio, json; "
  230. "from app.config import get_settings; "
  231. "settings = get_settings(); "
  232. "from app.services.inference_service import run_inference_single; "
  233. f"result = asyncio.run(run_inference_single("
  234. f"'{model_id}', '{adapter_id}', '{safe_prompt}', "
  235. f"{max_new_tokens}, {temperature}, {top_p}, {repetition_penalty}, {do_sample}"
  236. ")); "
  237. "print(json.dumps(result, ensure_ascii=False))\" 2>&1"
  238. )
  239. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  240. logger.info(f"Remote inference SSH result: code={code}, stdout={stdout.strip()}, stderr={stderr.strip()[:500]}")
  241. if code != 0:
  242. logger.error(f"Remote inference failed: stderr={stderr}, stdout={stdout}")
  243. return {"error": stderr.strip() or stdout.strip() or "Remote inference failed"}
  244. # 提取最后一行 JSON
  245. for line in reversed(stdout.strip().split("\n")):
  246. line = line.strip()
  247. if line.startswith("{"):
  248. try:
  249. return json.loads(line)
  250. except json.JSONDecodeError:
  251. continue
  252. return {"error": f"Invalid JSON response: {stdout[:500]}"}