| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723 |
- """部署服务 —— 导出模型 / 部署为在线推理服务。
- 架构:
- - 253 算力节点运行轻量 inference_worker.py(纯 stdlib + torch/transformers,不需要 fastapi/uvicorn)
- - 151 主节点对外提供 OpenAI 兼容代理 API,通过 TCP 转发请求到 253
- """
- import asyncio
- import json
- import socket
- import struct
- import uuid
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Any
- from app.config import get_settings
- from app.core.background_tasks import background_task_manager
- from app.core.db import async_session, DeployTaskModel
- from app.core.logging import logger
- from app.core.remote_executor import ssh_exec
- from sqlalchemy import select
# Application settings, loaded once when this module is imported.
settings = get_settings()
# Inclusive TCP port range used for inference workers on the compute node (253).
_SERVE_PORT_MIN = 8100
_SERVE_PORT_MAX = 8199
- # ---------------------------------------------------------------------------
- # TCP 代理:151 → 253 inference_worker
- # ---------------------------------------------------------------------------
async def proxy_to_worker(task_id: str, request: dict) -> dict:
    """Forward an inference request over TCP to the inference_worker on node 253.

    Wire protocol: a 4-byte big-endian length prefix followed by a JSON body.

    Args:
        task_id: ID of the deploy task whose worker should handle the request.
        request: Request payload to send to the worker.

    Returns:
        The worker's decoded JSON response, or an ``{"error": ...}`` dict when
        the deploy task does not exist, is not running, or has no port stored.
    """
    # Look up the deploy task to find the port its worker listens on.
    lookup = select(DeployTaskModel).where(DeployTaskModel.id == task_id)
    async with async_session() as session:
        record = (await session.execute(lookup)).scalar_one_or_none()
        if not record:
            return {"error": "部署任务不存在"}
        if record.status != "running":
            return {"error": f"服务未运行(当前状态: {record.status})"}
        port = record.port

    if not port:
        return {"error": "未找到 worker 端口"}

    # The TCP exchange is synchronous, so run it on a worker thread.
    return await asyncio.to_thread(
        _tcp_request, settings.compute_node_host, port, request
    )
- def _tcp_request(host: str, port: int, request: dict) -> dict:
- """同步 TCP 请求:连接到 worker,发送请求,接收响应。"""
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.settimeout(120) # 推理可能耗时较长
- try:
- sock.connect((host, port))
- # 发送:4 字节长度 + JSON
- body = json.dumps(request, ensure_ascii=False).encode("utf-8")
- sock.sendall(struct.pack(">I", len(body)))
- sock.sendall(body)
- # 接收:4 字节长度 + JSON
- len_data = _recv_exact(sock, 4)
- resp_len = struct.unpack(">I", len_data)[0]
- resp_data = _recv_exact(sock, resp_len)
- return json.loads(resp_data.decode("utf-8"))
- except socket.timeout:
- return {"error": "推理超时(120s)"}
- except ConnectionRefusedError:
- return {"error": f"无法连接到推理 worker({host}:{port}),服务可能已停止"}
- except Exception as e:
- return {"error": f"代理请求失败: {e}"}
- finally:
- sock.close()
- def _recv_exact(sock: socket.socket, n: int) -> bytes:
- """确保接收恰好 n 字节。"""
- buf = bytearray()
- while len(buf) < n:
- chunk = sock.recv(n - len(buf))
- if not chunk:
- raise ConnectionError("Connection closed while reading")
- buf.extend(chunk)
- return bytes(buf)
- # ---------------------------------------------------------------------------
- # 导出 Adapter(导出文件模式)
- # ---------------------------------------------------------------------------
async def export_adapter(job_id: str, config: dict[str, Any], user_id: str = "") -> dict[str, Any]:
    """Create an export deploy task, hand it to the background runner, and return.

    Args:
        job_id: Training job whose adapter is exported.
        config: Options; reads ``merge_with_base`` (default False) and
            ``export_format`` (default "safetensors").
        user_id: Owner of the task; an empty string is stored as NULL.

    Returns:
        A dict with ``task_id``, ``job_id``, ``status`` ("pending") and
        ``deploy_mode`` ("export").
    """
    task_id = str(uuid.uuid4())
    merge_with_base = config.get("merge_with_base", False)
    export_format = config.get("export_format", "safetensors")

    # Persist the task row before scheduling the work.
    async with async_session() as session:
        session.add(
            DeployTaskModel(
                id=task_id,
                job_id=job_id,
                user_id=user_id or None,
                status="pending",
                deploy_mode="export",
            )
        )
        await session.commit()

    background_task_manager.register_task(task_id, "deployment", {"job_id": job_id})
    export_job = _execute_export(task_id, job_id, merge_with_base, export_format)
    await background_task_manager.run(task_id, "deployment", export_job)

    logger.info(f"Deploy task started: job={job_id} (task_id={task_id})")
    return {
        "task_id": task_id,
        "job_id": job_id,
        "status": "pending",
        "deploy_mode": "export",
    }
