| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- import os
- import json
- import uuid
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Any
- from app.config import get_settings
- from app.core.background_tasks import background_task_manager
- from app.core.db import async_session, ModelCache, ModelDownloadTask
- from app.core.logging import logger
- from sqlalchemy import select
# Application settings, created once at import time via get_settings(); the
# functions in this module read settings.models_dir from it.
settings = get_settings()
async def resolve_model_path(model_id: str) -> str | None:
    """Resolve a model's on-disk directory, tolerating HuggingFace and ModelScope layouts.

    A directory qualifies only if it contains a ``config.json``. Candidates are
    tried in this order:

    1. the path recorded in the database (via ``get_model_info``);
    2. HuggingFace-style flat directory ``models_dir/<namespace>_<name>``;
    3. ModelScope-style nested directory ``models_dir/<namespace>/<name>``
       (possibly a symlink);
    4. a recursive scan of ``models_dir``: first a directory whose name equals
       the model name exactly, otherwise the first directory whose path
       relative to ``models_dir`` contains the model name.

    Args:
        model_id: Model identifier, e.g. ``"namespace/name"``.

    Returns:
        The model directory as a string, or ``None`` if nothing matched.
    """
    # Strategy 1: trust the DB path if it still holds a config.json.
    info = await get_model_info(model_id)
    if info and info.get("path"):
        db_path = Path(info["path"])
        if (db_path / "config.json").exists():
            return str(db_path)

    # Strategies 2 & 3: the two conventional layouts under models_dir.
    for candidate in (
        settings.models_dir / model_id.replace("/", "_"),  # HF style, flattened
        settings.models_dir / model_id,  # ModelScope style, nested
    ):
        if (candidate / "config.json").exists():
            return str(candidate)

    # Strategy 4: scan models_dir. An exact directory-name match wins over a
    # fuzzy one, so "gpt2" does not resolve to "gpt2-large" just because rglob
    # yielded it first. The fuzzy match looks only at the path *relative to*
    # models_dir, so a model name that happens to occur in models_dir's own
    # absolute path does not match every directory.
    model_name = model_id.split("/")[-1]
    fuzzy_match: Path | None = None
    for config_file in settings.models_dir.rglob("config.json"):
        model_dir = config_file.parent
        if model_dir.name == model_name:
            return str(model_dir)
        if fuzzy_match is None and model_name in str(model_dir.relative_to(settings.models_dir)):
            fuzzy_match = model_dir
    return str(fuzzy_match) if fuzzy_match is not None else None
async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
    """Start a background download of ``model_id`` and return its task id immediately.

    If the in-memory task manager already holds a "model_download" task for the
    same model in status "pending", "downloading" or "running", that task is
    returned with ``"duplicate": True`` instead of starting another download.

    Args:
        model_id: Model identifier, e.g. ``"namespace/name"``.
        use_modelscope: Download via the ModelScope CLI instead of HuggingFace Hub.

    Returns:
        ``{"task_id", "model_id", "status"}``; duplicates also carry ``"duplicate": True``.

    Raises:
        Whatever the DB layer raises if the task record cannot be committed.
    """
    # Reuse an in-flight download of the same model rather than starting a second one.
    for existing_id, task in background_task_manager.tasks.items():
        if (
            task.get("task_type") == "model_download"
            and task.get("model_id") == model_id
            and task.get("status") in ("pending", "downloading", "running")
        ):
            return {"task_id": existing_id, "model_id": model_id, "status": task["status"], "duplicate": True}

    task_id = str(uuid.uuid4())
    # Register in memory BEFORE the first await: the DB commit below yields to the
    # event loop, and a concurrent request for the same model must already see this
    # task in the duplicate check above. NOTE(review): this assumes register_task
    # records a status the duplicate check recognises (e.g. "pending") — confirm.
    background_task_manager.register_task(task_id, "model_download", {"model_id": model_id})

    record = ModelDownloadTask(
        id=task_id,
        model_id=model_id,
        use_modelscope=1 if use_modelscope else 0,
        status="pending",
    )
    try:
        async with async_session() as session:
            session.add(record)
            await session.commit()
    except Exception as e:
        # Don't leave a phantom in-memory task that would block future downloads.
        background_task_manager.update_task(
            task_id, status="failed", path=None, error=str(e),
            finished_at=datetime.now(timezone.utc).isoformat(),
        )
        raise

    background_task_manager.run(
        task_id, "model_download", _execute_model_download(task_id, model_id, use_modelscope)
    )
    logger.info(f"Model download task started: {model_id} (task_id={task_id})")
    return {"task_id": task_id, "model_id": model_id, "status": "pending"}
