model_test_service.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. from pathlib import Path
  2. from typing import Any
  3. from app.config import get_settings
  4. from app.core.logging import logger
  5. settings = get_settings()
  6. async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
  7. """加载已缓存模型并生成测试响应。"""
  8. if settings.use_remote_compute:
  9. return await _test_model_remote(model_id, prompt, max_new_tokens, temperature, top_p)
  10. return await _test_model_local(model_id, prompt, max_new_tokens, temperature, top_p)
  11. async def _test_model_remote(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  12. """在算力节点容器内执行模型测试(通过 SSH + docker exec)。
  13. 方案:通过 SSH 在远端容器内直接执行 Python 单行命令,
  14. 所有参数通过环境变量传入,避免任何引号/转义问题。
  15. """
  16. import base64
  17. import json
  18. from app.core.remote_executor import ssh_exec
  19. container = settings.compute_node_docker_container
  20. python = settings.compute_node_python
  21. workdir = settings.compute_node_workdir
  22. # 将 prompt 进行 base64 编码,避免引号/特殊字符问题
  23. prompt_b64 = base64.b64encode(prompt.encode("utf-8")).decode()
  24. do_sample = str(temperature > 0).lower()
  25. # 独立脚本:零 app/db 依赖,参数全部通过环境变量传入
  26. script = rf"""\
  27. import json, os, base64
  28. from pathlib import Path
  29. import torch
  30. from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
  31. def find_model_path(model_id):
  32. # 远端实际存储路径(与 model_service.resolve_model_path 一致)
  33. for base in [
  34. '/root/Fine-tuning/backend/data/models',
  35. '/root/.cache/huggingface/hub',
  36. '/root/.cache/modelscope/hub',
  37. '/root/models',
  38. ]:
  39. bp = Path(base)
  40. if not bp.is_dir():
  41. continue
  42. # 尝试 namespace_name 扁平化匹配(HF 风格)
  43. flat_name = model_id.replace("/", "_")
  44. if (bp / flat_name / "config.json").exists():
  45. return str(bp / flat_name)
  46. # 尝试 namespace/name 嵌套匹配(ModelScope 风格)
  47. if (bp / model_id / "config.json").exists():
  48. return str(bp / model_id)
  49. # 扫描所有目录
  50. try:
  51. for child in bp.rglob("config.json"):
  52. if child.parent.is_dir():
  53. return str(child.parent)
  54. except Exception:
  55. pass
  56. return None
  57. model_id = os.environ.get('MODEL_ID', '')
  58. prompt = base64.b64decode(os.environ.get('PROMPT_B64', '')).decode('utf-8')
  59. max_new_tokens = int(os.environ.get('MAX_TOKENS', '128'))
  60. temperature = float(os.environ.get('TEMPERATURE', '0.8'))
  61. top_p = float(os.environ.get('TOP_P', '0.95'))
  62. do_sample = os.environ.get('DO_SAMPLE', 'true').lower() == 'true'
  63. model_path = find_model_path(model_id)
  64. if model_path is None:
  65. print(json.dumps({{'error': f'Model not found in cache: {{model_id}}'}}))
  66. exit(1)
  67. t = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
  68. t.pad_token = t.pad_token or t.eos_token
  69. m = None
  70. load_errors = []
  71. for cls, kw in [(AutoModelForCausalLM, {{'trust_remote_code': True}}), (AutoModel, {{'trust_remote_code': True}})]:
  72. try:
  73. m = cls.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', **kw)
  74. break
  75. except Exception as e:
  76. load_errors.append(f'{{cls.__name__}} float16: {{str(e)[:200]}}')
  77. # float16 失败时尝试 float32
  78. try:
  79. m = cls.from_pretrained(model_path, torch_dtype=torch.float32, device_map='auto', **kw)
  80. break
  81. except Exception as e:
  82. load_errors.append(f'{{cls.__name__}} float32: {{str(e)[:200]}}')
  83. if m is None:
  84. print(json.dumps({{'error': 'Unable to load model', 'details': load_errors}}))
  85. exit(1)
  86. m.eval()
  87. inp = t(prompt, return_tensors='pt').to(m.device)
  88. out = m.generate(**inp, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, pad_token_id=t.eos_token_id)
  89. gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)
  90. print(json.dumps({{'generated_text': gen}}))
  91. """
  92. script_b64 = base64.b64encode(script.encode()).decode()
  93. # 通过环境变量传递参数,脚本通过 stdin 管道传入容器内的 Python
  94. remote_cmd = (
  95. f"echo {script_b64} | base64 -d | "
  96. f"docker exec -i -w {workdir} "
  97. f"-e MODEL_ID={model_id} "
  98. f"-e PROMPT_B64={prompt_b64} "
  99. f"-e MAX_TOKENS={max_new_tokens} "
  100. f"-e TEMPERATURE={temperature} "
  101. f"-e TOP_P={top_p} "
  102. f"-e DO_SAMPLE={do_sample} "
  103. f"{container} {python}"
  104. )
  105. code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
  106. logger.info(f"Remote test result: code={code}, stdout_len={len(stdout)}, stderr_len={len(stderr)}")
  107. if stdout:
  108. logger.info(f"stdout (first 500): {stdout[:500]}")
  109. if stderr:
  110. logger.info(f"stderr (first 500): {stderr[:500]}")
  111. if code != 0:
  112. logger.error(f"Remote model test failed: {stderr}")
  113. return {"error": stderr.strip() or "Remote test failed"}
  114. for line in reversed(stdout.strip().split("\n")):
  115. line = line.strip()
  116. if line.startswith("{"):
  117. try:
  118. result = json.loads(line)
  119. result["model_id"] = model_id
  120. result["prompt"] = prompt
  121. return result
  122. except json.JSONDecodeError:
  123. continue
  124. return {"error": f"Invalid response: {stdout[:500]}"}
  125. async def _test_model_local(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
  126. """本地执行模型测试(仅用于开发环境)。"""
  127. import torch
  128. from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
  129. from app.services.model_service import resolve_model_path
  130. model_path = await resolve_model_path(model_id)
  131. if not model_path:
  132. return {"error": f"Model not found in cache: {model_id}"}
  133. model_dir = Path(model_path)
  134. if not (model_dir / "config.json").exists():
  135. return {"error": f"Model directory not found: {model_dir}"}
  136. tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
  137. if tokenizer.pad_token is None:
  138. tokenizer.pad_token = tokenizer.eos_token
  139. model = None
  140. for loader_cls, kwargs in [
  141. (AutoModelForCausalLM, {"trust_remote_code": True}),
  142. (AutoModel, {"trust_remote_code": True}),
  143. ]:
  144. try:
  145. model = loader_cls.from_pretrained(
  146. model_dir,
  147. torch_dtype=torch.float16,
  148. device_map="auto",
  149. **kwargs,
  150. )
  151. break
  152. except Exception:
  153. continue
  154. if model is None:
  155. return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
  156. model.eval()
  157. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  158. with torch.no_grad():
  159. outputs = model.generate(
  160. **inputs,
  161. max_new_tokens=max_new_tokens,
  162. temperature=temperature,
  163. top_p=top_p,
  164. do_sample=temperature > 0,
  165. pad_token_id=tokenizer.eos_token_id,
  166. )
  167. generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
  168. return {
  169. "model_id": model_id,
  170. "prompt": prompt,
  171. "generated_text": generated_text,
  172. }