| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
import math
from typing import Dict, List, Optional

from fastapi import APIRouter, Depends, Query
from gpustack_runtime.detector import ManufacturerEnum
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from packaging.version import InvalidVersion, Version

from gpustack.api.exceptions import NotFoundException
from gpustack.schemas.common import PaginatedList, Pagination
from gpustack.schemas.gpu_devices import GPUDevice
from gpustack.server.catalog import (
    ModelSet,
    ModelSetPublic,
    ModelSpec,
    get_model_sets,
    get_model_set_specs,
)
from gpustack.server.deps import ListParamsDep, SessionDep
from gpustack.worker.backends.base import get_ascend_cann_variant
# FastAPI router on which this module registers its model-set endpoints
# (via the ``@router.get`` decorators).
router = APIRouter()
@router.get("", response_model=PaginatedList[ModelSetPublic])
async def get_model_sets(
    params: ListParamsDep,
    search: str = None,
    categories: Optional[List[str]] = Query(None, description="Filter by categories."),
    # Default values are evaluated when this ``def`` runs, so ``get_model_sets``
    # here still names the loader imported from gpustack.server.catalog; only
    # afterwards does this handler shadow that module-level name.
    model_sets: List[ModelSet] = Depends(get_model_sets),
):
    """List catalog model sets, optionally filtered and paginated.

    Args:
        params: Pagination parameters (``page`` and ``perPage``).
        search: Case-insensitive substring that a model set's name must contain.
        categories: Keep only model sets tagged with at least one of these
            categories; sets whose ``categories`` is None are dropped when
            this filter is given.
        model_sets: All model sets, injected from the catalog loader.

    Returns:
        A ``PaginatedList`` of the matching model sets. When ``page`` or
        ``perPage`` is below 1, every match is returned as a single page.
    """
    matches = model_sets

    if search:
        needle = search.lower()
        matches = [item for item in matches if needle in item.name.lower()]

    if categories:
        matches = [
            item
            for item in matches
            if item.categories is not None
            and any(tag in item.categories for tag in categories)
        ]

    total = len(matches)

    # A non-positive page or page size means "no pagination": one page with everything.
    if params.perPage < 1 or params.page < 1:
        return PaginatedList[ModelSetPublic](
            items=matches,
            pagination=Pagination(page=1, perPage=total, total=total, totalPage=1),
        )

    offset = (params.page - 1) * params.perPage
    return PaginatedList[ModelSetPublic](
        items=matches[offset : offset + params.perPage],
        pagination=Pagination(
            page=params.page,
            perPage=params.perPage,
            total=total,
            totalPage=math.ceil(total / params.perPage),
        ),
    )
@router.get("/{id}/specs", response_model=PaginatedList[ModelSpec])
async def get_model_specs(
    session: SessionDep,
    id: int,
    params: ListParamsDep,
    cluster_id: Optional[int] = Query(
        None, description="Filter specs compatible with the given cluster ID."
    ),
    model_set_specs: Dict[int, List[ModelSpec]] = Depends(get_model_set_specs),
):
    """List the specs of one model set that are compatible with the known GPUs.

    Args:
        session: Database session used to look up GPU devices.
        id: Model set identifier (key into ``model_set_specs``).
        params: Pagination parameters (``page`` and ``perPage``).
        cluster_id: When given, only GPUs in this cluster are considered;
            otherwise GPUs from every cluster are used.
        model_set_specs: Mapping of model set id to its specs, injected from
            the catalog.

    Returns:
        A ``PaginatedList`` of the compatible specs. When ``page`` or
        ``perPage`` is below 1, every spec is returned as a single page
        (the same convention as the model-set listing endpoint).

    Raises:
        NotFoundException: If ``id`` has no specs in the catalog.
    """
    specs = model_set_specs.get(id, [])
    if not specs:
        raise NotFoundException(message="Model set not found")

    fields = {}
    # ``is not None`` so that a falsy-but-valid id (0) still filters by cluster.
    if cluster_id is not None:
        fields["cluster_id"] = cluster_id
    gpus = await GPUDevice.all_by_fields(session, fields)
    specs = filter_specs_by_gpu(gpus or [], specs)

    count = len(specs)

    # Guard before dividing by perPage: a non-positive page size would
    # otherwise raise ZeroDivisionError or report a negative page count.
    if params.page < 1 or params.perPage < 1:
        pagination = Pagination(
            page=1,
            perPage=count,
            total=count,
            totalPage=1,
        )
        return PaginatedList[ModelSpec](items=specs, pagination=pagination)

    # Return only the requested page so the items agree with the metadata.
    start_index = (params.page - 1) * params.perPage
    end_index = start_index + params.perPage
    pagination = Pagination(
        page=params.page,
        perPage=params.perPage,
        total=count,
        totalPage=math.ceil(count / params.perPage),
    )
    return PaginatedList[ModelSpec](
        items=specs[start_index:end_index], pagination=pagination
    )
