"""Inference service supporting local execution and remote execution over SSH."""

import asyncio
import json
from pathlib import Path
from typing import Any

from app.config import get_settings
from app.core.logging import logger

settings = get_settings()

# Chat role markers. Generation stops when one of them is produced, and any
# that still appear in the decoded output are cut off.
_ROLE_MARKERS = ("<|user|>", "<|system|>", "<|assistant|>")


async def generate(
    adapter_path: str,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
) -> dict[str, Any]:
    """Generate text with a trained adapter.

    The base model ID is read from the adapter's ``adapter_config.json``.
    Depending on ``settings.use_remote_compute`` the request is either sent to
    the remote executor or run locally; both paths run in a worker thread so
    the event loop is not blocked.

    Args:
        adapter_path: Directory containing the trained adapter.
        prompt: Text to continue.
        max_new_tokens: Forwarded to the generation call.
        temperature: Forwarded to the generation call.
        top_p: Forwarded to the generation call.
        repetition_penalty: Forwarded to the generation call.
        do_sample: Forwarded to the generation call.

    Returns:
        The inference result dict, or ``{"error": ...}`` on failure.
    """
    base_model_id = _get_base_model_id(adapter_path)
    if not base_model_id:
        return {"error": "无法找到基础模型信息,请确保训练任务已完成"}

    if settings.use_remote_compute:
        # Imported lazily so the remote executor is only loaded when needed.
        from app.core.remote_executor import run_inference_remote

        # The remote side identifies the adapter by its directory name.
        adapter_id = Path(adapter_path).name
        result = await asyncio.to_thread(
            run_inference_remote,
            model_id=base_model_id,
            adapter_id=adapter_id,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
        )
        if result:
            return result
        return {"error": "Remote inference failed"}

    return await asyncio.to_thread(
        _generate_local,
        adapter_path=adapter_path,
        base_model_id=base_model_id,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
    )


def _generate_local(
    adapter_path: str,
    base_model_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    do_sample: bool,
) -> dict[str, Any]:
    """Run inference in the current process.

    Loads the tokenizer from ``adapter_path``, the base model from
    ``base_model_id``, applies the adapter and generates a continuation of
    ``prompt``. This call is blocking; async callers should run it in a
    worker thread.

    Returns:
        A dict with ``prompt``, ``generated_text`` (prompt + continuation),
        ``generated_only`` (continuation with role markers removed) and
        ``tokens_generated``; or ``{"error": ...}`` if anything raised.
    """
    try:
        # Heavy ML dependencies are imported lazily, only when local
        # inference is actually requested.
        import torch
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            StoppingCriteria,
            StoppingCriteriaList,
        )
        from peft import PeftModel

        tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Single-GPU inference: load the entire model on cuda:0.
        # device_map="auto" can split the model across GPUs and cause device
        # mismatch errors (e.g. rotary_emb, bmm operations).
        device_map = {"": 0} if torch.cuda.is_available() else "cpu"

        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            device_map=device_map,
        )
        model = PeftModel.from_pretrained(base_model, adapter_path)
        model.eval()

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]

        # Stop criteria: halt as soon as the sequence ends with a role marker
        # so the model does not keep producing further turns.
        encoded_markers = (
            tokenizer.encode(marker, add_special_tokens=False)
            for marker in _ROLE_MARKERS
        )
        # An empty encoding can never match, so drop it up front.
        stop_sequences = [ids for ids in encoded_markers if ids]
        longest_stop = max((len(ids) for ids in stop_sequences), default=0)

        class StopOnRoleToken(StoppingCriteria):
            """Stop when the first sequence ends with any role-marker token run."""

            def __call__(self, input_ids, scores, **kwargs):
                if not longest_stop:
                    return False
                # Only the trailing tokens can match; avoid converting the
                # whole sequence to a list on every generation step.
                tail = input_ids[0, -longest_stop:].tolist()
                return any(
                    len(tail) >= len(ids) and tail[-len(ids):] == ids
                    for ids in stop_sequences
                )

        stopping_criteria = StoppingCriteriaList([StopOnRoleToken()])

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                stopping_criteria=stopping_criteria,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_only = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        # Cut off anything from the first leftover role marker onwards.
        for marker in _ROLE_MARKERS:
            generated_only = generated_only.partition(marker)[0]
        generated_only = generated_only.strip()

        return {
            "prompt": prompt,
            "generated_text": generated_text,
            "generated_only": generated_only,
            "tokens_generated": int(outputs.shape[1] - prompt_length),
        }
    except Exception as e:
        # Boundary handler: callers expect an error dict, not an exception.
        logger.error(f"Inference failed: {e}")
        return {"error": str(e)}


