model_service.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import json
  2. from pathlib import Path
  3. from typing import Any
  4. from app.config import get_settings
  5. from app.core.db import async_session, ModelCache
  6. from app.core.logging import logger
  7. from sqlalchemy import select
  8. settings = get_settings()
  9. async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
  10. """从 HF 或 ModelScope 下载模型到本地缓存。"""
  11. try:
  12. if use_modelscope:
  13. from modelscope import snapshot_download as ms_download
  14. local_path = ms_download(model_id, cache_dir=str(settings.models_dir))
  15. else:
  16. from huggingface_hub import snapshot_download
  17. local_path = snapshot_download(
  18. repo_id=model_id,
  19. local_dir=str(settings.models_dir / model_id.replace("/", "_")),
  20. local_dir_use_symlinks=False,
  21. )
  22. # 读取 config.json 获取模型信息
  23. config_path = Path(local_path) / "config.json"
  24. model_type = "text"
  25. context_length = 2048
  26. peft_methods = "lora,qlora,ia3,adalora,prefix_tuning"
  27. if config_path.exists():
  28. with open(config_path) as f:
  29. cfg = json.load(f)
  30. model_type = cfg.get("model_type", "text")
  31. context_length = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
  32. # 写入数据库
  33. async with async_session() as session:
  34. record = ModelCache(
  35. id=model_id,
  36. name=model_id.split("/")[-1],
  37. model_type=model_type,
  38. path=local_path,
  39. is_downloaded=1,
  40. context_length=context_length,
  41. supported_peft_methods=peft_methods,
  42. )
  43. session.add(record)
  44. await session.commit()
  45. logger.info(f"Model downloaded: {model_id} -> {local_path}")
  46. return {"model_id": model_id, "status": "completed", "path": local_path}
  47. except Exception as e:
  48. error_msg = str(e)
  49. if "Connection" in error_msg or "timeout" in error_msg.lower() or "network" in error_msg.lower():
  50. error_msg += "\n提示: 可能是 HuggingFace 网络问题。尝试使用 ModelScope 下载。"
  51. logger.error(f"Model download failed: {e}")
  52. return {"model_id": model_id, "status": "failed", "error": error_msg}
  53. def list_cached_models() -> list[dict[str, Any]]:
  54. """列出本地已缓存的模型。"""
  55. models_dir = settings.models_dir
  56. if not models_dir.exists():
  57. return []
  58. result = []
  59. for d in models_dir.iterdir():
  60. if not d.is_dir():
  61. continue
  62. config_path = d / "config.json"
  63. info: dict[str, Any] = {
  64. "id": d.name,
  65. "name": d.name,
  66. "model_type": "text",
  67. "path": str(d),
  68. "is_downloaded": True,
  69. "context_length": None,
  70. "supported_peft_methods": [],
  71. }
  72. if config_path.exists():
  73. with open(config_path) as f:
  74. cfg = json.load(f)
  75. info["model_type"] = cfg.get("model_type", "text")
  76. info["context_length"] = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
  77. info["supported_peft_methods"] = ["lora", "qlora", "ia3", "adalora", "prefix_tuning"]
  78. result.append(info)
  79. return result
  80. async def get_model_info(model_id: str) -> dict[str, Any] | None:
  81. """获取已缓存模型的元数据。"""
  82. # 先查数据库
  83. async with async_session() as session:
  84. result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
  85. record = result.scalar_one_or_none()
  86. if record:
  87. return {
  88. "id": record.id,
  89. "name": record.name,
  90. "model_type": record.model_type,
  91. "path": record.path,
  92. "is_downloaded": bool(record.is_downloaded),
  93. "context_length": record.context_length,
  94. "supported_peft_methods": record.supported_peft_methods.split(",") if record.supported_peft_methods else [],
  95. }
  96. # 回退:直接从文件系统读取
  97. model_dir = settings.models_dir / model_id.replace("/", "_")
  98. config_path = model_dir / "config.json"
  99. if config_path.exists():
  100. with open(config_path) as f:
  101. cfg = json.load(f)
  102. return {
  103. "id": model_id,
  104. "name": model_id.split("/")[-1],
  105. "model_type": cfg.get("model_type", "text"),
  106. "path": str(model_dir),
  107. "is_downloaded": True,
  108. "context_length": cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048)),
  109. "supported_peft_methods": ["lora", "qlora", "ia3", "adalora", "prefix_tuning"],
  110. }
  111. return None