"""Model download, cache lookup and deletion services.

Downloads run as background tasks registered with ``background_task_manager``;
their state is mirrored in the ``ModelDownloadTask`` table, and successfully
downloaded models are recorded in the ``ModelCache`` table.
"""
import asyncio
import json
import os
import shutil
import subprocess
import traceback
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from sqlalchemy import select

from app.config import get_settings
from app.core.background_tasks import background_task_manager
from app.core.db import async_session, ModelCache, ModelDownloadTask
from app.core.logging import logger

settings = get_settings()

# In-memory task states that count as "a download for this model is in flight".
_ACTIVE_TASK_STATUSES = ("pending", "downloading", "running")
# Download-task states after which ``finished_at`` is stamped.
_FINISHED_STATUSES = ("completed", "failed")
# Fallbacks used when config.json is absent or lacks the relevant keys.
_DEFAULT_MODEL_TYPE = "text"
_DEFAULT_CONTEXT_LENGTH = 2048
_DEFAULT_PEFT_METHODS = "lora,qlora,adalora"
# Upper bound, in seconds, for a single ``modelscope download`` CLI run.
_MODELSCOPE_TIMEOUT_SECONDS = 3600


def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value ``datetime.utcnow()`` did, without relying on that
    deprecated call, so timestamps written to the database keep their shape.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _flat_model_dir(model_id: str) -> Path:
    """Return the flattened (``namespace_name``) directory for *model_id*."""
    return settings.models_dir / model_id.replace("/", "_")


async def resolve_model_path(model_id: str) -> str | None:
    """Resolve the on-disk directory of a model.

    Copes with the different layouts produced by HuggingFace and ModelScope
    downloads. A directory only qualifies if it contains ``config.json``.
    Strategies are tried in order:

    1. the path stored in the database,
    2. the flattened ``namespace_name`` directory under ``models_dir``,
    3. the nested ``namespace/name`` directory under ``models_dir``,
    4. a recursive scan of ``models_dir`` for a directory matching the model
       name.

    Returns:
        The directory path as a string, or ``None`` when nothing matches.
    """
    # Strategy 1: the path recorded in the database.
    info = await get_model_info(model_id)
    if info and info.get("path"):
        db_path = Path(info["path"])
        if (db_path / "config.json").exists():
            return str(db_path)

    # Strategy 2: HuggingFace style (namespace_name, flattened).
    hf_path = _flat_model_dir(model_id)
    if (hf_path / "config.json").exists():
        return str(hf_path)

    # Strategy 3: ModelScope style (namespace/name, nested, may be symlinks).
    ms_path = settings.models_dir / model_id
    if (ms_path / "config.json").exists():
        return str(ms_path)

    # Strategy 4: scan everything under models_dir and match on the name.
    model_name = model_id.split("/")[-1]
    if not model_name:
        # An empty name is a substring of every path and would match anything.
        return None
    loose_match: Path | None = None
    for config_file in settings.models_dir.rglob("config.json"):
        candidate = config_file.parent
        if candidate.name == model_name:
            # An exact directory-name match always wins.
            return str(candidate)
        if loose_match is None and model_name in str(candidate):
            # Keep the first substring match as a fallback only.
            loose_match = candidate
    return str(loose_match) if loose_match is not None else None


