model_service.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import json
  2. from pathlib import Path
  3. from typing import Any
  4. from app.config import get_settings
  5. from app.core.db import async_session, ModelCache
  6. from app.core.logging import logger
  7. from sqlalchemy import select
  # Application settings, resolved once when this module is imported and shared by
  # every function below (they read settings.models_dir).
  settings = get_settings()
  def _snapshot_download(model_id: str, use_modelscope: bool) -> str:
      """Synchronously download *model_id* and return the local directory path.

      This is blocking network and disk I/O; async callers should run it in a worker
      thread. The hub client libraries are imported lazily, so a missing package
      surfaces as an exception at call time rather than at module import.
      """
      if use_modelscope:
          from modelscope import snapshot_download as ms_download

          return ms_download(model_id, cache_dir=str(settings.models_dir))

      from huggingface_hub import snapshot_download

      # Hugging Face downloads go to a flat per-model directory: "org/name" -> "org_name".
      return snapshot_download(
          repo_id=model_id,
          local_dir=str(settings.models_dir / model_id.replace("/", "_")),
          local_dir_use_symlinks=False,
      )


  def _read_model_config(local_path: str) -> tuple[str, int]:
      """Return ``(model_type, context_length)`` read from ``<local_path>/config.json``.

      Defaults to ``("text", 2048)`` for anything that cannot be determined: when the
      file is missing, unreadable, not valid JSON, or not a JSON object. A broken
      config is logged but not raised, because the model files are already on disk.
      """
      model_type, context_length = "text", 2048
      config_path = Path(local_path) / "config.json"
      if not config_path.exists():
          return model_type, context_length
      try:
          with open(config_path, encoding="utf-8") as f:
              cfg = json.load(f)
      except (OSError, ValueError) as exc:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
          logger.warning(f"Could not read {config_path}: {exc}")
          return model_type, context_length
      if not isinstance(cfg, dict):
          return model_type, context_length
      model_type = cfg.get("model_type", model_type)
      # Different architectures name the context window differently.
      context_length = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", context_length))
      return model_type, context_length


  async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
      """Download a model from Hugging Face or ModelScope into the local cache.

      On success the model's metadata is inserted into, or updated in, the
      ``ModelCache`` table (keyed by ``model_id``).

      Args:
          model_id: Hub repository id, e.g. ``"org/name"``.
          use_modelscope: Download from ModelScope instead of Hugging Face.

      Returns:
          ``{"model_id", "status": "completed", "path"}`` on success, or
          ``{"model_id", "status": "failed", "error"}`` on failure. Exceptions are
          not propagated; every error is converted into a "failed" result.
      """
      import asyncio

      try:
          # The hub clients are synchronous and a download can take a long time; run it
          # in a worker thread so the event loop keeps serving other requests meanwhile.
          local_path = await asyncio.to_thread(_snapshot_download, model_id, use_modelscope)
          model_type, context_length = _read_model_config(local_path)

          fields = {
              "name": model_id.split("/")[-1],
              "model_type": model_type,
              "path": local_path,
              "is_downloaded": 1,
              "context_length": context_length,
              "supported_peft_methods": "lora,qlora,ia3,adalora,prefix_tuning",
          }
          # Upsert: refresh an existing record or insert a new one.
          async with async_session() as session:
              result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
              existing = result.scalar_one_or_none()
              if existing:
                  for attr, value in fields.items():
                      setattr(existing, attr, value)
              else:
                  session.add(ModelCache(id=model_id, **fields))
              await session.commit()

          logger.info(f"Model downloaded: {model_id} -> {local_path}")
          return {"model_id": model_id, "status": "completed", "path": local_path}
      except Exception as e:
          error_msg = str(e)
          lowered = error_msg.lower()
          # The ModelScope hint only makes sense when the failed attempt used Hugging Face.
          if not use_modelscope and any(word in lowered for word in ("connection", "timeout", "network")):
              error_msg += "\n提示: 可能是 HuggingFace 网络问题。尝试使用 ModelScope 下载。"
          logger.error(f"Model download failed: {e}")
          return {"model_id": model_id, "status": "failed", "error": error_msg}
  async def list_cached_models() -> list[dict[str, Any]]:
      """List the cached models recorded in the database, newest first.

      The database is the source of truth; the models directory is deliberately not
      scanned, so Hugging Face cache sub-directories never appear as models.
      ``is_downloaded`` is recomputed from the filesystem: the stored ``path`` is
      checked first, then ``settings.models_dir / <id with "/" replaced by "_">``.
      When only that fallback exists, it is reported as ``path``. This function is
      read-only: nothing is written back to the database.

      Returns:
          One dict per record with keys ``id``, ``name``, ``model_type``, ``path``,
          ``is_downloaded`` (bool), ``context_length`` and ``supported_peft_methods``
          (a list of method names).
      """
      async with async_session() as session:
          result = await session.execute(select(ModelCache).order_by(ModelCache.created_at.desc()))
          records = result.scalars().all()
          models = []
          for record in records:
              # Resolve the path into a local variable instead of assigning record.path:
              # this session is never committed, so mutating the ORM object would only
              # leave it dirty and the change would be silently discarded.
              path = record.path
              is_downloaded = bool(path) and Path(path).exists()
              if not is_downloaded:
                  fallback = settings.models_dir / record.id.replace("/", "_")
                  if fallback.exists():
                      path, is_downloaded = str(fallback), True
              peft = record.supported_peft_methods
              models.append({
                  "id": record.id,
                  "name": record.name,
                  "model_type": record.model_type,
                  "path": path,
                  "is_downloaded": is_downloaded,
                  "context_length": record.context_length,
                  "supported_peft_methods": peft.split(",") if peft else [],
              })
          return models
  async def get_model_info(model_id: str) -> dict[str, Any] | None:
      """Return metadata for one cached model, or ``None`` if it has no database record.

      ``is_downloaded`` is true only when the record has a path, the record is flagged
      as downloaded, and that path exists on disk. ``supported_peft_methods`` is the
      stored comma-separated string split into a list (empty when unset).
      """
      async with async_session() as session:
          query = select(ModelCache).where(ModelCache.id == model_id)
          row = (await session.execute(query)).scalar_one_or_none()
          if not row:
              return None

          on_disk = False
          if row.path:
              on_disk = bool(row.is_downloaded) and Path(row.path).exists()

          peft = row.supported_peft_methods
          peft_list = peft.split(",") if peft else []

          return {
              "id": row.id,
              "name": row.name,
              "model_type": row.model_type,
              "path": row.path,
              "is_downloaded": on_disk,
              "context_length": row.context_length,
              "supported_peft_methods": peft_list,
          }
  async def delete_model(model_id: str) -> dict[str, Any]:
      """Delete a cached model: its local directory and its database record.

      The directory removed is the record's stored ``path`` when that directory exists,
      otherwise ``settings.models_dir / <id with "/" replaced by "_">``. Deletion of
      files is best-effort: the database record is removed even if some files could
      not be, and ``files_deleted`` reports whether the directory is actually gone.
      As a safeguard, ``settings.models_dir`` itself (or any of its parents) is never
      removed.

      Returns:
          ``{"status": "not_found", "message"}`` when no record exists, otherwise
          ``{"status": "deleted", "model_id", "files_deleted"}``.
      """
      import asyncio
      import shutil

      async with async_session() as session:
          result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
          record = result.scalar_one_or_none()
          if not record:
              return {"status": "not_found", "message": f"Model not found: {model_id}"}

          # Locate the files: prefer the stored path, but fall back to the conventional
          # per-model directory when the stored path is missing or stale, so the files
          # are not orphaned once the record is gone.
          fallback_dir = settings.models_dir / record.id.replace("/", "_")
          model_dir = Path(record.path) if record.path else fallback_dir
          if not model_dir.is_dir():
              model_dir = fallback_dir

          deleted_files = False
          if model_dir.is_dir():
              models_root = settings.models_dir.resolve()
              # True when model_dir is models_dir itself or one of its ancestors.
              if models_root.is_relative_to(model_dir.resolve()):
                  logger.warning(f"Refusing to delete {model_dir}: it contains the models directory")
              else:
                  # rmtree is blocking filesystem I/O; keep it off the event loop.
                  await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=True)
                  # ignore_errors=True hides failures, so check the directory is really gone.
                  deleted_files = not model_dir.exists()

          await session.delete(record)
          await session.commit()
          logger.info(f"Model deleted: {model_id} (files={deleted_files})")
          return {"status": "deleted", "model_id": model_id, "files_deleted": deleted_files}