"""Deployment service: export trained adapters/models or deploy them as online inference services.

Architecture:
- The compute node ("253") runs a lightweight ``inference_worker.py`` that needs only the
  standard library plus torch/transformers (no fastapi/uvicorn).
- The main node ("151") exposes an OpenAI-compatible proxy API and forwards every
  request to the worker on 253 over TCP.
"""

import asyncio
import json
import os
import signal
import socket
import struct
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

from sqlalchemy import select

from app.config import get_settings
from app.core.background_tasks import background_task_manager
from app.core.db import async_session, DeployTaskModel
from app.core.logging import logger
from app.core.remote_executor import ssh_exec

settings = get_settings()

# TCP port range handed out to inference workers on the compute node (253).
_SERVE_PORT_MIN = 8100
_SERVE_PORT_MAX = 8199

# Characters that could break out of the quoting used when caller-supplied values
# (job_id, export_format) are embedded in the remote `docker exec ... python -c "..."`
# command built by _run_remote_export. Legitimate IDs/format names never contain them.
_SHELL_UNSAFE_CHARS = frozenset("'\"`$\\;&|<>\n\r")


# ---------------------------------------------------------------------------
# TCP proxy: 151 -> 253 inference_worker
# ---------------------------------------------------------------------------


async def proxy_to_worker(task_id: str, request: dict) -> dict:
    """Forward an inference request over TCP to the worker on 253 and return its reply.

    Wire protocol (both directions): a 4-byte big-endian length prefix followed by a
    UTF-8 JSON body.

    Args:
        task_id: ID of the serve-mode deploy task whose worker should handle the request.
        request: JSON-serializable request payload.

    Returns:
        The worker's decoded JSON reply, or ``{"error": ...}`` when the task does not
        exist, is not running, has no port recorded, or the TCP exchange fails.
    """
    # Look up the port the worker listens on.
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel).where(DeployTaskModel.id == task_id)
        )
        record = result.scalar_one_or_none()
        if not record:
            return {"error": "部署任务不存在"}
        if record.status != "running":
            return {"error": f"服务未运行(当前状态: {record.status})"}
        port = record.port

    if not port:
        return {"error": "未找到 worker 端口"}

    # The socket I/O below is blocking; run it in a thread so the event loop stays free.
    return await asyncio.to_thread(_tcp_request, settings.compute_node_host, port, request)


def _tcp_request(host: str, port: int, request: dict) -> dict:
    """Synchronously send one framed JSON request to a worker and read its framed reply.

    Never raises: timeouts, refused connections and any other failure are converted
    into ``{"error": ...}`` dicts.
    """
    try:
        # `with` guarantees the socket is closed on every path.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(120)  # inference can take a long time
            sock.connect((host, port))

            # Send: 4-byte length + JSON body.
            body = json.dumps(request, ensure_ascii=False).encode("utf-8")
            sock.sendall(struct.pack(">I", len(body)) + body)

            # Receive: 4-byte length + JSON body.
            (resp_len,) = struct.unpack(">I", _recv_exact(sock, 4))
            resp_data = _recv_exact(sock, resp_len)
        return json.loads(resp_data.decode("utf-8"))
    except socket.timeout:
        return {"error": "推理超时(120s)"}
    except ConnectionRefusedError:
        return {"error": f"无法连接到推理 worker({host}:{port}),服务可能已停止"}
    except Exception as e:
        return {"error": f"代理请求失败: {e}"}


