import json
import shlex
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

from sqlalchemy import select

from app.config import get_settings
from app.core.background_tasks import background_task_manager
from app.core.db import async_session, DeployTaskModel
from app.core.logging import logger
from app.core.remote_executor import ssh_exec

settings = get_settings()


def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive datetime.

    Produces the same kind of value as the deprecated ``datetime.utcnow()``
    (UTC wall-clock time, ``tzinfo`` stripped), so values written to
    ``finished_at`` keep the same form as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
    """Start an export as a background task and return immediately.

    Args:
        job_id: Identifier of the job whose adapter is exported.
        config: Export options. Reads ``merge_with_base`` (default ``False``)
            and ``export_format`` (default ``"safetensors"``).

    Returns:
        A dict with ``job_id``, ``task_id`` and ``status`` (``"pending"``).
        ``task_id`` is the key expected by ``get_deploy_status``.
    """
    task_id = str(uuid.uuid4())
    merge_with_base = config.get("merge_with_base", False)
    export_format = config.get("export_format", "safetensors")

    # Persist the task row before scheduling so status lookups can find it.
    task = DeployTaskModel(
        id=task_id,
        job_id=job_id,
        status="pending",
    )
    async with async_session() as session:
        session.add(task)
        await session.commit()

    # Register the task with the background manager and start it.
    background_task_manager.register_task(task_id, "deployment", {"job_id": job_id})
    background_task_manager.run(
        task_id,
        "deployment",
        _execute_export(task_id, job_id, merge_with_base, export_format),
    )

    logger.info(f"Deploy task started: job={job_id} (task_id={task_id})")
    return {"job_id": job_id, "task_id": task_id, "status": "pending"}


async def _execute_export(
    task_id: str, job_id: str, merge_with_base: bool, export_format: str
) -> dict:
    """Run the export in the background, remotely or locally.

    Any exception is logged, recorded on the task as ``"failed"`` and turned
    into an ``{"error": ...}`` result instead of propagating.

    Returns:
        ``{"output_path": ...}`` on success, ``{"error": ...}`` on failure.
    """
    try:
        # Remote mode: delegate to the compute node over SSH.
        if settings.use_remote_compute:
            return await _run_remote_export(task_id, job_id, merge_with_base, export_format)

        # Local mode.
        # NOTE(review): model loading/merging below is blocking work inside a
        # coroutine — confirm how background_task_manager schedules it.
        adapter_path = settings.adapters_dir / job_id
        if not adapter_path.exists():
            raise ValueError("Adapter not found")

        output_path = settings.adapters_dir / f"{job_id}_merged"

        if merge_with_base:
            # Heavy ML dependencies are only needed when actually merging.
            import torch
            from peft import PeftModel
            from transformers import AutoModelForCausalLM, AutoTokenizer

            base_model_id = _get_base_model_id_local(job_id)
            if base_model_id:
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_id, torch_dtype=torch.float16, device_map="auto"
                )
            else:
                # No base model recorded in the adapter config: fall back to
                # loading from the adapter directory itself.
                base_model = AutoModelForCausalLM.from_pretrained(
                    str(adapter_path), torch_dtype=torch.float16
                )

            # Both branches must end up wrapped in a PeftModel; previously the
            # branch with a known base model never defined ``merged``.
            merged = PeftModel.from_pretrained(base_model, str(adapter_path))
            merged = merged.merge_and_unload()
            merged.save_pretrained(output_path)

            tokenizer = AutoTokenizer.from_pretrained(adapter_path)
            tokenizer.save_pretrained(output_path)
        else:
            import shutil

            # dirs_exist_ok lets a repeated export of the same job overwrite
            # the previous copy instead of raising FileExistsError.
            shutil.copytree(adapter_path, output_path, dirs_exist_ok=True)

        if export_format == "gguf":
            gguf_path = output_path.with_suffix(".gguf")
            # Best effort: the helper logs failures and does not raise.
            _export_to_gguf_local(output_path, gguf_path)

        await _update_deploy_status(task_id, "completed", output_path=str(output_path))
        return {"output_path": str(output_path)}

    except Exception as e:
        logger.error(f"Export failed for job {job_id}: {e}")
        await _update_deploy_status(task_id, "failed", error=str(e))
        return {"error": str(e)}


