| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679 |
- import math
- import random
- import secrets
- from typing import Any, Callable, Optional, Union
- from urllib.parse import urlencode
- import aiohttp
- from fastapi import APIRouter, Depends, Request, Response, Query
- from fastapi.responses import RedirectResponse, StreamingResponse
- from enum import Enum
- from sqlalchemy.orm import selectinload
- from gpustack.api.exceptions import (
- AlreadyExistsException,
- InternalServerErrorException,
- NotFoundException,
- InvalidException,
- ConflictException,
- ServiceUnavailableException,
- )
- from gpustack.api.responses import StreamingResponseWithStatusCode
- from gpustack.schemas.common import PaginatedList, Pagination
- from gpustack.schemas.config import parse_base_model_to_env_vars
- from gpustack.api.tenant import (
- TenantContext,
- bypass_tenant_filter,
- cluster_visibility_conditions,
- assert_cluster_visible,
- assert_cluster_writable,
- validate_owner_principal,
- )
- from gpustack.schemas.workers import Worker, WorkerStateEnum
- from gpustack.server.db import async_session
- from gpustack.server.deps import SessionDep, TenantContextDep
- from gpustack.server.services import create_user_with_principal
- from gpustack.server.worker_request import stream_to_worker
- from gpustack.schemas.clusters import (
- ClusterListParams,
- ClusterUpdate,
- ClusterCreate,
- ClusterPublic,
- ClustersPublic,
- Cluster,
- ClusterStateEnum,
- ClusterProvider,
- SensitiveRegistrationConfig,
- ClusterRegistrationTokenPublic,
- WorkerPoolCreate,
- WorkerPoolPublic,
- WorkerPool,
- CloudOptions,
- )
- from gpustack.schemas.organizations import PLATFORM_PRINCIPAL_ID
- from gpustack.schemas.users import User, UserRole, system_name_prefix
- from gpustack.schemas.api_keys import ApiKey
- from gpustack.security import get_secret_hash, API_KEY_PREFIX
- from gpustack.k8s.manifest_template import TemplateConfig
- from gpustack.config.config import get_global_config, get_cluster_image_name
- from gpustack.utils.grafana import resolve_grafana_base_url
- from gpustack_runtime.detector import ManufacturerEnum
# Eager-load options passed to cluster queries in this module so the
# cluster_workers and cluster_models relationships are loaded with the cluster.
CLUSTER_LOAD_OPTIONS = [
    selectinload(Cluster.cluster_workers),
    selectinload(Cluster.cluster_models),
]

# Router on which every cluster endpoint defined below is registered.
router = APIRouter()
def get_server_url(request: Request, cluster_override: Optional[str]) -> str:
    """Resolve the server URL to advertise for a cluster.

    Precedence:
    1. ``cluster_override`` when truthy, with leading/trailing ``/`` stripped;
    2. the ``server_external_url`` of the global config when it is set;
    3. a URL rebuilt from the scheme, hostname and (if present) port of the
       incoming request.
    """
    if cluster_override:
        return cluster_override.strip("/")

    configured = get_global_config().server_external_url
    if configured:
        return configured

    # Nothing configured: derive the address from how the caller reached us.
    req_url = request.url
    port_suffix = f":{req_url.port}" if req_url.port else ""
    return f"{req_url.scheme}://{req_url.hostname}{port_suffix}"
def _is_cluster_visible(cluster: Cluster, ctx: TenantContext) -> bool:
    """In-memory counterpart of cluster_visibility_conditions.

    A cluster is visible when the tenant filter is bypassed for ``ctx``, when
    the caller's current principal owns the cluster, or when the cluster id is
    listed in ``ctx.accessible_cluster_ids``.
    """
    if bypass_tenant_filter(ctx):
        return True

    principal_id = ctx.current_principal_id
    owned_by_caller = (
        principal_id is not None and cluster.owner_principal_id == principal_id
    )
    return owned_by_caller or cluster.id in ctx.accessible_cluster_ids
