from pathlib import Path
from typing import Any

from app.config import get_settings
from app.core.logging import logger

settings = get_settings()


async def test_model(
    model_id: str,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
    top_p: float = 0.95,
) -> dict[str, Any]:
    """Load an already-cached model and generate a test response.

    Args:
        model_id: Identifier passed to ``resolve_model_path`` to locate the
            model on disk.
        prompt: Text to feed to the model.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature; a value ``<= 0`` selects greedy
            decoding (no sampling).
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        On success, a dict with ``model_id``, ``prompt`` and
        ``generated_text``. On a handled failure, a dict with a single
        ``error`` key describing the problem.
    """
    return await _test_model_local(model_id, prompt, max_new_tokens, temperature, top_p)


async def _test_model_local(
    model_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> dict[str, Any]:
    """Run the model test in-process.

    Resolves the model directory, loads tokenizer and model from it, generates
    a continuation for ``prompt`` and returns only the newly generated text.

    Args:
        model_id: Identifier passed to ``resolve_model_path``.
        prompt: Text to feed to the model; truncated to the configured
            maximum prompt length.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        ``{"model_id", "prompt", "generated_text"}`` on success, or
        ``{"error": ...}`` when the model cannot be located, loaded, or used
        for generation.
    """
    # Heavy dependencies are imported lazily so importing this module stays cheap.
    import torch
    from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

    from app.services.model_service import resolve_model_path

    model_path = await resolve_model_path(model_id)
    if not model_path:
        return {"error": f"Model not found in cache: {model_id}"}

    model_dir = Path(model_path)
    # A usable model directory must at least contain config.json.
    if not (model_dir / "config.json").exists():
        return {"error": f"Model directory not found: {model_dir}"}

    logger.info(f"Loading model: {model_id} from {model_dir}")
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): loading and generation below are synchronous calls inside
    # a coroutine; they will hold the event loop for their full duration.
    model = None
    # Try the causal-LM loader first, then fall back to the generic loader.
    for loader_cls in (AutoModelForCausalLM, AutoModel):
        try:
            model = loader_cls.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            break
        except Exception as e:  # best-effort: log and try the next loader
            logger.warning(f"Failed to load with {loader_cls.__name__}: {e}")

    if model is None:
        return {
            "error": "Unable to load model with any available loader. "
            "Model type may not be supported yet."
        }

    # A model obtained through the generic AutoModel fallback may not expose
    # generate(); report that as an error dict instead of raising
    # AttributeError further down.
    if not callable(getattr(model, "generate", None)):
        return {"error": f"Model does not support text generation: {model_id}"}

    model.eval()

    # Cap the prompt length to avoid running out of memory.
    max_prompt_len = getattr(settings, "default_max_seq_length", 2048)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_prompt_len,
    ).to(model.device)
    prompt_tokens = inputs["input_ids"].shape[1]
    logger.info(
        f"Prompt tokenized: {prompt_tokens} tokens, generating up to {max_new_tokens} new tokens"
    )

    do_sample = temperature > 0
    generation_kwargs: dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling parameters are only meaningful (and only valid) when
        # sampling; passing temperature <= 0 alongside greedy decoding would
        # hand an invalid value to the generation config.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # The output sequence starts with the prompt tokens; decode only what
    # was generated after them.
    generated_text = tokenizer.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)
    logger.info(f"Generated {outputs.shape[1] - prompt_tokens} tokens")

    return {
        "model_id": model_id,
        "prompt": prompt,
        "generated_text": generated_text,
    }