| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- """推理服务 — 支持本地执行和 SSH 远程执行两种模式。"""
- import asyncio
- import json
- from pathlib import Path
- from typing import Any
- from app.config import get_settings
- from app.core.logging import logger
# Application settings, resolved once when this module is imported and shared
# by every function below (e.g. `use_remote_compute`, `adapters_dir`).
settings = get_settings()
async def generate(
    adapter_path: str,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
) -> dict[str, Any]:
    """Generate text for *prompt* using a trained adapter.

    The base model ID is read from the adapter's ``adapter_config.json``.
    When ``settings.use_remote_compute`` is set, the request is forwarded to
    the remote executor; otherwise inference runs in-process. Either way the
    blocking call runs in a worker thread via ``asyncio.to_thread`` so the
    event loop is not stalled.

    Returns:
        The inference result dict. Returns ``{"error": ...}`` when the base
        model cannot be determined, or when the remote call returns a falsy
        result.
    """
    base_model_id = _get_base_model_id(adapter_path)
    if not base_model_id:
        return {"error": "无法找到基础模型信息,请确保训练任务已完成"}

    # Generation arguments shared verbatim by the remote and local paths.
    generation_kwargs = dict(
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
    )

    if not settings.use_remote_compute:
        # Local mode: GPU work happens off the event loop.
        return await asyncio.to_thread(
            _generate_local,
            adapter_path=adapter_path,
            base_model_id=base_model_id,
            **generation_kwargs,
        )

    # Remote mode: imported lazily so local-only deployments don't need it.
    from app.core.remote_executor import run_inference_remote

    remote_result = await asyncio.to_thread(
        run_inference_remote,
        model_id=base_model_id,
        # The remote side is given only the adapter's directory name.
        adapter_id=Path(adapter_path).name,
        **generation_kwargs,
    )
    return remote_result or {"error": "Remote inference failed"}
# Chat role markers. Generation stops as soon as the sequence ends with one of
# them, and the decoded completion is cut before the first one that survives.
_ROLE_MARKERS = ("<|user|>", "<|system|>", "<|assistant|>")


def _strip_role_markers(text: str) -> str:
    """Truncate *text* before the earliest role marker, then strip whitespace."""
    cut = min((text.index(m) for m in _ROLE_MARKERS if m in text), default=len(text))
    return text[:cut].strip()


def _build_role_stopping_criteria(tokenizer):
    """Return a ``StoppingCriteriaList`` that halts when the sequence ends with a role marker."""
    from transformers import StoppingCriteria, StoppingCriteriaList

    # Token ids of each marker encoded on its own; empty encodings can never match.
    stop_ids = [
        ids
        for ids in (tokenizer.encode(m, add_special_tokens=False) for m in _ROLE_MARKERS)
        if ids
    ]
    longest = max((len(ids) for ids in stop_ids), default=0)

    class _StopOnRoleMarker(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            if not longest:
                return False
            # Only the tail can match, so convert just the last `longest` tokens
            # of the first sequence instead of the whole sequence every step.
            tail = input_ids[0, -longest:].tolist()
            return any(tail[-len(ids):] == ids for ids in stop_ids)

    return StoppingCriteriaList([_StopOnRoleMarker()])


def _generate_local(
    adapter_path: str,
    base_model_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    do_sample: bool,
) -> dict[str, Any]:
    """Run inference in-process with *base_model_id* plus the PEFT adapter at *adapter_path*.

    This call blocks while it loads the weights and generates; async callers
    should run it via ``asyncio.to_thread``. The tokenizer is loaded from
    *adapter_path*.

    Returns:
        On success, a dict with:
        - ``prompt``: the input prompt.
        - ``generated_text``: prompt plus completion, special tokens removed.
        - ``generated_only``: the completion only, cut before the first role
          marker and stripped.
        - ``tokens_generated``: the number of new tokens.

        On any failure, ``{"error": <message>}``. The exception is logged with
        its traceback and never raised.
    """
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Single-device inference: load the whole model on cuda:0. device_map="auto"
        # can split it across GPUs and cause device-mismatch errors (e.g. in
        # rotary_emb / bmm). Use fp16 only on GPU; half-precision kernels are
        # missing or very slow on CPU, so fall back to fp32 there.
        if torch.cuda.is_available():
            device_map, dtype = {"": 0}, torch.float16
        else:
            device_map, dtype = "cpu", torch.float32

        # NOTE(review): the base model and adapter are reloaded on every call;
        # consider caching them if request latency matters.
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=dtype,
            device_map=device_map,
        )
        model = PeftModel.from_pretrained(base_model, adapter_path)
        model.eval()

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                # Stop at role markers so the model doesn't continue the dialogue.
                stopping_criteria=_build_role_stopping_criteria(tokenizer),
            )

        completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return {
            "prompt": prompt,
            "generated_text": tokenizer.decode(outputs[0], skip_special_tokens=True),
            # Markers that are not special tokens survive decoding; cut them here.
            "generated_only": _strip_role_markers(completion),
            "tokens_generated": int(outputs.shape[1] - prompt_len),
        }
    except Exception as e:
        # Boundary handler: report the failure as data, but keep the traceback.
        logger.exception(f"Inference failed: {e}")
        return {"error": str(e)}
