# inference_service.py
"""Inference service.

Supports two execution modes: local execution and remote execution over SSH.
"""
import asyncio
import json
from pathlib import Path
from typing import Any

from app.config import get_settings
from app.core.logging import logger

# Settings object loaded once at import time; the functions in this module
# read `use_remote_compute` and `adapters_dir` from it.
settings = get_settings()
  9. async def generate(
  10. adapter_path: str,
  11. prompt: str,
  12. max_new_tokens: int = 256,
  13. temperature: float = 0.8,
  14. top_p: float = 0.95,
  15. repetition_penalty: float = 1.1,
  16. do_sample: bool = True,
  17. ) -> dict[str, Any]:
  18. """使用已训练的 adapter 生成文本。"""
  19. # 从 adapter config 中获取 base model ID
  20. base_model_id = _get_base_model_id(adapter_path)
  21. if not base_model_id:
  22. return {"error": "无法找到基础模型信息,请确保训练任务已完成"}
  23. if settings.use_remote_compute:
  24. # 远程执行模式(用 to_thread 避免阻塞事件循环)
  25. from app.core.remote_executor import run_inference_remote
  26. adapter_dir = Path(adapter_path)
  27. adapter_id = adapter_dir.name
  28. result = await asyncio.to_thread(
  29. run_inference_remote,
  30. model_id=base_model_id,
  31. adapter_id=adapter_id,
  32. prompt=prompt,
  33. max_new_tokens=max_new_tokens,
  34. temperature=temperature,
  35. top_p=top_p,
  36. repetition_penalty=repetition_penalty,
  37. do_sample=do_sample,
  38. )
  39. if result:
  40. return result
  41. return {"error": "Remote inference failed"}
  42. # 本地执行模式(用 to_thread 避免 GPU 操作阻塞事件循环)
  43. return await asyncio.to_thread(
  44. _generate_local,
  45. adapter_path=adapter_path,
  46. base_model_id=base_model_id,
  47. prompt=prompt,
  48. max_new_tokens=max_new_tokens,
  49. temperature=temperature,
  50. top_p=top_p,
  51. repetition_penalty=repetition_penalty,
  52. do_sample=do_sample,
  53. )
  54. def _generate_local(
  55. adapter_path: str,
  56. base_model_id: str,
  57. prompt: str,
  58. max_new_tokens: int,
  59. temperature: float,
  60. top_p: float,
  61. repetition_penalty: float,
  62. do_sample: bool,
  63. ) -> dict[str, Any]:
  64. """本地执行推理。"""
  65. try:
  66. import os
  67. import torch
  68. from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
  69. from peft import PeftModel
  70. tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
  71. if tokenizer.pad_token is None:
  72. tokenizer.pad_token = tokenizer.eos_token
  73. # CUDA_VISIBLE_DEVICES 由调用方(deploy_service)设置为多卡
  74. # device_map="auto" 自动将模型层分散到所有可见 GPU
  75. import torch
  76. device_map = "auto" if torch.cuda.is_available() else "cpu"
  77. torch.cuda.set_device(0)
  78. base_model = AutoModelForCausalLM.from_pretrained(
  79. base_model_id,
  80. torch_dtype=torch.float16,
  81. device_map=device_map,
  82. )
  83. model = PeftModel.from_pretrained(base_model, adapter_path)
  84. model.eval()
  85. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  86. # 构建 stop criteria:遇到角色标记就停止,防止复读
  87. stop_phrases = ["<|user|>", "<|system|>", "<|assistant|>"]
  88. stop_token_ids = [tokenizer.encode(p, add_special_tokens=False) for p in stop_phrases]
  89. class StopOnRoleToken(StoppingCriteria):
  90. def __call__(self, input_ids, scores, **kwargs):
  91. gen_seq = input_ids[0].tolist()
  92. for s_ids in stop_token_ids:
  93. if len(gen_seq) >= len(s_ids) and gen_seq[-len(s_ids):] == s_ids:
  94. return True
  95. return False
  96. stopping_criteria = StoppingCriteriaList([StopOnRoleToken()])
  97. with torch.no_grad():
  98. outputs = model.generate(
  99. **inputs,
  100. max_new_tokens=max_new_tokens,
  101. temperature=temperature,
  102. top_p=top_p,
  103. repetition_penalty=repetition_penalty,
  104. do_sample=do_sample,
  105. pad_token_id=tokenizer.eos_token_id,
  106. eos_token_id=tokenizer.eos_token_id,
  107. stopping_criteria=stopping_criteria,
  108. )
  109. generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
  110. generated_only = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  111. # 清理可能残留的角色标记
  112. for marker in ["<|user|>", "<|system|>", "<|assistant|>"]:
  113. if marker in generated_only:
  114. generated_only = generated_only[:generated_only.index(marker)]
  115. generated_only = generated_only.strip()
  116. return {
  117. "prompt": prompt,
  118. "generated_text": generated_text,
  119. "generated_only": generated_only,
  120. "tokens_generated": int(outputs.shape[1] - inputs["input_ids"].shape[1]),
  121. }
  122. except Exception as e:
  123. logger.error(f"Inference failed: {e}")
  124. return {"error": str(e)}
  125. def _get_base_model_id(adapter_path: str) -> str | None:
  126. """从 adapter config 中获取 base model ID。"""
  127. config_path = Path(adapter_path) / "adapter_config.json"
  128. if config_path.exists():
  129. with open(config_path) as f:
  130. cfg = json.load(f)
  131. return cfg.get("base_model_name_or_path")
  132. return None
  133. async def get_available_adapters() -> list[dict[str, Any]]:
  134. """列出所有已训练完成的 adapter(仅显示 status=completed 的任务)。"""
  135. from app.core.db import async_session, TrainingJobModel
  136. from sqlalchemy import select
  137. # 查询数据库中训练完成的任务
  138. async with async_session() as session:
  139. result = await session.execute(
  140. select(TrainingJobModel).where(TrainingJobModel.status == "completed")
  141. )
  142. completed_jobs = {job.id: job for job in result.scalars().all()}
  143. if not completed_jobs:
  144. return []
  145. adapters_dir = settings.adapters_dir
  146. if not adapters_dir.exists():
  147. return []
  148. result = []
  149. for job_id, job in sorted(completed_jobs.items(), key=lambda x: x[1].created_at, reverse=True):
  150. adapter_dir = adapters_dir / job_id
  151. if not adapter_dir.is_dir():
  152. continue
  153. adapter_config = adapter_dir / "adapter_config.json"
  154. if not adapter_config.exists():
  155. continue
  156. with open(adapter_config) as f:
  157. cfg = json.load(f)
  158. result.append({
  159. "id": job_id,
  160. "path": str(adapter_dir),
  161. "base_model": cfg.get("base_model_name_or_path", "unknown"),
  162. "peft_type": cfg.get("peft_type", "unknown"),
  163. "model_id": job.model_id,
  164. "peft_method": job.peft_method,
  165. "task_type": job.task_type,
  166. "created_at": job.created_at.isoformat() if job.created_at else None,
  167. })
  168. return result
  169. async def run_inference_single(
  170. model_id: str,
  171. adapter_id: str,
  172. prompt: str,
  173. max_new_tokens: int,
  174. temperature: float,
  175. top_p: float,
  176. repetition_penalty: float,
  177. do_sample: bool,
  178. ) -> dict[str, Any]:
  179. """供远程 SSH 调用的单条推理入口。"""
  180. adapter_path = str(settings.adapters_dir / adapter_id)
  181. return _generate_local(
  182. adapter_path=adapter_path,
  183. base_model_id=model_id,
  184. prompt=prompt,
  185. max_new_tokens=max_new_tokens,
  186. temperature=temperature,
  187. top_p=top_p,
  188. repetition_penalty=repetition_penalty,
  189. do_sample=do_sample,
  190. )