async def _execute_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
    """Background job: run the export and record the outcome on the deploy task.

    Returns ``{"output_path": ...}`` on success or ``{"error": ...}`` on failure;
    the task row is set to "completed" or "failed" accordingly.
    """
    try:
        remote = settings.use_remote_compute
        if remote:
            export_result = await _run_remote_export(task_id, job_id, merge_with_base, export_format)
        else:
            export_result = await _run_local_export(task_id, job_id, merge_with_base)
        output_path = export_result.get("output_path")

        # Ship inference_worker.py and the launch script next to the exported model.
        if remote and output_path:
            await _copy_worker_template_remote(output_path)

        await _update_deploy_status(task_id, "completed", output_path=output_path)
    except Exception as e:
        logger.error(f"Export failed for job {job_id}: {e}")
        await _update_deploy_status(task_id, "failed", error=str(e))
        return {"error": str(e)}
    else:
        return {"output_path": output_path}
- # ---------------------------------------------------------------------------
- # 部署为在线服务(serve 模式)
- # ---------------------------------------------------------------------------
async def start_serving(job_id: str, config: dict[str, Any], user_id: str = "") -> dict[str, Any]:
    """Deploy a job as an online inference service (151 proxies, 253 runs the worker).

    Args:
        job_id: Training job to serve.
        config: Options; reads ``merge_with_base`` (default True) and ``port``.
            When ``port`` is missing or falsy, one is taken from the port pool.
        user_id: Owner of the task; an empty string is stored as NULL.

    Returns:
        A dict with ``task_id``, ``job_id``, ``status`` ("pending"),
        ``deploy_mode`` ("serve") and the chosen ``port``.
    """
    task_id = str(uuid.uuid4())
    merge_with_base = config.get("merge_with_base", True)
    port = config.get("port") or await _allocate_port()

    async with async_session() as session:
        session.add(
            DeployTaskModel(
                id=task_id,
                job_id=job_id,
                user_id=user_id or None,
                status="pending",
                deploy_mode="serve",
                port=port,
            )
        )
        await session.commit()

    background_task_manager.register_task(
        task_id, "deployment", {"job_id": job_id, "mode": "serve"}
    )
    serve_job = _execute_serve(task_id, job_id, merge_with_base, port)
    await background_task_manager.run(task_id, "deployment", serve_job)

    logger.info(f"Serve task started: job={job_id} port={port} (task_id={task_id})")
    return {
        "task_id": task_id,
        "job_id": job_id,
        "status": "pending",
        "deploy_mode": "serve",
        "port": port,
    }
async def _execute_serve(task_id: str, job_id: str, merge_with_base: bool, port: int) -> dict:
    """Background job: export the model, then start the TCP inference worker.

    Returns ``{"endpoint_url", "port", "pid"}`` on success or ``{"error": ...}``
    on failure; the task row is set to "running" or "failed" accordingly.
    """
    try:
        remote = settings.use_remote_compute

        # Step 1: export (merge the adapter).
        if remote:
            export_result = await _run_remote_export(task_id, job_id, merge_with_base, "safetensors")
        else:
            export_result = await _run_local_export(task_id, job_id, merge_with_base)
        output_path = export_result.get("output_path")
        if not output_path:
            raise RuntimeError("导出失败,无法获取输出路径")

        # Step 2: start the inference worker.
        launch = _launch_remote_worker if remote else _launch_local_worker
        pid = await launch(task_id, output_path, port)

        # endpoint_url is the proxy path on 151 (relative; the frontend prepends the origin).
        endpoint_url = f"/api/v1/deployment/proxy/{task_id}/v1"
        await _update_deploy_status(
            task_id,
            "running",
            output_path=output_path,
            endpoint_url=endpoint_url,
            port=port,
            pid=pid,
        )
    except Exception as e:
        logger.error(f"Serve failed for job {job_id}: {e}")
        await _update_deploy_status(task_id, "failed", error=str(e))
        return {"error": str(e)}
    else:
        return {"endpoint_url": endpoint_url, "port": port, "pid": pid}
