| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
import asyncio
import json
import shlex
import shutil
import subprocess
import sys
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

from sqlalchemy import select

from app.config import get_settings
from app.core.background_tasks import background_task_manager
from app.core.db import async_session, DeployTaskModel
from app.core.logging import logger
from app.core.remote_executor import ssh_exec
# Application settings, fetched once when this module is imported. The export
# helpers in this module read adapters_dir, use_remote_compute and the
# compute_node_* options from it.
settings = get_settings()
async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
    """Start a background export of a job's adapter and return immediately.

    A ``DeployTaskModel`` row is created with status ``pending``. The export
    itself is scheduled on ``background_task_manager`` as ``_execute_export``.

    Args:
        job_id: Training job whose adapter should be exported.
        config: Export options. Recognized keys:
            ``merge_with_base`` (default ``False``): merge the adapter into
            its base model instead of copying the adapter directory;
            ``export_format`` (default ``"safetensors"``): ``"gguf"``
            additionally attempts a GGUF conversion.

    Returns:
        ``{"job_id": ..., "task_id": ..., "status": "pending"}``. ``task_id``
        is the key accepted by ``get_deploy_status``.
    """
    task_id = str(uuid.uuid4())
    # Coerce to a real bool: the value is embedded into remote Python code by
    # the remote path, and the local path only tests its truthiness.
    merge_with_base = bool(config.get("merge_with_base", False))
    export_format = config.get("export_format", "safetensors")

    # Persist the task before launching it so status lookups can find it.
    task = DeployTaskModel(id=task_id, job_id=job_id, status="pending")
    async with async_session() as session:
        session.add(task)
        await session.commit()

    # Register and launch the background job.
    background_task_manager.register_task(task_id, "deployment", {"job_id": job_id})
    background_task_manager.run(
        task_id, "deployment", _execute_export(task_id, job_id, merge_with_base, export_format)
    )
    logger.info(f"Deploy task started: job={job_id} (task_id={task_id})")
    # task_id must be returned: get_deploy_status() looks tasks up by it.
    return {"job_id": job_id, "task_id": task_id, "status": "pending"}
async def _execute_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
    """Run an export job in the background and record its outcome.

    Dispatches to the remote compute node when ``settings.use_remote_compute``
    is set; otherwise exports on this machine. On success the deploy record
    is marked ``completed`` with its ``output_path``. On any exception it is
    marked ``failed`` with the error message.

    Args:
        task_id: Deploy task id (primary key of ``DeployTaskModel``).
        job_id: Training job whose adapter is exported.
        merge_with_base: Merge the adapter into its base model when true;
            otherwise copy the adapter directory as-is.
        export_format: ``"gguf"`` additionally attempts a GGUF conversion;
            any other value produces only the directory output.

    Returns:
        ``{"output_path": ...}`` on success or ``{"error": ...}`` on failure.
        Exceptions are not propagated to the caller.
    """
    try:
        await _update_deploy_status(task_id, "running")

        # Remote mode: execute on the compute node over SSH.
        if settings.use_remote_compute:
            return await _run_remote_export(task_id, job_id, merge_with_base, export_format)

        # Local mode: model loading/merging and directory copies are blocking
        # and can be slow, so run them in a worker thread instead of blocking
        # the event loop.
        output_path = await asyncio.to_thread(
            _export_local_sync, job_id, merge_with_base, export_format
        )
        await _update_deploy_status(task_id, "completed", output_path=str(output_path))
        return {"output_path": str(output_path)}
    except Exception as e:
        logger.exception(f"Export failed for job {job_id}: {e}")
        await _update_deploy_status(task_id, "failed", error=str(e))
        return {"error": str(e)}


def _export_local_sync(job_id: str, merge_with_base: bool, export_format: str) -> Path:
    """Blocking local export; return the output directory.

    Raises:
        ValueError: If the adapter directory does not exist, or (when merging)
            its base model id cannot be determined.
    """
    adapter_path = settings.adapters_dir / job_id
    if not adapter_path.exists():
        raise ValueError("Adapter not found")

    # The "_merged" suffix is used for both modes; kept for compatibility with
    # existing output locations.
    output_path = settings.adapters_dir / f"{job_id}_merged"
    if merge_with_base:
        _merge_adapter_into_base(job_id, adapter_path, output_path)
    else:
        # dirs_exist_ok lets a repeated export refresh an earlier copy instead
        # of failing with FileExistsError.
        shutil.copytree(adapter_path, output_path, dirs_exist_ok=True)

    if export_format == "gguf":
        # Best-effort: _export_to_gguf_local logs failures instead of raising.
        # NOTE(review): converting an unmerged adapter copy is unlikely to
        # succeed with an HF->GGUF converter; confirm intended behavior.
        _export_to_gguf_local(output_path, output_path.with_suffix(".gguf"))
    return output_path