def _download_model_files(model_id: str, use_modelscope: bool, download_dir: str) -> str:
    """Download ``model_id`` into ``download_dir`` and return the local model path.

    This call blocks for the whole transfer (a ModelScope CLI subprocess with a
    3600 s timeout, or ``huggingface_hub.snapshot_download``), so code running
    on the event loop must call it in a worker thread.

    Raises:
        RuntimeError: the ModelScope CLI exited with a non-zero status.
        subprocess.TimeoutExpired: the ModelScope CLI ran past the timeout.
    """
    if use_modelscope:
        import subprocess

        proc = subprocess.run(
            [
                "modelscope", "download",
                "--model", model_id,
                "--local_dir", download_dir,
            ],
            capture_output=True, text=True, timeout=3600,
        )
        if proc.returncode != 0:
            raise RuntimeError(f"modelscope CLI failed: {proc.stderr}")
        return download_dir

    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id=model_id,
        local_dir=download_dir,
        local_dir_use_symlinks=False,
    )


def _read_model_config(local_path: str) -> tuple[str, int]:
    """Return ``(model_type, context_length)`` from ``<local_path>/config.json``.

    Falls back to ``("text", 2048)`` for anything missing. Values are taken
    from the JSON as-is (not validated).
    """
    model_type = "text"
    context_length = 2048
    config_path = Path(local_path) / "config.json"
    if config_path.exists():
        # JSON is UTF-8; the platform default codec (e.g. GBK/cp1252 on Windows)
        # would fail on non-ASCII content.
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
        model_type = cfg.get("model_type", "text")
        context_length = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
    return model_type, context_length


async def _execute_model_download(task_id: str, model_id: str, use_modelscope: bool) -> dict:
    """Background body of a model download task.

    Downloads the files, reads ``config.json`` for the model type and context
    length, upserts the ``ModelCache`` row, and marks the download task
    "completed". Any exception is logged and recorded as a "failed" task
    instead of propagating.

    Returns:
        ``{"path": local_path}`` on success, ``{"error": message}`` on failure.
    """
    import asyncio

    try:
        download_dir = str(settings.models_dir / model_id.replace("/", "_"))
        # The download blocks for the whole transfer (up to an hour for the CLI);
        # running it inline would freeze the event loop and every other request.
        # NOTE: cancelling this coroutine does not stop the worker thread.
        local_path = await asyncio.to_thread(
            _download_model_files, model_id, use_modelscope, download_dir
        )
        model_type, context_length = _read_model_config(local_path)
        peft_methods = "lora,qlora,adalora"

        # Upsert the ModelCache row for this model.
        async with async_session() as session:
            result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
            existing = result.scalar_one_or_none()
            if existing:
                existing.name = model_id.split("/")[-1]
                existing.model_type = model_type
                existing.path = local_path
                existing.is_downloaded = 1
                existing.context_length = context_length
                existing.supported_peft_methods = peft_methods
            else:
                session.add(
                    ModelCache(
                        id=model_id,
                        name=model_id.split("/")[-1],
                        model_type=model_type,
                        path=local_path,
                        is_downloaded=1,
                        context_length=context_length,
                        supported_peft_methods=peft_methods,
                    )
                )
            await session.commit()

        await _update_model_download_status(task_id, "completed", path=local_path)
        logger.info(f"Model downloaded: {model_id} -> {local_path}")
        return {"path": local_path}
    except Exception as e:
        import traceback

        tb = traceback.format_exc()
        logger.error(f"Model download failed: {type(e).__name__}: {e}")
        logger.error(f"Traceback:\n{tb}")
        error_msg = str(e)
        looks_like_network_error = (
            "Connection" in error_msg
            or "timeout" in error_msg.lower()
            or "network" in error_msg.lower()
        )
        # The hint blames HuggingFace and suggests ModelScope, so it only
        # applies when the HuggingFace path was used.
        if looks_like_network_error and not use_modelscope:
            error_msg += "\n提示: 可能是 HuggingFace 网络问题。尝试使用 ModelScope 下载。"
        await _update_model_download_status(task_id, "failed", error=error_msg)
        return {"error": error_msg}
