| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
import asyncio
import json
from pathlib import Path
from typing import Any

from sqlalchemy import select

from app.config import get_settings
from app.core.db import async_session, ModelCache
from app.core.logging import logger
# Application settings resolved once at import time; `settings.models_dir` is
# the root directory that the functions below download into and read from.
settings = get_settings()
async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
    """Download a model from Hugging Face Hub or ModelScope into the local cache.

    The hub clients are blocking, so the snapshot download runs in a worker
    thread to keep the event loop responsive. Afterwards basic metadata is read
    from the model's ``config.json`` (when present) and a ``ModelCache`` row is
    written (inserted or updated) for ``model_id``.

    Args:
        model_id: Repository id, e.g. ``"org/name"``.
        use_modelscope: Download from ModelScope instead of Hugging Face Hub.

    Returns:
        ``{"model_id", "status": "completed", "path"}`` on success, or
        ``{"model_id", "status": "failed", "error"}`` if any step raised
        (including a missing hub client package or a database error).
    """
    try:
        if use_modelscope:
            from modelscope import snapshot_download as ms_download

            # NOTE(review): ModelScope presumably lays files out under its own
            # cache structure inside models_dir rather than
            # models_dir/<org>_<name> like the HF branch, so filesystem-based
            # lookups keyed on that layout may not find these downloads; the
            # DB row below stores the real path. Confirm against the pinned
            # modelscope version.
            local_path = await asyncio.to_thread(
                ms_download, model_id, cache_dir=str(settings.models_dir)
            )
        else:
            from huggingface_hub import snapshot_download

            local_path = await asyncio.to_thread(
                snapshot_download,
                repo_id=model_id,
                local_dir=str(settings.models_dir / model_id.replace("/", "_")),
                # Ask for real files rather than symlinks into the HF cache.
                # NOTE(review): newer huggingface_hub releases deprecate this
                # argument — confirm against the pinned version.
                local_dir_use_symlinks=False,
            )

        # Read config.json for model metadata; fall back to defaults if absent.
        config_path = Path(local_path) / "config.json"
        model_type = "text"
        context_length = 2048
        peft_methods = "lora,qlora,ia3,adalora,prefix_tuning"
        if config_path.exists():
            with open(config_path, encoding="utf-8") as f:
                cfg = json.load(f)
            model_type = cfg.get("model_type", "text")
            context_length = cfg.get(
                "max_position_embeddings", cfg.get("max_sequence_length", 2048)
            )

        # Persist to the database. merge() upserts by primary key, so
        # re-downloading an already-recorded model updates its row instead of
        # failing the commit with a duplicate-key error.
        # NOTE(review): assumes `id` is ModelCache's primary key (get_model_info
        # looks rows up by it) — confirm.
        async with async_session() as session:
            await session.merge(
                ModelCache(
                    id=model_id,
                    name=model_id.split("/")[-1],
                    model_type=model_type,
                    path=local_path,
                    is_downloaded=1,
                    context_length=context_length,
                    supported_peft_methods=peft_methods,
                )
            )
            await session.commit()

        logger.info(f"Model downloaded: {model_id} -> {local_path}")
        return {"model_id": model_id, "status": "completed", "path": local_path}
    except Exception as e:
        # Top-level boundary: report failure to the caller, keep the traceback.
        logger.exception(f"Model download failed: {e}")
        return {"model_id": model_id, "status": "failed", "error": str(e)}
- def list_cached_models() -> list[dict[str, Any]]:
- """列出本地已缓存的模型。"""
- models_dir = settings.models_dir
- if not models_dir.exists():
- return []
- result = []
- for d in models_dir.iterdir():
- if not d.is_dir():
- continue
- config_path = d / "config.json"
- info: dict[str, Any] = {
- "id": d.name,
- "name": d.name,
- "model_type": "text",
- "path": str(d),
- "is_downloaded": True,
- "context_length": None,
- "supported_peft_methods": [],
- }
- if config_path.exists():
- with open(config_path) as f:
- cfg = json.load(f)
- info["model_type"] = cfg.get("model_type", "text")
- info["context_length"] = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
- info["supported_peft_methods"] = ["lora", "qlora", "ia3", "adalora", "prefix_tuning"]
- result.append(info)
- return result
async def get_model_info(model_id: str) -> dict[str, Any] | None:
    """Return metadata for a cached model, or ``None`` if it is unknown.

    The ``ModelCache`` table is consulted first. If it has no row for
    ``model_id``, the function reads ``config.json`` directly from
    ``models_dir/<model_id with "/" replaced by "_">``. That is the directory
    layout used by the Hugging Face download path.

    Args:
        model_id: Model identifier, e.g. ``"org/name"``.

    Returns:
        A dict with ``id``, ``name``, ``model_type``, ``path``,
        ``is_downloaded``, ``context_length`` and ``supported_peft_methods``.
        Returns ``None`` when neither the database nor the filesystem fallback
        knows the model.
    """
    # Check the database first.
    async with async_session() as session:
        result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
        record = result.scalar_one_or_none()
        if record is not None:
            # Stored as a comma-separated string; ignore blank/padded entries.
            stored_methods = record.supported_peft_methods or ""
            return {
                "id": record.id,
                "name": record.name,
                "model_type": record.model_type,
                "path": record.path,
                "is_downloaded": bool(record.is_downloaded),
                "context_length": record.context_length,
                "supported_peft_methods": [
                    m.strip() for m in stored_methods.split(",") if m.strip()
                ],
            }

    # Fallback: read metadata straight from the filesystem.
    dir_name = model_id.replace("/", "_")
    model_dir = settings.models_dir / dir_name
    # Refuse ids such as ".." (or ones containing other path separators) that
    # would make the fallback read a config.json outside models_dir.
    if dir_name in {"", ".", ".."} or model_dir.parent != settings.models_dir:
        return None
    config_path = model_dir / "config.json"
    if not config_path.exists():
        return None

    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)
    return {
        "id": model_id,
        "name": model_id.split("/")[-1],
        "model_type": cfg.get("model_type", "text"),
        "path": str(model_dir),
        "is_downloaded": True,
        "context_length": cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048)),
        "supported_peft_methods": ["lora", "qlora", "ia3", "adalora", "prefix_tuning"],
    }
|