async def _run_remote_export(
    task_id: str, job_id: str, merge_with_base: bool, export_format: str
) -> dict:
    """Run the export inside the remote container via SSH.

    The remote command prints a JSON object; the last stdout line that parses
    as JSON is taken as the result.

    Raises:
        RuntimeError: If the remote command exits non-zero, reports an
            ``error`` key, or produces no parseable JSON line.
    """
    # Arguments are embedded as Python literals via repr() and the whole
    # snippet is shell-quoted, so quotes or shell metacharacters in job_id /
    # export_format cannot break out of the command.
    python_code = (
        "import asyncio, json; "
        "from app.core.remote_deploy import run_remote_export; "
        "result = asyncio.run(run_remote_export("
        f"{str(job_id)!r}, {bool(merge_with_base)!r}, {str(export_format)!r})); "
        "print(json.dumps(result, ensure_ascii=False))"
    )
    remote_cmd = (
        "docker exec "
        "-e MACA_MPS_MODE=1 "
        "-e CUDA_VISIBLE_DEVICES=3 "
        f"-w {settings.compute_node_workdir} "
        f"{settings.compute_node_docker_container} "
        f"{settings.compute_node_python} -c {shlex.quote(python_code)} 2>&1"
    )

    # NOTE(review): ssh_exec is called synchronously (no await) with a 600s
    # timeout — confirm this does not stall the event loop it runs on.
    code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
    if code != 0:
        # The command redirects stderr into stdout (2>&1), so stderr is
        # usually empty; fall back to the tail of stdout for diagnostics.
        detail = stderr or stdout[-500:]
        raise RuntimeError(f"Remote export failed: {detail}")

    # Scan from the end: earlier lines may be log noise from the container.
    for line in reversed(stdout.strip().split("\n")):
        line = line.strip()
        if not line.startswith("{"):
            continue
        try:
            result = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "error" in result:
            raise RuntimeError(result["error"])
        output_path = result.get("output_path")
        await _update_deploy_status(task_id, "completed", output_path=output_path)
        return {"output_path": output_path}

    raise RuntimeError(f"Invalid response: {stdout[:500]}")


async def _update_deploy_status(
    task_id: str,
    status: str,
    output_path: Optional[str] = None,
    error: Optional[str] = None,
) -> None:
    """Write a status change to the task row and the background task manager.

    ``output_path`` and ``error`` are only written to the row when truthy.
    ``finished_at`` is set when the status is ``"completed"`` or ``"failed"``.
    """
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel).where(DeployTaskModel.id == task_id)
        )
        record = result.scalar_one_or_none()
        if record:
            record.status = status
            if output_path:
                record.output_path = output_path
            if error:
                record.error = error
            if status in ("completed", "failed"):
                record.finished_at = _utcnow_naive()
            await session.commit()

    background_task_manager.update_task(
        task_id,
        status=status,
        output_path=output_path,
        error=error,
    )


def _get_base_model_id_local(job_id: str) -> Optional[str]:
    """Return ``base_model_name_or_path`` from the job's adapter_config.json.

    Returns ``None`` when the config file does not exist or lacks the key.
    """
    config_path = settings.adapters_dir / job_id / "adapter_config.json"
    if not config_path.exists():
        return None
    with open(config_path, encoding="utf-8") as f:
        return json.load(f).get("base_model_name_or_path")


def _export_to_gguf_local(model_path: Path, output_path: Path) -> None:
    """Convert the model at ``model_path`` to GGUF via a subprocess (best effort).

    Failures are logged and never raised, so the caller's export is not
    failed by a GGUF conversion problem.
    """
    try:
        import subprocess

        result = subprocess.run(
            [
                "python",
                "-m",
                "llama_cpp.convert_hf_to_gguf",
                str(model_path),
                "--outfile",
                str(output_path),
            ],
            capture_output=True,
            text=True,
            timeout=600,
        )
        if result.returncode != 0:
            logger.error(f"GGUF export failed: {result.stderr}")
    except Exception as e:
        # Deliberately broad: covers a missing interpreter and timeouts alike.
        logger.warning(f"GGUF export not available: {e}")


async def get_deploy_status(task_id: str) -> dict[str, Any]:
    """Return the stored state of a deploy task.

    Returns:
        A dict with ``job_id``, ``status``, ``progress``, ``output_path`` and
        ``error``. When no row matches ``task_id``, ``status`` is
        ``"not_found"`` and the other fields hold placeholder values.
    """
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel).where(DeployTaskModel.id == task_id)
        )
        record = result.scalar_one_or_none()
        if record:
            return {
                "job_id": record.job_id,
                "status": record.status,
                "progress": record.progress,
                "output_path": record.output_path,
                "error": record.error,
            }
    return {
        "job_id": "",
        "status": "not_found",
        "progress": 0.0,
        "output_path": None,
        "error": None,
    }


async def recover_stale_deploys() -> None:
    """Mark every task still ``pending`` or ``running`` as failed.

    Each matching row gets status ``"failed"``, a fixed error message and a
    ``finished_at`` timestamp; the change is committed only if rows matched.
    """
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel).where(
                DeployTaskModel.status.in_(["pending", "running"])
            )
        )
        records = result.scalars().all()
        for record in records:
            record.status = "failed"
            record.error = "Server restarted, task interrupted"
            record.finished_at = _utcnow_naive()
        if records:
            await session.commit()
            logger.info(f"Recovered {len(records)} stale deploy tasks")