inference_service.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
"""Inference service — runs locally."""
import json
from pathlib import Path
from typing import Any
from app.config import get_settings
from app.core.logging import logger
# Settings are resolved once, at import time; get_available_adapters() reads
# settings.adapters_dir from this module-level object.
settings = get_settings()
  8. async def generate(
  9. adapter_path: str,
  10. prompt: str,
  11. max_new_tokens: int = 256,
  12. temperature: float = 0.8,
  13. top_p: float = 0.95,
  14. repetition_penalty: float = 1.1,
  15. do_sample: bool = True,
  16. ) -> dict[str, Any]:
  17. """使用已训练的 adapter 生成文本。"""
  18. # 从 adapter config 中获取 base model ID
  19. base_model_id = _get_base_model_id(adapter_path)
  20. if not base_model_id:
  21. return {"error": "无法找到基础模型信息,请确保训练任务已完成"}
  22. return _generate_local(
  23. adapter_path=adapter_path,
  24. base_model_id=base_model_id,
  25. prompt=prompt,
  26. max_new_tokens=max_new_tokens,
  27. temperature=temperature,
  28. top_p=top_p,
  29. repetition_penalty=repetition_penalty,
  30. do_sample=do_sample,
  31. )
  32. def _generate_local(
  33. adapter_path: str,
  34. base_model_id: str,
  35. prompt: str,
  36. max_new_tokens: int,
  37. temperature: float,
  38. top_p: float,
  39. repetition_penalty: float,
  40. do_sample: bool,
  41. ) -> dict[str, Any]:
  42. """本地执行推理。"""
  43. try:
  44. import torch
  45. from transformers import AutoModelForCausalLM, AutoTokenizer
  46. from peft import PeftModel
  47. tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
  48. if tokenizer.pad_token is None:
  49. tokenizer.pad_token = tokenizer.eos_token
  50. base_model = AutoModelForCausalLM.from_pretrained(
  51. base_model_id,
  52. torch_dtype=torch.float16,
  53. device_map="auto",
  54. )
  55. model = PeftModel.from_pretrained(base_model, adapter_path)
  56. model.eval()
  57. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  58. with torch.no_grad():
  59. outputs = model.generate(
  60. **inputs,
  61. max_new_tokens=max_new_tokens,
  62. temperature=temperature,
  63. top_p=top_p,
  64. repetition_penalty=repetition_penalty,
  65. do_sample=do_sample,
  66. pad_token_id=tokenizer.eos_token_id,
  67. )
  68. generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
  69. generated_only = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  70. return {
  71. "prompt": prompt,
  72. "generated_text": generated_text,
  73. "generated_only": generated_only,
  74. "tokens_generated": int(outputs.shape[1] - inputs["input_ids"].shape[1]),
  75. }
  76. except Exception as e:
  77. logger.error(f"Inference failed: {e}")
  78. return {"error": str(e)}
  79. def _get_base_model_id(adapter_path: str) -> str | None:
  80. """从 adapter config 中获取 base model ID。"""
  81. config_path = Path(adapter_path) / "adapter_config.json"
  82. if config_path.exists():
  83. with open(config_path) as f:
  84. cfg = json.load(f)
  85. return cfg.get("base_model_name_or_path")
  86. return None
  87. async def get_available_adapters() -> list[dict[str, Any]]:
  88. """列出所有已训练的 adapter。"""
  89. adapters_dir = settings.adapters_dir
  90. if not adapters_dir.exists():
  91. return []
  92. result = []
  93. for d in sorted(adapters_dir.iterdir()):
  94. if not d.is_dir():
  95. continue
  96. adapter_config = d / "adapter_config.json"
  97. if adapter_config.exists():
  98. with open(adapter_config) as f:
  99. cfg = json.load(f)
  100. result.append({
  101. "id": d.name,
  102. "path": str(d),
  103. "base_model": cfg.get("base_model_name_or_path", "unknown"),
  104. "peft_type": cfg.get("peft_type", "unknown"),
  105. })
  106. return result