async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
    """Start a background download for *model_id* and return immediately.

    If the in-memory task registry already holds an active download for the
    same model, that task is reported instead of starting a second one.

    Args:
        model_id: Repository id of the model (``namespace/name``).
        use_modelscope: Download through the ModelScope CLI instead of
            ``huggingface_hub``.

    Returns:
        ``task_id``, ``model_id`` and ``status``; an already-running download
        additionally carries ``"duplicate": True``.
    """
    # Reuse an in-flight download of the same model, if any.
    for existing_id, task in background_task_manager.tasks.items():
        if (
            task.get("task_type") == "model_download"
            and task.get("model_id") == model_id
            and task.get("status") in _ACTIVE_TASK_STATUSES
        ):
            return {
                "task_id": existing_id,
                "model_id": model_id,
                "status": task["status"],
                "duplicate": True,
            }

    task_id = str(uuid.uuid4())

    # Persist the task row.
    record = ModelDownloadTask(
        id=task_id,
        model_id=model_id,
        use_modelscope=1 if use_modelscope else 0,
        status="pending",
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    # Register the task in memory and hand the coroutine to the task manager.
    background_task_manager.register_task(task_id, "model_download", {"model_id": model_id})
    await background_task_manager.run(
        task_id,
        "model_download",
        _execute_model_download(task_id, model_id, use_modelscope),
    )

    logger.info(f"Model download task started: {model_id} (task_id={task_id})")
    return {"task_id": task_id, "model_id": model_id, "status": "pending"}


def _download_from_modelscope(model_id: str, download_dir: str) -> str:
    """Download *model_id* with the ``modelscope`` CLI (blocking).

    Raises:
        RuntimeError: The CLI exited with a non-zero status.
        subprocess.TimeoutExpired: The CLI exceeded the time limit.
    """
    proc = subprocess.run(
        [
            "modelscope", "download",
            "--model", model_id,
            "--local_dir", download_dir,
        ],
        capture_output=True,
        text=True,
        timeout=_MODELSCOPE_TIMEOUT_SECONDS,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"modelscope CLI failed: {proc.stderr}")
    return download_dir


def _download_from_huggingface(model_id: str, download_dir: str) -> str:
    """Download *model_id* with ``huggingface_hub.snapshot_download`` (blocking)."""
    # Imported lazily so the module loads even when huggingface_hub is absent;
    # an ImportError then surfaces as a failed download task.
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id=model_id,
        local_dir=download_dir,
        local_dir_use_symlinks=False,
    )


def _read_model_config(model_dir: Path) -> tuple[str, Any]:
    """Read ``model_type`` and the context length from ``config.json``.

    Falls back to the module defaults when the file does not exist or the
    keys are missing. A malformed file raises, as it did before.
    """
    config_path = model_dir / "config.json"
    if not config_path.exists():
        return _DEFAULT_MODEL_TYPE, _DEFAULT_CONTEXT_LENGTH
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)
    model_type = cfg.get("model_type", _DEFAULT_MODEL_TYPE)
    context_length = cfg.get(
        "max_position_embeddings",
        cfg.get("max_sequence_length", _DEFAULT_CONTEXT_LENGTH),
    )
    return model_type, context_length


async def _execute_model_download(task_id: str, model_id: str, use_modelscope: bool) -> dict:
    """Run one model download in the background.

    Downloads the files, reads ``config.json``, upserts the ``ModelCache`` row
    and marks the download task completed. Any ``Exception`` is logged and
    recorded on the task as ``failed`` rather than propagated.

    Returns:
        ``{"path": ...}`` on success, ``{"error": ...}`` on failure.
    """
    try:
        download_dir = str(_flat_model_dir(model_id))
        fetch = _download_from_modelscope if use_modelscope else _download_from_huggingface
        # Both downloaders block for a long time; run them in a worker thread
        # so the event loop stays responsive while the download is in progress.
        local_path = await asyncio.to_thread(fetch, model_id, download_dir)

        # Read metadata from config.json.
        model_type, context_length = _read_model_config(Path(local_path))
        peft_methods = _DEFAULT_PEFT_METHODS
        display_name = model_id.split("/")[-1]

        # Upsert the ModelCache row.
        async with async_session() as session:
            result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
            existing = result.scalar_one_or_none()
            if existing:
                existing.name = display_name
                existing.model_type = model_type
                existing.path = local_path
                existing.is_downloaded = 1
                existing.context_length = context_length
                existing.supported_peft_methods = peft_methods
            else:
                session.add(
                    ModelCache(
                        id=model_id,
                        name=display_name,
                        model_type=model_type,
                        path=local_path,
                        is_downloaded=1,
                        context_length=context_length,
                        supported_peft_methods=peft_methods,
                    )
                )
            await session.commit()

        # Mark the download task as finished.
        await _update_model_download_status(task_id, "completed", path=local_path)
        logger.info(f"Model downloaded: {model_id} -> {local_path}")
        return {"path": local_path}

    except Exception as e:
        tb = traceback.format_exc()
        logger.error(f"Model download failed: {type(e).__name__}: {e}")
        logger.error(f"Traceback:\n{tb}")

        error_msg = str(e)
        lowered = error_msg.lower()
        looks_like_network_error = (
            "Connection" in error_msg or "timeout" in lowered or "network" in lowered
        )
        # The hint recommends switching to ModelScope, so it only makes sense
        # when the failed attempt went through HuggingFace.
        if looks_like_network_error and not use_modelscope:
            error_msg += "\n提示: 可能是 HuggingFace 网络问题。尝试使用 ModelScope 下载。"

        await _update_model_download_status(task_id, "failed", error=error_msg)
        return {"error": error_msg}


