|
@@ -6,6 +6,8 @@ from typing import Any
|
|
|
from app.config import get_settings
|
|
from app.config import get_settings
|
|
|
from app.core.db import async_session, DeployTaskModel
|
|
from app.core.db import async_session, DeployTaskModel
|
|
|
from app.core.logging import logger
|
|
from app.core.logging import logger
|
|
|
|
|
+from app.core.remote_executor import ssh_exec
|
|
|
|
|
+from sqlalchemy import select
|
|
|
|
|
|
|
|
settings = get_settings()
|
|
settings = get_settings()
|
|
|
|
|
|
|
@@ -16,12 +18,6 @@ async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
|
|
|
merge_with_base = config.get("merge_with_base", False)
|
|
merge_with_base = config.get("merge_with_base", False)
|
|
|
export_format = config.get("export_format", "safetensors")
|
|
export_format = config.get("export_format", "safetensors")
|
|
|
|
|
|
|
|
- adapter_path = settings.adapters_dir / job_id
|
|
|
|
|
- if not adapter_path.exists():
|
|
|
|
|
- return {"job_id": job_id, "status": "failed", "output_path": None, "error": "Adapter not found"}
|
|
|
|
|
-
|
|
|
|
|
- output_path = settings.adapters_dir / f"{job_id}_merged"
|
|
|
|
|
-
|
|
|
|
|
# 写入数据库
|
|
# 写入数据库
|
|
|
task = DeployTaskModel(
|
|
task = DeployTaskModel(
|
|
|
id=task_id,
|
|
id=task_id,
|
|
@@ -34,24 +30,32 @@ async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
|
|
|
await session.commit()
|
|
await session.commit()
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
|
|
+ # 远程模式:通过 SSH 在算力节点执行
|
|
|
|
|
+ if settings.use_remote_compute:
|
|
|
|
|
+ result = await _run_remote_export(task_id, job_id, merge_with_base, export_format)
|
|
|
|
|
+ return result
|
|
|
|
|
+
|
|
|
|
|
+ # 本地模式(原有逻辑)
|
|
|
|
|
+ adapter_path = settings.adapters_dir / job_id
|
|
|
|
|
+ if not adapter_path.exists():
|
|
|
|
|
+ return _update_task_status(task_id, "failed", error="Adapter not found")
|
|
|
|
|
+
|
|
|
|
|
+ output_path = settings.adapters_dir / f"{job_id}_merged"
|
|
|
|
|
+
|
|
|
import torch
|
|
import torch
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
if merge_with_base:
|
|
if merge_with_base:
|
|
|
- # 加载 base model 并合并 adapter
|
|
|
|
|
- base_model_id = _get_base_model_id(job_id)
|
|
|
|
|
|
|
+ base_model_id = _get_base_model_id_local(job_id)
|
|
|
if base_model_id:
|
|
if base_model_id:
|
|
|
base_model = AutoModelForCausalLM.from_pretrained(
|
|
base_model = AutoModelForCausalLM.from_pretrained(
|
|
|
base_model_id, torch_dtype=torch.float16, device_map="auto"
|
|
base_model_id, torch_dtype=torch.float16, device_map="auto"
|
|
|
)
|
|
)
|
|
|
else:
|
|
else:
|
|
|
- # 尝试从 adapter config 中推断
|
|
|
|
|
from peft import PeftModel
|
|
from peft import PeftModel
|
|
|
-
|
|
|
|
|
- # 直接从 adapter 加载(需要 base_model_name_or_path)
|
|
|
|
|
merged = PeftModel.from_pretrained(
|
|
merged = PeftModel.from_pretrained(
|
|
|
AutoModelForCausalLM.from_pretrained(
|
|
AutoModelForCausalLM.from_pretrained(
|
|
|
- adapter_path / "adapter_config.json", torch_dtype=torch.float16
|
|
|
|
|
|
|
+ str(adapter_path), torch_dtype=torch.float16
|
|
|
),
|
|
),
|
|
|
adapter_path,
|
|
adapter_path,
|
|
|
)
|
|
)
|
|
@@ -59,76 +63,89 @@ async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
|
|
|
merged.save_pretrained(output_path)
|
|
merged.save_pretrained(output_path)
|
|
|
tokenizer = AutoTokenizer.from_pretrained(adapter_path)
|
|
tokenizer = AutoTokenizer.from_pretrained(adapter_path)
|
|
|
tokenizer.save_pretrained(output_path)
|
|
tokenizer.save_pretrained(output_path)
|
|
|
- logger.info(f"Adapter merged and saved to {output_path}")
|
|
|
|
|
else:
|
|
else:
|
|
|
- # 仅复制 adapter 文件
|
|
|
|
|
import shutil
|
|
import shutil
|
|
|
shutil.copytree(adapter_path, output_path)
|
|
shutil.copytree(adapter_path, output_path)
|
|
|
- logger.info(f"Adapter copied to {output_path}")
|
|
|
|
|
|
|
|
|
|
- # 可选导出 GGUF
|
|
|
|
|
if export_format == "gguf":
|
|
if export_format == "gguf":
|
|
|
gguf_path = output_path.with_suffix(".gguf")
|
|
gguf_path = output_path.with_suffix(".gguf")
|
|
|
- _export_to_gguf(output_path, gguf_path)
|
|
|
|
|
|
|
+ _export_to_gguf_local(output_path, gguf_path)
|
|
|
|
|
|
|
|
- # 更新数据库
|
|
|
|
|
- async with async_session() as session:
|
|
|
|
|
- from sqlalchemy import select
|
|
|
|
|
- result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
|
|
|
|
|
- record = result.scalar_one_or_none()
|
|
|
|
|
- if record:
|
|
|
|
|
- record.status = "completed"
|
|
|
|
|
- record.output_path = str(output_path)
|
|
|
|
|
- await session.commit()
|
|
|
|
|
-
|
|
|
|
|
- return {"job_id": job_id, "status": "completed", "output_path": str(output_path)}
|
|
|
|
|
|
|
+ return _update_task_status(task_id, "completed", output_path=str(output_path))
|
|
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"Export failed for job {job_id}: {e}")
|
|
logger.error(f"Export failed for job {job_id}: {e}")
|
|
|
|
|
+ return _update_task_status(task_id, "failed", error=str(e))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
    """Run the export inside the compute node's Docker container over SSH.

    Builds a ``docker exec`` command that invokes
    ``app.core.remote_deploy.run_remote_export`` with the given arguments,
    runs it through ``ssh_exec`` (600 s timeout), and looks for the last
    stdout line that parses as a JSON object.

    Args:
        task_id: ID of the deploy task row whose status is updated.
        job_id: Training job whose adapter is exported.
        merge_with_base: Whether the remote side should merge the adapter
            into the base model.
        export_format: Export format forwarded to the remote side.

    Returns:
        The dict produced by ``_update_task_status`` with ``job_id`` filled in.
    """
    # Local imports: none of these are visibly imported at module level.
    import asyncio
    import json
    import shlex

    def _finish(status: str, **fields) -> dict:
        # _update_task_status has no access to the job id and reports "",
        # so restore it here for the caller.
        outcome = _update_task_status(task_id, status, **fields)
        outcome["job_id"] = job_id
        return outcome

    # Arguments are embedded as Python literals via repr() and the whole
    # script is shell-quoted, so quotes or shell metacharacters in job_id /
    # export_format cannot break out of the command.
    script = (
        "import asyncio, json; "
        "from app.core.remote_deploy import run_remote_export; "
        "result = asyncio.run(run_remote_export("
        f"{str(job_id)!r}, {bool(merge_with_base)!r}, {str(export_format)!r})); "
        "print(json.dumps(result, ensure_ascii=False))"
    )
    remote_cmd = (
        "docker exec "
        "-e MACA_MPS_MODE=1 "
        "-e METAX_VISIBLE_DEVICES=2,3 "
        f"-w {settings.compute_node_workdir} "
        f"{settings.compute_node_docker_container} "
        f"{settings.compute_node_python} -c {shlex.quote(script)} 2>&1"
    )

    # ssh_exec is a blocking call; run it in a worker thread so the event
    # loop stays responsive for the duration of the export.
    code, stdout, stderr = await asyncio.to_thread(ssh_exec, remote_cmd, timeout=600)

    if code != 0:
        # The remote command merges stderr into stdout (2>&1), so stderr is
        # often empty here; fall back to the tail of stdout for diagnostics.
        detail = stderr.strip() or stdout.strip()[-500:]
        logger.error(f"Remote export failed: {detail}")
        return _finish("failed", error=detail)

    # The remote process may print log lines before the result; scan from
    # the end for the last line that is a JSON object.
    for raw_line in reversed(stdout.strip().split("\n")):
        candidate = raw_line.strip()
        if not candidate.startswith("{"):
            continue
        try:
            result = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if "error" in result:
            return _finish("failed", error=result["error"])
        return _finish("completed", output_path=result.get("output_path"))

    return _finish("failed", error=f"Invalid response: {stdout[:500]}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
# Strong references to in-flight status updates; the event loop itself only
# keeps weak references to tasks, so they could otherwise be collected early.
_PENDING_STATUS_UPDATES: set = set()


def _update_task_status(task_id: str, status: str, output_path: str | None = None, error: str | None = None) -> dict:
    """Persist a deploy task's status and build the matching result dict.

    The function is synchronous but is called from coroutines. When an event
    loop is already running, the database update is scheduled on that loop
    and the result dict is returned immediately (calling
    ``run_until_complete`` on a running loop raises ``RuntimeError``). When
    no loop is running, the update is executed to completion before returning.

    Args:
        task_id: ID of the ``DeployTaskModel`` row to update.
        status: New status value to store.
        output_path: Stored on the row only when truthy.
        error: Stored on the row, and added to the result, only when truthy.

    Returns:
        A dict with ``job_id`` (always ``""``), ``status`` and
        ``output_path``, plus ``error`` when one was given.
    """
    import asyncio

    async def _update() -> None:
        async with async_session() as session:
            result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
            record = result.scalar_one_or_none()
            if record:
                record.status = status
                if output_path:
                    record.output_path = output_path
                if error:
                    record.error = error
                await session.commit()

    def _on_done(finished) -> None:
        _PENDING_STATUS_UPDATES.discard(finished)
        # Surface failures of the background update instead of losing them.
        if not finished.cancelled() and finished.exception() is not None:
            logger.error(f"Failed to persist status for deploy task {task_id}: {finished.exception()}")

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: run the update synchronously.
        asyncio.run(_update())
    else:
        pending = loop.create_task(_update())
        _PENDING_STATUS_UPDATES.add(pending)
        pending.add_done_callback(_on_done)

    base = {"job_id": "", "status": status, "output_path": output_path}
    if error:
        base["error"] = error
    return base
|
|
|
|
|
|
|
|
|
|
|
|
|
-def _get_base_model_id(job_id: str) -> str | None:
|
|
|
|
|
- """从 adapter config 中获取 base model ID。"""
|
|
|
|
|
|
|
def _get_base_model_id_local(job_id: str) -> str | None:
    """Read the base model ID from a local adapter's ``adapter_config.json``.

    Args:
        job_id: Name of the adapter directory under ``settings.adapters_dir``.

    Returns:
        The ``base_model_name_or_path`` value from the config, or ``None``
        when the config file does not exist or has no such key.
    """
    import json

    config_path = settings.adapters_dir / job_id / "adapter_config.json"
    if not config_path.exists():
        return None
    # JSON is UTF-8; do not depend on the platform's default encoding.
    with open(config_path, encoding="utf-8") as f:
        return json.load(f).get("base_model_name_or_path")
|
|
|
|
|
|
|
|
|
|
|
|
|
-def _export_to_gguf(model_path: Path, output_path: Path):
|
|
|
|
|
- """导出模型为 GGUF 格式。"""
|
|
|
|
|
|
|
+def _export_to_gguf_local(model_path: Path, output_path: Path):
|
|
|
try:
|
|
try:
|
|
|
- from llama_cpp import Llama
|
|
|
|
|
- # 使用 llama-cpp-python 的 convert 工具
|
|
|
|
|
import subprocess
|
|
import subprocess
|
|
|
result = subprocess.run(
|
|
result = subprocess.run(
|
|
|
["python", "-m", "llama_cpp.convert_hf_to_gguf", str(model_path), "--outfile", str(output_path)],
|
|
["python", "-m", "llama_cpp.convert_hf_to_gguf", str(model_path), "--outfile", str(output_path)],
|
|
@@ -138,3 +155,18 @@ def _export_to_gguf(model_path: Path, output_path: Path):
|
|
|
logger.error(f"GGUF export failed: {result.stderr}")
|
|
logger.error(f"GGUF export failed: {result.stderr}")
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.warning(f"GGUF export not available: {e}")
|
|
logger.warning(f"GGUF export not available: {e}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
async def get_deploy_status(task_id: str) -> dict[str, Any]:
    """Return the stored state of a deploy task.

    Looks up the ``DeployTaskModel`` row whose ID equals ``task_id``. If no
    row matches, a placeholder dict with status ``"not_found"`` is returned.
    """
    query = select(DeployTaskModel).where(DeployTaskModel.id == task_id)
    async with async_session() as session:
        record = (await session.execute(query)).scalar_one_or_none()
        if not record:
            return {"job_id": "", "status": "not_found", "output_path": None, "error": None}
        return {
            "job_id": record.job_id,
            "status": record.status,
            "output_path": record.output_path,
            "error": record.error,
        }
|