lxylxy123321 1 týždeň pred
rodič
commit
a539f22e73
1 zmenil súbory, kde vykonal 43 pridanie a 19 odobranie
  1. 43 19
      backend/app/services/model_test_service.py

+ 43 - 19
backend/app/services/model_test_service.py

@@ -19,29 +19,53 @@ def _test_model_remote(model_id: str, prompt: str, max_new_tokens: int, temperat
     import json
     from app.core.remote_executor import ssh_exec
 
-    safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n")
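+    # Escape single quotes so the prompt fits the '%s' literal in the script template (NOTE(review): backslashes and newlines are left unescaped).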
+    safe_prompt = prompt.replace("'", "\\'")
+
+    python_script = """\
+import json, asyncio
+from app.services.model_service import resolve_model_path
+
+model_path = asyncio.run(resolve_model_path('%s'))
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
+
+t = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+t.pad_token = t.pad_token or t.eos_token
+
+m = None
+loaders = [
+    (AutoModelForCausalLM, {'trust_remote_code': True}),
+    (AutoModel, {'trust_remote_code': True}),
+]
+for cls, kw in loaders:
+    try:
+        m = cls.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', **kw)
+        break
+    except Exception:
+        pass
+
+if m is None:
+    print(json.dumps({'error': 'Unable to load model'}))
+    exit(1)
+
+m.eval()
+inp = t('%s', return_tensors='pt').to(m.device)
+out = m.generate(**inp, max_new_tokens=%d, temperature=%f, top_p=%f, do_sample=%s, pad_token_id=t.eos_token_id)
+gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)
+print(json.dumps({'generated_text': gen}))
+""" % (model_id, safe_prompt, max_new_tokens, temperature, top_p, str(temperature > 0).lower())
+
     container = settings.compute_node_docker_container
     python = settings.compute_node_python
 
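+    # Pass the script to the Python interpreter's stdin inside the container using `docker exec -i` plus a heredoc,
+    # to avoid long-command truncation and quote-parsing errors.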
     remote_cmd = (
-        f"docker exec {container} "
-        f"{python} -c \""
-        "import json, asyncio; "
-        "from app.services.model_service import resolve_model_path; "
-        "model_path = asyncio.run(resolve_model_path('" + model_id + "')); "
-        "import torch; "
-        "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel; "
-        "t = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True); "
-        "t.pad_token = t.pad_token or t.eos_token; "
-        "m = None; "
-        "for cls, kw in [(AutoModelForCausalLM, {'trust_remote_code': True}), (AutoModel, {'trust_remote_code': True})]: "
-        "    try: m = cls.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', **kw); break; "
-        "    except: pass; "
-        "m.eval(); "
-        "inp = t('" + safe_prompt + "', return_tensors='pt').to(m.device); "
-        "out = m.generate(**inp, max_new_tokens=" + str(max_new_tokens) + ", temperature=" + str(temperature) + ", top_p=" + str(top_p) + ", do_sample=" + str(temperature > 0).lower() + ", pad_token_id=t.eos_token_id); "
-        "gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True); "
-        "print(json.dumps({'generated_text': gen}))\" 2>&1"
+        f"docker exec -i {container} {python} << 'PYTHON_SCRIPT_EOF'\n"
+        f"{python_script}\n"
+        f"PYTHON_SCRIPT_EOF"
     )
 
     code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)