|
@@ -23,10 +23,10 @@ class TextEngine(BaseEngine):
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
# 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
|
|
# 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
|
|
|
- from app.services.model_service import get_model_info as _get_model_info
|
|
|
|
|
- info = await _get_model_info(model_id)
|
|
|
|
|
- if info and info.get("path"):
|
|
|
|
|
- local_path = info["path"]
|
|
|
|
|
|
|
+ from app.services.model_service import resolve_model_path
|
|
|
|
|
+ model_path = await resolve_model_path(model_id)
|
|
|
|
|
+ if model_path:
|
|
|
|
|
+ local_path = model_path
|
|
|
else:
|
|
else:
|
|
|
local_path = str(settings.models_dir / model_id.replace("/", "_"))
|
|
local_path = str(settings.models_dir / model_id.replace("/", "_"))
|
|
|
|
|
|
|
@@ -219,17 +219,23 @@ class TextEngine(BaseEngine):
|
|
|
import json
|
|
import json
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
- # 优先从数据库获取实际路径
|
|
|
|
|
- from app.services.model_service import get_model_info as _get_db_info
|
|
|
|
|
- import asyncio
|
|
|
|
|
- try:
|
|
|
|
|
- db_info = asyncio.get_event_loop().run_until_complete(_get_db_info(model_id))
|
|
|
|
|
- except RuntimeError:
|
|
|
|
|
- db_info = None
|
|
|
|
|
- if db_info and db_info.get("path"):
|
|
|
|
|
- config_path = Path(db_info["path"]) / "config.json"
|
|
|
|
|
- else:
|
|
|
|
|
- config_path = settings.models_dir / model_id.replace("/", "_") / "config.json"
|
|
|
|
|
|
|
+ # 同步查找模型路径(兼容 HF 和 ModelScope)
|
|
|
|
|
+ candidates = [
|
|
|
|
|
+ settings.models_dir / model_id.replace("/", "_"),
|
|
|
|
|
+ settings.models_dir / model_id,
|
|
|
|
|
+ ]
|
|
|
|
|
+ config_path = None
|
|
|
|
|
+ for p in candidates:
|
|
|
|
|
+ if (p / "config.json").exists():
|
|
|
|
|
+ config_path = p / "config.json"
|
|
|
|
|
+ break
|
|
|
|
|
+ if not config_path:
|
|
|
|
|
+ # 最后尝试扫描
|
|
|
|
|
+ model_name = model_id.split("/")[-1]
|
|
|
|
|
+ for cp in settings.models_dir.rglob("config.json"):
|
|
|
|
|
+ if model_name in str(cp.parent):
|
|
|
|
|
+ config_path = cp
|
|
|
|
|
+ break
|
|
|
|
|
|
|
|
if config_path.exists():
|
|
if config_path.exists():
|
|
|
with open(config_path) as f:
|
|
with open(config_path) as f:
|