| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581 |
- import logging
- import math
- from typing import List, Optional, Union
- from fastapi import APIRouter, Depends, Query, Request
- from fastapi.responses import RedirectResponse, StreamingResponse
- from urllib.parse import urlencode
- from gpustack_runtime.detector import ManufacturerEnum
- from sqlalchemy.orm import selectinload
- from sqlmodel import select, and_, or_
- from sqlmodel.ext.asyncio.session import AsyncSession
- from enum import Enum
- from gpustack.api.exceptions import (
- AlreadyExistsException,
- InternalServerErrorException,
- BadRequestException,
- )
- from gpustack.schemas.common import Pagination
- from gpustack.schemas.inference_backend import is_custom_backend
- from gpustack.schemas.models import (
- ModelInstance,
- ModelInstancesPublic,
- BackendEnum,
- ModelListParams,
- )
- from gpustack.schemas.clusters import Cluster
- from gpustack.schemas.workers import GPUDeviceStatus, Worker
- from gpustack.api.tenant import (
- bypass_tenant_filter,
- assert_resource_visible,
- tenant_list_conditions,
- )
- from gpustack.server.db import async_session
- from gpustack.server.deps import ListParamsDep, SessionDep, TenantContextDep
- from gpustack.schemas.models import (
- Model,
- ModelCreate,
- ModelSpecBase,
- ModelUpdate,
- ModelPublic,
- ModelsPublic,
- )
- from gpustack.schemas.model_routes import (
- ModelRoute,
- ModelRouteTarget,
- TargetStateEnum,
- )
- from gpustack.server.services import (
- ModelService,
- WorkerService,
- revoke_model_access_cache,
- )
- from gpustack.utils.command import find_parameter
- from gpustack.utils.convert import safe_int
- from gpustack.utils.gpu import parse_gpu_id
- from gpustack.routes.model_common import (
- build_category_conditions,
- categories_filter,
- )
- from gpustack.config.config import get_global_config
- from gpustack.utils.grafana import resolve_grafana_base_url
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router that collects the model endpoints declared in this module.
router = APIRouter()
class ModelStateFilterEnum(str, Enum):
    """Values accepted by the ``state`` query filter of the model list endpoint.

    Because the class mixes in ``str``, each member compares equal to its
    plain string value (e.g. ``ModelStateFilterEnum.READY == "ready"``).
    """

    # At least one replica is ready.
    READY = "ready"
    # Replicas are requested (> 0) but none are ready.
    NOT_READY = "not_ready"
    # Scaled down to zero replicas.
    STOPPED = "stopped"
# ORDER BY columns inserted right after every "source" sort key so models that
# share a source kind still come back in a deterministic order.
_SOURCE_TIEBREAK_FIELDS = (
    "huggingface_repo_id",
    "huggingface_filename",
    "model_scope_model_id",
    "model_scope_file_path",
    "local_path",
)

# SQL condition factory per ``state`` filter value. Built lazily (lambdas) so
# no column expression is created at import time.
_STATE_CONDITIONS = {
    ModelStateFilterEnum.READY: lambda: Model.ready_replicas > 0,
    ModelStateFilterEnum.NOT_READY: lambda: and_(
        Model.ready_replicas == 0, Model.replicas > 0
    ),
    ModelStateFilterEnum.STOPPED: lambda: Model.replicas == 0,
}


def _expand_source_order_by(order_by):
    """Return a new order-by list where every ``("source", direction)`` entry
    is followed by the ``_SOURCE_TIEBREAK_FIELDS`` sorted in that direction."""
    expanded = []
    for field_name, direction in order_by:
        expanded.append((field_name, direction))
        if field_name == "source":
            expanded.extend(
                (tiebreak, direction) for tiebreak in _SOURCE_TIEBREAK_FIELDS
            )
    return expanded