def filter_specs_by_gpu(
    gpus: List[GPUDevice], specs: List[ModelSpec]
) -> List[ModelSpec]:
    """Select at most one spec per mode that the given GPUs can serve.

    Specs are scanned in input order and the first compatible spec for each
    ``mode`` wins. A spec without ``gpu_filters`` is always compatible.
    Otherwise each filter that is set must be satisfied by the GPU pool:

    * ``vendor``: some GPU's vendor matches (case-insensitive);
    * ``compute_capability``: some GPU's compute capability satisfies the
      specifier (see ``match_compute_capability``);
    * ``vendor_variant``: some Ascend GPU's CANN variant matches
      (case-insensitive).

    NOTE(review): each check is evaluated against the whole pool on its own,
    so different GPUs may satisfy different checks.

    Returns:
        The selected specs ordered throughput, latency, standard, then any
        other modes alphabetically.
    """
    present_vendors = {device.vendor.lower() for device in gpus}
    # Vendor variants currently only come from Ascend devices (CANN variants).
    present_variants = {
        get_ascend_cann_variant(device.arch_family).lower()
        for device in gpus
        if device.arch_family is not None and device.vendor == ManufacturerEnum.ASCEND
    }

    def _is_compatible(spec: ModelSpec) -> bool:
        """Return True when the GPU pool satisfies every filter set on ``spec``."""
        gf = spec.gpu_filters
        if gf is None:
            return True
        if gf.vendor and {name.lower() for name in gf.vendor}.isdisjoint(
            present_vendors
        ):
            return False
        if gf.compute_capability and not any(
            match_compute_capability(gf.compute_capability, device.compute_capability)
            for device in gpus
        ):
            return False
        if gf.vendor_variant and {
            variant.lower() for variant in gf.vendor_variant
        }.isdisjoint(present_variants):
            return False
        return True

    chosen: Dict[str, ModelSpec] = {}
    for spec in specs:
        # Short-circuit keeps the original order: a filled mode is never re-checked.
        if spec.mode not in chosen and _is_compatible(spec):
            chosen[spec.mode] = spec

    # Conventional priority, applied in case catalog entries are out of order.
    mode_rank = {"throughput": 0, "latency": 1, "standard": 2}
    return sorted(
        chosen.values(),
        key=lambda item: (mode_rank.get(item.mode, 999), item.mode),
    )
def match_compute_capability(filter_str: Optional[str], gpu_cc: Optional[str]) -> bool:
    """Check whether a GPU compute capability satisfies a version specifier.

    Args:
        filter_str: A pip-style version specifier string (e.g. ``">=8.0,<9.0"``).
            An empty or missing filter matches every GPU.
        gpu_cc: The GPU compute capability as a version string (e.g. ``"8.6"``).

    Returns:
        True if ``gpu_cc`` satisfies ``filter_str``. False if it does not, if
        ``gpu_cc`` is empty or missing, or if either string cannot be parsed
        by ``packaging``.
    """
    if not filter_str:
        return True
    if not gpu_cc:
        return False
    try:
        spec_set = SpecifierSet(filter_str)
        cc_version = Version(gpu_cc)
    except (InvalidSpecifier, InvalidVersion):
        # Unparseable values are treated as "no match"; unexpected errors are
        # no longer silently swallowed.
        return False
    return cc_version in spec_set
|