model_test_service.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from pathlib import Path
  2. from typing import Any
  3. from app.config import get_settings
  4. from app.core.logging import logger
  5. settings = get_settings()
  6. async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
  7. """加载已缓存模型并生成测试响应。"""
  8. if settings.use_remote_compute:
  9. return _test_model_remote(model_id, prompt, max_new_tokens, temperature, top_p)
  10. return _test_model_local(model_id, prompt, max_new_tokens, temperature, top_p)
  11. def _test_model_remote(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  12. """通过 SSH 在算力节点执行模型测试。"""
  13. import json
  14. from app.core.remote_executor import ssh_exec
  15. # 将 prompt 中的单引号转义,用于 Python 字符串格式化
  16. safe_prompt = prompt.replace("'", "\\'")
  17. python_script = """\
  18. import json, asyncio
  19. from app.services.model_service import resolve_model_path
  20. model_path = asyncio.run(resolve_model_path('%s'))
  21. import torch
  22. from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
  23. t = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
  24. t.pad_token = t.pad_token or t.eos_token
  25. m = None
  26. loaders = [
  27. (AutoModelForCausalLM, {'trust_remote_code': True}),
  28. (AutoModel, {'trust_remote_code': True}),
  29. ]
  30. for cls, kw in loaders:
  31. try:
  32. m = cls.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', **kw)
  33. break
  34. except Exception:
  35. pass
  36. if m is None:
  37. print(json.dumps({'error': 'Unable to load model'}))
  38. exit(1)
  39. m.eval()
  40. inp = t('%s', return_tensors='pt').to(m.device)
  41. out = m.generate(**inp, max_new_tokens=%d, temperature=%f, top_p=%f, do_sample=%s, pad_token_id=t.eos_token_id)
  42. gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)
  43. print(json.dumps({'generated_text': gen}))
  44. """ % (model_id, safe_prompt, max_new_tokens, temperature, top_p, str(temperature > 0).lower())
  45. container = settings.compute_node_docker_container
  46. python = settings.compute_node_python
  47. # 使用 docker exec -i + heredoc 传递脚本到容器内 Python stdin,
  48. # 避免长命令被截断或引号解析错误;-w 指定工作目录确保 app 模块可导入
  49. remote_cmd = (
  50. f"docker exec -i -w {settings.compute_node_workdir} {container} "
  51. f"{python} << 'PYTHON_SCRIPT_EOF'\n"
  52. f"{python_script}\n"
  53. f"PYTHON_SCRIPT_EOF"
  54. )
  55. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  56. logger.info(f"Remote test result: code={code}, stdout_len={len(stdout)}, stderr_len={len(stderr)}")
  57. if stdout:
  58. logger.info(f"stdout (first 500): {stdout[:500]}")
  59. if stderr:
  60. logger.info(f"stderr (first 500): {stderr[:500]}")
  61. if code != 0:
  62. logger.error(f"Remote model test failed: {stderr}")
  63. return {"error": stderr.strip() or "Remote test failed"}
  64. # 提取最后一行 JSON
  65. for line in reversed(stdout.strip().split("\n")):
  66. line = line.strip()
  67. if line.startswith("{"):
  68. try:
  69. result = json.loads(line)
  70. result["model_id"] = model_id
  71. result["prompt"] = prompt
  72. return result
  73. except json.JSONDecodeError:
  74. continue
  75. return {"error": f"Invalid response: {stdout[:500]}"}
  76. async def _test_model_local(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  77. """本地执行模型测试(仅用于开发环境)。"""
  78. import torch
  79. from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
  80. from app.services.model_service import resolve_model_path
  81. model_path = await resolve_model_path(model_id)
  82. if not model_path:
  83. return {"error": f"Model not found in cache: {model_id}"}
  84. model_dir = Path(model_path)
  85. if not (model_dir / "config.json").exists():
  86. return {"error": f"Model directory not found: {model_dir}"}
  87. tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
  88. if tokenizer.pad_token is None:
  89. tokenizer.pad_token = tokenizer.eos_token
  90. # 通用加载策略:尝试多种加载方式,自动兼容各种新架构
  91. model = None
  92. for loader_cls, kwargs in [
  93. (AutoModelForCausalLM, {"trust_remote_code": True}),
  94. (AutoModel, {"trust_remote_code": True}),
  95. ]:
  96. try:
  97. model = loader_cls.from_pretrained(
  98. model_dir,
  99. torch_dtype=torch.float16,
  100. device_map="auto",
  101. **kwargs,
  102. )
  103. break
  104. except Exception:
  105. continue
  106. if model is None:
  107. return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
  108. model.eval()
  109. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  110. with torch.no_grad():
  111. outputs = model.generate(
  112. **inputs,
  113. max_new_tokens=max_new_tokens,
  114. temperature=temperature,
  115. top_p=top_p,
  116. do_sample=temperature > 0,
  117. pad_token_id=tokenizer.eos_token_id,
  118. )
  119. generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  120. return {
  121. "model_id": model_id,
  122. "prompt": prompt,
  123. "generated_text": generated_text,
  124. }