# GET "": list models visible to the caller.
#
# Filters: fuzzy ``search`` on the name, exact ``cluster_id`` / ``backend``,
# ``categories`` and ``state``. With ``watch`` set an event stream is returned
# instead of a page; the stream applies the equality fields, the fuzzy name
# filter and ``categories`` only (``state`` and ``order_by`` are not applied).
@router.get("", response_model=ModelsPublic)
async def get_models(
    ctx: TenantContextDep,
    params: ModelListParams = Depends(),
    state: Optional[ModelStateFilterEnum] = Query(
        default=None,
        description="Filter by model state.",
    ),
    search: str = None,
    categories: Optional[List[str]] = Query(None, description="Filter by categories."),
    cluster_id: int = None,
    backend: Optional[str] = Query(None, description="Filter by backend."),
):
    fuzzy_fields = {"name": search} if search else {}

    fields = {}
    if cluster_id:
        fields["cluster_id"] = cluster_id
    if backend:
        fields["backend"] = backend
    # Streaming uses field-equality only; scope by current org so non-admin
    # users never see cross-org rows via the live stream. Admin without an
    # explicit org context keeps the unfiltered cross-org stream. System
    # users (workers / cluster accounts) bypass — they need the cross-org
    # view to handle instances scheduled to them on clusters outside their
    # default Org.
    if ctx.current_principal_id is not None and not bypass_tenant_filter(ctx):
        fields["owner_principal_id"] = ctx.current_principal_id

    if params.watch:
        return StreamingResponse(
            Model.streaming(
                fields=fields,
                fuzzy_fields=fuzzy_fields,
                filter_func=lambda data: categories_filter(data, categories),
            ),
            media_type="text/event-stream",
        )

    async with async_session() as session:
        extra_conditions = list(tenant_list_conditions(ctx, Model))
        if categories:
            category_conditions = build_category_conditions(
                session, Model, categories
            )
            extra_conditions.append(or_(*category_conditions))

        # None (no filter) has no entry in the table and adds nothing.
        make_state_condition = _STATE_CONDITIONS.get(state)
        if make_state_condition is not None:
            extra_conditions.append(make_state_condition())

        order_by = params.order_by
        if order_by:
            order_by = _expand_source_order_by(order_by)

        return await Model.paginated_by_query(
            session=session,
            fuzzy_fields=fuzzy_fields,
            extra_conditions=extra_conditions,
            page=params.page,
            per_page=params.perPage,
            fields=fields,
            order_by=order_by,
        )
# GET "/{id}": return a single model, subject to the caller's tenant
# visibility (lookup and check are done by _get_model).
@router.get("/{id}", response_model=ModelPublic)
async def get_model(
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
):
    model = await _get_model(session, ctx, id)
    return model
# GET "/{id}/dashboard": 302-redirect to the model's Grafana dashboard.
#
# The model is looked up with the same tenant visibility rules as GET "/{id}".
# The redirect URL carries the model name and, when the model references an
# existing cluster, the cluster name as Grafana template variables.
@router.get("/{id}/dashboard")
async def get_model_dashboard(
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
    request: Request,
):
    model = await _get_model(session=session, ctx=ctx, id=id)

    cfg = get_global_config()
    if not (cfg.get_grafana_url() and cfg.grafana_model_dashboard_uid):
        raise InternalServerErrorException(
            message="Grafana dashboard settings are not configured"
        )

    cluster = (
        await Cluster.one_by_id(session, model.cluster_id)
        if model.cluster_id is not None
        else None
    )

    # Insertion order defines the rendered query string: cluster, then model.
    query_params = {}
    if cluster is not None:
        query_params["var-cluster_name"] = cluster.name
    query_params["var-model_name"] = model.name

    grafana_base = resolve_grafana_base_url(cfg, request)
    dashboard_path = f"/d/{cfg.grafana_model_dashboard_uid}/gpustack-model"
    # query_params always holds the model name, so a query string is always
    # appended.
    redirect_url = f"{grafana_base}{dashboard_path}?{urlencode(query_params)}"
    return RedirectResponse(url=redirect_url, status_code=302)
async def _get_model(
    session: SessionDep,
    ctx,
    id: int,
):
    """Load model *id* and enforce the caller's tenant visibility.

    Returns the model when ``assert_resource_visible`` accepts it; otherwise
    that helper raises using the message "Model not found". It is called
    directly on the result of ``Model.one_by_id``, so it presumably also
    handles a missing model (None) — TODO confirm.
    """
    found = await Model.one_by_id(session, id)
    assert_resource_visible(ctx, found, not_found_message="Model not found")
    return found