def _load_adapter_config(adapter_dir: Path) -> dict[str, Any] | None:
    """Read ``adapter_config.json`` from ``adapter_dir``.

    Returns:
        The parsed JSON object, or ``None`` if the file is missing,
        unreadable, not valid JSON, or not a JSON object.
    """
    config_path = adapter_dir / "adapter_config.json"
    try:
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error(f"Failed to read adapter config {config_path}: {e}")
        return None
    if not isinstance(cfg, dict):
        logger.error(f"Adapter config {config_path} is not a JSON object")
        return None
    return cfg


def _get_base_model_id(adapter_path: str) -> str | None:
    """Return the base model ID recorded in the adapter config.

    Returns ``None`` when the config is missing or unreadable, or when it has
    no ``base_model_name_or_path`` entry.
    """
    cfg = _load_adapter_config(Path(adapter_path))
    if cfg is None:
        return None
    return cfg.get("base_model_name_or_path")


async def get_available_adapters() -> list[dict[str, Any]]:
    """List adapters whose training job has ``status == "completed"``.

    Only jobs that also have a readable ``adapter_config.json`` under
    ``settings.adapters_dir / <job id>`` are included. Entries are ordered by
    ``created_at``, newest first; jobs without a ``created_at`` come last.
    """
    from app.core.db import async_session, TrainingJobModel
    from sqlalchemy import select

    # Fetch the completed training jobs from the database.
    async with async_session() as session:
        query_result = await session.execute(
            select(TrainingJobModel).where(TrainingJobModel.status == "completed")
        )
        completed_jobs = {job.id: job for job in query_result.scalars().all()}

    if not completed_jobs:
        return []

    adapters_dir = settings.adapters_dir
    if not adapters_dir.exists():
        return []

    # The leading boolean keeps a None created_at from ever being compared
    # with a datetime, which would raise TypeError.
    jobs_newest_first = sorted(
        completed_jobs.items(),
        key=lambda item: (item[1].created_at is not None, item[1].created_at),
        reverse=True,
    )

    adapters: list[dict[str, Any]] = []
    for job_id, job in jobs_newest_first:
        adapter_dir = adapters_dir / job_id
        if not adapter_dir.is_dir():
            continue

        # Skip adapters with a missing or corrupt config instead of failing
        # the whole listing.
        cfg = _load_adapter_config(adapter_dir)
        if cfg is None:
            continue

        adapters.append({
            "id": job_id,
            "path": str(adapter_dir),
            "base_model": cfg.get("base_model_name_or_path", "unknown"),
            "peft_type": cfg.get("peft_type", "unknown"),
            "model_id": job.model_id,
            "peft_method": job.peft_method,
            "task_type": job.task_type,
            "created_at": job.created_at.isoformat() if job.created_at else None,
        })

    return adapters


async def run_inference_single(
    model_id: str,
    adapter_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    do_sample: bool,
) -> dict[str, Any]:
    """Single-prompt inference entry point for remote SSH invocation.

    Resolves ``adapter_id`` to a directory under ``settings.adapters_dir`` and
    runs local inference with ``model_id`` as the base model. The blocking
    work runs in a worker thread, as in ``generate``, so the event loop stays
    responsive.
    """
    adapter_path = str(settings.adapters_dir / adapter_id)
    return await asyncio.to_thread(
        _generate_local,
        adapter_path=adapter_path,
        base_model_id=model_id,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
    )