async def _update_model_download_status(
    task_id: str,
    status: str,
    path: str | None = None,
    error: str | None = None,
) -> None:
    """Write a download task's new state to the database and the in-memory registry.

    ``path`` and ``error`` are only written to the database row when truthy.
    ``finished_at`` is stamped when *status* is ``completed`` or ``failed``.
    """
    is_finished = status in _FINISHED_STATUSES

    async with async_session() as session:
        result = await session.execute(
            select(ModelDownloadTask).where(ModelDownloadTask.id == task_id)
        )
        record = result.scalar_one_or_none()
        if record:
            record.status = status
            if path:
                record.path = path
            if error:
                record.error = error
            if is_finished:
                record.finished_at = _utcnow_naive()
            await session.commit()

    background_task_manager.update_task(
        task_id,
        status=status,
        path=path,
        error=error,
        finished_at=datetime.now(timezone.utc).isoformat() if is_finished else None,
    )


async def get_model_download_status(task_id: str) -> dict[str, Any]:
    """Return the state of a download task.

    The database row is preferred; the in-memory registry is consulted when no
    row exists. An unknown task yields ``{"status": "not_found"}``.
    """
    async with async_session() as session:
        result = await session.execute(
            select(ModelDownloadTask).where(ModelDownloadTask.id == task_id)
        )
        record = result.scalar_one_or_none()
        if record:
            return {
                "task_id": record.id,
                "model_id": record.model_id,
                "status": record.status,
                "use_modelscope": bool(record.use_modelscope),
                "path": record.path,
                "error": record.error,
                "progress": record.progress,
                "created_at": record.created_at.isoformat() if record.created_at else "",
            }

    # Fall back to the in-memory registry.
    mem = background_task_manager.get_task(task_id)
    if mem:
        return {
            "task_id": task_id,
            "model_id": mem.get("model_id", ""),
            "status": mem["status"],
            "error": mem.get("error"),
            "progress": mem.get("progress", 0),
        }
    return {"task_id": task_id, "status": "not_found"}


async def list_model_downloads() -> list[dict[str, Any]]:
    """List every download task in the database, newest first."""
    async with async_session() as session:
        result = await session.execute(
            select(ModelDownloadTask).order_by(ModelDownloadTask.created_at.desc())
        )
        records = result.scalars().all()
        return [
            {
                "task_id": r.id,
                "model_id": r.model_id,
                "status": r.status,
                "use_modelscope": bool(r.use_modelscope),
                "path": r.path,
                "error": r.error,
                "created_at": r.created_at.isoformat() if r.created_at else "",
            }
            for r in records
        ]


async def cancel_model_download(task_id: str) -> dict[str, Any]:
    """Request cancellation of a download task.

    Asks the task manager to cancel the task and, when the database row is
    still ``pending`` or ``downloading``, marks it ``cancelled``.

    Returns:
        ``{"task_id": ..., "status": "cancelled"}`` unconditionally, including
        when no matching row exists or the row is already in another state.
    """
    background_task_manager.cancel_task(task_id)
    async with async_session() as session:
        result = await session.execute(
            select(ModelDownloadTask).where(ModelDownloadTask.id == task_id)
        )
        record = result.scalar_one_or_none()
        if record and record.status in ("pending", "downloading"):
            record.status = "cancelled"
            record.error = "Cancelled by user"
            record.finished_at = _utcnow_naive()
            await session.commit()
    return {"task_id": task_id, "status": "cancelled"}


async def recover_stale_downloads() -> None:
    """Mark download tasks that were interrupted by a restart as ``failed``."""
    async with async_session() as session:
        result = await session.execute(
            select(ModelDownloadTask).where(
                ModelDownloadTask.status.in_(["pending", "downloading"])
            )
        )
        records = result.scalars().all()
        finished_at = _utcnow_naive()
        for record in records:
            record.status = "failed"
            record.error = "Server restarted, task interrupted"
            record.finished_at = finished_at
        if records:
            await session.commit()
            logger.info(f"Recovered {len(records)} stale model download tasks")