# GET "/{id}/instances": list the instances of model *id*.
#
# Both the streaming (``watch``) path and the plain listing path first verify
# that the model is visible to the caller's tenant context.
@router.get("/{id}/instances", response_model=ModelInstancesPublic)
async def get_model_instances(ctx: TenantContextDep, id: int, params: ListParamsDep):
    if params.watch:
        # Enforce the same tenant visibility as the listing path before opening
        # the stream; without it any caller could watch another Org's
        # instances just by knowing a model id. The lookup session is closed
        # before streaming starts so it is not held for the stream's lifetime.
        async with async_session() as session:
            model = await Model.one_by_id(session, id)
            assert_resource_visible(ctx, model, not_found_message="Model not found")
        return StreamingResponse(
            ModelInstance.streaming(fields={"model_id": id}),
            media_type="text/event-stream",
        )

    async with async_session() as session:
        model = await Model.one_by_id(
            session, id, options=[selectinload(Model.instances)]
        )
        assert_resource_visible(ctx, model, not_found_message="Model not found")
        instances = model.instances
        count = len(instances)
        # NOTE(review): all instances are returned regardless of page/perPage;
        # only the pagination metadata reflects the requested page size.
        pagination = Pagination(
            page=params.page,
            perPage=params.perPage,
            total=count,
            totalPage=math.ceil(count / params.perPage),
        )
        return ModelInstancesPublic(items=instances, pagination=pagination)
# Backend parameters GPUStack manages itself, each paired with the error
# returned when a user tries to set it.
_UNSUPPORTED_BACKEND_PARAMS = (
    (
        ["port"],
        (
            "Setting the port using --port is not supported. Ports are "
            "automatically allocated by GPUStack."
        ),
    ),
    (
        ["api-key"],
        (
            "Setting the API key using --api-key is not supported. API keys "
            "are managed by GPUStack."
        ),
    ),
    (
        ["served-model-name"],
        (
            "Setting the served model name using --served-model-name is not "
            "supported. The model name is automatically set from your "
            "deployment configuration."
        ),
    ),
)


async def validate_model_in(
    session: SessionDep,
    model_in: Union[ModelCreate, ModelUpdate, ModelSpecBase],
    *,
    cluster_id: Optional[int] = None,
):
    """Validate a model spec before it is created or updated.

    Checks, in order:
    - manual GPU selection (``validate_gpu_ids``) when a GPU selector is set
      and at least one replica is requested;
    - for non-custom backends with backend parameters: the
      ``--ngl/--gpu-layers/--n-gpu-layers`` value, and parameters GPUStack
      manages itself (``--port``, ``--api-key``, ``--served-model-name``).

    Args:
        session: database session used by the GPU-selection checks.
        model_in: the incoming model spec.
        cluster_id: cluster to validate manual GPU selection against; when
            omitted, ``validate_gpu_ids`` falls back to ``model_in.cluster_id``.

    Raises:
        BadRequestException: when a check fails.
    """
    if model_in.gpu_selector is not None and model_in.replicas > 0:
        await validate_gpu_ids(session, model_in, cluster_id=cluster_id)

    if is_custom_backend(model_in.backend):
        logger.info("Skip model validation for custom backend")
        return

    if not model_in.backend_parameters:
        return

    param_gpu_layers = find_parameter(
        model_in.backend_parameters, ["ngl", "gpu-layers", "n-gpu-layers"]
    )
    if param_gpu_layers:
        gpu_layers = safe_int(param_gpu_layers, None)
        # isdigit() rejects signs, spaces and decimals, but it also accepts
        # characters like "²" that int() cannot parse. safe_int(..., None) is
        # then expected to yield None; treat that as invalid so the caller gets
        # this 400 instead of a TypeError from comparing None with an int.
        if (
            not param_gpu_layers.isdigit()
            or gpu_layers is None
            or not 0 <= gpu_layers <= 999
        ):
            raise BadRequestException(
                message="Invalid backend parameter --gpu-layers. Please provide an integer in the range 0-999 (inclusive)."
            )
        if (
            gpu_layers == 0
            and model_in.gpu_selector is not None
            # Truthiness (not len()) so a selector whose gpu_ids is None does
            # not raise TypeError.
            and model_in.gpu_selector.gpu_ids
        ):
            raise BadRequestException(
                message="Cannot set --gpu-layers to 0 and manually select GPUs at the same time. Setting --gpu-layers to 0 means running on CPU only."
            )

    for param_names, error_message in _UNSUPPORTED_BACKEND_PARAMS:
        if find_parameter(model_in.backend_parameters, param_names):
            raise BadRequestException(message=error_message)