async def _update_model_download_status(
    task_id: str, status: str, path: str | None = None, error: str | None = None
):
    """Persist a download task's new status and mirror it into the in-memory task manager.

    Args:
        task_id: ID of the ``ModelDownloadTask`` row / in-memory task.
        status: New status. "completed" and "failed" are terminal and stamp a
            finish time.
        path: Local model path; written to the DB row only when truthy.
        error: Error message; written to the DB row only when truthy.
    """
    # One instant for both stores so the DB and in-memory finish times agree.
    now = datetime.now(timezone.utc)
    is_terminal = status in ("completed", "failed")
    async with async_session() as session:
        result = await session.execute(select(ModelDownloadTask).where(ModelDownloadTask.id == task_id))
        record = result.scalar_one_or_none()
        if record:
            record.status = status
            if path:
                record.path = path
            if error:
                record.error = error
            if is_terminal:
                # Naive UTC: the same value the deprecated datetime.utcnow() produced.
                record.finished_at = now.replace(tzinfo=None)
            await session.commit()
    background_task_manager.update_task(
        task_id, status=status, path=path, error=error,
        finished_at=now.isoformat() if is_terminal else None,
    )
async def get_model_download_status(task_id: str) -> dict[str, Any]:
    """Describe a download task, preferring the persisted DB row over in-memory state.

    Falls back to the background task manager when the DB has no matching row,
    and reports status "not_found" when neither source knows the id.
    """
    async with async_session() as session:
        query = select(ModelDownloadTask).where(ModelDownloadTask.id == task_id)
        row = (await session.execute(query)).scalar_one_or_none()
        if row:
            created = row.created_at
            return {
                "task_id": row.id,
                "model_id": row.model_id,
                "status": row.status,
                "use_modelscope": bool(row.use_modelscope),
                "path": row.path,
                "error": row.error,
                "progress": row.progress,
                "created_at": created.isoformat() if created else "",
            }

    # Not persisted: consult the in-memory task registry.
    in_memory = background_task_manager.get_task(task_id)
    if not in_memory:
        return {"task_id": task_id, "status": "not_found"}
    return {
        "task_id": task_id,
        "model_id": in_memory.get("model_id", ""),
        "status": in_memory["status"],
        "error": in_memory.get("error"),
        "progress": in_memory.get("progress", 0),
    }
async def list_model_downloads() -> list[dict[str, Any]]:
    """Return every model download task as a plain dict, newest first."""

    def _as_dict(task) -> dict[str, Any]:
        created = task.created_at
        return {
            "task_id": task.id,
            "model_id": task.model_id,
            "status": task.status,
            "use_modelscope": bool(task.use_modelscope),
            "path": task.path,
            "error": task.error,
            "created_at": created.isoformat() if created else "",
        }

    async with async_session() as session:
        stmt = select(ModelDownloadTask).order_by(ModelDownloadTask.created_at.desc())
        rows = (await session.execute(stmt)).scalars().all()
        return list(map(_as_dict, rows))
async def cancel_model_download(task_id: str) -> dict[str, Any]:
    """Cancel a model download task.

    Asks the in-memory task manager to cancel ``task_id``. If the DB record is
    still "pending" or "downloading", it is marked "cancelled" with an error
    note and a finish time. Records in any other state are left untouched.

    Returns:
        ``{"task_id": task_id, "status": "cancelled"}``. This is returned
        unconditionally, even when the record was missing or already finished.
    """
    background_task_manager.cancel_task(task_id)
    async with async_session() as session:
        result = await session.execute(select(ModelDownloadTask).where(ModelDownloadTask.id == task_id))
        record = result.scalar_one_or_none()
        if record and record.status in ("pending", "downloading"):
            record.status = "cancelled"
            record.error = "Cancelled by user"
            # Naive UTC, equivalent to the deprecated datetime.utcnow().
            record.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
            await session.commit()
    return {"task_id": task_id, "status": "cancelled"}
async def recover_stale_downloads() -> None:
    """Mark download tasks still "pending"/"downloading" in the DB as failed.

    Intended for tasks interrupted by a server restart: they can no longer be
    running, so each gets status "failed", an explanatory error, and a finish
    time. Nothing is written when there are no such tasks.
    """
    async with async_session() as session:
        result = await session.execute(
            select(ModelDownloadTask).where(
                ModelDownloadTask.status.in_(["pending", "downloading"])
            )
        )
        records = result.scalars().all()
        if not records:
            return
        # Naive UTC, equivalent to the deprecated datetime.utcnow().
        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
        for record in records:
            record.status = "failed"
            record.error = "Server restarted, task interrupted"
            record.finished_at = finished_at
        await session.commit()
    logger.info(f"Recovered {len(records)} stale model download tasks")