@router.get("", response_model=ClustersPublic, response_model_exclude_none=True)
async def get_clusters(
    session: SessionDep,
    ctx: TenantContextDep,
    params: ClusterListParams = Depends(),
    name: Optional[str] = None,
    search: Optional[str] = None,
):
    """List the clusters visible to the caller, or stream them in watch mode.

    Args:
        session: Injected DB session.
        ctx: Tenant context used for visibility filtering.
        params: Paging, ordering and watch options.
        name: Optional exact-match filter on the cluster name.
        search: Optional fuzzy filter on the cluster name.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) when ``params.watch``
        is set, otherwise a ``PaginatedList[ClusterPublic]``. When ``page`` or
        ``perPage`` is below 1 every matching item is returned as one page.
    """
    fuzzy_fields = {"name": search} if search else {}

    # Soft-deleted clusters are always excluded. The exact-name filter must
    # narrow this condition, not replace it.
    fields = {'deleted_at': None}
    if name:
        fields["name"] = name

    if params.watch:
        return StreamingResponse(
            Cluster.streaming(
                fields=fields,
                fuzzy_fields=fuzzy_fields,
                options=CLUSTER_LOAD_OPTIONS,
                filter_func=lambda c: _is_cluster_visible(c, ctx),
            ),
            media_type="text/event-stream",
        )

    # NOTE(review): this opens a fresh session that shadows the injected
    # `session` dependency; kept as in the original, confirm it is intended.
    async with async_session() as session:
        # Push visibility filtering into the query — own-Org cluster OR
        # cluster_access grant — instead of fetching the whole table and
        # filtering in Python.
        items = await Cluster.all_by_fields(
            session=session,
            fields=fields,
            fuzzy_fields=fuzzy_fields,
            options=CLUSTER_LOAD_OPTIONS,
            extra_conditions=cluster_visibility_conditions(ctx, Cluster),
        )

        if not items:
            return PaginatedList[ClusterPublic](
                items=[],
                pagination=Pagination(
                    page=params.page,
                    perPage=params.perPage,
                    total=0,
                    totalPage=0,
                ),
            )

        if params.page < 1 or params.perPage < 1:
            # Return all items.
            pagination = Pagination(
                page=1,
                perPage=len(items),
                total=len(items),
                totalPage=1,
            )
            return PaginatedList[ClusterPublic](items=items, pagination=pagination)

        # Sort in memory. Keys are applied from least to most significant so
        # the stable sort yields the requested multi-key ordering.
        order_by = params.order_by
        if order_by:
            for field, direction in reversed(order_by):
                items.sort(
                    key=_make_sort_key(field),
                    reverse=direction == "desc",
                )

        # Paginate results.
        start = (params.page - 1) * params.perPage
        end = start + params.perPage
        paginated_items = items[start:end]
        count = len(items)
        total_page = math.ceil(count / params.perPage)
        pagination = Pagination(
            page=params.page,
            perPage=params.perPage,
            total=count,
            totalPage=total_page,
        )
        return PaginatedList[ClusterPublic](
            items=paginated_items, pagination=pagination
        )
def _make_sort_key(field: str) -> Callable[[Any], tuple]:
    """Build a ``key=`` callable that orders objects by attribute ``field``.

    The produced key is a ``(rank, value)`` tuple:
    - a missing or ``None`` attribute yields ``(1, None)``, so it sorts after
      every real value in an ascending sort (and therefore before them when
      the caller sorts with ``reverse=True``);
    - an ``Enum`` member is compared through its ``.value``;
    - any other value is used as-is and must be comparable with its peers.
    """

    def key_func(obj: Any) -> tuple:
        attr = getattr(obj, field, None)
        if attr is None:
            return (1, None)
        comparable = attr.value if isinstance(attr, Enum) else attr
        return (0, comparable)

    return key_func
@router.get("/{id}", response_model=ClusterPublic, response_model_exclude_none=True)
async def get_cluster(session: SessionDep, ctx: TenantContextDep, id: int):
    """Return a single cluster by id, with workers and models eager-loaded.

    The lookup result is handed to ``assert_cluster_visible`` together with a
    not-found message; that helper decides whether the caller may see it
    (presumably also rejecting a missing cluster — verify in the helper).
    """
    missing_message = f"cluster {id} not found"
    found = await Cluster.one_by_id(session, id, options=CLUSTER_LOAD_OPTIONS)
    assert_cluster_visible(ctx, found, not_found_message=missing_message)
    return found