async def validate_gpu_ids(  # noqa: C901
    session: SessionDep,
    model_in: Union[ModelCreate, ModelUpdate, ModelSpecBase],
    *,
    cluster_id: Optional[int] = None,
):
    """Validate a manual GPU selection (``model_in.gpu_selector``).

    Checks that:
    - at least ``gpus_per_replica`` GPU ids are selected (when both are set);
    - vox-box models use at most one GPU;
    - every GPU id parses, and its worker exists in the effective cluster
      (``cluster_id`` if given, else ``model_in.cluster_id``);
    - every selected GPU found in its worker's reported status suits the
      backend (``validate_gpu``). GPUs missing from that status are not
      rejected here;
    - when the selection spans several workers: custom backends are refused
      for a single replica, and for vLLM every selected worker is checked with
      ``validate_distributed_vllm_limit_per_worker``.

    Raises:
        BadRequestException: when a check fails.
    """
    gpu_selector = model_in.gpu_selector
    if not gpu_selector:
        return

    gpu_ids = gpu_selector.gpu_ids or []
    gpus_per_replica = gpu_selector.gpus_per_replica
    effective_cluster_id = (
        cluster_id if cluster_id is not None else getattr(model_in, "cluster_id", None)
    )

    if gpu_ids and gpus_per_replica and len(gpu_ids) < gpus_per_replica:
        raise BadRequestException(
            message="The number of selected GPUs must be greater than or equal to gpus_per_replica."
        )

    model_backend = model_in.backend
    if model_backend == BackendEnum.VOX_BOX and (
        len(gpu_ids) > 1 or (gpus_per_replica is not None and gpus_per_replica > 1)
    ):
        raise BadRequestException(
            message="The vox-box backend is restricted to execution on a single NVIDIA GPU."
        )

    # Each referenced worker is looked up once; insertion order = first use.
    workers = {}  # worker name -> Worker
    for gpu_id in gpu_ids:
        is_valid, matched = parse_gpu_id(gpu_id)
        if not is_valid:
            raise BadRequestException(message=f"Invalid GPU ID: {gpu_id}")
        worker_name = matched.get("worker_name")
        gpu_index = safe_int(matched.get("gpu_index"), -1)

        worker = workers.get(worker_name)
        if worker is None:
            if effective_cluster_id is None:
                raise BadRequestException(
                    message=f"A cluster context is required for manual GPU selection, but was not provided. Cannot validate worker '{worker_name}'."
                )
            worker = await WorkerService(session).get_by_cluster_id_name(
                effective_cluster_id, worker_name
            )
            if not worker:
                raise BadRequestException(message=f"Worker {worker_name} not found")
            workers[worker_name] = worker

        gpu_devices = worker.status.gpu_devices if worker.status else None
        gpu = next(
            (device for device in gpu_devices or [] if device.index == gpu_index),
            None,
        )
        if gpu:
            validate_gpu(gpu, model_backend=model_backend)

    if len(workers) > 1:
        if is_custom_backend(model_backend) and model_in.replicas == 1:
            raise BadRequestException(
                message="Distributed inference across multiple workers is not supported for custom backends."
            )
        if model_backend == BackendEnum.VLLM:
            # Check every worker in the distributed set (including the first
            # one), once each.
            for worker in workers.values():
                await validate_distributed_vllm_limit_per_worker(
                    session, model_in, worker
                )