def _recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly ``n`` bytes from ``sock``.

    Raises:
        ConnectionError: if the peer closes the connection before ``n`` bytes arrive.
    """
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("Connection closed while reading")
        buf.extend(chunk)
    return bytes(buf)


# ---------------------------------------------------------------------------
# Export adapter (file export mode)
# ---------------------------------------------------------------------------


async def export_adapter(job_id: str, config: dict[str, Any], user_id: str = "") -> dict[str, Any]:
    """Create an export deploy task, start it in the background and return its task_id.

    Recognized ``config`` keys: ``merge_with_base`` (default False) and
    ``export_format`` (default ``"safetensors"``).
    """
    task_id = str(uuid.uuid4())
    merge_with_base = config.get("merge_with_base", False)
    export_format = config.get("export_format", "safetensors")

    task = DeployTaskModel(
        id=task_id,
        job_id=job_id,
        user_id=user_id or None,
        status="pending",
        deploy_mode="export",
    )
    async with async_session() as session:
        session.add(task)
        await session.commit()

    background_task_manager.register_task(task_id, "deployment", {"job_id": job_id})
    await background_task_manager.run(
        task_id, "deployment", _execute_export(task_id, job_id, merge_with_base, export_format)
    )

    logger.info(f"Deploy task started: job={job_id} (task_id={task_id})")
    return {"task_id": task_id, "job_id": job_id, "status": "pending", "deploy_mode": "export"}


async def _execute_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
    """Background body of an export task; records completion/failure in the DB."""
    try:
        if settings.use_remote_compute:
            result = await _run_remote_export(task_id, job_id, merge_with_base, export_format)
        else:
            result = await _run_local_export(task_id, job_id, merge_with_base)

        output_path = result.get("output_path")

        # Copy inference_worker.py and a start script next to the exported model.
        # This helper makes blocking SSH calls, so keep it off the event loop.
        if output_path and settings.use_remote_compute:
            await asyncio.to_thread(_copy_worker_template_remote, output_path)

        await _update_deploy_status(task_id, "completed", output_path=output_path)
        return {"output_path": output_path}
    except Exception as e:
        logger.error(f"Export failed for job {job_id}: {e}")
        await _update_deploy_status(task_id, "failed", error=str(e))
        return {"error": str(e)}


# ---------------------------------------------------------------------------
# Deploy as an online service (serve mode)
# ---------------------------------------------------------------------------


async def start_serving(job_id: str, config: dict[str, Any], user_id: str = "") -> dict[str, Any]:
    """Deploy a job as an online inference service (151 proxies, the 253 worker infers).

    Recognized ``config`` keys: ``merge_with_base`` (default True) and ``port``
    (allocated from the port pool when missing/falsy).
    """
    task_id = str(uuid.uuid4())
    merge_with_base = config.get("merge_with_base", True)
    port = config.get("port") or await _allocate_port()

    task = DeployTaskModel(
        id=task_id,
        job_id=job_id,
        user_id=user_id or None,
        status="pending",
        deploy_mode="serve",
        port=port,
    )
    async with async_session() as session:
        session.add(task)
        await session.commit()

    background_task_manager.register_task(task_id, "deployment", {"job_id": job_id, "mode": "serve"})
    await background_task_manager.run(
        task_id, "deployment", _execute_serve(task_id, job_id, merge_with_base, port)
    )

    logger.info(f"Serve task started: job={job_id} port={port} (task_id={task_id})")
    return {"task_id": task_id, "job_id": job_id, "status": "pending", "deploy_mode": "serve", "port": port}


async def _execute_serve(task_id: str, job_id: str, merge_with_base: bool, port: int) -> dict:
    """Background body of a serve task: export the model, then launch the TCP worker."""
    try:
        # Step 1: export (merging the adapter when requested).
        if settings.use_remote_compute:
            export_result = await _run_remote_export(task_id, job_id, merge_with_base, "safetensors")
        else:
            export_result = await _run_local_export(task_id, job_id, merge_with_base)
        output_path = export_result.get("output_path")
        if not output_path:
            raise RuntimeError("导出失败,无法获取输出路径")

        # Step 2: start the inference worker.
        if settings.use_remote_compute:
            pid = await _launch_remote_worker(task_id, output_path, port)
        else:
            pid = await _launch_local_worker(task_id, output_path, port)

        # endpoint_url is the proxy path on 151 (relative; the frontend prepends the origin).
        endpoint_url = f"/api/v1/deployment/proxy/{task_id}/v1"
        await _update_deploy_status(
            task_id, "running",
            output_path=output_path,
            endpoint_url=endpoint_url,
            port=port,
            pid=pid,
        )
        return {"endpoint_url": endpoint_url, "port": port, "pid": pid}
    except Exception as e:
        logger.error(f"Serve failed for job {job_id}: {e}")
        await _update_deploy_status(task_id, "failed", error=str(e))
        return {"error": str(e)}


async def _launch_remote_worker(task_id: str, model_path: str, port: int) -> str:
    """Start inference_worker.py inside the remote 253 container and return its PID.

    The worker depends only on torch + transformers (no fastapi/uvicorn). Its output
    goes to ``/tmp/serve_<task_id>.log``, which is polled for a ``READY`` marker for up
    to ~5 minutes.

    Raises:
        RuntimeError: if copying/launching fails, no numeric PID is reported, or the
            worker process exits before becoming ready.
    """
    container = settings.compute_node_docker_container
    log_file = f"/tmp/serve_{task_id}.log"

    # Path of the worker script inside the container.
    worker_template = f"{settings.compute_node_workdir}/app/core/inference_worker.py"

    # Copy the worker into the model directory.
    copy_cmd = (
        f"docker exec {container} "
        f"bash -c 'cp {worker_template} {model_path}/inference_worker.py'"
    )
    code, _, stderr = await _ssh_exec_async(copy_cmd, timeout=30)
    if code != 0:
        raise RuntimeError(f"复制 inference_worker.py 失败: {stderr}")

    # Start the worker in the background. `$!` is the PID of the backgrounded nohup,
    # which execs python in place, so it is the worker's PID. stdout/stderr MUST be
    # redirected to the log: the READY poll below reads it, and an unredirected
    # background process would keep the SSH output pipe open.
    launch_cmd = (
        f"docker exec "
        f"-e MACA_MPS_MODE=1 "
        f"-e CUDA_VISIBLE_DEVICES=3 "
        f"-w {model_path} "
        f"{container} "
        f"bash -c '"
        f"nohup {settings.compute_node_python} inference_worker.py "
        f"--model-path {model_path} "
        f"--port {port} "
        f"> {log_file} 2>&1 &"
        f" echo $!'"
    )
    code, stdout, stderr = await _ssh_exec_async(launch_cmd, timeout=30)
    if code != 0:
        raise RuntimeError(f"启动推理 worker 失败: {stderr}")

    # The PID is the last line printed; it is interpolated into `kill -0` below, so
    # insist on a plain integer.
    out_lines = stdout.strip().splitlines()
    pid = out_lines[-1].strip() if out_lines else ""
    if not pid.isdigit():
        raise RuntimeError(f"启动推理 worker 失败: 无法获取 PID (stdout={stdout!r})")
    logger.info(f"Remote worker launched: task={task_id} port={port} pid={pid}")

    # Wait for the model to load. Each poll uses a single SSH round-trip that checks
    # both the READY marker and process liveness.
    # Note: `grep -c` prints "0" AND exits 1 when nothing matches, so it must not be
    # combined with `|| echo 0` (that yields "0\n0" and looks ready). A missing file
    # prints nothing, which `${ready:-0}` maps to "0".
    check_cmd = (
        f"docker exec {container} "
        f"bash -c '"
        f" ready=$(grep -c READY {log_file} 2>/dev/null); "
        f" if [ \"${{ready:-0}}\" != \"0\" ]; then echo \"READY:$ready\"; exit 0; fi; "
        f" if ! kill -0 {pid} 2>/dev/null; then echo \"DEAD\"; exit 0; fi; "
        f" echo \"ALIVE\"; "
        f"'"
    )
    for attempt in range(60):  # at most ~5 minutes (60 * 5s)
        await asyncio.sleep(5)
        code, stdout, _ = await _ssh_exec_async(check_cmd, timeout=30)
        if code != 0:
            continue  # transient SSH failure: just poll again

        state = stdout.strip()
        if state.startswith("READY:"):
            logger.info(f"Worker ready: task={task_id} (after ~{(attempt+1)*5}s)")
            return pid
        if state == "DEAD":
            # Grab the tail of the log to explain why the worker died.
            log_cmd = (
                f"docker exec {container} "
                f"bash -c 'tail -20 {log_file} 2>/dev/null'"
            )
            _, log_stdout, _ = await _ssh_exec_async(log_cmd, timeout=30)
            raise RuntimeError(f"Worker 进程已退出: {log_stdout}")
        # "ALIVE" -> keep waiting.

    logger.warning(f"Worker not ready after 5min: task={task_id}, proceeding anyway")
    return pid


async def _launch_local_worker(task_id: str, model_path: str, port: int) -> str:
    """Start the inference worker as a local subprocess (development only); return its PID.

    The worker's stdout/stderr are discarded.
    """
    import shutil
    import subprocess
    import sys

    worker_src = Path(__file__).resolve().parent.parent / "core" / "inference_worker.py"
    shutil.copy(worker_src, Path(model_path) / "inference_worker.py")

    proc = subprocess.Popen(
        [sys.executable, "inference_worker.py", "--model-path", model_path, "--port", str(port)],
        cwd=model_path,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    return str(proc.pid)


# ---------------------------------------------------------------------------
# Stop service / list / status
# ---------------------------------------------------------------------------


async def stop_serving(task_id: str, user_id: str = "") -> dict[str, Any]:
    """Stop a deployed online service and mark its task as ``stopped``.

    When ``user_id`` is given, only the owner of a task with a recorded owner may stop it.
    """
    async with async_session() as session:
        result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
        record = result.scalar_one_or_none()
        if not record:
            return {"error": "任务不存在"}
        if record.deploy_mode != "serve":
            return {"error": "该任务不是在线服务"}
        if user_id and record.user_id and record.user_id != user_id:
            return {"error": "无权操作此任务"}

        pid = record.pid
        if pid and settings.use_remote_compute:
            # Kill child processes first: once the parent is gone its children get
            # re-parented and `pkill -P <pid>` would no longer find them.
            kill_cmd = (
                f"docker exec {settings.compute_node_docker_container} "
                f"bash -c 'pkill -P {pid} 2>/dev/null; kill {pid} 2>/dev/null; true'"
            )
            code, _, _ = await _ssh_exec_async(kill_cmd, timeout=15)
            logger.info(f"Stop serving: task={task_id} pid={pid} kill_code={code}")
        elif pid:
            # Worker started locally by _launch_local_worker (dev mode). Without this
            # it would keep running and hold on to its port.
            try:
                os.kill(int(pid), signal.SIGTERM)
            except (ValueError, OSError) as exc:
                logger.warning(f"Stop local worker failed: task={task_id} pid={pid}: {exc}")
            else:
                logger.info(f"Stop serving: task={task_id} pid={pid} (local)")

        record.status = "stopped"
        record.pid = None
        record.finished_at = datetime.utcnow()
        await session.commit()

    background_task_manager.update_task(task_id, status="stopped")
    return {"task_id": task_id, "status": "stopped"}


async def list_deployed_services(user_id: str = "") -> list[dict[str, Any]]:
    """List serve-mode deploy tasks, newest first, optionally filtered by owner.

    In remote mode, ``running`` tasks whose process is no longer alive are reported
    (and persisted) as ``stopped``.
    """
    async with async_session() as session:
        query = select(DeployTaskModel).where(DeployTaskModel.deploy_mode == "serve")
        if user_id:
            query = query.where(DeployTaskModel.user_id == user_id)
        query = query.order_by(DeployTaskModel.created_at.desc())
        result = await session.execute(query)
        records = result.scalars().all()

        services = []
        for r in records:
            status = r.status
            # For running tasks, verify that the remote process is still alive.
            if status == "running" and r.pid and settings.use_remote_compute:
                from app.core.remote_executor import is_process_running
                if not is_process_running(r.pid):
                    status = "stopped"
                    await _update_deploy_status(r.id, "stopped", error="进程已退出")

            services.append({
                "task_id": r.id,
                "job_id": r.job_id,
                "status": status,
                "endpoint_url": r.endpoint_url,
                "base_url": r.endpoint_url,
                "port": r.port,
                "output_path": r.output_path,
                "created_at": r.created_at.isoformat() if r.created_at else None,
                "error": r.error,
            })
        return services


async def get_deploy_status(task_id: str) -> dict[str, Any]:
    """Return a status dict for a deploy task, or a ``not_found`` placeholder dict."""
    async with async_session() as session:
        result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
        record = result.scalar_one_or_none()
        if record:
            return {
                "task_id": record.id,
                "job_id": record.job_id,
                "status": record.status,
                "deploy_mode": record.deploy_mode or "export",
                "progress": record.progress,
                "output_path": record.output_path,
                "endpoint_url": record.endpoint_url,
                "port": record.port,
                "error": record.error,
            }
    return {"task_id": None, "job_id": "", "status": "not_found", "deploy_mode": "export",
            "progress": 0.0, "output_path": None, "endpoint_url": None, "port": None, "error": None}


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


async def _ssh_exec_async(cmd: str, timeout: int) -> tuple:
    """Run ``ssh_exec`` in a worker thread so slow remote commands don't block the event loop.

    Returns whatever ``ssh_exec`` returns; callers in this module unpack it as
    ``(exit_code, stdout, stderr)``.
    """
    return await asyncio.to_thread(ssh_exec, cmd, timeout=timeout)


def _require_shell_safe(value: str, field: str) -> str:
    """Return ``value`` unchanged, or raise ValueError if it contains shell/quote metacharacters."""
    if any(ch in _SHELL_UNSAFE_CHARS for ch in str(value)):
        raise ValueError(f"Unsafe characters in {field}: {value!r}")
    return value


async def _allocate_port() -> int:
    """Return the lowest port in the pool not held by a pending/running serve task.

    Raises:
        RuntimeError: if every port in the pool is in use.
    """
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel.port).where(
                DeployTaskModel.deploy_mode == "serve",
                DeployTaskModel.status.in_(["pending", "running"]),
                DeployTaskModel.port.isnot(None),
            )
        )
        used = {row[0] for row in result.all()}

    for port in range(_SERVE_PORT_MIN, _SERVE_PORT_MAX + 1):
        if port not in used:
            return port
    raise RuntimeError(f"无可用端口({_SERVE_PORT_MIN}-{_SERVE_PORT_MAX} 全部占用)")


async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
    """Merge/export the model inside the remote container over SSH.

    The remote script prints its result as a JSON line; the last JSON-object line of
    stdout is parsed and returned.

    Raises:
        ValueError: if job_id/export_format contain characters that could escape the
            remote command's quoting.
        RuntimeError: on a non-zero exit, an ``"error"`` key in the result, or when no
            JSON result line is found.
    """
    # These values end up inside a shell string and a Python literal on the remote side.
    _require_shell_safe(job_id, "job_id")
    _require_shell_safe(export_format, "export_format")

    remote_cmd = (
        f"docker exec "
        f"-e MACA_MPS_MODE=1 "
        f"-e CUDA_VISIBLE_DEVICES=3 "
        f"-w {settings.compute_node_workdir} "
        f"{settings.compute_node_docker_container} "
        f"{settings.compute_node_python} -c \""
        "import asyncio, json; "
        "from app.core.remote_deploy import run_remote_export; "
        f"result = asyncio.run(run_remote_export('{job_id}', {bool(merge_with_base)}, '{export_format}')); "
        "print(json.dumps(result, ensure_ascii=False))\" 2>&1"
    )

    code, stdout, stderr = await _ssh_exec_async(remote_cmd, timeout=600)
    if code != 0:
        # `2>&1` folds the remote traceback into stdout, so stderr is usually empty.
        detail = stderr.strip() or stdout.strip()[-1000:]
        raise RuntimeError(f"Remote export failed: {detail}")

    for line in reversed(stdout.strip().split("\n")):
        line = line.strip()
        if not line.startswith("{"):
            continue
        try:
            result = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "error" in result:
            raise RuntimeError(result["error"])
        return result

    raise RuntimeError(f"Invalid response: {stdout[:500]}")


async def _run_local_export(task_id: str, job_id: str, merge_with_base: bool) -> dict:
    """Export locally (development only) into ``<adapters_dir>/<job_id>_merged``.

    With ``merge_with_base`` the adapter is merged into its base model (the
    ``base_model_name_or_path`` from adapter_config.json, falling back to loading the
    adapter directory itself); otherwise the adapter directory is copied verbatim.

    Raises:
        ValueError: if the adapter directory does not exist.
    """
    adapter_path = settings.adapters_dir / job_id
    if not adapter_path.exists():
        raise ValueError("Adapter not found")

    output_path = settings.adapters_dir / f"{job_id}_merged"

    if merge_with_base:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        base_model_id = _get_base_model_id_local(job_id)
        if base_model_id:
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id, torch_dtype=torch.float16, device_map="auto"
            )
            peft_model = PeftModel.from_pretrained(base_model, adapter_path)
        else:
            peft_model = PeftModel.from_pretrained(
                AutoModelForCausalLM.from_pretrained(
                    str(adapter_path), torch_dtype=torch.float16
                ),
                str(adapter_path),
            )

        merged = peft_model.merge_and_unload()
        merged.save_pretrained(output_path)
        tokenizer = AutoTokenizer.from_pretrained(adapter_path)
        tokenizer.save_pretrained(output_path)
    else:
        import shutil

        # Replace any previous export of the same job.
        if output_path.exists():
            shutil.rmtree(output_path)
        shutil.copytree(adapter_path, output_path)

    return {"output_path": str(output_path)}


def _copy_worker_template_remote(output_path: str):
    """Copy inference_worker.py and generate a ``start.sh`` in the remote model directory.

    Best effort: failures are logged as warnings, never raised. The generated script
    starts the worker on port 8100.
    """
    worker_template = f"{settings.compute_node_workdir}/app/core/inference_worker.py"
    copy_cmd = (
        f"docker exec {settings.compute_node_docker_container} "
        f"bash -c 'cp {worker_template} {output_path}/inference_worker.py'"
    )
    code, _, stderr = ssh_exec(copy_cmd, timeout=30)
    if code != 0:
        logger.warning(f"复制 inference_worker.py 到 {output_path} 失败: {stderr}")

    # Generate a convenience start script.
    start_script = (
        f"#!/bin/bash\n"
        f"cd {output_path}\n"
        f"CUDA_VISIBLE_DEVICES=3 MACA_MPS_MODE=1 "
        f"{settings.compute_node_python} inference_worker.py "
        f"--model-path . --port 8100\n"
    )
    script_cmd = (
        f"docker exec {settings.compute_node_docker_container} "
        f"bash -c 'cat > {output_path}/start.sh << \"EOF\"\n{start_script}EOF\n"
        f"chmod +x {output_path}/start.sh'"
    )
    code, _, _ = ssh_exec(script_cmd, timeout=15)
    if code != 0:
        logger.warning("生成 start.sh 失败")


def _get_base_model_id_local(job_id: str):
    """Return ``base_model_name_or_path`` from the job's adapter_config.json, or None."""
    config_path = settings.adapters_dir / job_id / "adapter_config.json"
    if config_path.exists():
        with open(config_path, encoding="utf-8") as f:
            return json.load(f).get("base_model_name_or_path")
    return None


