# inference_service.py
  1. """推理服务 — 支持本地执行和 SSH 远程执行两种模式。"""
  2. import json
  3. from pathlib import Path
  4. from typing import Any
  5. from app.config import get_settings
  6. from app.core.logging import logger
  7. settings = get_settings()
  8. async def generate(
  9. adapter_path: str,
  10. prompt: str,
  11. max_new_tokens: int = 256,
  12. temperature: float = 0.8,
  13. top_p: float = 0.95,
  14. repetition_penalty: float = 1.1,
  15. do_sample: bool = True,
  16. ) -> dict[str, Any]:
  17. """使用已训练的 adapter 生成文本。"""
  18. # 从 adapter config 中获取 base model ID
  19. base_model_id = _get_base_model_id(adapter_path)
  20. if not base_model_id:
  21. return {"error": "无法找到基础模型信息,请确保训练任务已完成"}
  22. if settings.use_remote_compute:
  23. # 远程执行模式
  24. from app.core.remote_executor import run_inference_remote
  25. adapter_dir = Path(adapter_path)
  26. adapter_id = adapter_dir.name
  27. result = run_inference_remote(
  28. model_id=base_model_id,
  29. adapter_id=adapter_id,
  30. prompt=prompt,
  31. max_new_tokens=max_new_tokens,
  32. temperature=temperature,
  33. top_p=top_p,
  34. repetition_penalty=repetition_penalty,
  35. do_sample=do_sample,
  36. )
  37. if result:
  38. return result
  39. return {"error": "Remote inference failed"}
  40. # 本地执行模式
  41. return _generate_local(
  42. adapter_path=adapter_path,
  43. base_model_id=base_model_id,
  44. prompt=prompt,
  45. max_new_tokens=max_new_tokens,
  46. temperature=temperature,
  47. top_p=top_p,
  48. repetition_penalty=repetition_penalty,
  49. do_sample=do_sample,
  50. )
  51. def _generate_local(
  52. adapter_path: str,
  53. base_model_id: str,
  54. prompt: str,
  55. max_new_tokens: int,
  56. temperature: float,
  57. top_p: float,
  58. repetition_penalty: float,
  59. do_sample: bool,
  60. ) -> dict[str, Any]:
  61. """本地执行推理。"""
  62. try:
  63. import os
  64. import torch
  65. from transformers import AutoModelForCausalLM, AutoTokenizer
  66. from peft import PeftModel
  67. tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
  68. if tokenizer.pad_token is None:
  69. tokenizer.pad_token = tokenizer.eos_token
  70. # CUDA_VISIBLE_DEVICES=3 已将物理 GPU 3 映射为逻辑 GPU 0
  71. import torch
  72. device_map = {"": 0}
  73. torch.cuda.set_device(0)
  74. base_model = AutoModelForCausalLM.from_pretrained(
  75. base_model_id,
  76. torch_dtype=torch.float16,
  77. device_map=device_map,
  78. )
  79. model = PeftModel.from_pretrained(base_model, adapter_path)
  80. model.eval()
  81. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  82. with torch.no_grad():
  83. outputs = model.generate(
  84. **inputs,
  85. max_new_tokens=max_new_tokens,
  86. temperature=temperature,
  87. top_p=top_p,
  88. repetition_penalty=repetition_penalty,
  89. do_sample=do_sample,
  90. pad_token_id=tokenizer.eos_token_id,
  91. )
  92. generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
  93. generated_only = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  94. return {
  95. "prompt": prompt,
  96. "generated_text": generated_text,
  97. "generated_only": generated_only,
  98. "tokens_generated": int(outputs.shape[1] - inputs["input_ids"].shape[1]),
  99. }
  100. except Exception as e:
  101. logger.error(f"Inference failed: {e}")
  102. return {"error": str(e)}
  103. def _get_base_model_id(adapter_path: str) -> str | None:
  104. """从 adapter config 中获取 base model ID。"""
  105. config_path = Path(adapter_path) / "adapter_config.json"
  106. if config_path.exists():
  107. with open(config_path) as f:
  108. cfg = json.load(f)
  109. return cfg.get("base_model_name_or_path")
  110. return None
  111. async def get_available_adapters() -> list[dict[str, Any]]:
  112. """列出所有已训练完成的 adapter(仅显示 status=completed 的任务)。"""
  113. from app.core.db import async_session, TrainingJobModel
  114. from sqlalchemy import select
  115. # 查询数据库中训练完成的任务
  116. async with async_session() as session:
  117. result = await session.execute(
  118. select(TrainingJobModel).where(TrainingJobModel.status == "completed")
  119. )
  120. completed_jobs = {job.id: job for job in result.scalars().all()}
  121. if not completed_jobs:
  122. return []
  123. adapters_dir = settings.adapters_dir
  124. if not adapters_dir.exists():
  125. return []
  126. result = []
  127. for job_id, job in sorted(completed_jobs.items(), key=lambda x: x[1].created_at, reverse=True):
  128. adapter_dir = adapters_dir / job_id
  129. if not adapter_dir.is_dir():
  130. continue
  131. adapter_config = adapter_dir / "adapter_config.json"
  132. if not adapter_config.exists():
  133. continue
  134. with open(adapter_config) as f:
  135. cfg = json.load(f)
  136. result.append({
  137. "id": job_id,
  138. "path": str(adapter_dir),
  139. "base_model": cfg.get("base_model_name_or_path", "unknown"),
  140. "peft_type": cfg.get("peft_type", "unknown"),
  141. "model_id": job.model_id,
  142. "peft_method": job.peft_method,
  143. "task_type": job.task_type,
  144. "created_at": job.created_at.isoformat() if job.created_at else None,
  145. })
  146. return result
  147. async def run_inference_single(
  148. model_id: str,
  149. adapter_id: str,
  150. prompt: str,
  151. max_new_tokens: int,
  152. temperature: float,
  153. top_p: float,
  154. repetition_penalty: float,
  155. do_sample: bool,
  156. ) -> dict[str, Any]:
  157. """供远程 SSH 调用的单条推理入口。"""
  158. adapter_path = str(settings.adapters_dir / adapter_id)
  159. return _generate_local(
  160. adapter_path=adapter_path,
  161. base_model_id=model_id,
  162. prompt=prompt,
  163. max_new_tokens=max_new_tokens,
  164. temperature=temperature,
  165. top_p=top_p,
  166. repetition_penalty=repetition_penalty,
  167. do_sample=do_sample,
  168. )