catalog.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. from datetime import date
  2. import logging
  3. import os
  4. from typing import Dict, List, Optional
  5. from urllib.parse import urlparse
  6. from pathlib import Path
  7. from fastapi import APIRouter
  8. import requests
  9. import yaml
  10. from gpustack.schemas.model_sets import (
  11. Catalog,
  12. ModelSet,
  13. DraftModel,
  14. ModelSetPublic,
  15. ModelSpec,
  16. )
  17. from gpustack.utils import file
  18. from gpustack.utils.compat_importlib import pkg_resources
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# FastAPI router object exposed by this module.
router = APIRouter()
# Currently loaded catalog; remains None until init_model_catalog() assigns it.
model_catalog: Optional[Catalog] = None
# Specs of each model set, keyed by the model set id; rebuilt by init_model_set_specs().
model_set_specs: Dict[int, List[ModelSpec]] = {}
# Spec lookup keyed by each spec's model_source_key; filled by init_model_set_specs().
model_set_specs_by_key: Dict[str, ModelSpec] = {}
  24. def get_model_sets() -> List[ModelSet]:
  25. return model_catalog.model_sets if model_catalog else []
  26. def get_catalog_draft_models() -> List[DraftModel]:
  27. return model_catalog.draft_models if model_catalog else []
  28. def get_model_set_specs() -> Dict[int, List[ModelSpec]]:
  29. return model_set_specs
  30. def convert_to_public(model_sets: List[ModelSet]) -> List[ModelSetPublic]:
  31. return [
  32. ModelSetPublic(**model_set.model_dump(exclude={"templates"}))
  33. for model_set in model_sets
  34. ]
  35. def init_model_catalog(model_catalog_file: Optional[str] = None):
  36. model_sets: List[ModelSet] = []
  37. try:
  38. if model_catalog_file is None:
  39. model_catalog_file = get_builtin_model_catalog_file()
  40. raw_data = None
  41. parsed_url = urlparse(model_catalog_file)
  42. if parsed_url.scheme in ("http", "https"):
  43. response = requests.get(model_catalog_file)
  44. response.raise_for_status()
  45. raw_data = yaml.safe_load(response.text)
  46. else:
  47. with open(model_catalog_file, "r") as f:
  48. raw_data = yaml.safe_load(f)
  49. global model_catalog
  50. model_catalog = Catalog(**raw_data)
  51. logger.debug(
  52. f"Loaded {len(model_catalog.model_sets)} model sets from model catalog: {model_catalog_file}"
  53. )
  54. model_sets = model_catalog.model_sets
  55. # Use index as the id for each model set
  56. for idx, model_set in enumerate(model_sets):
  57. model_set.id = idx + 1
  58. model_sets = sort_model_sets(model_sets)
  59. model_catalog.model_sets = convert_to_public(model_sets)
  60. init_model_set_specs(model_sets)
  61. except Exception as e:
  62. raise Exception(f"Failed to load model catalog: {e}")
  63. def sort_model_sets(model_sets: List[ModelSet]) -> List[ModelSet]:
  64. """
  65. Sort model sets by order asc, then by release_date desc
  66. """
  67. return sorted(
  68. model_sets,
  69. key=lambda x: (
  70. x.order if x.order is not None else float('inf'),
  71. -(x.release_date.toordinal() if x.release_date else date.min.toordinal()),
  72. ),
  73. )
  74. def init_model_set_specs(model_sets: List[ModelSet]):
  75. global model_set_specs, model_set_specs_by_key
  76. model_set_specs = {}
  77. for model_set in model_sets:
  78. model_set_specs[model_set.id] = model_set.specs
  79. # Initialize specs by key for quick lookup.
  80. # Later specs override earlier ones to prioritize standard specs.
  81. for spec in reversed(model_set.specs):
  82. if not model_set_specs_by_key.get(spec.model_source_key):
  83. model_set_specs_by_key[spec.model_source_key] = spec
  84. def prepare_chat_templates(data_dir: str):
  85. source_dir = pkg_resources.files("gpustack").joinpath("assets/chat_templates")
  86. target_dir = Path(data_dir).joinpath("chat_templates")
  87. if not os.path.exists(source_dir):
  88. return
  89. file.copy_with_owner(source_dir, target_dir)
  90. def get_builtin_model_catalog_file() -> str:
  91. huggingface_url = "https://huggingface.co"
  92. modelscope_url = "https://modelscope.cn"
  93. model_catalog_file_name = "model-catalog.yaml"
  94. if not can_access(huggingface_url) and can_access(modelscope_url):
  95. model_catalog_file_name = "model-catalog-modelscope.yaml"
  96. logger.info(f"Cannot access {huggingface_url}, using ModelScope model catalog.")
  97. return str(pkg_resources.files("gpustack.assets").joinpath(model_catalog_file_name))
  98. def can_access(url: str) -> bool:
  99. """
  100. Check if the URL is accessible
  101. """
  102. try:
  103. response = requests.get(url, timeout=3)
  104. return response.status_code >= 200 and response.status_code < 300
  105. except requests.RequestException:
  106. return False