model_sets.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import math
  2. from typing import Dict, List, Optional
  3. from fastapi import APIRouter, Depends, Query
  4. from packaging.version import Version
  5. from packaging.specifiers import SpecifierSet
  6. from gpustack_runtime.detector import ManufacturerEnum
  7. from gpustack.api.exceptions import NotFoundException
  8. from gpustack.schemas.common import PaginatedList, Pagination
  9. from gpustack.schemas.gpu_devices import GPUDevice
  10. from gpustack.server.catalog import (
  11. ModelSet,
  12. ModelSetPublic,
  13. ModelSpec,
  14. get_model_sets,
  15. get_model_set_specs,
  16. )
  17. from gpustack.server.deps import ListParamsDep, SessionDep
  18. from gpustack.worker.backends.base import get_ascend_cann_variant
  19. router = APIRouter()
  20. @router.get("", response_model=PaginatedList[ModelSetPublic])
  21. async def get_model_sets(
  22. params: ListParamsDep,
  23. search: str = None,
  24. categories: Optional[List[str]] = Query(None, description="Filter by categories."),
  25. model_sets: List[ModelSet] = Depends(get_model_sets),
  26. ):
  27. if search:
  28. model_sets = [
  29. model for model in model_sets if search.lower() in model.name.lower()
  30. ]
  31. if categories:
  32. model_sets = [
  33. model
  34. for model in model_sets
  35. if model.categories is not None
  36. and any(category in model.categories for category in categories)
  37. ]
  38. count = len(model_sets)
  39. if params.page < 1 or params.perPage < 1:
  40. # Return all items.
  41. pagination = Pagination(
  42. page=1,
  43. perPage=count,
  44. total=count,
  45. totalPage=1,
  46. )
  47. return PaginatedList[ModelSetPublic](items=model_sets, pagination=pagination)
  48. # Paginate results.
  49. total_page = math.ceil(count / params.perPage)
  50. start_index = (params.page - 1) * params.perPage
  51. end_index = start_index + params.perPage
  52. paginated_items = model_sets[start_index:end_index]
  53. pagination = Pagination(
  54. page=params.page,
  55. perPage=params.perPage,
  56. total=count,
  57. totalPage=total_page,
  58. )
  59. return PaginatedList[ModelSetPublic](items=paginated_items, pagination=pagination)
  60. @router.get("/{id}/specs", response_model=PaginatedList[ModelSpec])
  61. async def get_model_specs(
  62. session: SessionDep,
  63. id: int,
  64. params: ListParamsDep,
  65. cluster_id: Optional[int] = Query(
  66. None, description="Filter specs compatible with the given cluster ID."
  67. ),
  68. model_set_specs: Dict[int, List[ModelSpec]] = Depends(get_model_set_specs),
  69. ):
  70. specs = model_set_specs.get(id, [])
  71. if not specs:
  72. raise NotFoundException(message="Model set not found")
  73. fields = {}
  74. if cluster_id:
  75. fields["cluster_id"] = cluster_id
  76. gpus = await GPUDevice.all_by_fields(session, fields)
  77. specs = filter_specs_by_gpu(gpus or [], specs)
  78. count = len(specs)
  79. total_page = math.ceil(count / params.perPage)
  80. pagination = Pagination(
  81. page=params.page,
  82. perPage=params.perPage,
  83. total=count,
  84. totalPage=total_page,
  85. )
  86. return PaginatedList[ModelSpec](items=specs, pagination=pagination)
  87. def filter_specs_by_gpu(
  88. gpus: List[GPUDevice], specs: List[ModelSpec]
  89. ) -> List[ModelSpec]:
  90. """Filter model specs based on the GPUs available."""
  91. # Matched specs mapping by mode (standard, throughput, latency, etc.).
  92. filtered: Dict[str, ModelSpec] = {}
  93. gpu_vendors = {gpu.vendor.lower() for gpu in gpus}
  94. # Vendor variants. Now only Ascend CANN variants are supported.
  95. vendor_variants = {
  96. get_ascend_cann_variant(gpu.arch_family).lower()
  97. for gpu in gpus
  98. if gpu.arch_family is not None and gpu.vendor == ManufacturerEnum.ASCEND
  99. }
  100. for spec in specs:
  101. # If already selected for this mode, skip
  102. if spec.mode in filtered:
  103. continue
  104. gf = spec.gpu_filters
  105. if gf is None:
  106. filtered[spec.mode] = spec
  107. continue
  108. # --- GPU Vendor check ---
  109. if gf.vendor:
  110. wanted = {v.lower() for v in gf.vendor}
  111. if wanted.isdisjoint(gpu_vendors):
  112. continue
  113. # --- Compute Capability check ---
  114. if gf.compute_capability:
  115. if not any(
  116. match_compute_capability(gf.compute_capability, gpu.compute_capability)
  117. for gpu in gpus
  118. ):
  119. continue
  120. # --- Vendor Variant check ---
  121. if gf.vendor_variant:
  122. wanted = {v.lower() for v in gf.vendor_variant}
  123. if wanted.isdisjoint(vendor_variants):
  124. continue
  125. filtered[spec.mode] = spec
  126. result = list(filtered.values())
  127. # Sort by mode priority in case catalog items are messy. These are our conventional priorities.
  128. ORDER = {"throughput": 0, "latency": 1, "standard": 2}
  129. result.sort(key=lambda spec: (ORDER.get(spec.mode, 999), spec.mode))
  130. return result
  131. def match_compute_capability(filter_str: Optional[str], gpu_cc: Optional[str]) -> bool:
  132. """Check if the GPU compute capability matches the given filter string.
  133. Args:
  134. filter_str (Optional[str]): The pip-style version specifier string.
  135. gpu_cc (Optional[str]): The GPU compute capability version string.
  136. Returns:
  137. bool: True if the GPU compute capability matches the filter, False otherwise.
  138. """
  139. if not filter_str:
  140. return True
  141. if not gpu_cc:
  142. return False
  143. try:
  144. spec_set = SpecifierSet(filter_str)
  145. cc_version = Version(gpu_cc)
  146. return cc_version in spec_set
  147. except Exception:
  148. return False