lxylxy123321 1 неделя назад
Родитель
Commit
c6f4c693e1
1 изменённый файл: 16 добавлений и 6 удалений
  1. 16 6
      backend/app/services/model_test_service.py

+ 16 - 6
backend/app/services/model_test_service.py

@@ -8,7 +8,8 @@ async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temp
     """加载已缓存模型并生成测试响应。"""
     try:
         import torch
-        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+        from transformers import AutoConfig
 
         from app.services.model_service import resolve_model_path
         model_path = await resolve_model_path(model_id)
@@ -23,11 +24,20 @@ async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temp
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        model = AutoModelForCausalLM.from_pretrained(
-            model_dir,
-            torch_dtype=torch.float16,
-            device_map="auto",
-        )
+        # 优先尝试因果语言模型加载,失败则回退到通用 AutoModel
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_dir,
+                torch_dtype=torch.float16,
+                device_map="auto",
+            )
+        except (KeyError, ValueError, TypeError):
+            model = AutoModel.from_pretrained(
+                model_dir,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True,
+            )
         model.eval()
 
         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)