model_test_service.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from pathlib import Path
  2. from typing import Any
  3. from app.config import get_settings
  4. from app.core.logging import logger
# Application settings, fetched once when this module is imported.  The
# functions in this module read `use_remote_compute` and the `compute_node_*`
# values (docker container, python executable, workdir) from this object.
settings = get_settings()
  6. async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
  7. """加载已缓存模型并生成测试响应。"""
  8. if settings.use_remote_compute:
  9. return _test_model_remote(model_id, prompt, max_new_tokens, temperature, top_p)
  10. return _test_model_local(model_id, prompt, max_new_tokens, temperature, top_p)
  11. def _test_model_remote(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  12. """通过 SSH 在算力节点执行模型测试。
  13. 通过环境变量传递参数,base64 编码脚本通过 stdin 管道传给 docker exec -i python。
  14. """
  15. import base64
  16. import json
  17. from app.core.remote_executor import ssh_exec
  18. container = settings.compute_node_docker_container
  19. python = settings.compute_node_python
  20. workdir = settings.compute_node_workdir
  21. # 参数通过 base64 编码,脚本内通过 os.environ 读取,完全避免引号/转义问题
  22. prompt_b64 = base64.b64encode(prompt.encode('utf-8')).decode()
  23. do_sample = str(temperature > 0).lower()
  24. # 独立的 Python 脚本(参数通过环境变量传入)
  25. script = rf"""\
  26. import json, os, base64
  27. from pathlib import Path
  28. import torch
  29. from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
  30. def find_model_path(model_id):
  31. for base in ['/root/.cache/huggingface/hub', '/root/.cache/modelscope/hub', '/root/models']:
  32. bp = Path(base)
  33. if not bp.is_dir():
  34. continue
  35. try:
  36. for child in bp.rglob('config.json'):
  37. if child.parent.is_dir():
  38. return str(child.parent)
  39. except Exception:
  40. pass
  41. return None
  42. model_id = os.environ.get('MODEL_ID', '')
  43. prompt = base64.b64decode(os.environ.get('PROMPT_B64', '')).decode('utf-8')
  44. max_new_tokens = int(os.environ.get('MAX_TOKENS', '128'))
  45. temperature = float(os.environ.get('TEMPERATURE', '0.8'))
  46. top_p = float(os.environ.get('TOP_P', '0.95'))
  47. do_sample = os.environ.get('DO_SAMPLE', 'true').lower() == 'true'
  48. model_path = find_model_path(model_id)
  49. if model_path is None:
  50. print(json.dumps({{'error': f'Model not found in cache: {{model_id}}'}}))
  51. exit(1)
  52. t = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
  53. t.pad_token = t.pad_token or t.eos_token
  54. m = None
  55. for cls, kw in [(AutoModelForCausalLM, {{'trust_remote_code': True}}), (AutoModel, {{'trust_remote_code': True}})]:
  56. try:
  57. m = cls.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', **kw)
  58. break
  59. except Exception:
  60. pass
  61. if m is None:
  62. print(json.dumps({{'error': 'Unable to load model'}}))
  63. exit(1)
  64. m.eval()
  65. inp = t(prompt, return_tensors='pt').to(m.device)
  66. out = m.generate(**inp, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, pad_token_id=t.eos_token_id)
  67. gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)
  68. print(json.dumps({{'generated_text': gen}}))
  69. """
  70. script_b64 = base64.b64encode(script.encode()).decode()
  71. # 环境变量通过 docker exec -e 传入容器,脚本通过 stdin 管道传入
  72. remote_cmd = (
  73. f"echo {script_b64} | base64 -d | "
  74. f"docker exec -i -w {workdir} "
  75. f"-e MODEL_ID={model_id} "
  76. f"-e PROMPT_B64={prompt_b64} "
  77. f"-e MAX_TOKENS={max_new_tokens} "
  78. f"-e TEMPERATURE={temperature} "
  79. f"-e TOP_P={top_p} "
  80. f"-e DO_SAMPLE={do_sample} "
  81. f"{container} {python}"
  82. )
  83. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  84. logger.info(f"Remote test result: code={code}, stdout_len={len(stdout)}, stderr_len={len(stderr)}")
  85. if stdout:
  86. logger.info(f"stdout (first 500): {stdout[:500]}")
  87. if stderr:
  88. logger.info(f"stderr (first 500): {stderr[:500]}")
  89. if code != 0:
  90. logger.error(f"Remote model test failed: {stderr}")
  91. return {"error": stderr.strip() or "Remote test failed"}
  92. # 提取最后一行 JSON
  93. for line in reversed(stdout.strip().split("\n")):
  94. line = line.strip()
  95. if line.startswith("{"):
  96. try:
  97. result = json.loads(line)
  98. result["model_id"] = model_id
  99. result["prompt"] = prompt
  100. return result
  101. except json.JSONDecodeError:
  102. continue
  103. return {"error": f"Invalid response: {stdout[:500]}"}
  104. async def _test_model_local(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  105. """本地执行模型测试(仅用于开发环境)。"""
  106. import torch
  107. from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
  108. from app.services.model_service import resolve_model_path
  109. model_path = await resolve_model_path(model_id)
  110. if not model_path:
  111. return {"error": f"Model not found in cache: {model_id}"}
  112. model_dir = Path(model_path)
  113. if not (model_dir / "config.json").exists():
  114. return {"error": f"Model directory not found: {model_dir}"}
  115. tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
  116. if tokenizer.pad_token is None:
  117. tokenizer.pad_token = tokenizer.eos_token
  118. # 通用加载策略:尝试多种加载方式,自动兼容各种新架构
  119. model = None
  120. for loader_cls, kwargs in [
  121. (AutoModelForCausalLM, {"trust_remote_code": True}),
  122. (AutoModel, {"trust_remote_code": True}),
  123. ]:
  124. try:
  125. model = loader_cls.from_pretrained(
  126. model_dir,
  127. torch_dtype=torch.float16,
  128. device_map="auto",
  129. **kwargs,
  130. )
  131. break
  132. except Exception:
  133. continue
  134. if model is None:
  135. return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
  136. model.eval()
  137. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  138. with torch.no_grad():
  139. outputs = model.generate(
  140. **inputs,
  141. max_new_tokens=max_new_tokens,
  142. temperature=temperature,
  143. top_p=top_p,
  144. do_sample=temperature > 0,
  145. pad_token_id=tokenizer.eos_token_id,
  146. )
  147. generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  148. return {
  149. "model_id": model_id,
  150. "prompt": prompt,
  151. "generated_text": generated_text,
  152. }