model_test_service.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
import asyncio
from pathlib import Path
from typing import Any

from app.core.logging import logger
  4. async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
  5. """加载已缓存模型并生成测试响应。"""
  6. try:
  7. import torch
  8. from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
  9. from transformers import AutoConfig
  10. from app.services.model_service import resolve_model_path
  11. model_path = await resolve_model_path(model_id)
  12. if not model_path:
  13. return {"error": f"Model not found in cache: {model_id}"}
  14. model_dir = Path(model_path)
  15. if not (model_dir / "config.json").exists():
  16. return {"error": f"Model directory not found: {model_dir}"}
  17. tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
  18. if tokenizer.pad_token is None:
  19. tokenizer.pad_token = tokenizer.eos_token
  20. # 通用加载策略:尝试多种加载方式,自动兼容各种新架构
  21. model = None
  22. for loader_cls, kwargs in [
  23. (AutoModelForCausalLM, {"trust_remote_code": True}),
  24. (AutoModel, {"trust_remote_code": True}),
  25. ]:
  26. try:
  27. model = loader_cls.from_pretrained(
  28. model_dir,
  29. torch_dtype=torch.float16,
  30. device_map="auto",
  31. **kwargs,
  32. )
  33. break
  34. except Exception:
  35. continue
  36. if model is None:
  37. return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
  38. model.eval()
  39. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  40. with torch.no_grad():
  41. outputs = model.generate(
  42. **inputs,
  43. max_new_tokens=max_new_tokens,
  44. temperature=temperature,
  45. top_p=top_p,
  46. do_sample=temperature > 0,
  47. pad_token_id=tokenizer.eos_token_id,
  48. )
  49. generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  50. return {
  51. "model_id": model_id,
  52. "prompt": prompt,
  53. "generated_text": generated_text,
  54. }
  55. except Exception as e:
  56. logger.error(f"Model test failed: {e}")
  57. return {"error": str(e)}