async def _launch_remote_worker(task_id: str, model_path: str, port: int) -> str:
    """Start inference_worker.py inside the remote 253 container and return its PID.

    The worker only needs torch + transformers (no fastapi/uvicorn).

    Steps: free the port, copy the worker script into the model directory,
    launch it in the background with its output redirected to
    ``/tmp/serve_<task_id>.log``, then poll that log for a READY marker.

    Args:
        task_id: Deploy task ID; used to name the worker log file.
        model_path: Model directory inside the container.
        port: TCP port the worker should listen on.

    Returns:
        The PID (as a string) of the process that owns the port once READY was
        seen. If READY never appears within about 5 minutes, a warning is
        logged and the launched PID is returned anyway.

    Raises:
        RuntimeError: If copying the worker script fails, the launch command
            fails, or the worker process exits before becoming ready.
    """
    container = settings.compute_node_docker_container

    # Clear anything still bound to the port so no stale worker survives.
    # docker exec is slow on 253 (many child processes), hence the generous timeout.
    kill_cmd = (
        f"docker exec {container} "
        f"bash -c 'fuser -k {port}/tcp 2>/dev/null; sleep 2; fuser -k {port}/tcp 2>/dev/null; sleep 1; true'"
    )
    await asyncio.to_thread(ssh_exec, kill_cmd, timeout=60)

    # Path of the worker script inside the container.
    worker_template = f"{settings.compute_node_workdir}/app/core/inference_worker.py"

    # Copy the worker into the model directory.
    copy_cmd = (
        f"docker exec {container} "
        f"bash -c 'cp {worker_template} {model_path}/inference_worker.py'"
    )
    code, _, stderr = await asyncio.to_thread(ssh_exec, copy_cmd, timeout=30)
    if code != 0:
        raise RuntimeError(f"复制 inference_worker.py 失败: {stderr}")

    # Launch in the background; single-GPU inference uses the first entry of
    # inference_cuda_devices.
    inference_gpu = settings.inference_cuda_devices.split(",")[0].strip()
    launch_cmd = (
        f"docker exec "
        f"-e MACA_MPS_MODE=1 "
        f"-e CUDA_VISIBLE_DEVICES={inference_gpu} "
        f"-w {model_path} "
        f"{container} "
        f"bash -c '"
        f"{settings.compute_node_python} inference_worker.py "
        f"--model-path {model_path} "
        f"--port {port} "
        f"</dev/null >/tmp/serve_{task_id}.log 2>&1 &"
        f" echo $!'"
    )
    code, stdout, stderr = await asyncio.to_thread(ssh_exec, launch_cmd, timeout=60)
    if code != 0:
        raise RuntimeError(f"启动推理 worker 失败: {stderr}")

    pid = stdout.strip()
    logger.info(f"Remote worker launched: task={task_id} port={port} pid={pid}")

    # Model loading can take a while. Each poll uses one SSH round trip that
    # checks both the READY marker and whether the process is still alive.
    #
    # `grep -c` prints "0" AND exits non-zero when nothing matches, so appending
    # `|| echo 0` would yield "0\n0" and make the != "0" test pass immediately.
    # `${ready:-0}` only supplies a default when grep printed nothing at all
    # (for example, the log file does not exist yet).
    check_cmd = (
        f"docker exec {container} "
        f"bash -c '"
        f" ready=$(grep -c READY /tmp/serve_{task_id}.log 2>/dev/null); "
        f" if [ \"${{ready:-0}}\" != \"0\" ]; then echo \"READY:$ready\"; exit 0; fi; "
        f" if ! kill -0 {pid} 2>/dev/null; then echo \"DEAD\"; exit 0; fi; "
        f" echo \"ALIVE\"; "
        f"'"
    )
    for attempt in range(60):  # at most 5 minutes (60 * 5s)
        await asyncio.sleep(5)
        code, stdout, stderr = await asyncio.to_thread(ssh_exec, check_cmd, timeout=60)
        if code != 0:
            # A failed probe (e.g. SSH hiccup) says nothing about the worker; retry.
            continue

        result = stdout.strip()
        if result.startswith("READY:"):
            logger.info(f"Worker ready: task={task_id} (after ~{(attempt+1)*5}s)")
            # Verify which PID actually owns the port, in case an earlier stop
            # left an old process behind and the PIDs do not match.
            actual_pid = await _get_port_pid(port)
            if actual_pid and actual_pid != pid:
                logger.warning(f"Port {port} PID mismatch: launched={pid}, actual={actual_pid}")
                pid = actual_pid
            return pid

        if result == "DEAD":
            # Pull the tail of the log to report what went wrong.
            log_cmd = (
                f"docker exec {container} "
                f"bash -c 'tail -20 /tmp/serve_{task_id}.log 2>/dev/null'"
            )
            _, log_stdout, _ = await asyncio.to_thread(ssh_exec, log_cmd, timeout=60)
            raise RuntimeError(f"Worker 进程已退出: {log_stdout}")
        # result == "ALIVE": keep waiting.

    logger.warning(f"Worker not ready after 5min: task={task_id}, proceeding anyway")
    return pid