def _merge_adapter_into_base(job_id: str, adapter_path: Path, output_path: Path) -> None:
    """Load the base model, apply the PEFT adapter, merge it, and save model + tokenizer."""
    # Heavy ML dependencies are imported lazily so the copy-only path and the
    # remote mode do not require them.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model_id = _get_base_model_id_local(job_id)
    if not base_model_id:
        raise ValueError("Base model not found in adapter_config.json")

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id, torch_dtype=torch.float16, device_map="auto"
    )
    # Wrap the base model with the adapter weights, then fold them in.
    merged = PeftModel.from_pretrained(base_model, str(adapter_path)).merge_and_unload()
    merged.save_pretrained(output_path)

    try:
        tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    except (OSError, ValueError):
        # Presumably the adapter dir may lack tokenizer files; fall back to
        # the base model's tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    tokenizer.save_pretrained(output_path)
async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
    """Run the export inside the compute node's Docker container over SSH.

    Invokes ``app.core.remote_deploy.run_remote_export`` in the container.
    The remote snippet prints the result dict as one JSON line. On success
    the deploy record is marked ``completed``.

    Returns:
        ``{"output_path": ...}`` as reported by the remote side.

    Raises:
        RuntimeError: On a non-zero exit status, an ``"error"`` key in the
            remote result, or when no JSON object can be parsed from stdout.
    """
    # Build the Python snippet with repr() so job_id / export_format become
    # proper string literals and merge_with_base a real bool. Raw interpolation
    # let a crafted value (possibly originating from request data) inject code.
    py_code = (
        "import asyncio, json; "
        "from app.core.remote_deploy import run_remote_export; "
        f"result = asyncio.run(run_remote_export({job_id!r}, {bool(merge_with_base)!r}, {export_format!r})); "
        "print(json.dumps(result, ensure_ascii=False))"
    )
    # shlex.quote protects the snippet from the remote shell (no $, `, or "
    # expansion). Settings values are trusted configuration and kept as-is.
    # NOTE(review): GPU index and MACA_MPS_MODE are hard-coded; consider moving
    # them into settings.
    remote_cmd = (
        "docker exec "
        "-e MACA_MPS_MODE=1 "
        "-e CUDA_VISIBLE_DEVICES=3 "
        f"-w {settings.compute_node_workdir} "
        f"{settings.compute_node_docker_container} "
        f"{settings.compute_node_python} -c {shlex.quote(py_code)} 2>&1"
    )

    # ssh_exec is synchronous (its result is unpacked directly) and may block
    # up to timeout=600. Run it in a worker thread so the event loop stays
    # responsive. Assumes ssh_exec is safe to call from a thread.
    code, stdout, stderr = await asyncio.to_thread(ssh_exec, remote_cmd, timeout=600)
    if code != 0:
        # "2>&1" routes remote stderr into stdout, so surface stdout's tail
        # too; stderr alone is typically empty here.
        detail = (stderr or "").strip() or stdout.strip()[-500:]
        raise RuntimeError(f"Remote export failed: {detail}")

    # Log lines may precede the result; take the last line that parses as a
    # JSON object. split("\n") (not splitlines) avoids splitting on U+2028
    # that ensure_ascii=False may leave inside JSON strings.
    for line in reversed(stdout.strip().split("\n")):
        line = line.strip()
        if not line.startswith("{"):
            continue
        try:
            result = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "error" in result:
            raise RuntimeError(result["error"])
        await _update_deploy_status(task_id, "completed", output_path=result.get("output_path"))
        return {"output_path": result.get("output_path")}
    raise RuntimeError(f"Invalid response: {stdout[:500]}")