def create_update_check(
    provider: ClusterProvider, input: Union[ClusterCreate, ClusterUpdate]
):
    """Validate a cluster create/update payload against its provider.

    Rules:
    - Cloud providers (anything other than Kubernetes or Docker) need a
      ``credential_id`` on creation.
    - Cloud providers need a usable server URL, either on the payload or as
      the globally configured ``server_external_url``.
    - Kubernetes clusters need at least one ``k8s_volume_mounts`` entry, and
      the first entry must have a hostPath volume source (it is the gpustack
      data dir mount).

    Raises:
        InvalidException: when any rule above is violated.
    """
    cfg = get_global_config()
    is_cloud_provider = provider not in [
        ClusterProvider.Kubernetes,
        ClusterProvider.Docker,
    ]

    if (
        is_cloud_provider
        and isinstance(input, ClusterCreate)
        and input.credential_id is None
    ):
        raise InvalidException(
            message=f"credential_id is required for provider {provider}"
        )

    # Use a falsiness test: an empty-string URL (e.g. an unset config value)
    # is as unusable as None and must not pass validation.
    server_url = input.server_url or cfg.server_external_url
    if is_cloud_provider and not server_url:
        raise InvalidException(
            message=f"server_url is required for provider {provider}"
        )

    if provider == ClusterProvider.Kubernetes:
        # At least one volume mount is required, and the default one is for
        # the gpustack data dir.
        if not input.k8s_volume_mounts:
            raise InvalidException(
                message="At least one k8s_volume_mount is required, and the default one is for gpustack data dir."
            )
        data_dir_source = input.k8s_volume_mounts[0].volume_source
        if data_dir_source is None or data_dir_source.host_path is None:
            raise InvalidException(
                message="The first k8s_volume_mount must be for gpustack data dir with hostPath volume source."
            )
def enforce_data_dir_mounts(input: Union[ClusterCreate, ClusterUpdate]):
    """Force the first k8s volume mount into the gpustack data-dir shape.

    The first entry of ``input.k8s_volume_mounts`` is taken to be the data
    dir mount (callers validate that with create_update_check beforehand).
    Its name, mount path, read-only flag and hostPath type are overwritten in
    place, and the configMap / PVC sources are cleared so that hostPath is
    the only remaining source.
    """
    mount = input.k8s_volume_mounts[0]

    forced_settings = (
        ("name", "gpustack-data-dir"),
        ("mount_path", "/var/lib/gpustack"),
        ("read_only", False),
    )
    for attribute, forced_value in forced_settings:
        setattr(mount, attribute, forced_value)

    source = mount.volume_source
    source.host_path.type = "DirectoryOrCreate"
    source.config_map = None
    source.persistent_volume_claim = None
