model_test_service.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from pathlib import Path
  2. from typing import Any
  3. from app.config import get_settings
  4. from app.core.logging import logger
# Application settings, resolved once when this module is imported. The
# functions in this module read `use_remote_compute`,
# `compute_node_docker_container` and `compute_node_python` from it.
settings = get_settings()
  6. async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
  7. """加载已缓存模型并生成测试响应。"""
  8. if settings.use_remote_compute:
  9. return _test_model_remote(model_id, prompt, max_new_tokens, temperature, top_p)
  10. return _test_model_local(model_id, prompt, max_new_tokens, temperature, top_p)
  11. def _test_model_remote(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  12. """通过 SSH 在算力节点执行模型测试。"""
  13. import json
  14. from app.core.remote_executor import ssh_exec
  15. safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n")
  16. container = settings.compute_node_docker_container
  17. python = settings.compute_node_python
  18. remote_cmd = (
  19. f"docker exec {container} "
  20. f"{python} -c \""
  21. "import json, asyncio; "
  22. "from app.services.model_service import resolve_model_path; "
  23. "model_path = asyncio.run(resolve_model_path('" + model_id + "')); "
  24. "import torch; "
  25. "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel; "
  26. "t = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True); "
  27. "t.pad_token = t.pad_token or t.eos_token; "
  28. "m = None; "
  29. "for cls, kw in [(AutoModelForCausalLM, {'trust_remote_code': True}), (AutoModel, {'trust_remote_code': True})]: "
  30. " try: m = cls.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', **kw); break; "
  31. " except: pass; "
  32. "m.eval(); "
  33. "inp = t('" + safe_prompt + "', return_tensors='pt').to(m.device); "
  34. "out = m.generate(**inp, max_new_tokens=" + str(max_new_tokens) + ", temperature=" + str(temperature) + ", top_p=" + str(top_p) + ", do_sample=" + str(temperature > 0).lower() + ", pad_token_id=t.eos_token_id); "
  35. "gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True); "
  36. "print(json.dumps({'generated_text': gen}))\" 2>&1"
  37. )
  38. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  39. if code != 0:
  40. logger.error(f"Remote model test failed: {stderr}")
  41. return {"error": stderr.strip() or "Remote test failed"}
  42. # 提取最后一行 JSON
  43. for line in reversed(stdout.strip().split("\n")):
  44. line = line.strip()
  45. if line.startswith("{"):
  46. try:
  47. result = json.loads(line)
  48. result["model_id"] = model_id
  49. result["prompt"] = prompt
  50. return result
  51. except json.JSONDecodeError:
  52. continue
  53. return {"error": f"Invalid response: {stdout[:500]}"}
  54. async def _test_model_local(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  55. """本地执行模型测试(仅用于开发环境)。"""
  56. import torch
  57. from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
  58. from app.services.model_service import resolve_model_path
  59. model_path = await resolve_model_path(model_id)
  60. if not model_path:
  61. return {"error": f"Model not found in cache: {model_id}"}
  62. model_dir = Path(model_path)
  63. if not (model_dir / "config.json").exists():
  64. return {"error": f"Model directory not found: {model_dir}"}
  65. tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
  66. if tokenizer.pad_token is None:
  67. tokenizer.pad_token = tokenizer.eos_token
  68. # 通用加载策略:尝试多种加载方式,自动兼容各种新架构
  69. model = None
  70. for loader_cls, kwargs in [
  71. (AutoModelForCausalLM, {"trust_remote_code": True}),
  72. (AutoModel, {"trust_remote_code": True}),
  73. ]:
  74. try:
  75. model = loader_cls.from_pretrained(
  76. model_dir,
  77. torch_dtype=torch.float16,
  78. device_map="auto",
  79. **kwargs,
  80. )
  81. break
  82. except Exception:
  83. continue
  84. if model is None:
  85. return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
  86. model.eval()
  87. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  88. with torch.no_grad():
  89. outputs = model.generate(
  90. **inputs,
  91. max_new_tokens=max_new_tokens,
  92. temperature=temperature,
  93. top_p=top_p,
  94. do_sample=temperature > 0,
  95. pad_token_id=tokenizer.eos_token_id,
  96. )
  97. generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  98. return {
  99. "model_id": model_id,
  100. "prompt": prompt,
  101. "generated_text": generated_text,
  102. }