deploy_service.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721
  1. """部署服务 —— 导出模型 / 部署为在线推理服务。
  2. 架构:
  3. - 253 算力节点运行轻量 inference_worker.py(纯 stdlib + torch/transformers,不需要 fastapi/uvicorn)
  4. - 151 主节点对外提供 OpenAI 兼容代理 API,通过 TCP 转发请求到 253
  5. """
  6. import asyncio
  7. import json
  8. import socket
  9. import struct
  10. import uuid
  11. from datetime import datetime, timezone
  12. from pathlib import Path
  13. from typing import Any
  14. from app.config import get_settings
  15. from app.core.background_tasks import background_task_manager
  16. from app.core.db import async_session, DeployTaskModel
  17. from app.core.logging import logger
  18. from app.core.remote_executor import ssh_exec
  19. from sqlalchemy import select
settings = get_settings()

# TCP port range used by inference workers on the compute node (253)
_SERVE_PORT_MIN = 8100
_SERVE_PORT_MAX = 8199
  24. # ---------------------------------------------------------------------------
  25. # TCP 代理:151 → 253 inference_worker
  26. # ---------------------------------------------------------------------------
  27. async def proxy_to_worker(task_id: str, request: dict) -> dict:
  28. """通过 TCP 把推理请求转发到 253 的 inference_worker,返回响应。
  29. 协议:4 字节大端长度前缀 + JSON body
  30. """
  31. # 查 DB 获取 worker 监听的端口
  32. async with async_session() as session:
  33. result = await session.execute(
  34. select(DeployTaskModel).where(DeployTaskModel.id == task_id)
  35. )
  36. record = result.scalar_one_or_none()
  37. if not record:
  38. return {"error": "部署任务不存在"}
  39. if record.status != "running":
  40. return {"error": f"服务未运行(当前状态: {record.status})"}
  41. port = record.port
  42. if not port:
  43. return {"error": "未找到 worker 端口"}
  44. # 通过 asyncio 在线程池中执行同步 TCP 操作
  45. return await asyncio.to_thread(_tcp_request, settings.compute_node_host, port, request)
  46. def _tcp_request(host: str, port: int, request: dict) -> dict:
  47. """同步 TCP 请求:连接到 worker,发送请求,接收响应。"""
  48. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  49. sock.settimeout(120) # 推理可能耗时较长
  50. try:
  51. sock.connect((host, port))
  52. # 发送:4 字节长度 + JSON
  53. body = json.dumps(request, ensure_ascii=False).encode("utf-8")
  54. sock.sendall(struct.pack(">I", len(body)))
  55. sock.sendall(body)
  56. # 接收:4 字节长度 + JSON
  57. len_data = _recv_exact(sock, 4)
  58. resp_len = struct.unpack(">I", len_data)[0]
  59. resp_data = _recv_exact(sock, resp_len)
  60. return json.loads(resp_data.decode("utf-8"))
  61. except socket.timeout:
  62. return {"error": "推理超时(120s)"}
  63. except ConnectionRefusedError:
  64. return {"error": f"无法连接到推理 worker({host}:{port}),服务可能已停止"}
  65. except Exception as e:
  66. return {"error": f"代理请求失败: {e}"}
  67. finally:
  68. sock.close()
  69. def _recv_exact(sock: socket.socket, n: int) -> bytes:
  70. """确保接收恰好 n 字节。"""
  71. buf = bytearray()
  72. while len(buf) < n:
  73. chunk = sock.recv(n - len(buf))
  74. if not chunk:
  75. raise ConnectionError("Connection closed while reading")
  76. buf.extend(chunk)
  77. return bytes(buf)
  78. # ---------------------------------------------------------------------------
  79. # 导出 Adapter(导出文件模式)
  80. # ---------------------------------------------------------------------------
  81. async def export_adapter(job_id: str, config: dict[str, Any], user_id: str = "") -> dict[str, Any]:
  82. """启动导出后台任务,立即返回 task_id。"""
  83. task_id = str(uuid.uuid4())
  84. merge_with_base = config.get("merge_with_base", False)
  85. export_format = config.get("export_format", "safetensors")
  86. task = DeployTaskModel(
  87. id=task_id,
  88. job_id=job_id,
  89. user_id=user_id or None,
  90. status="pending",
  91. deploy_mode="export",
  92. )
  93. async with async_session() as session:
  94. session.add(task)
  95. await session.commit()
  96. background_task_manager.register_task(task_id, "deployment", {"job_id": job_id})
  97. await background_task_manager.run(
  98. task_id, "deployment", _execute_export(task_id, job_id, merge_with_base, export_format)
  99. )
  100. logger.info(f"Deploy task started: job={job_id} (task_id={task_id})")
  101. return {"task_id": task_id, "job_id": job_id, "status": "pending", "deploy_mode": "export"}
  102. async def _execute_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
  103. """后台执行导出。"""
  104. try:
  105. if settings.use_remote_compute:
  106. result = await _run_remote_export(task_id, job_id, merge_with_base, export_format)
  107. else:
  108. result = await _run_local_export(task_id, job_id, merge_with_base)
  109. output_path = result.get("output_path")
  110. # 把 inference_worker.py 和启动脚本复制到输出目录
  111. if output_path and settings.use_remote_compute:
  112. await _copy_worker_template_remote(output_path)
  113. await _update_deploy_status(task_id, "completed", output_path=output_path)
  114. return {"output_path": output_path}
  115. except Exception as e:
  116. logger.error(f"Export failed for job {job_id}: {e}")
  117. await _update_deploy_status(task_id, "failed", error=str(e))
  118. return {"error": str(e)}
  119. # ---------------------------------------------------------------------------
  120. # 部署为在线服务(serve 模式)
  121. # ---------------------------------------------------------------------------
  122. async def start_serving(job_id: str, config: dict[str, Any], user_id: str = "") -> dict[str, Any]:
  123. """部署为在线推理服务,151 代理对外,253 worker 做推理。"""
  124. task_id = str(uuid.uuid4())
  125. merge_with_base = config.get("merge_with_base", True)
  126. port = config.get("port")
  127. if not port:
  128. port = await _allocate_port()
  129. task = DeployTaskModel(
  130. id=task_id,
  131. job_id=job_id,
  132. user_id=user_id or None,
  133. status="pending",
  134. deploy_mode="serve",
  135. port=port,
  136. )
  137. async with async_session() as session:
  138. session.add(task)
  139. await session.commit()
  140. background_task_manager.register_task(task_id, "deployment", {"job_id": job_id, "mode": "serve"})
  141. await background_task_manager.run(
  142. task_id, "deployment", _execute_serve(task_id, job_id, merge_with_base, port)
  143. )
  144. logger.info(f"Serve task started: job={job_id} port={port} (task_id={task_id})")
  145. return {"task_id": task_id, "job_id": job_id, "status": "pending", "deploy_mode": "serve", "port": port}
  146. async def _execute_serve(task_id: str, job_id: str, merge_with_base: bool, port: int) -> dict:
  147. """后台执行:导出模型 → 复制 worker → 启动 TCP 推理 worker。"""
  148. try:
  149. # 第一步:导出(合并 adapter)
  150. if settings.use_remote_compute:
  151. export_result = await _run_remote_export(task_id, job_id, merge_with_base, "safetensors")
  152. output_path = export_result.get("output_path")
  153. else:
  154. export_result = await _run_local_export(task_id, job_id, merge_with_base)
  155. output_path = export_result.get("output_path")
  156. if not output_path:
  157. raise RuntimeError("导出失败,无法获取输出路径")
  158. # 第二步:启动推理 worker
  159. if settings.use_remote_compute:
  160. pid = await _launch_remote_worker(task_id, output_path, port)
  161. else:
  162. pid = await _launch_local_worker(task_id, output_path, port)
  163. # endpoint_url 是 151 上的代理路径(相对路径,前端拼接 origin)
  164. endpoint_url = f"/api/v1/deployment/proxy/{task_id}/v1"
  165. await _update_deploy_status(
  166. task_id, "running",
  167. output_path=output_path,
  168. endpoint_url=endpoint_url,
  169. port=port,
  170. pid=pid,
  171. )
  172. return {"endpoint_url": endpoint_url, "port": port, "pid": pid}
  173. except Exception as e:
  174. logger.error(f"Serve failed for job {job_id}: {e}")
  175. await _update_deploy_status(task_id, "failed", error=str(e))
  176. return {"error": str(e)}
  177. async def _launch_remote_worker(task_id: str, model_path: str, port: int) -> str:
  178. """在远程 253 容器里启动 inference_worker.py,返回进程 PID。
  179. 只依赖 torch + transformers(不需要 fastapi/uvicorn)。
  180. """
  181. # 启动前先清理端口占用,确保不会有旧进程残留
  182. # 253 容器内子进程多,docker exec 执行较慢,给足超时
  183. kill_cmd = (
  184. f"docker exec {settings.compute_node_docker_container} "
  185. f"bash -c 'fuser -k {port}/tcp 2>/dev/null; sleep 2; fuser -k {port}/tcp 2>/dev/null; sleep 1; true'"
  186. )
  187. await asyncio.to_thread(ssh_exec, kill_cmd, timeout=60)
  188. # worker 脚本在容器内的路径
  189. worker_template = f"{settings.compute_node_workdir}/app/core/inference_worker.py"
  190. # 复制 worker 到模型目录
  191. copy_cmd = (
  192. f"docker exec {settings.compute_node_docker_container} "
  193. f"bash -c 'cp {worker_template} {model_path}/inference_worker.py'"
  194. )
  195. code, _, stderr = await asyncio.to_thread(ssh_exec, copy_cmd, timeout=30)
  196. if code != 0:
  197. raise RuntimeError(f"复制 inference_worker.py 失败: {stderr}")
  198. # 在容器内后台启动 worker(多卡推理:CUDA_VISIBLE_DEVICES 使用配置项)
  199. launch_cmd = (
  200. f"docker exec "
  201. f"-e MACA_MPS_MODE=1 "
  202. f"-e CUDA_VISIBLE_DEVICES={settings.inference_cuda_devices} "
  203. f"-w {model_path} "
  204. f"{settings.compute_node_docker_container} "
  205. f"bash -c '"
  206. f"{settings.compute_node_python} inference_worker.py "
  207. f"--model-path {model_path} "
  208. f"--port {port} "
  209. f"</dev/null >/tmp/serve_{task_id}.log 2>&1 &"
  210. f" echo $!'"
  211. )
  212. code, stdout, stderr = await asyncio.to_thread(ssh_exec, launch_cmd, timeout=60)
  213. if code != 0:
  214. raise RuntimeError(f"启动推理 worker 失败: {stderr}")
  215. pid = stdout.strip()
  216. logger.info(f"Remote worker launched: task={task_id} port={port} pid={pid}")
  217. # 等待模型加载(可能需要较长时间),检查 READY 标记
  218. # 每次轮询只用一次 SSH 连接,同时检查 READY 和进程状态
  219. for attempt in range(60): # 最多等 5 分钟(60 * 5s)
  220. await asyncio.sleep(5)
  221. check_cmd = (
  222. f"docker exec {settings.compute_node_docker_container} "
  223. f"bash -c '"
  224. f" ready=$(grep -c READY /tmp/serve_{task_id}.log 2>/dev/null || echo 0); "
  225. f" if [ \"$ready\" != \"0\" ]; then echo \"READY:$ready\"; exit 0; fi; "
  226. f" if ! kill -0 {pid} 2>/dev/null; then echo \"DEAD\"; exit 0; fi; "
  227. f" echo \"ALIVE\"; "
  228. f"'"
  229. )
  230. code, stdout, stderr = await asyncio.to_thread(ssh_exec, check_cmd, timeout=60)
  231. if code == 0:
  232. result = stdout.strip()
  233. if result.startswith("READY:"):
  234. logger.info(f"Worker ready: task={task_id} (after ~{(attempt+1)*5}s)")
  235. # 校验实际占用端口的 PID(防止 stop 没杀干净旧进程导致 PID 对不上)
  236. actual_pid = await _get_port_pid(port)
  237. if actual_pid and actual_pid != pid:
  238. logger.warning(f"Port {port} PID mismatch: launched={pid}, actual={actual_pid}")
  239. pid = actual_pid
  240. return pid
  241. elif result == "DEAD":
  242. # 读取日志看什么错了
  243. log_cmd = (
  244. f"docker exec {settings.compute_node_docker_container} "
  245. f"bash -c 'tail -20 /tmp/serve_{task_id}.log 2>/dev/null'"
  246. )
  247. _, log_stdout, _ = await asyncio.to_thread(ssh_exec, log_cmd, timeout=60)
  248. raise RuntimeError(f"Worker 进程已退出: {log_stdout}")
  249. # result == "ALIVE" → 继续等待
  250. logger.warning(f"Worker not ready after 5min: task={task_id}, proceeding anyway")
  251. return pid
  252. async def _get_port_pid(port: int) -> str | None:
  253. """获取远程容器内占用指定端口的进程 PID。"""
  254. cmd = (
  255. f"docker exec {settings.compute_node_docker_container} "
  256. f"bash -c 'fuser {port}/tcp 2>/dev/null'"
  257. )
  258. code, stdout, _ = await asyncio.to_thread(ssh_exec, cmd, timeout=60)
  259. if code == 0 and stdout.strip():
  260. # fuser 输出格式可能是 "8100/tcp: 372" 或直接 " 372"
  261. parts = stdout.strip().split()
  262. for p in reversed(parts):
  263. if p.isdigit():
  264. return p
  265. return None
  266. async def _launch_local_worker(task_id: str, model_path: str, port: int) -> str:
  267. """在本地启动推理 worker(开发用)。"""
  268. import subprocess
  269. import shutil
  270. import sys
  271. worker_src = Path(__file__).resolve().parent.parent / "core" / "inference_worker.py"
  272. shutil.copy(worker_src, Path(model_path) / "inference_worker.py")
  273. proc = subprocess.Popen(
  274. [sys.executable, "inference_worker.py", "--model-path", model_path, "--port", str(port)],
  275. cwd=model_path,
  276. stdout=subprocess.DEVNULL,
  277. stderr=subprocess.DEVNULL,
  278. )
  279. return str(proc.pid)
  280. # ---------------------------------------------------------------------------
  281. # 停止服务 / 列表 / 状态
  282. # ---------------------------------------------------------------------------
  283. async def stop_serving(task_id: str, user_id: str = "") -> dict[str, Any]:
  284. """停止已部署的在线服务。"""
  285. async with async_session() as session:
  286. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  287. record = result.scalar_one_or_none()
  288. if not record:
  289. return {"error": "任务不存在"}
  290. if record.deploy_mode != "serve":
  291. return {"error": "该任务不是在线服务"}
  292. if user_id and record.user_id and record.user_id != user_id:
  293. return {"error": "无权操作此任务"}
  294. pid = record.pid
  295. port = record.port
  296. output_path = record.output_path
  297. if pid and settings.use_remote_compute:
  298. # 方式1: kill -9 主进程及其子进程
  299. # 方式2: fuser 直接杀占用端口的进程(最可靠,防止 PID 对不上)
  300. kill_cmd = (
  301. f"docker exec {settings.compute_node_docker_container} "
  302. f"bash -c '"
  303. f"kill -9 {pid} 2>/dev/null; "
  304. f"pkill -9 -P {pid} 2>/dev/null; "
  305. f"fuser -k {port}/tcp 2>/dev/null; "
  306. f"sleep 2; "
  307. f"fuser -k {port}/tcp 2>/dev/null; "
  308. f"true'"
  309. )
  310. code, _, _ = await asyncio.to_thread(ssh_exec, kill_cmd, timeout=60)
  311. logger.info(f"Stop serving: task={task_id} pid={pid} port={port} kill_code={code}")
  312. record.status = "stopped"
  313. record.pid = None
  314. record.finished_at = datetime.utcnow()
  315. await session.commit()
  316. background_task_manager.update_task(task_id, status="stopped")
  317. return {"task_id": task_id, "status": "stopped"}
  318. async def restart_serving(task_id: str, user_id: str = "") -> dict[str, Any]:
  319. """重启已停止的在线服务(不重新导出模型,只启动 worker)。"""
  320. async with async_session() as session:
  321. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  322. record = result.scalar_one_or_none()
  323. if not record:
  324. return {"error": "任务不存在"}
  325. if record.deploy_mode != "serve":
  326. return {"error": "该任务不是在线服务"}
  327. if record.status != "stopped":
  328. return {"error": f"只能重启已停止的服务(当前状态: {record.status})"}
  329. if user_id and record.user_id and record.user_id != user_id:
  330. return {"error": "无权操作此任务"}
  331. if not record.output_path:
  332. return {"error": "模型文件路径丢失,无法重启,请重新部署"}
  333. output_path = record.output_path
  334. original_port = record.port # 记住原端口,尽量复用
  335. # 优先复用原端口,如果被其他 pending/running 服务占了才重新分配
  336. if original_port:
  337. port_available = await _check_port_available(original_port)
  338. port = original_port if port_available else await _allocate_port()
  339. else:
  340. port = await _allocate_port()
  341. # 更新状态为 pending,标记正在重启
  342. await _update_deploy_status(task_id, "pending", port=port)
  343. background_task_manager.register_task(task_id, "deployment", {"mode": "restart"})
  344. await background_task_manager.run(
  345. task_id, "deployment", _execute_restart(task_id, output_path, port)
  346. )
  347. logger.info(f"Restart serving: task={task_id} output_path={output_path} port={port}")
  348. return {"task_id": task_id, "status": "pending", "deploy_mode": "serve", "port": port}
  349. async def _execute_restart(task_id: str, output_path: str, port: int) -> dict:
  350. """后台执行重启:只启动 worker,不重新导出。"""
  351. try:
  352. if settings.use_remote_compute:
  353. pid = await _launch_remote_worker(task_id, output_path, port)
  354. else:
  355. pid = await _launch_local_worker(task_id, output_path, port)
  356. endpoint_url = f"/api/v1/deployment/proxy/{task_id}/v1"
  357. await _update_deploy_status(
  358. task_id, "running",
  359. output_path=output_path,
  360. endpoint_url=endpoint_url,
  361. port=port,
  362. pid=pid,
  363. )
  364. return {"endpoint_url": endpoint_url, "port": port, "pid": pid}
  365. except Exception as e:
  366. logger.error(f"Restart failed for task {task_id}: {e}")
  367. await _update_deploy_status(task_id, "failed", error=str(e))
  368. return {"error": str(e)}
  369. async def list_deployed_services(user_id: str = "") -> list[dict[str, Any]]:
  370. """列出 serve 模式的部署任务(按用户过滤)。"""
  371. async with async_session() as session:
  372. query = select(DeployTaskModel).where(DeployTaskModel.deploy_mode == "serve")
  373. if user_id:
  374. query = query.where(DeployTaskModel.user_id == user_id)
  375. query = query.order_by(DeployTaskModel.created_at.desc())
  376. result = await session.execute(query)
  377. records = result.scalars().all()
  378. services = []
  379. for r in records:
  380. status = r.status
  381. # 对 running 状态,检查远程进程是否还活着
  382. if status == "running" and r.pid and settings.use_remote_compute:
  383. from app.core.remote_executor import is_process_running
  384. proc_state = await asyncio.to_thread(is_process_running, r.pid)
  385. if proc_state == "stopped":
  386. # 确认进程已退出,标记为 stopped
  387. status = "stopped"
  388. await _update_deploy_status(r.id, "stopped", error="进程已退出")
  389. r.port = None
  390. r.pid = None
  391. # proc_state == "unknown" 时不改状态(SSH 超时不代表进程死了)
  392. services.append({
  393. "task_id": r.id,
  394. "job_id": r.job_id,
  395. "status": status,
  396. "endpoint_url": r.endpoint_url,
  397. "base_url": r.endpoint_url,
  398. "port": r.port,
  399. "output_path": r.output_path,
  400. "created_at": r.created_at.isoformat() if r.created_at else None,
  401. "error": r.error,
  402. })
  403. return services
  404. async def get_deploy_status(task_id: str) -> dict[str, Any]:
  405. """获取部署任务状态。"""
  406. async with async_session() as session:
  407. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  408. record = result.scalar_one_or_none()
  409. if record:
  410. return {
  411. "task_id": record.id,
  412. "job_id": record.job_id,
  413. "status": record.status,
  414. "deploy_mode": record.deploy_mode or "export",
  415. "progress": record.progress,
  416. "output_path": record.output_path,
  417. "endpoint_url": record.endpoint_url,
  418. "port": record.port,
  419. "error": record.error,
  420. }
  421. return {"task_id": None, "job_id": "", "status": "not_found", "deploy_mode": "export",
  422. "progress": 0.0, "output_path": None, "endpoint_url": None, "port": None, "error": None}
  423. # ---------------------------------------------------------------------------
  424. # 辅助函数
  425. # ---------------------------------------------------------------------------
  426. async def _allocate_port() -> int:
  427. """从端口池里分配一个未使用的端口。"""
  428. async with async_session() as session:
  429. result = await session.execute(
  430. select(DeployTaskModel.port).where(
  431. DeployTaskModel.deploy_mode == "serve",
  432. DeployTaskModel.status.in_(["pending", "running"]),
  433. DeployTaskModel.port.isnot(None),
  434. )
  435. )
  436. used = {row[0] for row in result.all()}
  437. for port in range(_SERVE_PORT_MIN, _SERVE_PORT_MAX + 1):
  438. if port not in used:
  439. return port
  440. raise RuntimeError(f"无可用端口({_SERVE_PORT_MIN}-{_SERVE_PORT_MAX} 全部占用)")
  441. async def _check_port_available(port: int) -> bool:
  442. """检查指定端口是否可被复用(没有被其他 pending/running 服务占用)。"""
  443. async with async_session() as session:
  444. result = await session.execute(
  445. select(DeployTaskModel.id).where(
  446. DeployTaskModel.deploy_mode == "serve",
  447. DeployTaskModel.status.in_(["pending", "running"]),
  448. DeployTaskModel.port == port,
  449. )
  450. )
  451. return result.first() is None
  452. async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
  453. """通过 SSH 在远程容器执行模型合并/导出。"""
  454. remote_cmd = (
  455. f"docker exec "
  456. f"-e MACA_MPS_MODE=1 "
  457. f"-e CUDA_VISIBLE_DEVICES=3 "
  458. f"-w {settings.compute_node_workdir} "
  459. f"{settings.compute_node_docker_container} "
  460. f"{settings.compute_node_python} -c \""
  461. "import asyncio, json; "
  462. "from app.core.remote_deploy import run_remote_export; "
  463. f"result = asyncio.run(run_remote_export('{job_id}', {merge_with_base}, '{export_format}')); "
  464. "print(json.dumps(result, ensure_ascii=False))\" 2>&1"
  465. )
  466. code, stdout, stderr = await asyncio.to_thread(ssh_exec, remote_cmd, timeout=600)
  467. if code != 0:
  468. raise RuntimeError(f"Remote export failed: {stderr}")
  469. for line in reversed(stdout.strip().split("\n")):
  470. line = line.strip()
  471. if line.startswith("{"):
  472. try:
  473. result = json.loads(line)
  474. if "error" in result:
  475. raise RuntimeError(result["error"])
  476. return result
  477. except json.JSONDecodeError:
  478. continue
  479. raise RuntimeError(f"Invalid response: {stdout[:500]}")
  480. async def _run_local_export(task_id: str, job_id: str, merge_with_base: bool) -> dict:
  481. """本地执行导出(开发用)。"""
  482. adapter_path = settings.adapters_dir / job_id
  483. if not adapter_path.exists():
  484. raise ValueError("Adapter not found")
  485. output_path = settings.adapters_dir / f"{job_id}_merged"
  486. if merge_with_base:
  487. import torch
  488. from transformers import AutoModelForCausalLM, AutoTokenizer
  489. base_model_id = _get_base_model_id_local(job_id)
  490. if base_model_id:
  491. from peft import PeftModel
  492. base_model = AutoModelForCausalLM.from_pretrained(
  493. base_model_id, torch_dtype=torch.float16, device_map="auto"
  494. )
  495. peft_model = PeftModel.from_pretrained(base_model, adapter_path)
  496. merged = peft_model.merge_and_unload()
  497. merged.save_pretrained(output_path)
  498. tokenizer = AutoTokenizer.from_pretrained(adapter_path)
  499. tokenizer.save_pretrained(output_path)
  500. else:
  501. from peft import PeftModel
  502. merged = PeftModel.from_pretrained(
  503. AutoModelForCausalLM.from_pretrained(
  504. str(adapter_path), torch_dtype=torch.float16
  505. ),
  506. str(adapter_path),
  507. )
  508. merged = merged.merge_and_unload()
  509. merged.save_pretrained(output_path)
  510. tokenizer = AutoTokenizer.from_pretrained(adapter_path)
  511. tokenizer.save_pretrained(output_path)
  512. else:
  513. import shutil
  514. if output_path.exists():
  515. shutil.rmtree(output_path)
  516. shutil.copytree(adapter_path, output_path)
  517. return {"output_path": str(output_path)}
  518. async def _copy_worker_template_remote(output_path: str):
  519. """把 inference_worker.py 和启动脚本复制到远程模型目录。"""
  520. worker_template = f"{settings.compute_node_workdir}/app/core/inference_worker.py"
  521. copy_cmd = (
  522. f"docker exec {settings.compute_node_docker_container} "
  523. f"bash -c 'cp {worker_template} {output_path}/inference_worker.py'"
  524. )
  525. code, _, stderr = await asyncio.to_thread(ssh_exec, copy_cmd, timeout=30)
  526. if code != 0:
  527. logger.warning(f"复制 inference_worker.py 到 {output_path} 失败: {stderr}")
  528. # 生成快捷启动脚本
  529. start_script = (
  530. f"#!/bin/bash\n"
  531. f"cd {output_path}\n"
  532. f"CUDA_VISIBLE_DEVICES={settings.inference_cuda_devices} MACA_MPS_MODE=1 "
  533. f"{settings.compute_node_python} inference_worker.py "
  534. f"--model-path . --port 8100\n"
  535. )
  536. script_cmd = (
  537. f"docker exec {settings.compute_node_docker_container} "
  538. f"bash -c 'cat > {output_path}/start.sh << \"EOF\"\n{start_script}EOF\n"
  539. f"chmod +x {output_path}/start.sh'"
  540. )
  541. code, _, _ = await asyncio.to_thread(ssh_exec, script_cmd, timeout=15)
  542. if code != 0:
  543. logger.warning(f"生成 start.sh 失败")
  544. def _get_base_model_id_local(job_id: str):
  545. config_path = settings.adapters_dir / job_id / "adapter_config.json"
  546. if config_path.exists():
  547. with open(config_path) as f:
  548. return json.load(f).get("base_model_name_or_path")
  549. return None
  550. async def _update_deploy_status(
  551. task_id: str, status: str,
  552. output_path: str = None, error: str = None,
  553. endpoint_url: str = None, port: int = None, pid: str = None,
  554. ):
  555. async with async_session() as session:
  556. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  557. record = result.scalar_one_or_none()
  558. if record:
  559. record.status = status
  560. if output_path:
  561. record.output_path = output_path
  562. if error:
  563. record.error = error
  564. if endpoint_url:
  565. record.endpoint_url = endpoint_url
  566. if port:
  567. record.port = port
  568. if pid:
  569. record.pid = pid
  570. if status in ("completed", "failed", "stopped"):
  571. record.finished_at = datetime.utcnow()
  572. if status == "pending":
  573. # 重启时清除完成时间和错误信息
  574. record.finished_at = None
  575. record.error = None
  576. await session.commit()
  577. background_task_manager.update_task(
  578. task_id, status=status, output_path=output_path, error=error,
  579. endpoint_url=endpoint_url,
  580. )
  581. async def recover_stale_deploys() -> None:
  582. async with async_session() as session:
  583. result = await session.execute(
  584. select(DeployTaskModel).where(
  585. DeployTaskModel.status.in_(["pending", "running"])
  586. )
  587. )
  588. records = result.scalars().all()
  589. for record in records:
  590. if record.deploy_mode == "export":
  591. record.status = "failed"
  592. record.error = "Server restarted, task interrupted"
  593. elif record.deploy_mode == "serve":
  594. if record.pid and settings.use_remote_compute:
  595. from app.core.remote_executor import is_process_running
  596. proc_state = await asyncio.to_thread(is_process_running, record.pid)
  597. if proc_state == "stopped":
  598. record.status = "stopped"
  599. record.error = "Server restarted, process no longer running"
  600. else:
  601. continue # 进程还在或无法确认,保持 running
  602. else:
  603. record.status = "stopped"
  604. record.error = "Server restarted, process state unknown"
  605. # 释放端口,确保下次分配时可用
  606. if record.status == "stopped":
  607. record.port = None
  608. record.pid = None
  609. record.finished_at = datetime.utcnow()
  610. if records:
  611. await session.commit()
  612. logger.info(f"Recovered {len(records)} stale deploy tasks")