Ver código fonte

修复模型加载找不到模型路径的问题

lxylxy123321 2 semanas atrás
pai
commit
9c944fa2ab

+ 4 - 4
backend/app/engines/multimodal_engine.py

@@ -22,10 +22,10 @@ class MultimodalEngine(BaseEngine):
         from transformers import AutoProcessor, LlavaForConditionalGeneration
 
         # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
-        from app.services.model_service import get_model_info as _get_model_info
-        info = await _get_model_info(model_id)
-        if info and info.get("path"):
-            local_path = info["path"]
+        from app.services.model_service import resolve_model_path
+        model_path = await resolve_model_path(model_id)
+        if model_path:
+            local_path = model_path
         else:
             local_path = str(settings.models_dir / model_id.replace("/", "_"))
 

+ 21 - 15
backend/app/engines/text_engine.py

@@ -23,10 +23,10 @@ class TextEngine(BaseEngine):
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
         # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
-        from app.services.model_service import get_model_info as _get_model_info
-        info = await _get_model_info(model_id)
-        if info and info.get("path"):
-            local_path = info["path"]
+        from app.services.model_service import resolve_model_path
+        model_path = await resolve_model_path(model_id)
+        if model_path:
+            local_path = model_path
         else:
             local_path = str(settings.models_dir / model_id.replace("/", "_"))
 
@@ -219,17 +219,23 @@ class TextEngine(BaseEngine):
         import json
         from pathlib import Path
 
-        # 优先从数据库获取实际路径
-        from app.services.model_service import get_model_info as _get_db_info
-        import asyncio
-        try:
-            db_info = asyncio.get_event_loop().run_until_complete(_get_db_info(model_id))
-        except RuntimeError:
-            db_info = None
-        if db_info and db_info.get("path"):
-            config_path = Path(db_info["path"]) / "config.json"
-        else:
-            config_path = settings.models_dir / model_id.replace("/", "_") / "config.json"
+        # 同步查找模型路径(兼容 HF 和 ModelScope)
+        candidates = [
+            settings.models_dir / model_id.replace("/", "_"),
+            settings.models_dir / model_id,
+        ]
+        config_path = None
+        for p in candidates:
+            if (p / "config.json").exists():
+                config_path = p / "config.json"
+                break
+        if not config_path:
+            # 最后尝试扫描
+            model_name = model_id.split("/")[-1]
+            for cp in settings.models_dir.rglob("config.json"):
+                if model_name in str(cp.parent):
+                    config_path = cp
+                    break
 
         if config_path.exists():
             with open(config_path) as f:

+ 4 - 4
backend/app/engines/vision_engine.py

@@ -22,10 +22,10 @@ class VisionEngine(BaseEngine):
         from transformers import AutoImageProcessor, AutoModelForImageClassification
 
         # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
-        from app.services.model_service import get_model_info as _get_model_info
-        info = await _get_model_info(model_id)
-        if info and info.get("path"):
-            local_path = info["path"]
+        from app.services.model_service import resolve_model_path
+        model_path = await resolve_model_path(model_id)
+        if model_path:
+            local_path = model_path
         else:
             local_path = str(settings.models_dir / model_id.replace("/", "_"))
 

+ 28 - 0
backend/app/services/model_service.py

@@ -10,6 +10,34 @@ from sqlalchemy import select
 settings = get_settings()
 
 
+async def resolve_model_path(model_id: str) -> str | None:
+    """Resolve the on-disk directory of a downloaded model.
+
+    Supports the different directory layouts produced by HuggingFace and
+    ModelScope downloads. A directory only counts as a model directory when
+    it contains a ``config.json`` file. Candidates are tried in order:
+
+    1. the path recorded in the database for ``model_id``;
+    2. ``models_dir/<namespace>_<name>`` (flattened, HuggingFace style);
+    3. ``models_dir/<namespace>/<name>`` (nested, ModelScope style, may
+       contain symlinks);
+    4. a recursive scan of ``models_dir``, preferring a directory named
+       exactly like the model over one whose path merely contains the name.
+
+    Args:
+        model_id: Model identifier; the part after the last ``/`` is used as
+            the model name for the fallback scan.
+
+    Returns:
+        The model directory as a string, or ``None`` when nothing matches.
+    """
+    # Strategy 1: the path stored in the database, if it still holds a model.
+    info = await get_model_info(model_id)
+    if info and info.get("path"):
+        p = Path(info["path"])
+        if (p / "config.json").exists():
+            return str(p)
+
+    # Strategy 2: HuggingFace style (namespace_name, flattened).
+    hf_path = settings.models_dir / model_id.replace("/", "_")
+    if (hf_path / "config.json").exists():
+        return str(hf_path)
+
+    # Strategy 3: ModelScope style (namespace/name nested, may hold symlinks).
+    ms_path = settings.models_dir / model_id
+    if (ms_path / "config.json").exists():
+        return str(ms_path)
+
+    # Strategy 4: scan everything below models_dir and match on the name.
+    model_name = model_id.split("/")[-1]
+    if not model_name:
+        # An empty name is a substring of every path and would otherwise
+        # resolve to an arbitrary, unrelated model.
+        return None
+
+    loose_match = None
+    for p in settings.models_dir.rglob("config.json"):
+        if p.parent.name == model_name:
+            # An exactly named directory always wins, whatever the scan order.
+            return str(p.parent)
+        if loose_match is None and model_name in str(p.parent):
+            # Remember the first path that merely contains the name (e.g. a
+            # snapshot directory nested below a folder carrying the name).
+            loose_match = p.parent
+
+    return str(loose_match) if loose_match is not None else None
+
+
 async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
     """从 HF 或 ModelScope 下载模型到本地缓存。"""
     try:

+ 4 - 5
backend/app/services/model_test_service.py

@@ -10,13 +10,12 @@ async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temp
         import torch
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
-        # 从数据库获取模型实际路径
-        from app.services.model_service import get_model_info
-        info = await get_model_info(model_id)
-        if not info or not info.get("path"):
+        from app.services.model_service import resolve_model_path
+        model_path = await resolve_model_path(model_id)
+        if not model_path:
             return {"error": f"Model not found in cache: {model_id}"}
 
-        model_dir = Path(info["path"])
+        model_dir = Path(model_path)
         if not (model_dir / "config.json").exists():
             return {"error": f"Model directory not found: {model_dir}"}