async def _update_deploy_status(
    task_id: str,
    status: str,
    output_path: Optional[str] = None,
    error: Optional[str] = None,
    endpoint_url: Optional[str] = None,
    port: Optional[int] = None,
    pid: Optional[str] = None,
):
    """Persist a status change (plus any truthy optional fields) and mirror it in the task manager.

    Terminal statuses (completed/failed/stopped) also stamp ``finished_at``.
    """
    async with async_session() as session:
        result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
        record = result.scalar_one_or_none()
        if record:
            record.status = status
            # Only overwrite fields that were actually supplied.
            if output_path:
                record.output_path = output_path
            if error:
                record.error = error
            if endpoint_url:
                record.endpoint_url = endpoint_url
            if port:
                record.port = port
            if pid:
                record.pid = pid
            if status in ("completed", "failed", "stopped"):
                record.finished_at = datetime.utcnow()
            await session.commit()

    background_task_manager.update_task(
        task_id, status=status, output_path=output_path, error=error, endpoint_url=endpoint_url,
    )


async def recover_stale_deploys() -> None:
    """Reconcile pending/running deploy tasks after a server restart.

    Export tasks (including legacy rows with no ``deploy_mode``, which the rest of
    this module treats as export) are marked failed. Serve tasks whose remote
    process is still alive are left running; all other serve tasks are marked stopped.
    """
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel).where(
                DeployTaskModel.status.in_(["pending", "running"])
            )
        )
        records = result.scalars().all()

        recovered = 0
        for record in records:
            if record.deploy_mode == "serve":
                if record.pid and settings.use_remote_compute:
                    from app.core.remote_executor import is_process_running
                    if is_process_running(record.pid):
                        continue  # process still alive: keep it running
                    record.status = "stopped"
                    record.error = "Server restarted, process no longer running"
                else:
                    record.status = "stopped"
                    record.error = "Server restarted, process state unknown"
            else:
                record.status = "failed"
                record.error = "Server restarted, task interrupted"
            record.finished_at = datetime.utcnow()
            recovered += 1

        if recovered:
            await session.commit()
            logger.info(f"Recovered {recovered} stale deploy tasks")