"""推理服务 — 本地执行。""" import json from pathlib import Path from typing import Any from app.config import get_settings from app.core.logging import logger settings = get_settings() async def generate( adapter_path: str, prompt: str, max_new_tokens: int = 256, temperature: float = 0.8, top_p: float = 0.95, repetition_penalty: float = 1.1, do_sample: bool = True, ) -> dict[str, Any]: """使用已训练的 adapter 生成文本。""" # 从 adapter config 中获取 base model ID base_model_id = _get_base_model_id(adapter_path) if not base_model_id: return {"error": "无法找到基础模型信息,请确保训练任务已完成"} return _generate_local( adapter_path=adapter_path, base_model_id=base_model_id, prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, do_sample=do_sample, ) def _generate_local( adapter_path: str, base_model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float, repetition_penalty: float, do_sample: bool, ) -> dict[str, Any]: """本地执行推理。""" try: import torch from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token base_model = AutoModelForCausalLM.from_pretrained( base_model_id, torch_dtype=torch.float16, device_map="auto", ) model = PeftModel.from_pretrained(base_model, adapter_path) model.eval() inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, do_sample=do_sample, pad_token_id=tokenizer.eos_token_id, ) generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) generated_only = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True) return { "prompt": prompt, "generated_text": generated_text, "generated_only": generated_only, "tokens_generated": int(outputs.shape[1] - inputs["input_ids"].shape[1]), } except Exception as e: logger.error(f"Inference failed: {e}") return {"error": str(e)} def _get_base_model_id(adapter_path: str) -> str | None: """从 adapter config 中获取 base model ID。""" config_path = Path(adapter_path) / "adapter_config.json" if config_path.exists(): with open(config_path) as f: cfg = json.load(f) return cfg.get("base_model_name_or_path") return None async def get_available_adapters() -> list[dict[str, Any]]: """列出所有已训练的 adapter。""" adapters_dir = settings.adapters_dir if not adapters_dir.exists(): return [] result = [] for d in sorted(adapters_dir.iterdir()): if not d.is_dir(): continue adapter_config = d / "adapter_config.json" if adapter_config.exists(): with open(adapter_config) as f: cfg = json.load(f) result.append({ "id": d.name, "path": str(d), "base_model": cfg.get("base_model_name_or_path", "unknown"), "peft_type": cfg.get("peft_type", "unknown"), }) return result