@router.post("", response_model=ClusterPublic, response_model_exclude_none=True)
async def create_cluster(
    session: SessionDep, ctx: TenantContextDep, input: ClusterCreate
):
    """Create a cluster with its worker pools, system user and API key.

    The cluster, its pools, the system user and the API key are written with
    ``auto_commit=False`` and committed together at the end.

    Raises:
        AlreadyExistsException: a non-deleted cluster with the same name
            already exists in the owning Org.
        InvalidException: the payload fails ``create_update_check``.
        InternalServerErrorException: any failure while persisting; the
            original exception is chained as the cause.
    """
    # Every cluster has an owner Org. Fill in a sensible default when the
    # caller omitted it: their current Org context, or the platform Org
    # for admin in "All" mode (admin's home is Default).
    if input.owner_principal_id is None:
        input.owner_principal_id = ctx.current_principal_id or PLATFORM_PRINCIPAL_ID
    validate_owner_principal(input.owner_principal_id, ctx, resource_label="cluster")

    # Cluster names are unique within their owning Org, not globally —
    # two Orgs can each have a "c1".
    existing = await Cluster.one_by_fields(
        session,
        {
            'deleted_at': None,
            "name": input.name,
            "owner_principal_id": input.owner_principal_id,
        },
    )
    if existing:
        raise AlreadyExistsException(message=f"cluster {input.name} already exists")

    create_update_check(input.provider, input)
    if input.provider == ClusterProvider.Kubernetes:
        enforce_data_dir_mounts(input)

    # Auto-promote the first cluster in an Org to that Org's default so
    # users don't have to flip a separate switch after onboarding.
    has_existing_in_org = await Cluster.first_by_field(
        session=session, field="owner_principal_id", value=input.owner_principal_id
    )
    auto_default = has_existing_in_org is None

    # The registration token embeds the access/secret key pair; only the
    # hash of the secret key is stored on the API key record.
    access_key = secrets.token_hex(8)
    secret_key = secrets.token_hex(16)

    # Kubernetes/Docker clusters start READY; other providers stay PENDING
    # until workers are provisioned.
    target_state = ClusterStateEnum.READY
    state_message = None
    if input.provider not in [ClusterProvider.Kubernetes, ClusterProvider.Docker]:
        target_state = ClusterStateEnum.PENDING
        state_message = "No workers have been provisioned for this cluster yet."

    pools = input.worker_pools or []
    to_create_cluster = Cluster.model_validate(
        {
            **input.model_dump(exclude={"worker_pools"}),
            "state": target_state,
            "state_message": state_message,
            "hashed_suffix": secrets.token_hex(6),
            "registration_token": f"{API_KEY_PREFIX}_{access_key}_{secret_key}",
            "is_default": auto_default,
        }
    )
    to_create_user = User(
        username=f'{system_name_prefix}-{to_create_cluster.hashed_suffix}',
        is_system=True,
        role=UserRole.Cluster,
        hashed_password="",
    )
    to_create_apikey = ApiKey(
        name=f'{system_name_prefix}-{to_create_cluster.hashed_suffix}',
        access_key=access_key,
        hashed_secret_key=get_secret_hash(secret_key),
    )

    try:
        # create cluster
        cluster = await Cluster.create(session, to_create_cluster, auto_commit=False)

        # create pools
        for pool in pools:
            to_create_pool = WorkerPool.model_validate(
                {
                    **pool.model_dump(),
                    "cluster_id": cluster.id,
                    # Pool inherits its cluster's owner so list filters
                    # can scope without joining.
                    "owner_principal_id": cluster.owner_principal_id,
                    "cloud_options": (
                        pool.cloud_options if pool.cloud_options else CloudOptions()
                    ),
                }
            )
            to_create_pool.cluster = cluster
            await WorkerPool.create(
                session=session, source=to_create_pool, auto_commit=False
            )

        to_create_user.cluster = cluster
        user = await create_user_with_principal(session, to_create_user)
        to_create_apikey.user_id = user.id
        await ApiKey.create(session=session, source=to_create_apikey, auto_commit=False)

        await session.commit()
        await session.refresh(cluster)
        return cluster
    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise InternalServerErrorException(
            message=f"Failed to create cluster: {e}"
        ) from e
@router.put("/{id}", response_model=ClusterPublic, response_model_exclude_none=True)
async def update_cluster(
    session: SessionDep, ctx: TenantContextDep, id: int, input: ClusterUpdate
):
    """Update a cluster and return it with workers and models eager-loaded.

    Raises:
        NotFoundException: the cluster does not exist or is soft-deleted.
        InvalidException: the payload fails ``create_update_check``.
        InternalServerErrorException: the update itself failed; the original
            exception is chained as the cause.
    """
    cluster = await Cluster.one_by_id(session, id)
    # Treat soft-deleted clusters as missing, as the delete, worker-pool and
    # registration-token endpoints do.
    if not cluster or cluster.deleted_at is not None:
        raise NotFoundException(message=f"cluster {id} not found")
    assert_cluster_writable(ctx, cluster)

    create_update_check(cluster.provider, input)
    if cluster.provider == ClusterProvider.Kubernetes:
        enforce_data_dir_mounts(input)

    try:
        await cluster.update(session=session, source=input)
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to update cluster: {e}"
        ) from e

    # Re-read so the response carries the eager-loaded relationships.
    return await Cluster.one_by_id(
        session,
        id,
        options=CLUSTER_LOAD_OPTIONS,
    )
