Quellcode durchsuchen

修复模型加载问题

lxylxy123321 vor 2 Wochen
Ursprung
Commit
f70de74f98

+ 7 - 1
backend/app/engines/multimodal_engine.py

@@ -21,7 +21,13 @@ class MultimodalEngine(BaseEngine):
         import torch
         from transformers import AutoProcessor, LlavaForConditionalGeneration
 
-        local_path = str(settings.models_dir / model_id.replace("/", "_"))
+        # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
+        from app.services.model_service import get_model_info as _get_model_info
+        info = await _get_model_info(model_id)
+        if info and info.get("path"):
+            local_path = info["path"]
+        else:
+            local_path = str(settings.models_dir / model_id.replace("/", "_"))
 
         if not (Path(local_path) / "config.json").exists():
             from huggingface_hub import snapshot_download

+ 18 - 3
backend/app/engines/text_engine.py

@@ -22,7 +22,13 @@ class TextEngine(BaseEngine):
         import torch
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
-        local_path = str(settings.models_dir / model_id.replace("/", "_"))
+        # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
+        from app.services.model_service import get_model_info as _get_model_info
+        info = await _get_model_info(model_id)
+        if info and info.get("path"):
+            local_path = info["path"]
+        else:
+            local_path = str(settings.models_dir / model_id.replace("/", "_"))
 
         # 如果本地没有,从 HF 下载
         if not (Path(local_path) / "config.json").exists():
@@ -213,8 +219,17 @@ class TextEngine(BaseEngine):
         import json
         from pathlib import Path
 
-        model_dir = settings.models_dir / model_id.replace("/", "_")
-        config_path = model_dir / "config.json"
+        # 优先从数据库获取实际路径
+        from app.services.model_service import get_model_info as _get_db_info
+        import asyncio
+        try:
+            db_info = asyncio.get_event_loop().run_until_complete(_get_db_info(model_id))
+        except RuntimeError:
+            db_info = None
+        if db_info and db_info.get("path"):
+            config_path = Path(db_info["path"]) / "config.json"
+        else:
+            config_path = settings.models_dir / model_id.replace("/", "_") / "config.json"
 
         if config_path.exists():
             with open(config_path) as f:

+ 7 - 1
backend/app/engines/vision_engine.py

@@ -21,7 +21,13 @@ class VisionEngine(BaseEngine):
         import torch
         from transformers import AutoImageProcessor, AutoModelForImageClassification
 
-        local_path = str(settings.models_dir / model_id.replace("/", "_"))
+        # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
+        from app.services.model_service import get_model_info as _get_model_info
+        info = await _get_model_info(model_id)
+        if info and info.get("path"):
+            local_path = info["path"]
+        else:
+            local_path = str(settings.models_dir / model_id.replace("/", "_"))
 
         if not (Path(local_path) / "config.json").exists():
             from huggingface_hub import snapshot_download

+ 7 - 6
backend/app/services/model_test_service.py

@@ -1,12 +1,8 @@
-import json
 from pathlib import Path
 from typing import Any
 
-from app.config import get_settings
 from app.core.logging import logger
 
-settings = get_settings()
-
 
 async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
     """加载已缓存模型并生成测试响应。"""
@@ -14,8 +10,13 @@ async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temp
         import torch
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
-        # 查找模型路径
-        model_dir = settings.models_dir / model_id.replace("/", "_")
+        # 从数据库获取模型实际路径
+        from app.services.model_service import get_model_info
+        info = await get_model_info(model_id)
+        if not info or not info.get("path"):
+            return {"error": f"Model not found in cache: {model_id}"}
+
+        model_dir = Path(info["path"])
         if not (model_dir / "config.json").exists():
             return {"error": f"Model directory not found: {model_dir}"}