async def _get_port_pid(port: int) -> str | None:
    """Return the PID of the process bound to ``port`` in the remote container.

    Returns None when the ``fuser`` command fails or prints no numeric token.
    """
    fuser_cmd = (
        f"docker exec {settings.compute_node_docker_container} "
        f"bash -c 'fuser {port}/tcp 2>/dev/null'"
    )
    code, stdout, _ = await asyncio.to_thread(ssh_exec, fuser_cmd, timeout=60)
    if code != 0:
        return None
    # fuser may print "8100/tcp: 372" or just " 372"; take the last numeric token.
    tokens = stdout.split()
    return next((tok for tok in reversed(tokens) if tok.isdigit()), None)
async def _launch_local_worker(task_id: str, model_path: str, port: int) -> str:
    """Start the inference worker as a local subprocess (development use).

    Copies inference_worker.py into ``model_path`` and launches it there with
    stdout/stderr discarded. Returns the subprocess PID as a string.
    """
    import shutil
    import subprocess
    import sys

    template = Path(__file__).resolve().parent.parent / "core" / "inference_worker.py"
    shutil.copy(template, Path(model_path) / "inference_worker.py")

    argv = [
        sys.executable,
        "inference_worker.py",
        "--model-path",
        model_path,
        "--port",
        str(port),
    ]
    worker = subprocess.Popen(
        argv,
        cwd=model_path,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    return str(worker.pid)
- # ---------------------------------------------------------------------------
- # 停止服务 / 列表 / 状态
- # ---------------------------------------------------------------------------
async def stop_serving(task_id: str, user_id: str = "") -> dict[str, Any]:
    """Stop a deployed online service and mark its task as stopped.

    Args:
        task_id: Deploy task to stop.
        user_id: Caller; when both this and the task's owner are set they must match.

    Returns:
        ``{"task_id", "status": "stopped"}`` on success, otherwise an
        ``{"error": ...}`` dict (task missing, not a serve task, or not owned
        by the caller).
    """
    query = select(DeployTaskModel).where(DeployTaskModel.id == task_id)
    async with async_session() as session:
        record = (await session.execute(query)).scalar_one_or_none()
        if not record:
            return {"error": "任务不存在"}
        if record.deploy_mode != "serve":
            return {"error": "该任务不是在线服务"}
        if user_id and record.user_id and record.user_id != user_id:
            return {"error": "无权操作此任务"}

        pid, port = record.pid, record.port
        if pid and settings.use_remote_compute:
            # 1) kill -9 the main process and its children;
            # 2) fuser -k whatever holds the port (most reliable if the PID is stale).
            kill_cmd = (
                f"docker exec {settings.compute_node_docker_container} "
                f"bash -c '"
                f"kill -9 {pid} 2>/dev/null; "
                f"pkill -9 -P {pid} 2>/dev/null; "
                f"fuser -k {port}/tcp 2>/dev/null; "
                f"sleep 2; "
                f"fuser -k {port}/tcp 2>/dev/null; "
                f"true'"
            )
            code, _, _ = await asyncio.to_thread(ssh_exec, kill_cmd, timeout=60)
            logger.info(f"Stop serving: task={task_id} pid={pid} port={port} kill_code={code}")

        record.status = "stopped"
        record.pid = None
        record.finished_at = datetime.utcnow()
        await session.commit()

    background_task_manager.update_task(task_id, status="stopped")
    return {"task_id": task_id, "status": "stopped"}
async def restart_serving(task_id: str, user_id: str = "") -> dict[str, Any]:
    """Restart a stopped online service without re-exporting the model.

    Args:
        task_id: Deploy task to restart; it must be a serve task in "stopped" state.
        user_id: Caller; when both this and the task's owner are set they must match.

    Returns:
        ``{"task_id", "status": "pending", "deploy_mode": "serve", "port"}`` once
        the restart has been scheduled, otherwise an ``{"error": ...}`` dict.
    """
    query = select(DeployTaskModel).where(DeployTaskModel.id == task_id)
    async with async_session() as session:
        record = (await session.execute(query)).scalar_one_or_none()
        if not record:
            return {"error": "任务不存在"}
        if record.deploy_mode != "serve":
            return {"error": "该任务不是在线服务"}
        if record.status != "stopped":
            return {"error": f"只能重启已停止的服务(当前状态: {record.status})"}
        if user_id and record.user_id and record.user_id != user_id:
            return {"error": "无权操作此任务"}
        if not record.output_path:
            return {"error": "模型文件路径丢失,无法重启,请重新部署"}

        output_path = record.output_path
        original_port = record.port  # remember the old port so it can be reused

    # Prefer the original port; allocate a new one only if another
    # pending/running service has taken it (or none was recorded).
    if original_port and await _check_port_available(original_port):
        port = original_port
    else:
        port = await _allocate_port()

    # Flip the task back to pending to signal that a restart is in progress.
    await _update_deploy_status(task_id, "pending", port=port)

    background_task_manager.register_task(task_id, "deployment", {"mode": "restart"})
    restart_job = _execute_restart(task_id, output_path, port)
    await background_task_manager.run(task_id, "deployment", restart_job)

    logger.info(f"Restart serving: task={task_id} output_path={output_path} port={port}")
    return {"task_id": task_id, "status": "pending", "deploy_mode": "serve", "port": port}
async def _execute_restart(task_id: str, output_path: str, port: int) -> dict:
    """Background job for a restart: start the worker only, no re-export.

    Returns ``{"endpoint_url", "port", "pid"}`` on success or ``{"error": ...}``
    on failure; the task row is set to "running" or "failed" accordingly.
    """
    try:
        launch = (
            _launch_remote_worker if settings.use_remote_compute else _launch_local_worker
        )
        pid = await launch(task_id, output_path, port)

        endpoint_url = f"/api/v1/deployment/proxy/{task_id}/v1"
        await _update_deploy_status(
            task_id,
            "running",
            output_path=output_path,
            endpoint_url=endpoint_url,
            port=port,
            pid=pid,
        )
    except Exception as e:
        logger.error(f"Restart failed for task {task_id}: {e}")
        await _update_deploy_status(task_id, "failed", error=str(e))
        return {"error": str(e)}
    else:
        return {"endpoint_url": endpoint_url, "port": port, "pid": pid}
async def list_deployed_services(user_id: str = "") -> list[dict[str, Any]]:
    """List serve-mode deploy tasks, newest first, optionally filtered by user.

    For tasks recorded as "running" with a PID (remote compute only), the remote
    process is probed; if it is confirmed stopped the task is marked "stopped"
    and reported without port/pid.
    """
    filters = [DeployTaskModel.deploy_mode == "serve"]
    if user_id:
        filters.append(DeployTaskModel.user_id == user_id)
    query = (
        select(DeployTaskModel)
        .where(*filters)
        .order_by(DeployTaskModel.created_at.desc())
    )
    async with async_session() as session:
        records = (await session.execute(query)).scalars().all()

    services = []
    for rec in records:
        status = rec.status

        # For running services, check whether the remote process is still alive.
        if status == "running" and rec.pid and settings.use_remote_compute:
            from app.core.remote_executor import is_process_running

            proc_state = await asyncio.to_thread(is_process_running, rec.pid)
            if proc_state == "stopped":
                # Process confirmed gone: mark the task as stopped.
                status = "stopped"
                await _update_deploy_status(rec.id, "stopped", error="进程已退出")
                rec.port = None
                rec.pid = None
            # "unknown" leaves the status alone: an SSH timeout does not mean the process died.

        created = rec.created_at.isoformat() if rec.created_at else None
        services.append(
            {
                "task_id": rec.id,
                "job_id": rec.job_id,
                "status": status,
                "endpoint_url": rec.endpoint_url,
                "base_url": rec.endpoint_url,
                "port": rec.port,
                "output_path": rec.output_path,
                "created_at": created,
                "error": rec.error,
            }
        )
    return services
async def get_deploy_status(task_id: str) -> dict[str, Any]:
    """Return the stored state of a deploy task.

    When no task matches, a placeholder dict with ``status`` "not_found" and
    empty fields is returned instead.
    """
    query = select(DeployTaskModel).where(DeployTaskModel.id == task_id)
    async with async_session() as session:
        record = (await session.execute(query)).scalar_one_or_none()
        if record:
            return {
                "task_id": record.id,
                "job_id": record.job_id,
                "status": record.status,
                # Rows without a deploy_mode are reported as exports.
                "deploy_mode": record.deploy_mode or "export",
                "progress": record.progress,
                "output_path": record.output_path,
                "endpoint_url": record.endpoint_url,
                "port": record.port,
                "error": record.error,
            }

    return {
        "task_id": None,
        "job_id": "",
        "status": "not_found",
        "deploy_mode": "export",
        "progress": 0.0,
        "output_path": None,
        "endpoint_url": None,
        "port": None,
        "error": None,
    }
- # ---------------------------------------------------------------------------
- # 辅助函数
- # ---------------------------------------------------------------------------
async def _allocate_port() -> int:
    """Pick the lowest port in the pool not held by a pending/running serve task.

    Raises:
        RuntimeError: If every port in the pool is in use.
    """
    in_use_query = select(DeployTaskModel.port).where(
        DeployTaskModel.deploy_mode == "serve",
        DeployTaskModel.status.in_(["pending", "running"]),
        DeployTaskModel.port.isnot(None),
    )
    async with async_session() as session:
        rows = (await session.execute(in_use_query)).all()
    taken = {row[0] for row in rows}

    candidates = range(_SERVE_PORT_MIN, _SERVE_PORT_MAX + 1)
    free_port = next((p for p in candidates if p not in taken), None)
    if free_port is None:
        raise RuntimeError(f"无可用端口({_SERVE_PORT_MIN}-{_SERVE_PORT_MAX} 全部占用)")
    return free_port
async def _check_port_available(port: int) -> bool:
    """Return True when no pending/running serve task currently holds ``port``."""
    holder_query = select(DeployTaskModel.id).where(
        DeployTaskModel.deploy_mode == "serve",
        DeployTaskModel.status.in_(["pending", "running"]),
        DeployTaskModel.port == port,
    )
    async with async_session() as session:
        holder = (await session.execute(holder_query)).first()
    return holder is None
async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
    """Run the model merge/export inside the remote container over SSH.

    A one-line Python script is executed with ``docker exec … python -c``; it
    calls ``app.core.remote_deploy.run_remote_export`` and prints the result as
    JSON. The last stdout line that parses as a JSON object is the result.

    Args:
        task_id: Deploy task ID (not used by the remote command).
        job_id: Training job whose adapter is exported.
        merge_with_base: Whether to merge the adapter into the base model.
        export_format: Export format name passed through to the remote function.

    Returns:
        The result dict printed by the remote script.

    Raises:
        RuntimeError: If the command exits non-zero, the remote result contains
            an ``error`` key, or no JSON object can be found in the output.
    """
    import shlex

    # repr() emits valid Python literals, so caller-supplied values cannot break
    # out of the call expression; shlex.quote then protects the whole script
    # from the remote shell.
    script = (
        "import asyncio, json; "
        "from app.core.remote_deploy import run_remote_export; "
        f"result = asyncio.run(run_remote_export("
        f"{str(job_id)!r}, {bool(merge_with_base)!r}, {str(export_format)!r})); "
        "print(json.dumps(result, ensure_ascii=False))"
    )
    remote_cmd = (
        f"docker exec "
        f"-e MACA_MPS_MODE=1 "
        f"-e CUDA_VISIBLE_DEVICES=3 "
        f"-w {settings.compute_node_workdir} "
        f"{settings.compute_node_docker_container} "
        f"{settings.compute_node_python} -c {shlex.quote(script)} 2>&1"
    )
    code, stdout, stderr = await asyncio.to_thread(ssh_exec, remote_cmd, timeout=600)
    if code != 0:
        # `2>&1` folds the remote stderr into stdout, so stderr is usually
        # empty here; fall back to the tail of stdout for the message.
        detail = (stderr or "").strip() or (stdout or "")[-500:].strip()
        raise RuntimeError(f"Remote export failed: {detail}")

    # Scan from the end: library warnings may precede the JSON result line.
    for line in reversed(stdout.strip().split("\n")):
        line = line.strip()
        if not line.startswith("{"):
            continue
        try:
            result = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "error" in result:
            raise RuntimeError(result["error"])
        return result
    raise RuntimeError(f"Invalid response: {stdout[:500]}")
async def _run_local_export(task_id: str, job_id: str, merge_with_base: bool) -> dict:
    """Export a job's adapter on the local machine (development use).

    With ``merge_with_base`` the adapter is merged into a model and saved,
    together with the tokenizer, to ``<adapters_dir>/<job_id>_merged``; without
    it the adapter directory is copied there, replacing any existing copy.

    Returns:
        ``{"output_path": <str>}``.

    Raises:
        ValueError: If the adapter directory does not exist.
    """
    adapter_path = settings.adapters_dir / job_id
    if not adapter_path.exists():
        raise ValueError("Adapter not found")
    output_path = settings.adapters_dir / f"{job_id}_merged"

    if not merge_with_base:
        import shutil

        if output_path.exists():
            shutil.rmtree(output_path)
        shutil.copytree(adapter_path, output_path)
        return {"output_path": str(output_path)}

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model_id = _get_base_model_id_local(job_id)

    from peft import PeftModel

    if base_model_id:
        backbone = AutoModelForCausalLM.from_pretrained(
            base_model_id, torch_dtype=torch.float16, device_map="auto"
        )
        peft_model = PeftModel.from_pretrained(backbone, adapter_path)
    else:
        # No base model recorded: load the model from the adapter directory itself.
        backbone = AutoModelForCausalLM.from_pretrained(
            str(adapter_path), torch_dtype=torch.float16
        )
        peft_model = PeftModel.from_pretrained(backbone, str(adapter_path))

    merged = peft_model.merge_and_unload()
    merged.save_pretrained(output_path)
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    tokenizer.save_pretrained(output_path)
    return {"output_path": str(output_path)}
async def _copy_worker_template_remote(output_path: str):
    """Place inference_worker.py and a start.sh launcher in the remote model directory.

    Both steps are best-effort: a failure is logged as a warning and not raised.
    """
    container = settings.compute_node_docker_container

    worker_template = f"{settings.compute_node_workdir}/app/core/inference_worker.py"
    copy_cmd = (
        f"docker exec {container} "
        f"bash -c 'cp {worker_template} {output_path}/inference_worker.py'"
    )
    code, _, stderr = await asyncio.to_thread(ssh_exec, copy_cmd, timeout=30)
    if code != 0:
        logger.warning(f"复制 inference_worker.py 到 {output_path} 失败: {stderr}")

    # Generate a convenience launch script (first GPU of inference_cuda_devices).
    inference_gpu = settings.inference_cuda_devices.split(",")[0].strip()
    start_script = (
        f"#!/bin/bash\n"
        f"cd {output_path}\n"
        f"CUDA_VISIBLE_DEVICES={inference_gpu} MACA_MPS_MODE=1 "
        f"{settings.compute_node_python} inference_worker.py "
        f"--model-path . --port 8100\n"
    )
    # The script body is written through a quoted heredoc, then made executable.
    script_cmd = (
        f"docker exec {container} "
        f"bash -c 'cat > {output_path}/start.sh << \"EOF\"\n{start_script}EOF\n"
        f"chmod +x {output_path}/start.sh'"
    )
    code, _, _ = await asyncio.to_thread(ssh_exec, script_cmd, timeout=15)
    if code != 0:
        logger.warning("生成 start.sh 失败")
def _get_base_model_id_local(job_id: str) -> str | None:
    """Return ``base_model_name_or_path`` from the job's local adapter_config.json.

    Returns None when the config file does not exist or the key is absent.
    """
    config_path = settings.adapters_dir / job_id / "adapter_config.json"
    if not config_path.exists():
        return None
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # locale encoding.
    with open(config_path, encoding="utf-8") as f:
        return json.load(f).get("base_model_name_or_path")
async def _update_deploy_status(
    task_id: str,
    status: str,
    output_path: str | None = None,
    error: str | None = None,
    endpoint_url: str | None = None,
    port: int | None = None,
    pid: str | None = None,
):
    """Persist a new status (and any supplied fields) for a deploy task.

    Only truthy optional values overwrite stored columns. Terminal statuses
    ("completed", "failed", "stopped") stamp ``finished_at``; "pending" clears
    ``finished_at`` and ``error`` (used when restarting). The in-memory
    background task entry is updated as well.
    """
    # Insertion order matches the order in which the columns are assigned.
    optional_fields = {
        "output_path": output_path,
        "error": error,
        "endpoint_url": endpoint_url,
        "port": port,
        "pid": pid,
    }
    query = select(DeployTaskModel).where(DeployTaskModel.id == task_id)
    async with async_session() as session:
        record = (await session.execute(query)).scalar_one_or_none()
        if record:
            record.status = status
            for column, value in optional_fields.items():
                if value:
                    setattr(record, column, value)
            if status in ("completed", "failed", "stopped"):
                record.finished_at = datetime.utcnow()
            elif status == "pending":
                # A restart wipes the previous finish time and error message.
                record.finished_at = None
                record.error = None
            await session.commit()

    # NOTE(review): source indentation was lost; this call is assumed to run
    # whether or not a DB row was found — confirm against the original file.
    background_task_manager.update_task(
        task_id,
        status=status,
        output_path=output_path,
        error=error,
        endpoint_url=endpoint_url,
    )
async def recover_stale_deploys() -> None:
    """Reconcile deploy tasks that were pending/running when the server restarted.

    - Export tasks (including rows with no ``deploy_mode``) are marked "failed".
    - Serve tasks with a PID under remote compute are probed: confirmed-dead
      workers become "stopped"; live or unverifiable ones are left untouched.
    - Serve tasks without a PID (or in local mode) become "stopped".

    Stopped tasks have their port and PID cleared so the port can be allocated
    again. Every task moved to a terminal state gets ``finished_at`` set.
    """
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel).where(
                DeployTaskModel.status.in_(["pending", "running"])
            )
        )
        records = result.scalars().all()

        for record in records:
            # A missing deploy_mode is handled as "export", the same default
            # get_deploy_status reports, so such rows are not left pending forever.
            mode = record.deploy_mode or "export"

            if mode == "export":
                record.status = "failed"
                record.error = "Server restarted, task interrupted"
                # Failed tasks are terminal too; stamp the finish time.
                record.finished_at = datetime.utcnow()
            elif mode == "serve":
                if record.pid and settings.use_remote_compute:
                    from app.core.remote_executor import is_process_running

                    proc_state = await asyncio.to_thread(is_process_running, record.pid)
                    if proc_state != "stopped":
                        # Still alive, or state could not be determined: keep as is.
                        continue
                    record.status = "stopped"
                    record.error = "Server restarted, process no longer running"
                else:
                    record.status = "stopped"
                    record.error = "Server restarted, process state unknown"

            # Release the port so the next allocation can use it.
            if record.status == "stopped":
                record.port = None
                record.pid = None
                record.finished_at = datetime.utcnow()

        if records:
            await session.commit()
            logger.info(f"Recovered {len(records)} stale deploy tasks")
|