model_test_service.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from pathlib import Path
  2. from typing import Any
  3. from app.config import get_settings
  4. from app.core.logging import logger
# Application settings, fetched once at import time. The functions in this
# module read `use_remote_compute` and the `compute_node_*` values from it.
settings = get_settings()
  6. async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
  7. """加载已缓存模型并生成测试响应。"""
  8. if settings.use_remote_compute:
  9. return _test_model_remote(model_id, prompt, max_new_tokens, temperature, top_p)
  10. return _test_model_local(model_id, prompt, max_new_tokens, temperature, top_p)
  11. def _test_model_remote(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  12. """通过 SSH 在算力节点执行模型测试。"""
  13. import base64
  14. import json
  15. from app.core.remote_executor import ssh_exec
  16. # 将 prompt 中的单引号/反斜杠转义
  17. safe_prompt = prompt.replace("\\", "\\\\").replace("'", "\\'")
  18. python_script = """\
  19. import json, asyncio
  20. from app.services.model_service import resolve_model_path
  21. model_path = asyncio.run(resolve_model_path('%s'))
  22. import torch
  23. from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
  24. t = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
  25. t.pad_token = t.pad_token or t.eos_token
  26. m = None
  27. loaders = [
  28. (AutoModelForCausalLM, {'trust_remote_code': True}),
  29. (AutoModel, {'trust_remote_code': True}),
  30. ]
  31. for cls, kw in loaders:
  32. try:
  33. m = cls.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', **kw)
  34. break
  35. except Exception:
  36. pass
  37. if m is None:
  38. print(json.dumps({'error': 'Unable to load model'}))
  39. exit(1)
  40. m.eval()
  41. inp = t('%s', return_tensors='pt').to(m.device)
  42. out = m.generate(**inp, max_new_tokens=%d, temperature=%f, top_p=%f, do_sample=%s, pad_token_id=t.eos_token_id)
  43. gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)
  44. print(json.dumps({'generated_text': gen}))
  45. """ % (model_id, safe_prompt, max_new_tokens, temperature, top_p, str(temperature > 0).lower())
  46. container = settings.compute_node_docker_container
  47. python = settings.compute_node_python
  48. workdir = settings.compute_node_workdir
  49. # 用 base64 编码脚本,通过 bash -c 执行:
  50. # 1. bash -c 能激活 conda 环境(与训练命令一致)
  51. # 2. base64 避免引号嵌套和命令截断问题
  52. script_b64 = base64.b64encode(python_script.encode()).decode()
  53. remote_cmd = (
  54. f"docker exec -w {workdir} {container} "
  55. f"bash -c 'echo {script_b64} | base64 -d | {python}'"
  56. )
  57. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  58. logger.info(f"Remote test result: code={code}, stdout_len={len(stdout)}, stderr_len={len(stderr)}")
  59. if stdout:
  60. logger.info(f"stdout (first 500): {stdout[:500]}")
  61. if stderr:
  62. logger.info(f"stderr (first 500): {stderr[:500]}")
  63. if code != 0:
  64. logger.error(f"Remote model test failed: {stderr}")
  65. return {"error": stderr.strip() or "Remote test failed"}
  66. # 提取最后一行 JSON
  67. for line in reversed(stdout.strip().split("\n")):
  68. line = line.strip()
  69. if line.startswith("{"):
  70. try:
  71. result = json.loads(line)
  72. result["model_id"] = model_id
  73. result["prompt"] = prompt
  74. return result
  75. except json.JSONDecodeError:
  76. continue
  77. return {"error": f"Invalid response: {stdout[:500]}"}
  78. async def _test_model_local(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  79. """本地执行模型测试(仅用于开发环境)。"""
  80. import torch
  81. from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
  82. from app.services.model_service import resolve_model_path
  83. model_path = await resolve_model_path(model_id)
  84. if not model_path:
  85. return {"error": f"Model not found in cache: {model_id}"}
  86. model_dir = Path(model_path)
  87. if not (model_dir / "config.json").exists():
  88. return {"error": f"Model directory not found: {model_dir}"}
  89. tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
  90. if tokenizer.pad_token is None:
  91. tokenizer.pad_token = tokenizer.eos_token
  92. # 通用加载策略:尝试多种加载方式,自动兼容各种新架构
  93. model = None
  94. for loader_cls, kwargs in [
  95. (AutoModelForCausalLM, {"trust_remote_code": True}),
  96. (AutoModel, {"trust_remote_code": True}),
  97. ]:
  98. try:
  99. model = loader_cls.from_pretrained(
  100. model_dir,
  101. torch_dtype=torch.float16,
  102. device_map="auto",
  103. **kwargs,
  104. )
  105. break
  106. except Exception:
  107. continue
  108. if model is None:
  109. return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
  110. model.eval()
  111. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  112. with torch.no_grad():
  113. outputs = model.generate(
  114. **inputs,
  115. max_new_tokens=max_new_tokens,
  116. temperature=temperature,
  117. top_p=top_p,
  118. do_sample=temperature > 0,
  119. pad_token_id=tokenizer.eos_token_id,
  120. )
  121. generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  122. return {
  123. "model_id": model_id,
  124. "prompt": prompt,
  125. "generated_text": generated_text,
  126. }