Parcourir la source

修复远程查询数据库问题

lxylxy123321 il y a 1 semaine
Parent
commit
99286fed1b

+ 7 - 9
backend/app/engines/multimodal_engine.py

@@ -21,17 +21,15 @@ class MultimodalEngine(BaseEngine):
         import torch
         from transformers import AutoProcessor, LlavaForConditionalGeneration
 
-        # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
-        from app.services.model_service import resolve_model_path
-        model_path = await resolve_model_path(model_id)
-        if model_path:
-            local_path = model_path
-        else:
-            local_path = str(settings.models_dir / model_id.replace("/", "_"))
+        local_path = str(settings.models_dir / model_id.replace("/", "_"))
 
         if not (Path(local_path) / "config.json").exists():
-            from huggingface_hub import snapshot_download
-            snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
+            ms_path = settings.models_dir / model_id
+            if (ms_path / "config.json").exists():
+                local_path = str(ms_path)
+            else:
+                from huggingface_hub import snapshot_download
+                snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
 
         self._processor = AutoProcessor.from_pretrained(local_path, trust_remote_code=True)
         self._model = LlavaForConditionalGeneration.from_pretrained(

+ 13 - 18
backend/app/engines/text_engine.py

@@ -40,27 +40,22 @@ class TextEngine(BaseEngine):
         import torch
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
-        # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
-        # 远程节点可能没有 sqlalchemy,直接回退到本地路径扫描
-        try:
-            from app.services.model_service import resolve_model_path
-            model_path = await resolve_model_path(model_id)
-        except ImportError:
-            model_path = None
-        if model_path:
-            local_path = model_path
-        else:
-            local_path = str(settings.models_dir / model_id.replace("/", "_"))
+        # 远程节点不查数据库,直接扫描本地模型目录
+        local_path = str(settings.models_dir / model_id.replace("/", "_"))
 
         # 如果本地没有,从 HF 下载
         if not (Path(local_path) / "config.json").exists():
-            from huggingface_hub import snapshot_download
-
-            snapshot_download(
-                repo_id=model_id,
-                local_dir=local_path,
-                local_dir_use_symlinks=False,
-            )
+            # 尝试 ModelScope 风格路径
+            ms_path = settings.models_dir / model_id
+            if (ms_path / "config.json").exists():
+                local_path = str(ms_path)
+            else:
+                from huggingface_hub import snapshot_download
+                snapshot_download(
+                    repo_id=model_id,
+                    local_dir=local_path,
+                    local_dir_use_symlinks=False,
+                )
 
         quantization = kwargs.get("quantization", None)
         

+ 7 - 9
backend/app/engines/vision_engine.py

@@ -21,17 +21,15 @@ class VisionEngine(BaseEngine):
         import torch
         from transformers import AutoImageProcessor, AutoModelForImageClassification
 
-        # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
-        from app.services.model_service import resolve_model_path
-        model_path = await resolve_model_path(model_id)
-        if model_path:
-            local_path = model_path
-        else:
-            local_path = str(settings.models_dir / model_id.replace("/", "_"))
+        local_path = str(settings.models_dir / model_id.replace("/", "_"))
 
         if not (Path(local_path) / "config.json").exists():
-            from huggingface_hub import snapshot_download
-            snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
+            ms_path = settings.models_dir / model_id
+            if (ms_path / "config.json").exists():
+                local_path = str(ms_path)
+            else:
+                from huggingface_hub import snapshot_download
+                snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
 
         self._processor = AutoImageProcessor.from_pretrained(local_path, trust_remote_code=True)
         self._model = AutoModelForImageClassification.from_pretrained(