|
|
@@ -21,17 +21,15 @@ class VisionEngine(BaseEngine):
|
|
|
import torch
|
|
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
|
|
|
|
|
- # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
|
|
|
- from app.services.model_service import resolve_model_path
|
|
|
- model_path = await resolve_model_path(model_id)
|
|
|
- if model_path:
|
|
|
- local_path = model_path
|
|
|
- else:
|
|
|
- local_path = str(settings.models_dir / model_id.replace("/", "_"))
|
|
|
+ local_path = str(settings.models_dir / model_id.replace("/", "_"))
|
|
|
|
|
|
if not (Path(local_path) / "config.json").exists():
|
|
|
- from huggingface_hub import snapshot_download
|
|
|
- snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
|
|
|
+ ms_path = settings.models_dir / model_id
|
|
|
+ if (ms_path / "config.json").exists():
|
|
|
+ local_path = str(ms_path)
|
|
|
+ else:
|
|
|
+ from huggingface_hub import snapshot_download
|
|
|
+ snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
|
|
|
|
|
|
self._processor = AutoImageProcessor.from_pretrained(local_path, trust_remote_code=True)
|
|
|
self._model = AutoModelForImageClassification.from_pretrained(
|