Răsfoiți Sursa

修复模型测试功能

lxylxy123321 1 săptămână în urmă
părinte
comite
e5f21bd330
1 a modificat fișierele cu 108 adăugiri și 53 ștergeri
  1. 108 53
      backend/app/services/model_test_service.py

+ 108 - 53
backend/app/services/model_test_service.py

@@ -1,70 +1,125 @@
 from pathlib import Path
 from typing import Any
 
+from app.config import get_settings
 from app.core.logging import logger
 
+settings = get_settings()
+
 
 async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
     """加载已缓存模型并生成测试响应。"""
-    try:
-        import torch
-        from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
-        from transformers import AutoConfig
-
-        from app.services.model_service import resolve_model_path
-        model_path = await resolve_model_path(model_id)
-        if not model_path:
-            return {"error": f"Model not found in cache: {model_id}"}
-
-        model_dir = Path(model_path)
-        if not (model_dir / "config.json").exists():
-            return {"error": f"Model directory not found: {model_dir}"}
-
-        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
-        # 通用加载策略:尝试多种加载方式,自动兼容各种新架构
-        model = None
-        for loader_cls, kwargs in [
-            (AutoModelForCausalLM, {"trust_remote_code": True}),
-            (AutoModel, {"trust_remote_code": True}),
-        ]:
+    if settings.use_remote_compute:
+        return _test_model_remote(model_id, prompt, max_new_tokens, temperature, top_p)
+    return _test_model_local(model_id, prompt, max_new_tokens, temperature, top_p)
+
+
+def _test_model_remote(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
+    """通过 SSH 在算力节点执行模型测试。"""
+    import json
+    from app.core.remote_executor import ssh_exec
+
+    safe_prompt = prompt.replace('"', '\\"').replace("'", "\\'").replace("\n", "\\n")
+    container = settings.compute_node_docker_container
+    python = settings.compute_node_python
+
+    remote_cmd = (
+        f"docker exec {container} "
+        f"{python} -c \""
+        "import json, asyncio; "
+        "from app.services.model_service import resolve_model_path; "
+        "model_path = asyncio.run(resolve_model_path('" + model_id + "')); "
+        "import torch; "
+        "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel; "
+        "t = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True); "
+        "t.pad_token = t.pad_token or t.eos_token; "
+        "m = None; "
+        "for cls, kw in [(AutoModelForCausalLM, {'trust_remote_code': True}), (AutoModel, {'trust_remote_code': True})]: "
+        "    try: m = cls.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto', **kw); break; "
+        "    except: pass; "
+        "m.eval(); "
+        "inp = t('" + safe_prompt + "', return_tensors='pt').to(m.device); "
+        "out = m.generate(**inp, max_new_tokens=" + str(max_new_tokens) + ", temperature=" + str(temperature) + ", top_p=" + str(top_p) + ", do_sample=" + str(temperature > 0).lower() + ", pad_token_id=t.eos_token_id); "
+        "gen = t.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True); "
+        "print(json.dumps({'generated_text': gen}))\" 2>&1"
+    )
+
+    code, stdout, stderr = ssh_exec(remote_cmd, timeout=600)
+
+    if code != 0:
+        logger.error(f"Remote model test failed: {stderr}")
+        return {"error": stderr.strip() or "Remote test failed"}
+
+    # 提取最后一行 JSON
+    for line in reversed(stdout.strip().split("\n")):
+        line = line.strip()
+        if line.startswith("{"):
             try:
-                model = loader_cls.from_pretrained(
-                    model_dir,
-                    torch_dtype=torch.float16,
-                    device_map="auto",
-                    **kwargs,
-                )
-                break
-            except Exception:
+                result = json.loads(line)
+                result["model_id"] = model_id
+                result["prompt"] = prompt
+                return result
+            except json.JSONDecodeError:
                 continue
 
-        if model is None:
-            return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
-        model.eval()
+    return {"error": f"Invalid response: {stdout[:500]}"}
+
+
+async def _test_model_local(model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> dict[str, Any]:
+    """本地执行模型测试(仅用于开发环境)。"""
+    import torch
+    from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+    from app.services.model_service import resolve_model_path
+    model_path = await resolve_model_path(model_id)
+    if not model_path:
+        return {"error": f"Model not found in cache: {model_id}"}
 
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    model_dir = Path(model_path)
+    if not (model_dir / "config.json").exists():
+        return {"error": f"Model directory not found: {model_dir}"}
 
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                do_sample=temperature > 0,
-                pad_token_id=tokenizer.eos_token_id,
+    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # 通用加载策略:尝试多种加载方式,自动兼容各种新架构
+    model = None
+    for loader_cls, kwargs in [
+        (AutoModelForCausalLM, {"trust_remote_code": True}),
+        (AutoModel, {"trust_remote_code": True}),
+    ]:
+        try:
+            model = loader_cls.from_pretrained(
+                model_dir,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                **kwargs,
             )
+            break
+        except Exception:
+            continue
+
+    if model is None:
+        return {"error": f"Unable to load model with any available loader. Model type may not be supported yet."}
+    model.eval()
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
-        generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            do_sample=temperature > 0,
+            pad_token_id=tokenizer.eos_token_id,
+        )
 
-        return {
-            "model_id": model_id,
-            "prompt": prompt,
-            "generated_text": generated_text,
-        }
+    generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
 
-    except Exception as e:
-        logger.error(f"Model test failed: {e}")
-        return {"error": str(e)}
+    return {
+        "model_id": model_id,
+        "prompt": prompt,
+        "generated_text": generated_text,
+    }