model_test_service.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from pathlib import Path
  2. from typing import Any
  3. from app.config import get_settings
  4. from app.core.logging import logger
  5. settings = get_settings()
  6. async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
  7. """加载已缓存模型并生成测试响应。"""
  8. if settings.use_remote_compute:
  9. return _test_model_remote(model_id, prompt, max_new_tokens, temperature, top_p)
  10. return _test_model_local(model_id, prompt, max_new_tokens, temperature, top_p)
  11. def _test_model_remote(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  12. """通过 SSH 在算力节点执行模型测试。
  13. 使用独立的 remote_model_test.py 脚本(无 app/db 依赖,不依赖 sqlalchemy),
  14. 通过 SSH + heredoc 部署到远端,docker exec 在容器内执行。
  15. """
  16. import json
  17. from app.core.remote_executor import ssh_exec
  18. # 转义 prompt 中的单引号和反斜杠,用于 shell 安全传递
  19. safe_prompt = prompt.replace("\\", "\\\\").replace("'", "\\'")
  20. container = settings.compute_node_docker_container
  21. python = settings.compute_node_python
  22. workdir = settings.compute_node_workdir
  23. # 将脚本写入远端临时文件,执行后清理
  24. remote_cmd = (
  25. f"cat > /tmp/remote_model_test.py << 'SCRIPT_EOF'\n"
  26. f"import json, sys\n"
  27. f"from pathlib import Path\n"
  28. f"import torch\n"
  29. f"from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel\n"
  30. f"\n"
  31. f"def find_model_path(model_id):\n"
  32. f" candidates = [\n"
  33. f" '/root/.cache/huggingface/hub',\n"
  34. f" '/root/.cache/modelscope/hub',\n"
  35. f" '/root/models',\n"
  36. f" ]\n"
  37. f" for base in candidates:\n"
  38. f" bp = Path(base)\n"
  39. f" if not bp.is_dir():\n"
  40. f" continue\n"
  41. f" # Direct match\n"
  42. f" for child in bp.rglob('config.json'):\n"
  43. f" parent = child.parent\n"
  44. f" if parent.is_dir():\n"
  45. f" return str(parent)\n"
  46. f" return None\n"
  47. f"\n"
  48. f"model_id = sys.argv[1]\n"
  49. f"prompt = sys.argv[2]\n"
  50. f"max_new_tokens = int(sys.argv[3])\n"
  51. f"temperature = float(sys.argv[4])\n"
  52. f"top_p = float(sys.argv[5])\n"
  53. f"\n"
  54. f"model_path = find_model_path(model_id)\n"
  55. f"if model_path is None:\n"
  56. f" print(json.dumps({{'error': f'Model not found in cache: {{model_id}}'}}))\n"
  57. f" sys.exit(1)\n"
  58. f"\n"
  59. f"t = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n"
  60. f"t.pad_token = t.pad_token or t.eos_token\n"
  61. f"\n"
  62. f"m = None\n"
  63. f"for cls, kw in [\n"
  64. f" (AutoModelForCausalLM, {{'trust_remote_code': True}}),\n"
  65. f" (AutoModel, {{'trust_remote_code': True}}),\n"
  66. f"]:\n"
  67. f" try:\n"
  68. f" m = cls.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', **kw)\n"
  69. f" break\n"
  70. f" except Exception:\n"
  71. f" pass\n"
  72. f"\n"
  73. f"if m is None:\n"
  74. f" print(json.dumps({{'error': 'Unable to load model'}}))\n"
  75. f" sys.exit(1)\n"
  76. f"\n"
  77. f"m.eval()\n"
  78. f"inp = t(prompt, return_tensors='pt').to(m.device)\n"
  79. f"out = m.generate(**inp, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample={str(temperature > 0).lower()}, pad_token_id=t.eos_token_id)\n"
  80. f"gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)\n"
  81. f"print(json.dumps({{'generated_text': gen}}))\n"
  82. f"SCRIPT_EOF\n"
  83. f"\n"
  84. f"docker exec -w {workdir} {container} {python} /tmp/remote_model_test.py '{model_id}' '{safe_prompt}' {max_new_tokens} {temperature} {top_p}\n"
  85. f"rm -f /tmp/remote_model_test.py"
  86. )
  87. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  88. logger.info(f"Remote test result: code={code}, stdout_len={len(stdout)}, stderr_len={len(stderr)}")
  89. if stdout:
  90. logger.info(f"stdout (first 500): {stdout[:500]}")
  91. if stderr:
  92. logger.info(f"stderr (first 500): {stderr[:500]}")
  93. if code != 0:
  94. logger.error(f"Remote model test failed: {stderr}")
  95. return {"error": stderr.strip() or "Remote test failed"}
  96. # 提取最后一行 JSON
  97. for line in reversed(stdout.strip().split("\n")):
  98. line = line.strip()
  99. if line.startswith("{"):
  100. try:
  101. result = json.loads(line)
  102. result["model_id"] = model_id
  103. result["prompt"] = prompt
  104. return result
  105. except json.JSONDecodeError:
  106. continue
  107. return {"error": f"Invalid response: {stdout[:500]}"}
  108. async def _test_model_local(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  109. """本地执行模型测试(仅用于开发环境)。"""
  110. import torch
  111. from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
  112. from app.services.model_service import resolve_model_path
  113. model_path = await resolve_model_path(model_id)
  114. if not model_path:
  115. return {"error": f"Model not found in cache: {model_id}"}
  116. model_dir = Path(model_path)
  117. if not (model_dir / "config.json").exists():
  118. return {"error": f"Model directory not found: {model_dir}"}
  119. tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
  120. if tokenizer.pad_token is None:
  121. tokenizer.pad_token = tokenizer.eos_token
  122. # 通用加载策略:尝试多种加载方式,自动兼容各种新架构
  123. model = None
  124. for loader_cls, kwargs in [
  125. (AutoModelForCausalLM, {"trust_remote_code": True}),
  126. (AutoModel, {"trust_remote_code": True}),
  127. ]:
  128. try:
  129. model = loader_cls.from_pretrained(
  130. model_dir,
  131. torch_dtype=torch.float16,
  132. device_map="auto",
  133. **kwargs,
  134. )
  135. break
  136. except Exception:
  137. continue
  138. if model is None:
  139. return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
  140. model.eval()
  141. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  142. with torch.no_grad():
  143. outputs = model.generate(
  144. **inputs,
  145. max_new_tokens=max_new_tokens,
  146. temperature=temperature,
  147. top_p=top_p,
  148. do_sample=temperature > 0,
  149. pad_token_id=tokenizer.eos_token_id,
  150. )
  151. generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  152. return {
  153. "model_id": model_id,
  154. "prompt": prompt,
  155. "generated_text": generated_text,
  156. }