model_test_service.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
import asyncio
import gc
import threading
from pathlib import Path
from typing import Any

from app.config import get_settings
from app.core.logging import logger
# Application settings, fetched once when this module is imported. This module
# reads `default_max_seq_length` from it (falling back to 2048 when absent) to
# cap prompt length before generation.
settings = get_settings()
  6. async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
  7. """加载已缓存模型并生成测试响应。"""
  8. return await _test_model_local(model_id, prompt, max_new_tokens, temperature, top_p)
  9. async def _test_model_local(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  10. """本地执行模型测试。"""
  11. import torch
  12. from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
  13. from app.services.model_service import resolve_model_path
  14. model_path = await resolve_model_path(model_id)
  15. if not model_path:
  16. return {"error": f"Model not found in cache: {model_id}"}
  17. model_dir = Path(model_path)
  18. if not (model_dir / "config.json").exists():
  19. return {"error": f"Model directory not found: {model_dir}"}
  20. logger.info(f"Loading model: {model_id} from {model_dir}")
  21. tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
  22. if tokenizer.pad_token is None:
  23. tokenizer.pad_token = tokenizer.eos_token
  24. model = None
  25. for loader_cls, kwargs in [
  26. (AutoModelForCausalLM, {"trust_remote_code": True}),
  27. (AutoModel, {"trust_remote_code": True}),
  28. ]:
  29. try:
  30. model = loader_cls.from_pretrained(
  31. model_dir,
  32. torch_dtype=torch.float16,
  33. device_map="auto",
  34. **kwargs,
  35. )
  36. break
  37. except Exception as e:
  38. logger.warning(f"Failed to load with {loader_cls.__name__}: {e}")
  39. continue
  40. if model is None:
  41. return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
  42. model.eval()
  43. # 限制 prompt 长度,避免 OOM
  44. max_prompt_len = getattr(settings, "default_max_seq_length", 2048)
  45. inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_prompt_len).to(model.device)
  46. prompt_tokens = inputs["input_ids"].shape[1]
  47. logger.info(f"Prompt tokenized: {prompt_tokens} tokens, generating up to {max_new_tokens} new tokens")
  48. with torch.no_grad():
  49. outputs = model.generate(
  50. **inputs,
  51. max_new_tokens=max_new_tokens,
  52. temperature=temperature,
  53. top_p=top_p,
  54. do_sample=temperature > 0,
  55. pad_token_id=tokenizer.eos_token_id,
  56. )
  57. generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  58. logger.info(f"Generated {outputs.shape[1] - inputs['input_ids'].shape[1]} tokens")
  59. return {
  60. "model_id": model_id,
  61. "prompt": prompt,
  62. "generated_text": generated_text,
  63. }