|
|
@@ -22,7 +22,13 @@ class TextEngine(BaseEngine):
|
|
|
import torch
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
- local_path = str(settings.models_dir / model_id.replace("/", "_"))
|
|
|
+ # 优先从数据库获取实际路径(兼容 ModelScope 下载的目录结构)
|
|
|
+ from app.services.model_service import get_model_info as _get_model_info
|
|
|
+ info = await _get_model_info(model_id)
|
|
|
+ if info and info.get("path"):
|
|
|
+ local_path = info["path"]
|
|
|
+ else:
|
|
|
+ local_path = str(settings.models_dir / model_id.replace("/", "_"))
|
|
|
|
|
|
# 如果本地没有,从 HF 下载
|
|
|
if not (Path(local_path) / "config.json").exists():
|
|
|
@@ -213,8 +219,17 @@ class TextEngine(BaseEngine):
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
|
|
|
- model_dir = settings.models_dir / model_id.replace("/", "_")
|
|
|
- config_path = model_dir / "config.json"
|
|
|
+ # 优先从数据库获取实际路径
|
|
|
+ from app.services.model_service import get_model_info as _get_db_info
|
|
|
+ import asyncio
|
|
|
+ try:
|
|
|
+ db_info = asyncio.get_event_loop().run_until_complete(_get_db_info(model_id))
|
|
|
+ except RuntimeError:
|
|
|
+ db_info = None
|
|
|
+ if db_info and db_info.get("path"):
|
|
|
+ config_path = Path(db_info["path"]) / "config.json"
|
|
|
+ else:
|
|
|
+ config_path = settings.models_dir / model_id.replace("/", "_") / "config.json"
|
|
|
|
|
|
if config_path.exists():
|
|
|
with open(config_path) as f:
|