async def list_cached_models() -> list[dict[str, Any]]:
    """List models recorded in the ModelCache table, newest first.

    Reads from the database instead of scanning the directory, so HF cache
    sub-directories don't show up as models. ``is_downloaded`` reflects whether
    the model's directory exists on disk right now. When the stored path is
    missing, the flat ``models_dir/<namespace>_<name>`` location is checked and,
    if present, reported as the path.
    """
    async with async_session() as session:
        result = await session.execute(select(ModelCache).order_by(ModelCache.created_at.desc()))
        records = result.scalars().all()
        models = []
        for r in records:
            path = r.path
            dir_exists = bool(path) and Path(path).exists()
            if not dir_exists:
                alt_path = settings.models_dir / r.id.replace("/", "_")
                dir_exists = alt_path.exists()
                if dir_exists:
                    # Report the fallback location without mutating the ORM row:
                    # this session never commits, so assigning r.path would be
                    # silently discarded (or autoflushed by a later query).
                    path = str(alt_path)
            models.append({
                "id": r.id,
                "name": r.name,
                "model_type": r.model_type,
                "path": path,
                "is_downloaded": dir_exists,
                "context_length": r.context_length,
                "supported_peft_methods": r.supported_peft_methods.split(",") if r.supported_peft_methods else [],
            })
        return models
async def get_model_info(model_id: str) -> dict[str, Any] | None:
    """Return metadata for a cached model, or None when no ModelCache row exists.

    ``is_downloaded`` is true only when the row is flagged as downloaded and
    its recorded path is non-empty and exists on disk.
    """
    async with async_session() as session:
        query = select(ModelCache).where(ModelCache.id == model_id)
        row = (await session.execute(query)).scalar_one_or_none()
        if not row:
            return None
        if row.path:
            downloaded = bool(row.is_downloaded) and Path(row.path).exists()
        else:
            downloaded = False
        peft = row.supported_peft_methods
        return {
            "id": row.id,
            "name": row.name,
            "model_type": row.model_type,
            "path": row.path,
            "is_downloaded": downloaded,
            "context_length": row.context_length,
            "supported_peft_methods": peft.split(",") if peft else [],
        }
async def delete_model(model_id: str) -> dict[str, Any]:
    """Delete a cached model: its local files and its ModelCache row.

    The directory removed is the row's ``path`` or, if that is empty, the flat
    ``models_dir/<namespace>_<name>`` location. If that directory is a symlink
    (ModelScope downloads may be linked), the real target directory is removed
    first and then the link itself.

    Args:
        model_id: ID of the ModelCache row.

    Returns:
        ``{"status": "not_found", "message": ...}`` if no row exists. Otherwise
        ``{"status": "deleted", "model_id": ..., "files_deleted": bool}``, where
        ``files_deleted`` is true only if the model directory is really gone.
        The DB row is deleted in either case.
    """
    import shutil

    async with async_session() as session:
        result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
        record = result.scalar_one_or_none()
        if not record:
            return {"status": "not_found", "message": f"Model not found: {model_id}"}

        # NOTE(review): record.path comes from the DB and is removed recursively
        # without checking that it lies under models_dir. Confirm that no flow
        # stores user-owned paths here.
        model_dir = Path(record.path) if record.path else settings.models_dir / record.id.replace("/", "_")
        deleted_files = False
        if model_dir.is_symlink():
            real_dir = model_dir.resolve()
            if real_dir.is_dir():
                shutil.rmtree(real_dir, ignore_errors=True)
            try:
                if not real_dir.exists():
                    # Remove the link too, so no dangling symlink is left behind.
                    model_dir.unlink(missing_ok=True)
                    deleted_files = True
                # A parent that is itself a symlink (e.g. a namespace dir) is only
                # unlinked once nothing is left under it. shutil.rmtree refuses to
                # operate on symlinks, and recursively deleting the target would
                # wipe sibling models.
                parent_link = model_dir.parent
                if parent_link.is_symlink():
                    parent_target = parent_link.resolve()
                    if not parent_target.exists() or (
                        parent_target.is_dir() and not any(parent_target.iterdir())
                    ):
                        parent_link.unlink()
            except OSError as e:
                logger.warning(f"Failed to clean up symlinks for {model_id}: {e}")
        elif model_dir.is_dir():
            shutil.rmtree(model_dir, ignore_errors=True)
            # ignore_errors hides failures; report what actually happened.
            deleted_files = not model_dir.exists()

        await session.delete(record)
        await session.commit()
    logger.info(f"Model deleted: {model_id} (files={deleted_files})")
    return {"status": "deleted", "model_id": model_id, "files_deleted": deleted_files}
|