async def list_cached_models() -> list[dict[str, Any]]:
    """List cached models from the database, newest first.

    The models directory is deliberately not scanned, so that HuggingFace
    cache sub-directories cannot show up as models. ``is_downloaded`` reflects
    whether a directory for the model actually exists on disk.
    """
    async with async_session() as session:
        result = await session.execute(
            select(ModelCache).order_by(ModelCache.created_at.desc())
        )
        records = result.scalars().all()

        models = []
        for r in records:
            # Trust the stored path only if it is still present on disk.
            path = r.path
            dir_exists = bool(path) and Path(path).exists()
            if not dir_exists:
                # Otherwise look for the flattened directory under models_dir.
                # The ORM row is left untouched; only the reported path changes.
                alt_path = _flat_model_dir(r.id)
                dir_exists = alt_path.exists()
                if dir_exists:
                    path = str(alt_path)
            models.append({
                "id": r.id,
                "name": r.name,
                "model_type": r.model_type,
                "path": path,
                "is_downloaded": dir_exists,
                "context_length": r.context_length,
                "supported_peft_methods": (
                    r.supported_peft_methods.split(",") if r.supported_peft_methods else []
                ),
            })
        return models


async def get_model_info(model_id: str) -> dict[str, Any] | None:
    """Return the cached metadata of *model_id*, or ``None`` if it is unknown.

    ``is_downloaded`` is true only when the row is flagged as downloaded, has
    a path, and that path exists on disk.
    """
    async with async_session() as session:
        result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
        record = result.scalar_one_or_none()
        if not record:
            return None

        is_downloaded = (
            bool(record.path)
            and bool(record.is_downloaded)
            and Path(record.path).exists()
        )
        return {
            "id": record.id,
            "name": record.name,
            "model_type": record.model_type,
            "path": record.path,
            "is_downloaded": is_downloaded,
            "context_length": record.context_length,
            "supported_peft_methods": (
                record.supported_peft_methods.split(",")
                if record.supported_peft_methods
                else []
            ),
        }


def _remove_symlink(link: Path) -> None:
    """Remove the symlink *link* itself (never its target); log on failure."""
    try:
        link.unlink(missing_ok=True)
    except OSError as exc:
        logger.warning(f"Could not remove symlink {link}: {exc}")


def _leads_to_nothing(link: Path) -> bool:
    """Return True when *link* is dangling or points at an empty directory."""
    try:
        return next(link.iterdir(), None) is None
    except FileNotFoundError:
        return True
    except OSError:
        # Not a directory, unreadable, ...: leave the link alone.
        return False


async def delete_model(model_id: str) -> dict[str, Any]:
    """Delete a cached model: its local files and its database row.

    When the model directory is a symlink (ModelScope downloads may create
    them), the real directory it points to is removed and the link itself is
    unlinked. A symlinked parent directory is unlinked as well, but only once
    it no longer leads to any entries, so sibling models stay reachable.

    Returns:
        ``{"status": "not_found", ...}`` when the model has no database row,
        otherwise ``{"status": "deleted", "model_id": ..., "files_deleted": ...}``.
        File removal is best effort: errors from ``shutil.rmtree`` are ignored.
    """
    async with async_session() as session:
        result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
        record = result.scalar_one_or_none()
        if not record:
            return {"status": "not_found", "message": f"Model not found: {model_id}"}

        # Remove the local directory.
        model_dir = Path(record.path) if record.path else _flat_model_dir(record.id)
        deleted_files = False
        if model_dir.is_symlink():
            real_dir = model_dir.resolve()
            if real_dir.exists() and real_dir.is_dir():
                shutil.rmtree(real_dir, ignore_errors=True)
            # shutil.rmtree refuses to operate on symlinks, so the (now
            # dangling) link has to be unlinked explicitly.
            _remove_symlink(model_dir)
            parent_link = model_dir.parent
            if parent_link.is_symlink() and _leads_to_nothing(parent_link):
                _remove_symlink(parent_link)
            deleted_files = True
        elif model_dir.exists() and model_dir.is_dir():
            shutil.rmtree(model_dir, ignore_errors=True)
            deleted_files = True

        # Remove the database row.
        await session.delete(record)
        await session.commit()

    logger.info(f"Model deleted: {model_id} (files={deleted_files})")
    return {"status": "deleted", "model_id": model_id, "files_deleted": deleted_files}