@router.delete("/{id}")
async def delete_cluster(session: SessionDep, ctx: TenantContextDep, id: int):
    """Delete a cluster that has no workers, models or model instances.

    Raises:
        NotFoundException: the cluster does not exist or is already deleted.
        ConflictException: the cluster still has workers, models or model
            instances (checked in that order).
        InternalServerErrorException: the delete itself failed.
    """
    eager_relations = [
        selectinload(Cluster.cluster_workers),
        selectinload(Cluster.cluster_models),
        selectinload(Cluster.cluster_model_instances),
    ]
    target = await Cluster.one_by_id(session, id, options=eager_relations)
    if not (target and target.deleted_at is None):
        raise NotFoundException(message=f"cluster {id} not found")
    assert_cluster_writable(ctx, target)

    # Any remaining dependent resource blocks deletion.
    blocking_relations = (
        ("cluster_workers", "workers"),
        ("cluster_models", "models"),
        ("cluster_model_instances", "model instances"),
    )
    for relation, label in blocking_relations:
        if len(getattr(target, relation)) > 0:
            raise ConflictException(
                message=f"cluster {target.name}(id: {id}) has {label}, cannot be deleted"
            )

    try:
        await target.delete(session=session)
    except Exception as e:
        raise InternalServerErrorException(message=f"Failed to delete cluster: {e}")
@router.post("/{id}/set-default")
async def set_default_cluster(session: SessionDep, ctx: TenantContextDep, id: int):
    """Make the given cluster the default cluster of its owning Org.

    Any other non-deleted default in the same Org is unset, and both changes
    are committed together.

    Raises:
        NotFoundException: the cluster does not exist or is soft-deleted.
        InternalServerErrorException: the update failed; the original
            exception is chained as the cause.
    """
    # "Default cluster" is a per-Org concept now: each Org has at most
    # one default, and that's what its members' deploy form falls back
    # to. Writing it follows the standard cluster-write rule (admin
    # always; Org admin only on their own Org's clusters).
    cluster = await Cluster.one_by_id(session, id)
    # A soft-deleted cluster must never become (or remain) an Org's default.
    if not cluster or cluster.deleted_at is not None:
        raise NotFoundException(message=f"cluster {id} not found")
    assert_cluster_writable(ctx, cluster)

    try:
        # Unset any existing default in this cluster's Org. The partial
        # unique index guarantees there's at most one to begin with.
        existing_defaults = await Cluster.all_by_fields(
            session,
            {
                'is_default': True,
                'deleted_at': None,
                'owner_principal_id': cluster.owner_principal_id,
            },
        )
        for dc in existing_defaults:
            if dc.id != cluster.id:
                await dc.update(
                    session=session,
                    source={"is_default": False},
                    auto_commit=False,
                )

        # Set this cluster as the Org's default.
        await cluster.update(
            session=session,
            source={"is_default": True},
            auto_commit=False,
        )
        await session.commit()
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to set default cluster: {e}"
        ) from e
@router.post("/{id}/worker-pools", response_model=WorkerPoolPublic)
async def create_worker_pool(
    session: SessionDep, ctx: TenantContextDep, id: int, input: WorkerPoolCreate
):
    """Create a worker pool inside a cloud-provider cluster.

    The pool takes the cluster's ``owner_principal_id`` and gets a default
    ``CloudOptions()`` when the payload carries none.

    Raises:
        NotFoundException: the cluster does not exist or is soft-deleted.
        InvalidException: the cluster uses the Docker or Kubernetes provider.
        InternalServerErrorException: building or persisting the pool failed.
    """
    cluster = await Cluster.one_by_id(session, id)
    if not (cluster and cluster.deleted_at is None):
        raise NotFoundException(message=f"cluster {id} not found")
    assert_cluster_writable(ctx, cluster)

    pool_less_providers = [ClusterProvider.Docker, ClusterProvider.Kubernetes]
    if cluster.provider in pool_less_providers:
        raise InvalidException(
            message=f"Cannot create worker pool for cluster {cluster.name}(id: {id}) with provider {cluster.provider}"
        )

    try:
        pool_fields = {
            **input.model_dump(),
            "cluster_id": id,
            "owner_principal_id": cluster.owner_principal_id,
            "cloud_options": input.cloud_options or CloudOptions(),
        }
        new_pool = WorkerPool.model_validate(pool_fields)
        new_pool.cluster = cluster
        return await WorkerPool.create(session, new_pool)
    except Exception as e:
        raise InternalServerErrorException(message=f"Failed to create worker pool: {e}")
