deploy_service.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import uuid
  2. from datetime import datetime
  3. from pathlib import Path
  4. from typing import Any
  5. from app.config import get_settings
  6. from app.core.db import async_session, DeployTaskModel
  7. from app.core.logging import logger
# Application settings, resolved once when this module is imported; the
# functions below read `settings.adapters_dir` from it.
settings = get_settings()
  9. async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
  10. """合并 adapter 与基础模型,并可选导出为 GGUF。"""
  11. task_id = str(uuid.uuid4())
  12. merge_with_base = config.get("merge_with_base", False)
  13. export_format = config.get("export_format", "safetensors")
  14. adapter_path = settings.adapters_dir / job_id
  15. if not adapter_path.exists():
  16. return {"job_id": job_id, "status": "failed", "output_path": None, "error": "Adapter not found"}
  17. output_path = settings.adapters_dir / f"{job_id}_merged"
  18. # 写入数据库
  19. task = DeployTaskModel(
  20. id=task_id,
  21. job_id=job_id,
  22. status="pending",
  23. created_at=datetime.utcnow(),
  24. )
  25. async with async_session() as session:
  26. session.add(task)
  27. await session.commit()
  28. try:
  29. import torch
  30. from transformers import AutoModelForCausalLM, AutoTokenizer
  31. if merge_with_base:
  32. # 加载 base model 并合并 adapter
  33. base_model_id = _get_base_model_id(job_id)
  34. if base_model_id:
  35. base_model = AutoModelForCausalLM.from_pretrained(
  36. base_model_id, torch_dtype=torch.float16, device_map="auto"
  37. )
  38. else:
  39. # 尝试从 adapter config 中推断
  40. from peft import PeftModel
  41. # 直接从 adapter 加载(需要 base_model_name_or_path)
  42. merged = PeftModel.from_pretrained(
  43. AutoModelForCausalLM.from_pretrained(
  44. adapter_path / "adapter_config.json", torch_dtype=torch.float16
  45. ),
  46. adapter_path,
  47. )
  48. merged = merged.merge_and_unload()
  49. merged.save_pretrained(output_path)
  50. tokenizer = AutoTokenizer.from_pretrained(adapter_path)
  51. tokenizer.save_pretrained(output_path)
  52. logger.info(f"Adapter merged and saved to {output_path}")
  53. else:
  54. # 仅复制 adapter 文件
  55. import shutil
  56. shutil.copytree(adapter_path, output_path)
  57. logger.info(f"Adapter copied to {output_path}")
  58. # 可选导出 GGUF
  59. if export_format == "gguf":
  60. gguf_path = output_path.with_suffix(".gguf")
  61. _export_to_gguf(output_path, gguf_path)
  62. # 更新数据库
  63. async with async_session() as session:
  64. from sqlalchemy import select
  65. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  66. record = result.scalar_one_or_none()
  67. if record:
  68. record.status = "completed"
  69. record.output_path = str(output_path)
  70. await session.commit()
  71. return {"job_id": job_id, "status": "completed", "output_path": str(output_path)}
  72. except Exception as e:
  73. logger.error(f"Export failed for job {job_id}: {e}")
  74. async with async_session() as session:
  75. from sqlalchemy import select
  76. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  77. record = result.scalar_one_or_none()
  78. if record:
  79. record.status = "failed"
  80. record.error = str(e)
  81. await session.commit()
  82. return {"job_id": job_id, "status": "failed", "output_path": None, "error": str(e)}
  83. async def get_deploy_status(task_id: str) -> dict[str, Any]:
  84. """获取部署任务状态。"""
  85. async with async_session() as session:
  86. from sqlalchemy import select
  87. result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
  88. record = result.scalar_one_or_none()
  89. if record:
  90. return {
  91. "job_id": record.job_id,
  92. "status": record.status,
  93. "output_path": record.output_path,
  94. "error": record.error,
  95. }
  96. return {"job_id": "", "status": "not_found", "output_path": None, "error": None}
  97. def _get_base_model_id(job_id: str) -> str | None:
  98. """从 adapter config 中获取 base model ID。"""
  99. config_path = settings.adapters_dir / job_id / "adapter_config.json"
  100. if config_path.exists():
  101. import json
  102. with open(config_path) as f:
  103. cfg = json.load(f)
  104. return cfg.get("base_model_name_or_path")
  105. return None
  106. def _export_to_gguf(model_path: Path, output_path: Path):
  107. """导出模型为 GGUF 格式。"""
  108. try:
  109. from llama_cpp import Llama
  110. # 使用 llama-cpp-python 的 convert 工具
  111. import subprocess
  112. result = subprocess.run(
  113. ["python", "-m", "llama_cpp.convert_hf_to_gguf", str(model_path), "--outfile", str(output_path)],
  114. capture_output=True, text=True, timeout=600,
  115. )
  116. if result.returncode != 0:
  117. logger.error(f"GGUF export failed: {result.stderr}")
  118. except Exception as e:
  119. logger.warning(f"GGUF export not available: {e}")