import asyncio
import json
import shlex
import shutil
import subprocess
import sys
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Optional

from sqlalchemy import select

from app.config import get_settings
from app.core.db import async_session, DeployTaskModel
from app.core.logging import logger
from app.core.remote_executor import ssh_exec

settings = get_settings()

# Upper bound on how much captured remote stdout is stored as an error message.
_MAX_ERROR_CHARS = 2000


async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
    """Merge an adapter with its base model and optionally export it as GGUF.

    A ``DeployTaskModel`` row is inserted with status ``"pending"`` and then
    updated to ``"completed"`` or ``"failed"``. When
    ``settings.use_remote_compute`` is set the work runs on the compute node
    over SSH; otherwise it runs locally.

    Args:
        job_id: Training job whose adapter lives in
            ``settings.adapters_dir / job_id``.
        config: Export options. Reads ``merge_with_base`` (default ``False``)
            and ``export_format`` (default ``"safetensors"``; ``"gguf"``
            additionally attempts a best-effort GGUF conversion locally).

    Returns:
        Dict with ``task_id``, ``job_id``, ``status``, ``output_path`` and,
        on failure, ``error``.
    """
    task_id = str(uuid.uuid4())
    merge_with_base = bool(config.get("merge_with_base", False))
    export_format = str(config.get("export_format", "safetensors"))

    # Record the task before doing any work so it can be polled immediately.
    task = DeployTaskModel(
        id=task_id,
        job_id=job_id,
        status="pending",
        created_at=datetime.utcnow(),
    )
    async with async_session() as session:
        session.add(task)
        await session.commit()

    try:
        if settings.use_remote_compute:
            # Remote mode: run the export inside the compute-node container.
            return await _run_remote_export(
                task_id, job_id, merge_with_base, export_format
            )

        # Local mode.
        adapter_path = settings.adapters_dir / job_id
        if not adapter_path.exists():
            return await _update_task_status(
                task_id, "failed", job_id=job_id, error="Adapter not found"
            )

        # Loading/merging models is long, blocking work; run it in a worker
        # thread so the event loop keeps serving other requests.
        output_path = await asyncio.to_thread(
            _export_local, job_id, adapter_path, merge_with_base, export_format
        )
        return await _update_task_status(
            task_id, "completed", job_id=job_id, output_path=str(output_path)
        )
    except Exception as e:
        logger.exception(f"Export failed for job {job_id}: {e}")
        return await _update_task_status(
            task_id, "failed", job_id=job_id, error=str(e)
        )


def _export_local(
    job_id: str, adapter_path: Path, merge_with_base: bool, export_format: str
) -> Path:
    """Build ``<adapters_dir>/<job_id>_merged`` from a local adapter (blocking).

    With ``merge_with_base`` the adapter is merged into its base model;
    otherwise the adapter directory is copied unchanged. When
    ``export_format == "gguf"`` a GGUF file is also attempted next to the
    output directory (best effort; failures are only logged).

    Returns:
        The output directory path.

    Raises:
        ValueError: merging was requested but ``adapter_config.json`` names
            no base model.
    """
    output_path = settings.adapters_dir / f"{job_id}_merged"

    if merge_with_base:
        _merge_adapter_into_base(job_id, adapter_path, output_path)
    else:
        # dirs_exist_ok lets a re-run overwrite a previous export instead of
        # failing with FileExistsError.
        shutil.copytree(adapter_path, output_path, dirs_exist_ok=True)

    if export_format == "gguf":
        _export_to_gguf_local(output_path, output_path.with_suffix(".gguf"))
    return output_path


def _merge_adapter_into_base(
    job_id: str, adapter_path: Path, output_path: Path
) -> None:
    """Apply the PEFT adapter to its base model, merge weights, and save."""
    # Heavy ML dependencies are imported lazily so that remote-only
    # deployments do not need them installed.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model_id = _get_base_model_id_local(job_id)
    if not base_model_id:
        raise ValueError(
            f"Base model not found in adapter_config.json for job {job_id}"
        )

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id, torch_dtype=torch.float16, device_map="auto"
    )
    merged = PeftModel.from_pretrained(base_model, str(adapter_path))
    merged = merged.merge_and_unload()
    merged.save_pretrained(output_path)

    # Prefer a tokenizer saved alongside the adapter; fall back to the base
    # model's tokenizer when the adapter directory does not contain one.
    try:
        tokenizer = AutoTokenizer.from_pretrained(str(adapter_path))
    except (OSError, ValueError):
        tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    tokenizer.save_pretrained(output_path)


