# deploy_service.py
import asyncio
import json
import shlex
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

from sqlalchemy import select

from app.config import get_settings
from app.core.background_tasks import background_task_manager
from app.core.db import async_session, DeployTaskModel
from app.core.logging import logger
from app.core.remote_executor import ssh_exec
# Settings object, loaded once at import time. The functions below read the
# adapter directory and the remote-compute options (use_remote_compute,
# compute_node_*) from it.
settings = get_settings()
  13. async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
  14. """启动导出后台任务,立即返回 task_id。"""
  15. task_id = str(uuid.uuid4())
  16. merge_with_base = config.get("merge_with_base", False)
  17. export_format = config.get("export_format", "safetensors")
  18. # 写 DB
  19. task = DeployTaskModel(
  20. id=task_id,
  21. job_id=job_id,
  22. status="pending",
  23. )
  24. async with async_session() as session:
  25. session.add(task)
  26. await session.commit()
  27. # 注册并启动
  28. background_task_manager.register_task(task_id, "deployment", {"job_id": job_id})
  29. background_task_manager.run(
  30. task_id, "deployment", _execute_export(task_id, job_id, merge_with_base, export_format)
  31. )
  32. logger.info(f"Deploy task started: job={job_id} (task_id={task_id})")
  33. return {"job_id": job_id, "status": "pending"}
  34. async def _execute_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
  35. """后台执行导出。"""
  36. try:
  37. # 远程模式:通过 SSH 在算力节点执行
  38. if settings.use_remote_compute:
  39. result = await _run_remote_export(task_id, job_id, merge_with_base, export_format)
  40. return result
  41. # 本地模式
  42. adapter_path = settings.adapters_dir / job_id
  43. if not adapter_path.exists():
  44. raise ValueError("Adapter not found")
  45. import torch
  46. from transformers import AutoModelForCausalLM, AutoTokenizer
  47. output_path = settings.adapters_dir / f"{job_id}_merged"
  48. if merge_with_base:
  49. base_model_id = _get_base_model_id_local(job_id)
  50. if base_model_id:
  51. base_model = AutoModelForCausalLM.from_pretrained(
  52. base_model_id, torch_dtype=torch.float16, device_map="auto"
  53. )
  54. else:
  55. from peft import PeftModel
  56. merged = PeftModel.from_pretrained(
  57. AutoModelForCausalLM.from_pretrained(
  58. str(adapter_path), torch_dtype=torch.float16
  59. ),
  60. str(adapter_path),
  61. )
  62. merged = merged.merge_and_unload()
  63. merged.save_pretrained(output_path)
  64. tokenizer = AutoTokenizer.from_pretrained(adapter_path)
  65. tokenizer.save_pretrained(output_path)
  66. else:
  67. import shutil
  68. shutil.copytree(adapter_path, output_path)
  69. if export_format == "gguf":
  70. gguf_path = output_path.with_suffix(".gguf")
  71. _export_to_gguf_local(output_path, gguf_path)
  72. await _update_deploy_status(task_id, "completed", output_path=str(output_path))
  73. return {"output_path": str(output_path)}
  74. except Exception as e:
  75. logger.error(f"Export failed for job {job_id}: {e}")
  76. await _update_deploy_status(task_id, "failed", error=str(e))
  77. return {"error": str(e)}
  78. async def _run_remote_export(task_id: str, job_id: str, merge_with_base: bool, export_format: str) -> dict:
  79. """通过 SSH 在远程容器执行部署。"""
  80. remote_cmd = (
  81. f"docker exec "
  82. f"-e MACA_MPS_MODE=1 "
  83. f"-e CUDA_VISIBLE_DEVICES=3 "
  84. f"-w {settings.compute_node_workdir} "
  85. f"{settings.compute_node_docker_container} "
  86. f"{settings.compute_node_python} -c \""
  87. "import asyncio, json; "
  88. "from app.core.remote_deploy import run_remote_export; "
  89. f"result = asyncio.run(run_remote_export('{job_id}', {merge_with_base}, '{export_format}')); "
  90. "print(json.dumps(result, ensure_ascii=False))\" 2>&1"
  91. )
  92. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  93. if code != 0:
  94. raise RuntimeError(f"Remote export failed: {stderr}")
  95. for line in reversed(stdout.strip().split("\n")):
  96. line = line.strip()
  97. if line.startswith("{"):
  98. try:
  99. result = json.loads(line)
  100. if "error" in result:
  101. raise RuntimeError(result["error"])
  102. await _update_deploy_status(task_id, "completed", output_path=result.get("output_path"))
  103. return {"output_path": result.get("output_path")}
  104. except json.JSONDecodeError:
  105. continue
  106. raise RuntimeError(f"Invalid response: {stdout[:500]}")
  107. async def _update_deploy_status(task_id: str, status: str, output_path: str = None, error: str = None):
  108. async with async_session() as session:
  109. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  110. record = result.scalar_one_or_none()
  111. if record:
  112. record.status = status
  113. if output_path:
  114. record.output_path = output_path
  115. if error:
  116. record.error = error
  117. if status in ("completed", "failed"):
  118. record.finished_at = datetime.utcnow()
  119. await session.commit()
  120. background_task_manager.update_task(
  121. task_id, status=status, output_path=output_path, error=error,
  122. )
  123. def _get_base_model_id_local(job_id: str):
  124. config_path = settings.adapters_dir / job_id / "adapter_config.json"
  125. if config_path.exists():
  126. import json
  127. with open(config_path) as f:
  128. return json.load(f).get("base_model_name_or_path")
  129. return None
  130. def _export_to_gguf_local(model_path: Path, output_path: Path):
  131. try:
  132. import subprocess
  133. result = subprocess.run(
  134. ["python", "-m", "llama_cpp.convert_hf_to_gguf", str(model_path), "--outfile", str(output_path)],
  135. capture_output=True, text=True, timeout=600,
  136. )
  137. if result.returncode != 0:
  138. logger.error(f"GGUF export failed: {result.stderr}")
  139. except Exception as e:
  140. logger.warning(f"GGUF export not available: {e}")
  141. async def get_deploy_status(task_id: str) -> dict[str, Any]:
  142. """获取部署任务状态。"""
  143. async with async_session() as session:
  144. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  145. record = result.scalar_one_or_none()
  146. if record:
  147. return {
  148. "job_id": record.job_id,
  149. "status": record.status,
  150. "progress": record.progress,
  151. "output_path": record.output_path,
  152. "error": record.error,
  153. }
  154. return {"job_id": "", "status": "not_found", "progress": 0.0, "output_path": None, "error": None}
  155. async def recover_stale_deploys() -> None:
  156. async with async_session() as session:
  157. result = await session.execute(
  158. select(DeployTaskModel).where(
  159. DeployTaskModel.status.in_(["pending", "running"])
  160. )
  161. )
  162. records = result.scalars().all()
  163. for record in records:
  164. record.status = "failed"
  165. record.error = "Server restarted, task interrupted"
  166. record.finished_at = datetime.utcnow()
  167. if records:
  168. await session.commit()
  169. logger.info(f"Recovered {len(records)} stale deploy tasks")