deploy_service.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import uuid
  2. from datetime import datetime
  3. from pathlib import Path
  4. from typing import Any
  5. from app.config import get_settings
  6. from app.core.db import async_session, DeployTaskModel
  7. from app.core.logging import logger
  8. from app.core.remote_executor import ssh_exec
  9. from sqlalchemy import select
# Application settings, resolved once when this module is imported.
settings = get_settings()
  11. async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
  12. """合并 adapter 与基础模型,并可选导出为 GGUF。"""
  13. task_id = str(uuid.uuid4())
  14. merge_with_base = config.get("merge_with_base", False)
  15. export_format = config.get("export_format", "safetensors")
  16. # 写入数据库
  17. task = DeployTaskModel(
  18. id=task_id,
  19. job_id=job_id,
  20. status="pending",
  21. created_at=datetime.utcnow(),
  22. )
  23. async with async_session() as session:
  24. session.add(task)
  25. await session.commit()
  26. try:
  27. # 远程模式:通过 SSH 在算力节点执行
  28. if settings.use_remote_compute:
  29. result = await _run_remote_export(task_id, job_id, merge_with_base, export_format)
  30. return result
  31. # 本地模式(原有逻辑)
  32. adapter_path = settings.adapters_dir / job_id
  33. if not adapter_path.exists():
  34. return _update_task_status(task_id, "failed", error="Adapter not found")
  35. output_path = settings.adapters_dir / f"{job_id}_merged"
  36. import torch
  37. from transformers import AutoModelForCausalLM, AutoTokenizer
  38. if merge_with_base:
  39. base_model_id = _get_base_model_id_local(job_id)
  40. if base_model_id:
  41. base_model = AutoModelForCausalLM.from_pretrained(
  42. base_model_id, torch_dtype=torch.float16, device_map="auto"
  43. )
  44. else:
  45. from peft import PeftModel
  46. merged = PeftModel.from_pretrained(
  47. AutoModelForCausalLM.from_pretrained(
  48. str(adapter_path), torch_dtype=torch.float16
  49. ),
  50. adapter_path,
  51. )
  52. merged = merged.merge_and_unload()
  53. merged.save_pretrained(output_path)
  54. tokenizer = AutoTokenizer.from_pretrained(adapter_path)
  55. tokenizer.save_pretrained(output_path)
  56. else:
  57. import shutil
  58. shutil.copytree(adapter_path, output_path)
  59. if export_format == "gguf":
  60. gguf_path = output_path.with_suffix(".gguf")
  61. _export_to_gguf_local(output_path, gguf_path)
  62. return _update_task_status(task_id, "completed", output_path=str(output_path))
  63. except Exception as e:
  64. logger.error(f"Export failed for job {job_id}: {e}")
  65. return _update_task_status(task_id, "failed", error=str(e))
  66. async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
  67. """通过 SSH 在远程容器执行部署。"""
  68. remote_cmd = (
  69. f"docker exec "
  70. f"-e MACA_MPS_MODE=1 "
  71. f"-e METAX_VISIBLE_DEVICES=2,3 "
  72. f"-w {settings.compute_node_workdir} "
  73. f"{settings.compute_node_docker_container} "
  74. f"{settings.compute_node_python} -c \""
  75. "import asyncio, json; "
  76. "from app.core.remote_deploy import run_remote_export; "
  77. f"result = asyncio.run(run_remote_export('{job_id}', {merge_with_base}, '{export_format}')); "
  78. "print(json.dumps(result, ensure_ascii=False))\" 2>&1"
  79. )
  80. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  81. if code != 0:
  82. logger.error(f"Remote export failed: {stderr}")
  83. return _update_task_status(task_id, "failed", error=stderr.strip())
  84. for line in reversed(stdout.strip().split("\n")):
  85. line = line.strip()
  86. if line.startswith("{"):
  87. try:
  88. result = json.loads(line)
  89. if "error" in result:
  90. return _update_task_status(task_id, "failed", error=result["error"])
  91. return _update_task_status(task_id, "completed", output_path=result.get("output_path"))
  92. except json.JSONDecodeError:
  93. continue
  94. return _update_task_status(task_id, "failed", error=f"Invalid response: {stdout[:500]}")
  95. def _update_task_status(task_id: str, status: str, output_path: str = None, error: str = None):
  96. import asyncio
  97. async def _update():
  98. async with async_session() as session:
  99. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  100. record = result.scalar_one_or_none()
  101. if record:
  102. record.status = status
  103. if output_path:
  104. record.output_path = output_path
  105. if error:
  106. record.error = error
  107. await session.commit()
  108. asyncio.get_event_loop().run_until_complete(_update())
  109. base = {"job_id": "", "status": status, "output_path": output_path}
  110. if error:
  111. base["error"] = error
  112. return base
  113. def _get_base_model_id_local(job_id: str):
  114. config_path = settings.adapters_dir / job_id / "adapter_config.json"
  115. if config_path.exists():
  116. import json
  117. with open(config_path) as f:
  118. return json.load(f).get("base_model_name_or_path")
  119. return None
  120. def _export_to_gguf_local(model_path: Path, output_path: Path):
  121. try:
  122. import subprocess
  123. result = subprocess.run(
  124. ["python", "-m", "llama_cpp.convert_hf_to_gguf", str(model_path), "--outfile", str(output_path)],
  125. capture_output=True, text=True, timeout=600,
  126. )
  127. if result.returncode != 0:
  128. logger.error(f"GGUF export failed: {result.stderr}")
  129. except Exception as e:
  130. logger.warning(f"GGUF export not available: {e}")
  131. async def get_deploy_status(task_id: str) -> dict[str, Any]:
  132. """获取部署任务状态。"""
  133. async with async_session() as session:
  134. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  135. record = result.scalar_one_or_none()
  136. if record:
  137. return {
  138. "job_id": record.job_id,
  139. "status": record.status,
  140. "output_path": record.output_path,
  141. "error": record.error,
  142. }
  143. return {"job_id": "", "status": "not_found", "output_path": None, "error": None}