| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
"""Inference service: generate text with trained adapters.

Two execution modes are supported: running the model in this process, or
delegating the work to a remote host over SSH.
"""
import json
from pathlib import Path
from typing import Any

from app.config import get_settings
from app.core.logging import logger

# Settings are resolved once at import time and read by every function below.
settings = get_settings()
async def generate(
    adapter_path: str,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
) -> dict[str, Any]:
    """Generate text with a trained adapter, locally or on the remote host.

    The base model ID is read from the adapter's ``adapter_config.json``. When
    ``settings.use_remote_compute`` is set, the request is delegated to
    ``run_inference_remote`` (addressing the adapter by its directory name);
    otherwise it runs in-process via ``_generate_local``.

    Both backends are blocking calls (an SSH round-trip, or a model load plus
    generation), so they run in a worker thread via ``asyncio.to_thread``.
    Calling them directly would stall the event loop for the whole inference.

    Args:
        adapter_path: Directory of the trained adapter.
        prompt: Input text to continue.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.
        do_sample: Sample when True, decode greedily when False.

    Returns:
        The result dict produced by the selected backend, or ``{"error": ...}``
        when the base model cannot be determined or remote inference returned
        nothing.
    """
    import asyncio

    # The base model ID comes from the adapter config.
    # NOTE(review): this reads the config from the local filesystem even in
    # remote mode, so the adapter must also exist locally — confirm intended.
    base_model_id = _get_base_model_id(adapter_path)
    if not base_model_id:
        return {"error": "无法找到基础模型信息,请确保训练任务已完成"}

    generation_args: dict[str, Any] = {
        "prompt": prompt,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
    }

    # NOTE(review): requests now run concurrently in worker threads, and each
    # local run loads its own model copy; add a semaphore if memory is tight.
    if settings.use_remote_compute:
        # Remote execution mode.
        from app.core.remote_executor import run_inference_remote

        # The remote side resolves the adapter by its directory name only.
        adapter_id = Path(adapter_path).name
        result = await asyncio.to_thread(
            run_inference_remote,
            model_id=base_model_id,
            adapter_id=adapter_id,
            **generation_args,
        )
        if result:
            return result
        return {"error": "Remote inference failed"}

    # Local execution mode.
    return await asyncio.to_thread(
        _generate_local,
        adapter_path=adapter_path,
        base_model_id=base_model_id,
        **generation_args,
    )
def _generate_local(
    adapter_path: str,
    base_model_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    do_sample: bool,
) -> dict[str, Any]:
    """Run inference in this process with a PEFT adapter on top of its base model.

    Loads the tokenizer, the base model and the adapter, then generates a
    completion for ``prompt``. The tokenizer comes from the adapter directory,
    or from the base model when the adapter ships none. Everything is loaded
    on every call; nothing is cached.

    Args:
        adapter_path: Directory of the trained adapter.
        base_model_id: Model ID or path of the base model to load.
        prompt: Input text to continue.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (used only when ``do_sample``).
        top_p: Nucleus-sampling mass (used only when ``do_sample``).
        repetition_penalty: Penalty applied to repeated tokens.
        do_sample: Sample when True, decode greedily when False.

    Returns:
        On success, a dict with these keys:

        - ``prompt``
        - ``generated_text``: the prompt plus the completion, with special
          tokens stripped.
        - ``generated_only``: the completion only.
        - ``tokens_generated``

        On any failure, ``{"error": str(e)}``, with the traceback logged.
    """
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        try:
            tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
        except (OSError, ValueError):
            # Not every training run saves the tokenizer next to the adapter;
            # the base model's tokenizer is the natural fallback.
            tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # float16 is only reliable on GPU; many CPU kernels have no
        # half-precision implementation, so use float32 without CUDA.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=dtype,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(base_model, adapter_path)
        model.eval()

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]

        gen_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
            "pad_token_id": tokenizer.eos_token_id,
        }
        if do_sample:
            # temperature/top_p only affect sampling; with greedy decoding
            # they are ignored and only produce transformers warnings.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

        sequence = outputs[0]
        return {
            "prompt": prompt,
            "generated_text": tokenizer.decode(sequence, skip_special_tokens=True),
            # The output echoes the prompt tokens first; slice them off.
            "generated_only": tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True),
            "tokens_generated": int(outputs.shape[1] - prompt_len),
        }
    except Exception as e:
        # Boundary handler: callers expect an error dict, not an exception.
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Inference failed: {e}")
        return {"error": str(e)}
- def _get_base_model_id(adapter_path: str) -> str | None:
- """从 adapter config 中获取 base model ID。"""
- config_path = Path(adapter_path) / "adapter_config.json"
- if config_path.exists():
- with open(config_path) as f:
- cfg = json.load(f)
- return cfg.get("base_model_name_or_path")
- return None
async def get_available_adapters() -> list[dict[str, Any]]:
    """List every trained adapter found directly under ``settings.adapters_dir``.

    A subdirectory counts as an adapter when it contains an
    ``adapter_config.json`` file. A directory whose config cannot be read or
    parsed is skipped with a warning instead of failing the whole listing.

    Returns:
        One dict per adapter, in sorted path order, with these keys:

        - ``id``: the directory name.
        - ``path``
        - ``base_model``: ``"unknown"`` when absent from the config.
        - ``peft_type``: ``"unknown"`` when absent from the config.

        An empty list when the adapters directory does not exist.
    """
    adapters_dir = settings.adapters_dir
    if not adapters_dir.exists():
        return []

    result: list[dict[str, Any]] = []
    for d in sorted(adapters_dir.iterdir()):
        config_path = d / "adapter_config.json"
        if not d.is_dir() or not config_path.is_file():
            continue
        try:
            with config_path.open(encoding="utf-8") as f:
                cfg = json.load(f)
        except (OSError, ValueError) as e:
            # One broken adapter must not hide all the others.
            logger.warning(f"Skipping adapter {d.name}: unreadable adapter_config.json ({e})")
            continue
        if not isinstance(cfg, dict):
            logger.warning(f"Skipping adapter {d.name}: adapter_config.json is not a JSON object")
            continue
        result.append({
            "id": d.name,
            "path": str(d),
            "base_model": cfg.get("base_model_name_or_path", "unknown"),
            "peft_type": cfg.get("peft_type", "unknown"),
        })
    return result
async def run_inference_single(
    model_id: str,
    adapter_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    do_sample: bool,
) -> dict[str, Any]:
    """Run a single inference request; the entry point invoked over SSH.

    Resolves ``adapter_id`` to a directory under ``settings.adapters_dir`` and
    runs ``_generate_local`` against ``model_id``. The call runs in a worker
    thread, because ``_generate_local`` blocks for the whole model load and
    generation.

    Args:
        model_id: Base model ID or path to load.
        adapter_id: Name of an adapter directory directly under
            ``settings.adapters_dir``; must be a single path component.
        prompt: Input text to continue.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.
        do_sample: Sample when True, decode greedily when False.

    Returns:
        The dict from ``_generate_local``, or ``{"error": ...}`` when
        ``adapter_id`` is not a plain directory name.
    """
    import asyncio

    # adapter_id comes from the SSH caller. Joining it onto adapters_dir
    # unchecked would let "../x" or an absolute path escape that directory.
    if adapter_id in ("", ".", "..") or Path(adapter_id).name != adapter_id:
        return {"error": f"Invalid adapter id: {adapter_id!r}"}

    adapter_path = str(settings.adapters_dir / adapter_id)
    return await asyncio.to_thread(
        _generate_local,
        adapter_path=adapter_path,
        base_model_id=model_id,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
    )
|