inference_service.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
import asyncio
import json
import threading
from pathlib import Path
from typing import Any

from sqlalchemy import select

from app.config import get_settings
from app.core.db import async_session, TrainingJobModel
from app.core.logging import logger
# Application settings, resolved once when this module is imported.
# get_available_adapters() reads settings.adapters_dir from this object.
settings = get_settings()
  9. async def generate(
  10. adapter_path: str,
  11. prompt: str,
  12. max_new_tokens: int = 256,
  13. temperature: float = 0.8,
  14. top_p: float = 0.95,
  15. repetition_penalty: float = 1.1,
  16. do_sample: bool = True,
  17. ) -> dict[str, Any]:
  18. """使用已训练的 adapter 生成文本。"""
  19. try:
  20. import torch
  21. from transformers import AutoModelForCausalLM, AutoTokenizer
  22. # 推断 base model
  23. base_model_id = _get_base_model_id(adapter_path)
  24. if not base_model_id:
  25. # 尝试从训练记录中获取
  26. return {"error": "无法找到基础模型信息,请确保训练任务已完成"}
  27. # 加载 tokenizer
  28. tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
  29. if tokenizer.pad_token is None:
  30. tokenizer.pad_token = tokenizer.eos_token
  31. # 加载 base model + adapter
  32. from peft import PeftModel
  33. base_model = AutoModelForCausalLM.from_pretrained(
  34. base_model_id,
  35. torch_dtype=torch.float16,
  36. device_map="auto",
  37. )
  38. model = PeftModel.from_pretrained(base_model, adapter_path)
  39. model.eval()
  40. # Tokenize prompt
  41. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  42. # Generate
  43. with torch.no_grad():
  44. outputs = model.generate(
  45. **inputs,
  46. max_new_tokens=max_new_tokens,
  47. temperature=temperature,
  48. top_p=top_p,
  49. repetition_penalty=repetition_penalty,
  50. do_sample=do_sample,
  51. pad_token_id=tokenizer.eos_token_id,
  52. )
  53. # Decode
  54. generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
  55. generated_only = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  56. return {
  57. "prompt": prompt,
  58. "generated_text": generated_text,
  59. "generated_only": generated_only,
  60. "tokens_generated": outputs.shape[1] - inputs["input_ids"].shape[1],
  61. }
  62. except Exception as e:
  63. logger.error(f"Inference failed: {e}")
  64. return {"error": str(e)}
  65. def _get_base_model_id(adapter_path: str) -> str | None:
  66. """从 adapter config 中获取 base model ID。"""
  67. config_path = Path(adapter_path) / "adapter_config.json"
  68. if config_path.exists():
  69. with open(config_path) as f:
  70. cfg = json.load(f)
  71. return cfg.get("base_model_name_or_path")
  72. return None
  73. async def get_available_adapters() -> list[dict[str, Any]]:
  74. """列出所有已训练的 adapter。"""
  75. adapters_dir = settings.adapters_dir
  76. if not adapters_dir.exists():
  77. return []
  78. result = []
  79. for d in adapters_dir.iterdir():
  80. if not d.is_dir():
  81. continue
  82. adapter_config = d / "adapter_config.json"
  83. if adapter_config.exists():
  84. with open(adapter_config) as f:
  85. cfg = json.load(f)
  86. result.append({
  87. "id": d.name,
  88. "path": str(d),
  89. "base_model": cfg.get("base_model_name_or_path", "unknown"),
  90. "peft_type": cfg.get("peft_type", "unknown"),
  91. })
  92. return result