def get_registration_from_cluster(
    request: Request, cluster: Cluster
) -> ClusterRegistrationTokenPublic:
    """Assemble the public registration payload for a cluster.

    The payload holds the registration token, the server URL resolved by
    ``get_server_url``, the image from ``get_cluster_image_name`` and env
    vars derived from the worker config plus the token. ``args`` is always
    an empty list.
    """
    worker_config = cluster.worker_config
    config_values = worker_config.model_dump() if worker_config else {}
    sensitive = SensitiveRegistrationConfig(
        token=cluster.registration_token, **config_values
    )

    server_url = get_server_url(request, cluster.server_url)
    # Default image, can be customized through the worker config.
    image = get_cluster_image_name(cluster.worker_config)
    env_vars = parse_base_model_to_env_vars(sensitive)

    return ClusterRegistrationTokenPublic(
        token=cluster.registration_token,
        server_url=server_url,
        image=image,
        env=env_vars,
        args=[],
    )
@router.get("/{id}/registration-token", response_model=ClusterRegistrationTokenPublic)
async def get_registration_token(
    request: Request, session: SessionDep, ctx: TenantContextDep, id: int
):
    """Return the registration payload (token included) for a cluster.

    Raises:
        NotFoundException: the cluster does not exist or is soft-deleted.
    """
    target = await Cluster.one_by_id(session, id)
    if not (target and target.deleted_at is None):
        raise NotFoundException(message=f"cluster {id} not found")

    # The token lets its holder register a worker into this cluster, so it is
    # treated as a write-class secret and gated by the writable check.
    assert_cluster_writable(ctx, target)
    return get_registration_from_cluster(request, target)
@router.get("/{id}/manifests")
async def get_cluster_manifests(
    request: Request,
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
    runtime: Optional[ManufacturerEnum] = Query(
        None, description="Optional runtime to include in the manifest"
    ),
):
    """Render the worker manifest for a Kubernetes cluster as a YAML download.

    Raises:
        NotFoundException: the cluster does not exist or is soft-deleted.
        InvalidException: the cluster is not a Kubernetes-provider cluster.
    """
    cluster = await Cluster.one_by_id(session, id)
    if not cluster or cluster.deleted_at is not None:
        raise NotFoundException(message=f"cluster {id} not found")

    # The template is given the cluster's registration token, which is a
    # write-class secret. Gate it like the registration-token endpoint.
    assert_cluster_writable(ctx, cluster)

    if cluster.provider != ClusterProvider.Kubernetes:
        raise InvalidException(
            message=f"Cannot get manifests for cluster {cluster.name}(id: {id}) with provider {cluster.provider}"
        )

    config = TemplateConfig(
        registration=get_registration_from_cluster(request, cluster),
        cluster_suffix=cluster.hashed_suffix,
        # worker_config may be unset; fall back to no explicit namespace.
        namespace=getattr(cluster.worker_config, "namespace", None),
        runtime_enum=runtime,
        k8s_volume_mounts=cluster.k8s_volume_mounts,
    )
    yaml_content = config.render()
    return Response(
        content=yaml_content,
        media_type="application/x-yaml",
        headers={"Content-Disposition": "attachment; filename=manifest.yaml"},
    )