- def _get_base_model_id(adapter_path: str) -> str | None:
- """从 adapter config 中获取 base model ID。"""
- config_path = Path(adapter_path) / "adapter_config.json"
- if config_path.exists():
- with open(config_path) as f:
- cfg = json.load(f)
- return cfg.get("base_model_name_or_path")
- return None
async def get_available_adapters() -> list[dict[str, Any]]:
    """List trained adapters whose training job has status ``"completed"``.

    A job is included only if ``<settings.adapters_dir>/<job.id>/`` is a
    directory containing an ``adapter_config.json`` file. Results are newest
    first by ``created_at``; jobs without a timestamp come last.

    An adapter whose config cannot be read, or is not a JSON object, is skipped
    with a warning rather than failing the whole listing.
    """
    from sqlalchemy import select

    from app.core.db import TrainingJobModel, async_session

    async with async_session() as session:
        rows = await session.execute(
            select(TrainingJobModel).where(TrainingJobModel.status == "completed")
        )
        completed_jobs = list(rows.scalars().all())
    if not completed_jobs:
        return []

    adapters_dir = settings.adapters_dir
    if not adapters_dir.exists():
        return []

    # Newest first. created_at may be None, and None cannot be compared with
    # datetimes, so sort on (has_timestamp, created_at). With reverse=True,
    # timestamped jobs come first and untimestamped ones keep DB order at the end.
    completed_jobs.sort(
        key=lambda job: (job.created_at is not None, job.created_at),
        reverse=True,
    )

    adapters: list[dict[str, Any]] = []
    for job in completed_jobs:
        adapter_dir = adapters_dir / job.id
        config_path = adapter_dir / "adapter_config.json"
        if not adapter_dir.is_dir() or not config_path.is_file():
            continue
        try:
            with config_path.open(encoding="utf-8") as f:
                cfg = json.load(f)
        except (OSError, ValueError) as e:
            logger.warning(f"Skipping adapter {job.id}: cannot read adapter_config.json ({e})")
            continue
        if not isinstance(cfg, dict):
            logger.warning(f"Skipping adapter {job.id}: adapter_config.json is not a JSON object")
            continue
        adapters.append({
            "id": job.id,
            "path": str(adapter_dir),
            "base_model": cfg.get("base_model_name_or_path", "unknown"),
            "peft_type": cfg.get("peft_type", "unknown"),
            "model_id": job.model_id,
            "peft_method": job.peft_method,
            "task_type": job.task_type,
            "created_at": job.created_at.isoformat() if job.created_at else None,
        })
    return adapters
async def run_inference_single(
    model_id: str,
    adapter_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    do_sample: bool,
) -> dict[str, Any]:
    """Single-prompt inference entry point for remote (SSH) invocation.

    Resolves ``adapter_id`` as a directory under ``settings.adapters_dir`` and
    runs ``_generate_local`` with ``model_id`` as the base model.

    The blocking model load and generation are offloaded with
    ``asyncio.to_thread``, as ``generate`` does, so this coroutine does not
    stall its event loop.

    Returns:
        ``_generate_local``'s result dict, or ``{"error": ...}`` on failure.
    """
    adapter_path = str(settings.adapters_dir / adapter_id)
    return await asyncio.to_thread(
        _generate_local,
        adapter_path=adapter_path,
        base_model_id=model_id,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
    )
|