remote_executor.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. """SSH 远程执行模块 — 在算力节点上运行 GPU 任务。"""
  2. import json
  3. import os
  4. import subprocess
  5. from typing import Any
  6. from app.config import get_settings
  7. from app.core.logging import logger
# Application settings, resolved once when this module is imported; the
# helpers below read the compute-node SSH options from this object.
settings = get_settings()
  9. def _get_ssh_prefix() -> list[str]:
  10. """构建 ssh/scp 命令前缀,支持密钥或密码登录。"""
  11. prefix = ["-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10"]
  12. if settings.compute_node_ssh_key:
  13. prefix.extend(["-i", settings.compute_node_ssh_key])
  14. elif settings.compute_node_ssh_password:
  15. prefix = ["sshpass", "-p", settings.compute_node_ssh_password] + prefix
  16. return prefix
  17. def ssh_exec(cmd: str, timeout: int | None = None) -> tuple[int, str, str]:
  18. """通过 SSH 在算力节点执行命令,返回 (exit_code, stdout, stderr)。"""
  19. if not settings.use_remote_compute:
  20. raise RuntimeError("未配置算力节点(compute_node_host 为空)")
  21. target = f"{settings.compute_node_ssh_user}@{settings.compute_node_host}"
  22. ssh_cmd = [
  23. "ssh", *_get_ssh_prefix(),
  24. "-p", str(settings.compute_node_ssh_port),
  25. target,
  26. cmd,
  27. ]
  28. timeout = timeout or settings.compute_node_ssh_timeout
  29. try:
  30. proc = subprocess.run(
  31. ssh_cmd,
  32. capture_output=True,
  33. text=True,
  34. timeout=timeout,
  35. )
  36. return proc.returncode, proc.stdout, proc.stderr
  37. except subprocess.TimeoutExpired:
  38. logger.error(f"SSH command timeout after {timeout}s: {cmd[:100]}")
  39. return -1, "", f"Command timed out after {timeout}s"
  40. except Exception as e:
  41. logger.error(f"SSH exec failed: {e}")
  42. return -1, "", str(e)
  43. def run_training_remote(
  44. job_id: str,
  45. model_id: str,
  46. model_type: str,
  47. dataset_id: str,
  48. config: dict[str, Any],
  49. ) -> bool:
  50. """在算力节点启动训练任务(后台执行,不阻塞)。
  51. 使用 nohup + & 让训练在后台运行,通过 WebSocket 回传进度。
  52. """
  53. config_json = json.dumps(config, ensure_ascii=False)
  54. # 转义双引号避免 shell 解析问题
  55. config_escaped = config_json.replace('"', '\\"')
  56. log_path = os.path.join(settings.compute_node_workdir, f"logs/{job_id}.log")
  57. log_dir = os.path.dirname(log_path)
  58. remote_cmd = (
  59. f"mkdir -p {log_dir} && "
  60. f"cd {settings.compute_node_workdir} && "
  61. f"nohup {settings.compute_node_python} -m app.engines.remote_train "
  62. f"'{job_id}' '{model_id}' '{model_type}' '{dataset_id}' '{config_escaped}' "
  63. f"> {log_path} 2>&1 & echo $!"
  64. )
  65. code, stdout, stderr = ssh_exec(remote_cmd, timeout=30)
  66. if code != 0:
  67. logger.error(f"Remote training launch failed: {stderr}")
  68. return False
  69. logger.info(f"Remote training launched: job={job_id}, pid={stdout.strip()}")
  70. return True
  71. def run_inference_remote(
  72. model_id: str,
  73. adapter_id: str,
  74. prompt: str,
  75. max_new_tokens: int,
  76. temperature: float,
  77. top_p: float,
  78. repetition_penalty: float,
  79. do_sample: bool,
  80. ) -> dict[str, Any] | None:
  81. """在算力节点执行推理。"""
  82. safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n")
  83. remote_cmd = (
  84. f"cd {settings.compute_node_workdir} && "
  85. f"{settings.compute_node_python} -c \""
  86. "import asyncio, json; "
  87. "from app.config import get_settings; "
  88. "settings = get_settings(); "
  89. "from app.services.inference_service import run_inference_single; "
  90. f"result = asyncio.run(run_inference_single("
  91. f"'{model_id}', '{adapter_id}', '{safe_prompt}', "
  92. f"{max_new_tokens}, {temperature}, {top_p}, {repetition_penalty}, {str(do_sample).lower()}"
  93. ")); "
  94. "print(json.dumps(result, ensure_ascii=False))\" 2>&1"
  95. )
  96. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  97. if code != 0:
  98. logger.error(f"Remote inference failed: {stderr}")
  99. return {"error": stderr.strip() or "Remote inference failed"}
  100. # 提取最后一行 JSON
  101. for line in reversed(stdout.strip().split("\n")):
  102. line = line.strip()
  103. if line.startswith("{"):
  104. try:
  105. return json.loads(line)
  106. except json.JSONDecodeError:
  107. continue
  108. return {"error": f"Invalid JSON response: {stdout[:500]}"}