| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
import asyncio
import json
import threading
from pathlib import Path
from typing import Any

from sqlalchemy import select

from app.config import get_settings
from app.core.db import async_session, TrainingJobModel
from app.core.logging import logger
# Application settings, loaded once when this module is imported
# (get_available_adapters reads `settings.adapters_dir` from it).
settings = get_settings()
# Serializes model loading + generation. The original implementation ran on
# the event loop and therefore handled one request at a time; keeping that
# guarantee stops concurrent requests from each loading a full model copy.
_INFERENCE_LOCK = threading.Lock()


async def generate(
    adapter_path: str,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
) -> dict[str, Any]:
    """Generate text from ``prompt`` using a trained PEFT adapter.

    The base model named in the adapter's ``adapter_config.json`` and the
    adapter itself are loaded on every call. Because loading and generation
    are blocking, they run in a worker thread (``asyncio.to_thread``) so the
    event loop stays responsive.

    Args:
        adapter_path: Directory holding the adapter weights,
            ``adapter_config.json`` and the tokenizer files.
        prompt: Text to continue.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; only used when ``do_sample`` is True.
        top_p: Nucleus-sampling threshold; only used when ``do_sample`` is True.
        repetition_penalty: Passed through to ``model.generate``.
        do_sample: Sample when True, greedy decoding when False.

    Returns:
        On success: ``prompt``, ``generated_text`` (prompt + continuation),
        ``generated_only`` (continuation only) and ``tokens_generated``.
        On failure: a dict with a single ``error`` key. Exceptions are logged
        with their traceback and never propagated to the caller.
    """
    try:
        return await asyncio.to_thread(
            _generate_blocking,
            adapter_path,
            prompt,
            max_new_tokens,
            temperature,
            top_p,
            repetition_penalty,
            do_sample,
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Inference failed: {e}")
        return {"error": str(e)}


def _generate_blocking(
    adapter_path: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    do_sample: bool,
) -> dict[str, Any]:
    """Synchronously load base model + adapter and generate; may raise."""
    # Heavy ML dependencies are imported lazily; an ImportError propagates to
    # generate(), which converts it into an error result.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model_id = _get_base_model_id(adapter_path)
    if not base_model_id:
        # TODO: look the base model up from the training-job record when the
        # adapter config does not name it (not implemented yet).
        return {"error": "无法找到基础模型信息,请确保训练任务已完成"}

    with _INFERENCE_LOCK:
        # The tokenizer is loaded from the adapter directory, so it must have
        # been saved there alongside the adapter.
        tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # float16 is slow and only partially supported on CPU; use it on GPU only.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=dtype,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(base_model, adapter_path)
        model.eval()

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        gen_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
            "pad_token_id": tokenizer.eos_token_id,
        }
        if do_sample:
            # Sampling knobs are ignored (with a warning) under greedy decoding.
            gen_kwargs.update(temperature=temperature, top_p=top_p)

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

    prompt_len = inputs["input_ids"].shape[1]
    output_ids = outputs[0]
    return {
        "prompt": prompt,
        "generated_text": tokenizer.decode(output_ids, skip_special_tokens=True),
        # Slice off the prompt tokens to keep only the continuation.
        "generated_only": tokenizer.decode(output_ids[prompt_len:], skip_special_tokens=True),
        "tokens_generated": outputs.shape[1] - prompt_len,
    }
- def _get_base_model_id(adapter_path: str) -> str | None:
- """从 adapter config 中获取 base model ID。"""
- config_path = Path(adapter_path) / "adapter_config.json"
- if config_path.exists():
- with open(config_path) as f:
- cfg = json.load(f)
- return cfg.get("base_model_name_or_path")
- return None
async def get_available_adapters() -> list[dict[str, Any]]:
    """List every trained adapter found under ``settings.adapters_dir``.

    An adapter is any immediate sub-directory containing an
    ``adapter_config.json`` file. Results are sorted by directory name so the
    listing is stable. A directory whose config cannot be read or parsed is
    skipped with a warning instead of failing the whole listing.

    Returns:
        One dict per adapter with ``id`` (directory name), ``path``,
        ``base_model`` and ``peft_type``; the last two default to
        ``"unknown"`` when absent from the config. Empty if the adapters
        directory does not exist.
    """
    adapters_dir = settings.adapters_dir
    if not adapters_dir.exists():
        return []

    adapters: list[dict[str, Any]] = []
    for entry in sorted(adapters_dir.iterdir()):
        config_path = entry / "adapter_config.json"
        if not entry.is_dir() or not config_path.is_file():
            continue
        try:
            cfg = json.loads(config_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Skipping adapter {entry.name}: unreadable adapter_config.json ({e})")
            continue
        if not isinstance(cfg, dict):
            logger.warning(f"Skipping adapter {entry.name}: adapter_config.json is not a JSON object")
            continue
        adapters.append({
            "id": entry.name,
            "path": str(entry),
            "base_model": cfg.get("base_model_name_or_path", "unknown"),
            "peft_type": cfg.get("peft_type", "unknown"),
        })
    return adapters
|