model_service.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. import os
  2. import json
  3. import uuid
  4. from datetime import datetime, timezone
  5. from pathlib import Path
  6. from typing import Any
  7. from app.config import get_settings
  8. from app.core.background_tasks import background_task_manager
  9. from app.core.db import async_session, ModelCache, ModelDownloadTask
  10. from app.core.logging import logger
  11. from sqlalchemy import select
# Application settings, obtained once when this module is imported.
# The functions below read ``settings.models_dir`` as the root for model directories.
settings = get_settings()
  13. async def resolve_model_path(model_id: str) -> str | None:
  14. """解析模型的实际路径,兼容 HuggingFace 和 ModelScope 的不同目录结构。"""
  15. # 策略 1: 从数据库读取实际路径
  16. info = await get_model_info(model_id)
  17. if info and info.get("path"):
  18. p = Path(info["path"])
  19. if (p / "config.json").exists():
  20. return str(p)
  21. # 策略 2: HuggingFace 风格(namespace_name 扁平化)
  22. hf_path = settings.models_dir / model_id.replace("/", "_")
  23. if (hf_path / "config.json").exists():
  24. return str(hf_path)
  25. # 策略 3: ModelScope 风格(namespace/name 嵌套,含软链接)
  26. ms_path = settings.models_dir / model_id
  27. if (ms_path / "config.json").exists():
  28. return str(ms_path)
  29. # 策略 4: 扫描 models_dir 下所有目录,匹配名称
  30. model_name = model_id.split("/")[-1]
  31. for p in settings.models_dir.rglob("config.json"):
  32. if p.parent.name == model_name or model_name in str(p.parent):
  33. return str(p.parent)
  34. return None
  35. async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
  36. """启动模型下载后台任务,立即返回 task_id。"""
  37. task_id = str(uuid.uuid4())
  38. # 检查是否有正在进行的同模型下载
  39. for tid, t in background_task_manager.tasks.items():
  40. if (
  41. t.get("task_type") == "model_download"
  42. and t.get("model_id") == model_id
  43. and t.get("status") in ("pending", "downloading", "running")
  44. ):
  45. return {"task_id": tid, "model_id": model_id, "status": t["status"], "duplicate": True}
  46. # 写 DB
  47. record = ModelDownloadTask(
  48. id=task_id,
  49. model_id=model_id,
  50. use_modelscope=1 if use_modelscope else 0,
  51. status="pending",
  52. )
  53. async with async_session() as session:
  54. session.add(record)
  55. await session.commit()
  56. # 注册并启动
  57. background_task_manager.register_task(task_id, "model_download", {"model_id": model_id})
  58. await background_task_manager.run(
  59. task_id, "model_download", _execute_model_download(task_id, model_id, use_modelscope)
  60. )
  61. logger.info(f"Model download task started: {model_id} (task_id={task_id})")
  62. return {"task_id": task_id, "model_id": model_id, "status": "pending"}
  63. async def _execute_model_download(task_id: str, model_id: str, use_modelscope: bool) -> dict:
  64. """后台执行模型下载。"""
  65. try:
  66. if use_modelscope:
  67. import subprocess
  68. download_dir = str(settings.models_dir / model_id.replace("/", "_"))
  69. proc = subprocess.run(
  70. [
  71. "modelscope", "download",
  72. "--model", model_id,
  73. "--local_dir", download_dir,
  74. ],
  75. capture_output=True, text=True, timeout=3600,
  76. )
  77. if proc.returncode != 0:
  78. raise RuntimeError(f"modelscope CLI failed: {proc.stderr}")
  79. local_path = download_dir
  80. else:
  81. from huggingface_hub import snapshot_download
  82. local_path = snapshot_download(
  83. repo_id=model_id,
  84. local_dir=str(settings.models_dir / model_id.replace("/", "_")),
  85. local_dir_use_symlinks=False,
  86. )
  87. # 读取 config.json
  88. config_path = Path(local_path) / "config.json"
  89. model_type = "text"
  90. context_length = 2048
  91. peft_methods = "lora,qlora,adalora"
  92. if config_path.exists():
  93. with open(config_path) as f:
  94. cfg = json.load(f)
  95. model_type = cfg.get("model_type", "text")
  96. context_length = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
  97. # 更新 ModelCache
  98. async with async_session() as session:
  99. result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
  100. existing = result.scalar_one_or_none()
  101. if existing:
  102. existing.name = model_id.split("/")[-1]
  103. existing.model_type = model_type
  104. existing.path = local_path
  105. existing.is_downloaded = 1
  106. existing.context_length = context_length
  107. existing.supported_peft_methods = peft_methods
  108. else:
  109. record = ModelCache(
  110. id=model_id,
  111. name=model_id.split("/")[-1],
  112. model_type=model_type,
  113. path=local_path,
  114. is_downloaded=1,
  115. context_length=context_length,
  116. supported_peft_methods=peft_methods,
  117. )
  118. session.add(record)
  119. await session.commit()
  120. # 更新下载任务
  121. await _update_model_download_status(task_id, "completed", path=local_path)
  122. logger.info(f"Model downloaded: {model_id} -> {local_path}")
  123. return {"path": local_path}
  124. except Exception as e:
  125. import traceback
  126. tb = traceback.format_exc()
  127. logger.error(f"Model download failed: {type(e).__name__}: {e}")
  128. logger.error(f"Traceback:\n{tb}")
  129. error_msg = str(e)
  130. if "Connection" in error_msg or "timeout" in error_msg.lower() or "network" in error_msg.lower():
  131. error_msg += "\n提示: 可能是 HuggingFace 网络问题。尝试使用 ModelScope 下载。"
  132. await _update_model_download_status(task_id, "failed", error=error_msg)
  133. return {"error": error_msg}
  134. async def _update_model_download_status(task_id: str, status: str, path: str = None, error: str = None):
  135. async with async_session() as session:
  136. result = await session.execute(select(ModelDownloadTask).where(ModelDownloadTask.id == task_id))
  137. record = result.scalar_one_or_none()
  138. if record:
  139. record.status = status
  140. if path:
  141. record.path = path
  142. if error:
  143. record.error = error
  144. if status in ("completed", "failed"):
  145. record.finished_at = datetime.utcnow()
  146. await session.commit()
  147. background_task_manager.update_task(
  148. task_id, status=status, path=path, error=error,
  149. finished_at=datetime.now(timezone.utc).isoformat() if status in ("completed", "failed") else None,
  150. )
  151. async def get_model_download_status(task_id: str) -> dict[str, Any]:
  152. async with async_session() as session:
  153. result = await session.execute(select(ModelDownloadTask).where(ModelDownloadTask.id == task_id))
  154. record = result.scalar_one_or_none()
  155. if record:
  156. return {
  157. "task_id": record.id,
  158. "model_id": record.model_id,
  159. "status": record.status,
  160. "use_modelscope": bool(record.use_modelscope),
  161. "path": record.path,
  162. "error": record.error,
  163. "progress": record.progress,
  164. "created_at": record.created_at.isoformat() if record.created_at else "",
  165. }
  166. # 也查内存
  167. mem = background_task_manager.get_task(task_id)
  168. if mem:
  169. return {
  170. "task_id": task_id,
  171. "model_id": mem.get("model_id", ""),
  172. "status": mem["status"],
  173. "error": mem.get("error"),
  174. "progress": mem.get("progress", 0),
  175. }
  176. return {"task_id": task_id, "status": "not_found"}
  177. async def list_model_downloads() -> list[dict[str, Any]]:
  178. async with async_session() as session:
  179. result = await session.execute(
  180. select(ModelDownloadTask).order_by(ModelDownloadTask.created_at.desc())
  181. )
  182. records = result.scalars().all()
  183. return [
  184. {
  185. "task_id": r.id,
  186. "model_id": r.model_id,
  187. "status": r.status,
  188. "use_modelscope": bool(r.use_modelscope),
  189. "path": r.path,
  190. "error": r.error,
  191. "created_at": r.created_at.isoformat() if r.created_at else "",
  192. }
  193. for r in records
  194. ]
  195. async def cancel_model_download(task_id: str) -> dict[str, Any]:
  196. background_task_manager.cancel_task(task_id)
  197. async with async_session() as session:
  198. result = await session.execute(select(ModelDownloadTask).where(ModelDownloadTask.id == task_id))
  199. record = result.scalar_one_or_none()
  200. if record and record.status in ("pending", "downloading"):
  201. record.status = "cancelled"
  202. record.error = "Cancelled by user"
  203. record.finished_at = datetime.utcnow()
  204. await session.commit()
  205. return {"task_id": task_id, "status": "cancelled"}
  206. async def recover_stale_downloads() -> None:
  207. """把因重启中断的下载任务标记为 failed。"""
  208. async with async_session() as session:
  209. result = await session.execute(
  210. select(ModelDownloadTask).where(
  211. ModelDownloadTask.status.in_(["pending", "downloading"])
  212. )
  213. )
  214. records = result.scalars().all()
  215. for record in records:
  216. record.status = "failed"
  217. record.error = "Server restarted, task interrupted"
  218. record.finished_at = datetime.utcnow()
  219. if records:
  220. await session.commit()
  221. logger.info(f"Recovered {len(records)} stale model download tasks")
  222. async def list_cached_models() -> list[dict[str, Any]]:
  223. """从数据库列出已缓存的模型(不扫描目录,避免 HF 缓存子目录干扰)。"""
  224. async with async_session() as session:
  225. result = await session.execute(select(ModelCache).order_by(ModelCache.created_at.desc()))
  226. records = result.scalars().all()
  227. models = []
  228. for r in records:
  229. # 验证目录是否真的存在,如果不存在则标记为未下载
  230. dir_exists = r.path and Path(r.path).exists()
  231. if not dir_exists:
  232. # 尝试从 models_dir 下查找
  233. alt_path = settings.models_dir / r.id.replace("/", "_")
  234. dir_exists = alt_path.exists()
  235. if dir_exists:
  236. r.path = str(alt_path)
  237. models.append({
  238. "id": r.id,
  239. "name": r.name,
  240. "model_type": r.model_type,
  241. "path": r.path,
  242. "is_downloaded": dir_exists,
  243. "context_length": r.context_length,
  244. "supported_peft_methods": r.supported_peft_methods.split(",") if r.supported_peft_methods else [],
  245. })
  246. return models
  247. async def get_model_info(model_id: str) -> dict[str, Any] | None:
  248. """获取已缓存模型的元数据。"""
  249. async with async_session() as session:
  250. result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
  251. record = result.scalar_one_or_none()
  252. if record:
  253. return {
  254. "id": record.id,
  255. "name": record.name,
  256. "model_type": record.model_type,
  257. "path": record.path,
  258. "is_downloaded": bool(record.is_downloaded) and Path(record.path).exists() if record.path else False,
  259. "context_length": record.context_length,
  260. "supported_peft_methods": record.supported_peft_methods.split(",") if record.supported_peft_methods else [],
  261. }
  262. return None
  263. async def delete_model(model_id: str) -> dict[str, Any]:
  264. """删除已缓存的模型(数据库记录 + 本地文件)。"""
  265. async with async_session() as session:
  266. result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
  267. record = result.scalar_one_or_none()
  268. if not record:
  269. return {"status": "not_found", "message": f"Model not found: {model_id}"}
  270. # 删除本地文件目录(对软链接,删除其指向的真实目录)
  271. model_dir = Path(record.path) if record.path else settings.models_dir / record.id.replace("/", "_")
  272. deleted_files = False
  273. if model_dir.is_symlink():
  274. # ModelScope 下载的模型可能是软链接,删除真实目录
  275. real_dir = model_dir.resolve()
  276. import shutil
  277. if real_dir.exists() and real_dir.is_dir():
  278. shutil.rmtree(real_dir, ignore_errors=True)
  279. # 如果还有父级软链接(如 dphn/ 下的其他链接),一并清理
  280. parent_link = model_dir.parent
  281. if parent_link.is_symlink():
  282. shutil.rmtree(parent_link, ignore_errors=True)
  283. deleted_files = True
  284. elif model_dir.exists() and model_dir.is_dir():
  285. import shutil
  286. shutil.rmtree(model_dir, ignore_errors=True)
  287. deleted_files = True
  288. # 删除数据库记录
  289. await session.delete(record)
  290. await session.commit()
  291. logger.info(f"Model deleted: {model_id} (files={deleted_files})")
  292. return {"status": "deleted", "model_id": model_id, "files_deleted": deleted_files}