def validate_gpu(gpu_device: GPUDeviceStatus, model_backend: str = ""):
    """Reject a GPU device whose vendor the chosen backend cannot run on.

    Only two backends are vendor-restricted here: vox-box (NVIDIA only) and
    Ascend MindIE (Ascend only). Any other backend passes unchecked.

    Raises:
        BadRequestException: on a backend / vendor mismatch.
    """
    if model_backend == BackendEnum.VOX_BOX:
        if gpu_device.vendor != ManufacturerEnum.NVIDIA.value:
            raise BadRequestException(
                "The vox-box backend is supported only on NVIDIA GPUs."
            )
        return

    requires_ascend = model_backend == BackendEnum.ASCEND_MINDIE
    if requires_ascend and gpu_device.vendor != ManufacturerEnum.ASCEND.value:
        raise BadRequestException(
            f"Ascend MindIE backend requires Ascend NPUs. Selected {gpu_device.vendor} GPU is not supported."
        )
async def validate_distributed_vllm_limit_per_worker(
    session: AsyncSession, model: Union[ModelCreate, ModelUpdate], worker: Worker
):
    """
    Ensure *worker* does not already host another model's distributed vLLM
    instance.

    Looks at the instances whose ``worker_id`` is this worker. One that has
    subordinate workers (i.e. is distributed) and belongs to a different model
    name is a conflict.

    Raises:
        BadRequestException: naming the first conflicting instance found.
    """
    hosted = await ModelInstance.all_by_field(session, "worker_id", worker.id)
    conflict = next(
        (
            inst
            for inst in hosted
            if inst.distributed_servers
            and inst.distributed_servers.subordinate_workers
            and inst.model_name != model.name
        ),
        None,
    )
    if conflict is not None:
        raise BadRequestException(
            message=f"Each worker can run only one distributed vLLM instance. Worker '{worker.name}' already has '{conflict.name}'."
        )
async def _resolve_target_org_id(session, ctx, cluster_id: Optional[int]):
    """Pick the Org that should own a newly created model.

    Returns the caller's current Org when one is set. Otherwise (an admin in
    "All" mode) returns the owner Org of cluster *cluster_id*, so the model
    lives where it actually runs. Returns None when neither is available; the
    spec's own ``owner_principal_id`` is then kept.
    """
    if ctx.current_principal_id is not None:
        return ctx.current_principal_id
    if cluster_id is None:
        return None
    cluster = await Cluster.one_by_id(session, cluster_id)
    if cluster is None:
        return None
    return cluster.owner_principal_id