async def _run_remote_export(
    task_id: str, job_id: str, merge_with_base: bool, export_format: str
) -> dict[str, Any]:
    """Run the export inside the compute-node Docker container over SSH.

    The remote script prints its result as a JSON line. The last output line
    that parses as a JSON object determines the task outcome: an ``"error"``
    key marks it failed, otherwise ``output_path`` is recorded.
    """
    # Arguments are embedded as Python literals via repr() and the whole
    # script is shell-quoted, so caller-supplied values cannot break out of
    # the command.
    script = (
        "import asyncio, json; "
        "from app.core.remote_deploy import run_remote_export; "
        "result = asyncio.run(run_remote_export("
        f"{job_id!r}, {merge_with_base!r}, {export_format!r})); "
        "print(json.dumps(result, ensure_ascii=False))"
    )
    remote_cmd = (
        "docker exec "
        "-e MACA_MPS_MODE=1 "
        "-e CUDA_VISIBLE_DEVICES=2,3 "
        f"-w {settings.compute_node_workdir} "
        f"{settings.compute_node_docker_container} "
        f"{settings.compute_node_python} -c {shlex.quote(script)} 2>&1"
    )

    # ssh_exec is synchronous and may block up to the timeout; keep it off
    # the event loop.
    code, stdout, stderr = await asyncio.to_thread(
        ssh_exec, remote_cmd, timeout=600
    )
    output = (stdout or "").strip()

    if code != 0:
        # The remote command redirects 2>&1, so Python tracebacks arrive on
        # stdout; fall back to its tail when stderr is empty.
        detail = (stderr or "").strip() or output[-_MAX_ERROR_CHARS:]
        logger.error(f"Remote export failed: {detail}")
        return await _update_task_status(
            task_id, "failed", job_id=job_id, error=detail
        )

    for line in reversed(output.splitlines()):
        line = line.strip()
        if not line.startswith("{"):
            continue
        try:
            result = json.loads(line)
        except json.JSONDecodeError:
            continue
        if not isinstance(result, dict):
            continue
        if "error" in result:
            return await _update_task_status(
                task_id, "failed", job_id=job_id, error=str(result["error"])
            )
        return await _update_task_status(
            task_id,
            "completed",
            job_id=job_id,
            output_path=result.get("output_path"),
        )

    return await _update_task_status(
        task_id,
        "failed",
        job_id=job_id,
        error=f"Invalid response: {(stdout or '')[:500]}",
    )


async def _update_task_status(
    task_id: str,
    status: str,
    output_path: Optional[str] = None,
    error: Optional[str] = None,
    *,
    job_id: str = "",
) -> dict[str, Any]:
    """Persist a status change for ``task_id`` and build the response dict.

    ``output_path`` and ``error`` are written only when truthy, so existing
    values are never cleared. A missing task row is silently skipped.
    """
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel).where(DeployTaskModel.id == task_id)
        )
        record = result.scalar_one_or_none()
        if record is not None:
            record.status = status
            if output_path:
                record.output_path = output_path
            if error:
                record.error = error
            await session.commit()

    response: dict[str, Any] = {
        "task_id": task_id,
        "job_id": job_id,
        "status": status,
        "output_path": output_path,
    }
    if error:
        response["error"] = error
    return response


def _get_base_model_id_local(job_id: str) -> Optional[str]:
    """Return ``base_model_name_or_path`` from the job's adapter config, if any."""
    config_path = settings.adapters_dir / job_id / "adapter_config.json"
    if not config_path.exists():
        return None
    with open(config_path, encoding="utf-8") as f:
        return json.load(f).get("base_model_name_or_path")


def _export_to_gguf_local(model_path: Path, output_path: Path) -> bool:
    """Best-effort conversion of a model directory to a GGUF file.

    Failures are logged rather than raised, so an unavailable converter does
    not fail the whole export.

    Returns:
        ``True`` when the converter exited with status 0.
    """
    try:
        # sys.executable runs the converter in this interpreter's environment
        # instead of whatever "python" happens to be on PATH.
        result = subprocess.run(
            [
                sys.executable,
                "-m",
                "llama_cpp.convert_hf_to_gguf",
                str(model_path),
                "--outfile",
                str(output_path),
            ],
            capture_output=True,
            text=True,
            timeout=600,
        )
    except Exception as e:
        logger.warning(f"GGUF export not available: {e}")
        return False
    if result.returncode != 0:
        logger.error(f"GGUF export failed: {result.stderr}")
        return False
    return True


async def get_deploy_status(task_id: str) -> dict[str, Any]:
    """Return the stored state of a deploy task, or a ``not_found`` placeholder."""
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel).where(DeployTaskModel.id == task_id)
        )
        record = result.scalar_one_or_none()
        if record:
            return {
                "task_id": task_id,
                "job_id": record.job_id,
                "status": record.status,
                "output_path": record.output_path,
                "error": record.error,
            }
    return {
        "task_id": task_id,
        "job_id": "",
        "status": "not_found",
        "output_path": None,
        "error": None,
    }