@router.get("/{id}/dashboard")
async def get_cluster_dashboard(
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
    request: Request,
):
    """Redirect (302) to the Grafana worker dashboard scoped to this cluster.

    The cluster name is passed to Grafana as the ``var-cluster_name`` query
    variable.

    Raises:
        InternalServerErrorException: the Grafana URL or the worker dashboard
            UID is not configured.
    """
    cluster = await Cluster.one_by_id(session, id)
    assert_cluster_visible(ctx, cluster, not_found_message="cluster not found")

    cfg = get_global_config()
    if not (cfg.get_grafana_url() and cfg.grafana_worker_dashboard_uid):
        raise InternalServerErrorException(
            message="Grafana dashboard settings are not configured"
        )

    # The query string always carries the cluster-name variable.
    query_string = urlencode({"var-cluster_name": cluster.name})
    grafana_base = resolve_grafana_base_url(cfg, request)
    dashboard_path = f"d/{cfg.grafana_worker_dashboard_uid}/gpustack-worker"
    target_url = f"{grafana_base}/{dashboard_path}?{query_string}"
    return RedirectResponse(url=target_url, status_code=302)
# Lower-cased names of incoming request headers that are NOT forwarded to the
# worker by the cluster API-server proxy: hop-by-hop/connection-level headers
# plus other things we should not pass along. The worker layer will inject its
# own Authorization, and the worker→k8s leg will inject the in-pod
# ServiceAccount token.
_CLUSTER_PROXY_REQUEST_HEADER_SKIP = {
    # Connection / message-framing headers.
    "host",
    "content-length",
    "transfer-encoding",
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "upgrade",
    # Caller credentials, which are not forwarded to the worker.
    "authorization",
    "cookie",
    "x-api-key",
    # Forwarding metadata describing the original client-facing request.
    "x-forwarded-host",
    "x-forwarded-port",
    "x-forwarded-proto",
}
@router.api_route(
    "/{id}/proxy/{path:path}",
    methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"],
    include_in_schema=False,
)
async def cluster_apiserver_proxy(
    request: Request,
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
    path: str,
):
    """
    Proxy a request to the Kubernetes API server of a Kubernetes-provider
    cluster, by forwarding it through one of the cluster's worker pods. The
    worker uses its in-pod ServiceAccount credentials to call the API server.

    Raises:
        NotFoundException: the cluster does not exist or is soft-deleted.
        InvalidException: the cluster is not a Kubernetes-provider cluster.
        ServiceUnavailableException: the cluster has no READY worker.
    """
    cluster = await Cluster.one_by_id(session, id)
    if not cluster or cluster.deleted_at is not None:
        raise NotFoundException(message=f"cluster {id} not found")

    # The proxy forwards mutating verbs to the API server on the worker's
    # credentials, so apply the same write gate as the other mutating cluster
    # endpoints.
    assert_cluster_writable(ctx, cluster)

    if cluster.provider != ClusterProvider.Kubernetes:
        raise InvalidException(
            message=(
                f"cluster {cluster.name}(id: {id}) provider is "
                f"{cluster.provider.value}; API server proxy is only supported "
                "for Kubernetes-provider clusters."
            )
        )

    workers = await Worker.all_by_fields(
        session,
        fields={"cluster_id": id, "state": WorkerStateEnum.READY},
    )
    if not workers:
        raise ServiceUnavailableException(
            message=f"No ready workers in cluster {cluster.name}(id: {id})"
        )
    # Any ready worker can relay the call; pick one at random.
    worker = random.choice(workers)

    headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() not in _CLUSTER_PROXY_REQUEST_HEADER_SKIP
    }

    # Stream the request body through to avoid loading large payloads (e.g. apply
    # of big manifests) into memory.
    body = (
        request.stream() if request.method not in ("GET", "HEAD", "OPTIONS") else None
    )

    # request.query_params preserves order but a flat dict is sufficient for
    # the Kubernetes API surface we forward (no duplicate keys in practice).
    params = dict(request.query_params) or None

    # No total timeout — Kubernetes watch and log-follow streams may be open
    # indefinitely. Connect timeout still bounds the upstream connect step.
    timeout = aiohttp.ClientTimeout(total=None, sock_connect=10)

    return StreamingResponseWithStatusCode(
        stream_to_worker(
            worker=worker,
            method=request.method,
            path=f"cluster-proxy/{path}",
            proxy_client=request.app.state.http_client,
            no_proxy_client=request.app.state.http_client_no_proxy,
            params=params,
            data=body,
            headers=headers,
            timeout=timeout,
            raw=True,
        ),
    )
|