# POST "": create a model, optionally with a same-named model route and a
# route target pointing at it (created in one transaction).
@router.post("", response_model=ModelPublic)
async def create_model(
    session: SessionDep, ctx: TenantContextDep, model_in: ModelCreate
):
    model_in_dict = model_in.model_dump(
        exclude={"enable_model_route", "overwrite_deleted"}
    )
    # Stamp tenant scope. ModelBase has owner_principal_id defaulted to
    # PLATFORM_PRINCIPAL_ID, so `model_dump()` always emits the key;
    # `setdefault` would silently keep that default, so override directly.
    target_org_id = await _resolve_target_org_id(session, ctx, model_in.cluster_id)
    if target_org_id is not None:
        model_in_dict["owner_principal_id"] = target_org_id
    # The Org the model (and optional route) will really be stored under. The
    # duplicate checks below use it. In admin "All" mode the current context
    # is empty, and checking against it could let a same-named route in the
    # target Org slip past the check.
    org_scope = model_in_dict.get("owner_principal_id")

    # Look for ANY model with this name (soft-deleted ones and those in any
    # Org) to avoid unique-constraint violations. As written, an active model
    # with the same name blocks creation even when another Org owns it.
    result = await session.exec(select(Model).where(Model.name == model_in.name))
    any_existing = result.first()
    if any_existing:
        if any_existing.deleted_at is not None:
            if not model_in.overwrite_deleted:
                # Ask the caller to confirm the overwrite.
                raise AlreadyExistsException(
                    message=f"Model '{model_in.name}' was previously deleted. "
                    "Do you want to overwrite it? Set 'overwrite_deleted=true' to confirm."
                )
            # Confirmed: permanently delete the soft-deleted record.
            # NOTE(review): there is no ownership check here, so a soft-deleted
            # model owned by another Org is hard-deleted too. Confirm intended.
            await session.delete(any_existing)
            await session.flush()
        elif any_existing.owner_principal_id == org_scope:
            raise AlreadyExistsException(
                message=f"Model '{model_in.name}' already exists. "
                "Please choose a different name or check the existing model."
            )
        else:
            raise AlreadyExistsException(
                message=f"Model name '{model_in.name}' is already in use by another organization. "
                "Please choose a different name."
            )

    # Defensive re-check scoped to the owning Org.
    existing = await Model.one_by_fields(
        session,
        {"name": model_in.name, "owner_principal_id": org_scope},
    )
    if existing:
        raise AlreadyExistsException(
            message=f"Model '{model_in.name}' already exists. "
            "Please choose a different name or check the existing model."
        )

    should_create_route = bool(model_in.enable_model_route)
    if should_create_route:
        existing_route = await ModelRoute.one_by_fields(
            session,
            {"name": model_in.name, "owner_principal_id": org_scope},
        )
        if existing_route:
            raise AlreadyExistsException(
                message=f"Model route '{model_in.name}' already exists. "
                "Please choose a different name or check the existing model route."
            )

    await validate_model_in(session, model_in)

    try:
        # Without a route the model is committed immediately; with one,
        # everything is committed together below.
        model: Model = await Model.create(
            session, source=model_in_dict, auto_commit=(not should_create_route)
        )
        if should_create_route:
            route = await ModelRoute.create(
                session,
                source=ModelRoute(
                    name=model.name,
                    description=model.description,
                    categories=model.categories,
                    generic_proxy=model.generic_proxy,
                    created_by_model=True,
                    access_policy=model.access_policy,
                    owner_principal_id=model.owner_principal_id,
                ),
                auto_commit=False,
            )
            await ModelRouteTarget.create(
                session,
                source=ModelRouteTarget(
                    name=f"{model.name}-deployment",
                    route_name=route.name,
                    generic_proxy=model.generic_proxy,
                    model_route=route,
                    model=model,
                    weight=100,
                    state=TargetStateEnum.UNAVAILABLE,
                ),
                auto_commit=False,
            )
            await session.commit()
        await revoke_model_access_cache(session=session)
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to create model: {e}"
        ) from e

    return model
# PUT "/{id}": update a model visible to the caller.
@router.put("/{id}", response_model=ModelPublic)
async def update_model(
    session: SessionDep, ctx: TenantContextDep, id: int, model_in: ModelUpdate
):
    model = await Model.one_by_id(session, id)
    assert_resource_visible(ctx, model, not_found_message="Model not found")

    # validate_gpu_ids refuses manual GPU selection without a cluster context.
    # Use the cluster carried by the update when present, otherwise the
    # model's current cluster, so a manual selection on an existing model can
    # still be validated when the payload omits cluster_id.
    target_cluster_id = getattr(model_in, "cluster_id", None)
    if target_cluster_id is None:
        target_cluster_id = model.cluster_id
    await validate_model_in(session, model_in, cluster_id=target_cluster_id)

    # When the update's backend is not the custom backend, clear run_command /
    # image_name left on the stored model. The update is then forwarded as a
    # dict of only the explicitly-set fields plus those two cleared fields.
    if model_in.backend != BackendEnum.CUSTOM.value and (
        model.run_command or model.image_name
    ):
        patch = model_in.model_dump(exclude_unset=True)
        patch["run_command"] = None
        patch["image_name"] = None
        model_in = patch

    try:
        await ModelService(session).update(model, model_in)
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to update model: {e}"
        ) from e

    return model
# DELETE "/{id}": delete a model visible to the caller via ModelService.
@router.delete("/{id}")
async def delete_model(session: SessionDep, ctx: TenantContextDep, id: int):
    # Instances and route targets are eager-loaded, presumably because
    # ModelService.delete walks them (lazy loading is not usable on an async
    # session) — TODO confirm against ModelService.delete.
    model = await Model.one_by_id(
        session,
        id,
        options=[
            selectinload(Model.instances),
            selectinload(Model.model_route_targets),
        ],
    )
    assert_resource_visible(ctx, model, not_found_message="Model not found")

    try:
        await ModelService(session).delete(model)
    except Exception as e:
        # Chain the cause so the original traceback survives the wrapping.
        raise InternalServerErrorException(
            message=f"Failed to delete model: {e}"
        ) from e
|