model_service.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import os
  2. import json
  3. from pathlib import Path
  4. from typing import Any
  5. from app.config import get_settings
  6. from app.core.db import async_session, ModelCache
  7. from app.core.logging import logger
  8. from sqlalchemy import select
# Application settings, loaded once at import time. The functions in this
# module read ``settings.models_dir`` as the root directory for model files.
settings = get_settings()
  10. async def resolve_model_path(model_id: str) -> str | None:
  11. """解析模型的实际路径,兼容 HuggingFace 和 ModelScope 的不同目录结构。"""
  12. # 策略 1: 从数据库读取实际路径
  13. info = await get_model_info(model_id)
  14. if info and info.get("path"):
  15. p = Path(info["path"])
  16. if (p / "config.json").exists():
  17. return str(p)
  18. # 策略 2: HuggingFace 风格(namespace_name 扁平化)
  19. hf_path = settings.models_dir / model_id.replace("/", "_")
  20. if (hf_path / "config.json").exists():
  21. return str(hf_path)
  22. # 策略 3: ModelScope 风格(namespace/name 嵌套,含软链接)
  23. ms_path = settings.models_dir / model_id
  24. if (ms_path / "config.json").exists():
  25. return str(ms_path)
  26. # 策略 4: 扫描 models_dir 下所有目录,匹配名称
  27. model_name = model_id.split("/")[-1]
  28. for p in settings.models_dir.rglob("config.json"):
  29. if p.parent.name == model_name or model_name in str(p.parent):
  30. return str(p.parent)
  31. return None
  32. async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
  33. """从 HF 或 ModelScope 下载模型到本地缓存。"""
  34. logger.info(f"Starting model download: {model_id} (source={'ModelScope' if use_modelscope else 'HuggingFace'})")
  35. try:
  36. if use_modelscope:
  37. import subprocess
  38. download_dir = str(settings.models_dir / model_id.replace("/", "_"))
  39. # 用独立进程调用 CLI,完全隔离 FastAPI 事件循环,避免 __aenter__ 错误
  40. proc = subprocess.run(
  41. [
  42. "modelscope", "download",
  43. "--model", model_id,
  44. "--local_dir", download_dir,
  45. ],
  46. capture_output=True, text=True, timeout=3600,
  47. )
  48. if proc.returncode != 0:
  49. raise RuntimeError(f"modelscope CLI failed: {proc.stderr}")
  50. local_path = download_dir
  51. else:
  52. from huggingface_hub import snapshot_download
  53. local_path_dir = str(settings.models_dir / model_id.replace("/", "_"))
  54. logger.info(f"Downloading from HuggingFace: {model_id} -> {local_path_dir}")
  55. local_path = snapshot_download(
  56. repo_id=model_id,
  57. local_dir=local_path_dir,
  58. local_dir_use_symlinks=False,
  59. )
  60. # 读取 config.json 获取模型信息
  61. config_path = Path(local_path) / "config.json"
  62. model_type = "text"
  63. context_length = 2048
  64. peft_methods = "lora,qlora,ia3,adalora,prefix_tuning"
  65. if config_path.exists():
  66. with open(config_path) as f:
  67. cfg = json.load(f)
  68. model_type = cfg.get("model_type", "text")
  69. context_length = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
  70. # 写入数据库(如果已存在则更新)
  71. async with async_session() as session:
  72. result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
  73. existing = result.scalar_one_or_none()
  74. if existing:
  75. existing.name = model_id.split("/")[-1]
  76. existing.model_type = model_type
  77. existing.path = local_path
  78. existing.is_downloaded = 1
  79. existing.context_length = context_length
  80. existing.supported_peft_methods = peft_methods
  81. else:
  82. record = ModelCache(
  83. id=model_id,
  84. name=model_id.split("/")[-1],
  85. model_type=model_type,
  86. path=local_path,
  87. is_downloaded=1,
  88. context_length=context_length,
  89. supported_peft_methods=peft_methods,
  90. )
  91. session.add(record)
  92. await session.commit()
  93. logger.info(f"Model downloaded: {model_id} -> {local_path}")
  94. return {"model_id": model_id, "status": "completed", "path": local_path}
  95. except Exception as e:
  96. import traceback
  97. tb = traceback.format_exc()
  98. logger.error(f"Model download failed: {type(e).__name__}: {e}")
  99. logger.error(f"Traceback:\n{tb}")
  100. error_msg = str(e)
  101. if "Connection" in error_msg or "timeout" in error_msg.lower() or "network" in error_msg.lower():
  102. error_msg += "\n提示: 可能是 HuggingFace 网络问题。尝试使用 ModelScope 下载。"
  103. return {"model_id": model_id, "status": "failed", "error": error_msg}
  104. async def list_cached_models() -> list[dict[str, Any]]:
  105. """从数据库列出已缓存的模型(不扫描目录,避免 HF 缓存子目录干扰)。"""
  106. async with async_session() as session:
  107. result = await session.execute(select(ModelCache).order_by(ModelCache.created_at.desc()))
  108. records = result.scalars().all()
  109. models = []
  110. for r in records:
  111. # 验证目录是否真的存在,如果不存在则标记为未下载
  112. dir_exists = r.path and Path(r.path).exists()
  113. if not dir_exists:
  114. # 尝试从 models_dir 下查找
  115. alt_path = settings.models_dir / r.id.replace("/", "_")
  116. dir_exists = alt_path.exists()
  117. if dir_exists:
  118. r.path = str(alt_path)
  119. models.append({
  120. "id": r.id,
  121. "name": r.name,
  122. "model_type": r.model_type,
  123. "path": r.path,
  124. "is_downloaded": dir_exists,
  125. "context_length": r.context_length,
  126. "supported_peft_methods": r.supported_peft_methods.split(",") if r.supported_peft_methods else [],
  127. })
  128. return models
  129. async def get_model_info(model_id: str) -> dict[str, Any] | None:
  130. """获取已缓存模型的元数据。"""
  131. async with async_session() as session:
  132. result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
  133. record = result.scalar_one_or_none()
  134. if record:
  135. return {
  136. "id": record.id,
  137. "name": record.name,
  138. "model_type": record.model_type,
  139. "path": record.path,
  140. "is_downloaded": bool(record.is_downloaded) and Path(record.path).exists() if record.path else False,
  141. "context_length": record.context_length,
  142. "supported_peft_methods": record.supported_peft_methods.split(",") if record.supported_peft_methods else [],
  143. }
  144. return None
  145. async def delete_model(model_id: str) -> dict[str, Any]:
  146. """删除已缓存的模型(数据库记录 + 本地文件)。"""
  147. async with async_session() as session:
  148. result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
  149. record = result.scalar_one_or_none()
  150. if not record:
  151. return {"status": "not_found", "message": f"Model not found: {model_id}"}
  152. # 删除本地文件目录(对软链接,删除其指向的真实目录)
  153. model_dir = Path(record.path) if record.path else settings.models_dir / record.id.replace("/", "_")
  154. deleted_files = False
  155. if model_dir.is_symlink():
  156. # ModelScope 下载的模型可能是软链接,删除真实目录
  157. real_dir = model_dir.resolve()
  158. import shutil
  159. if real_dir.exists() and real_dir.is_dir():
  160. shutil.rmtree(real_dir, ignore_errors=True)
  161. # 如果还有父级软链接(如 dphn/ 下的其他链接),一并清理
  162. parent_link = model_dir.parent
  163. if parent_link.is_symlink():
  164. shutil.rmtree(parent_link, ignore_errors=True)
  165. deleted_files = True
  166. elif model_dir.exists() and model_dir.is_dir():
  167. import shutil
  168. shutil.rmtree(model_dir, ignore_errors=True)
  169. deleted_files = True
  170. # 删除数据库记录
  171. await session.delete(record)
  172. await session.commit()
  173. logger.info(f"Model deleted: {model_id} (files={deleted_files})")
  174. return {"status": "deleted", "model_id": model_id, "files_deleted": deleted_files}