async def _update_deploy_status(
    task_id: str,
    status: str,
    output_path: Optional[str] = None,
    error: Optional[str] = None,
) -> None:
    """Persist a deploy task's new status and mirror it to the task manager.

    Args:
        task_id: Deploy task id to update.
        status: New status string, e.g. ``"running"``, ``"completed"`` or
            ``"failed"``.
        output_path: Written only when truthy; otherwise the stored value is kept.
        error: Written only when truthy; otherwise the stored value is kept.

    The terminal statuses ``completed`` and ``failed`` also stamp
    ``finished_at``. If no DB row matches, the DB is left untouched, but
    ``background_task_manager`` is still updated.
    """
    async with async_session() as session:
        result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
        record = result.scalar_one_or_none()
        if record:
            record.status = status
            if output_path:
                record.output_path = output_path
            if error:
                record.error = error
            if status in ("completed", "failed"):
                # Naive UTC, the same value datetime.utcnow() produced, without
                # its Python 3.12 deprecation. Kept naive because the column
                # type is not visible here (confirm before switching to aware).
                record.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
            await session.commit()
    background_task_manager.update_task(
        task_id, status=status, output_path=output_path, error=error,
    )
def _get_base_model_id_local(job_id: str) -> Optional[str]:
    """Return the base model id recorded in a local adapter's config.

    Reads ``base_model_name_or_path`` from
    ``settings.adapters_dir / job_id / "adapter_config.json"``.

    Returns:
        The base model id, or ``None`` if the config file or key is missing.

    Raises:
        json.JSONDecodeError: If the config file is not valid JSON.
    """
    config_path = settings.adapters_dir / job_id / "adapter_config.json"
    if not config_path.exists():
        return None
    # Explicit UTF-8: JSON files are UTF-8, while the platform default
    # encoding (e.g. cp936/cp1252 on Windows) can fail on non-ASCII paths.
    with open(config_path, encoding="utf-8") as f:
        return json.load(f).get("base_model_name_or_path")
def _export_to_gguf_local(model_path: Path, output_path: Path) -> bool:
    """Best-effort conversion of a model directory to a GGUF file.

    Runs ``-m llama_cpp.convert_hf_to_gguf`` in a subprocess. Failures are
    logged rather than raised, so an export can still complete without the
    GGUF artifact.

    Args:
        model_path: Directory containing the model to convert.
        output_path: Destination ``.gguf`` file.

    Returns:
        ``True`` if the converter exited with status 0, ``False`` otherwise.
        (Previously ``None``; callers ignoring the result are unaffected.)
    """
    timeout_s = 600
    # sys.executable runs the converter with this server's interpreter/venv
    # rather than whatever "python" resolves to on PATH.
    # NOTE(review): confirm llama_cpp.convert_hf_to_gguf is importable as a
    # module in the deployed environment.
    cmd = [
        sys.executable, "-m", "llama_cpp.convert_hf_to_gguf",
        str(model_path), "--outfile", str(output_path),
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        logger.error(f"GGUF export timed out after {timeout_s}s: {model_path}")
        return False
    except Exception as e:
        logger.warning(f"GGUF export not available: {e}")
        return False
    if result.returncode != 0:
        logger.error(f"GGUF export failed: {result.stderr}")
        return False
    return True
async def get_deploy_status(task_id: str) -> dict[str, Any]:
    """Look up a deploy task by id and return a status snapshot.

    The result always has the keys ``job_id``, ``status``, ``progress``,
    ``output_path`` and ``error``. An unknown id yields
    ``status="not_found"`` with empty fields instead of raising.
    """
    query = select(DeployTaskModel).where(DeployTaskModel.id == task_id)
    async with async_session() as session:
        row = (await session.execute(query)).scalar_one_or_none()
        if not row:
            return {"job_id": "", "status": "not_found", "progress": 0.0, "output_path": None, "error": None}
        return dict(
            job_id=row.job_id,
            status=row.status,
            progress=row.progress,
            output_path=row.output_path,
            error=row.error,
        )
async def recover_stale_deploys() -> None:
    """Mark deploy tasks still ``pending``/``running`` in the DB as failed.

    Such rows belong to background tasks interrupted by a server restart, so
    they can never finish. Presumably this is called once at startup
    (confirm with the caller). Each affected row gets ``status="failed"``,
    an explanatory ``error`` and a ``finished_at`` timestamp.
    """
    # Naive UTC, the same value datetime.utcnow() produced (deprecated since
    # Python 3.12). One timestamp is used for the whole batch.
    finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
    async with async_session() as session:
        result = await session.execute(
            select(DeployTaskModel).where(
                DeployTaskModel.status.in_(["pending", "running"])
            )
        )
        records = result.scalars().all()
        for record in records:
            record.status = "failed"
            record.error = "Server restarted, task interrupted"
            record.finished_at = finished_at
        if records:
            await session.commit()
            logger.info(f"Recovered {len(records)} stale deploy tasks")
|