| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- """Cluster access authorization — platform admin only.
- Lets the platform admin grant or revoke a cluster's accessibility to a
- specific principal (USER / ORG / GROUP). The grant row stores a single
- ``principal_id`` FK; kind comes from the joined principals row at read
- time.
- """
- from datetime import datetime, timezone
- from typing import List, Optional
- from fastapi import APIRouter
- from pydantic import BaseModel
- from sqlmodel import select
- from gpustack.api.exceptions import (
- AlreadyExistsException,
- InvalidException,
- NotFoundException,
- )
- from gpustack.schemas.cluster_access import ClusterAccess, ClusterAccessPublic
- from gpustack.schemas.clusters import Cluster
- from gpustack.schemas.principals import Principal, PrincipalType
- from gpustack.schemas.users import User
- from gpustack.server.deps import SessionDep, TenantContextDep
# Router that the cluster-access endpoints defined in this module register on.
router = APIRouter()
class ClusterAccessGrant(BaseModel):
    # Request body of the grant endpoint: the principal to authorize,
    # identified by id, plus the kind the client declares it to be.
    #
    # ``principal_type`` is a discriminator kept on the input shape for
    # client-side validation. The server cross-checks it against the
    # ``principals`` row in ``_validate_principal`` and rejects the
    # request when the two disagree.
    principal_type: PrincipalType
    principal_id: int
async def _load_cluster(session, cluster_id: int) -> Cluster:
    """Fetch a cluster by id, treating soft-deleted rows as missing.

    Args:
        session: Database session handed to ``Cluster.one_by_id``.
        cluster_id: Primary key of the cluster to load.

    Returns:
        The cluster row, guaranteed to have ``deleted_at`` unset.

    Raises:
        NotFoundException: No cluster has this id, or its ``deleted_at``
            is set.
    """
    found = await Cluster.one_by_id(session, cluster_id)
    is_live = bool(found) and found.deleted_at is None
    if not is_live:
        raise NotFoundException(message="Cluster not found")
    return found
async def _validate_principal(
    session, principal_type: PrincipalType, principal_id: int
) -> Principal:
    """Check that a grant target is a live principal of the declared kind.

    Args:
        session: Database session used for the principal and user lookups.
        principal_type: Kind the caller claims the principal has.
        principal_id: Id of the principal row to validate.

    Returns:
        The validated principal row.

    Raises:
        InvalidException: The principal is missing or soft-deleted, its
            stored kind differs from ``principal_type``, or it is a USER
            principal whose user row is missing, soft-deleted, or flagged
            ``is_system``.
    """
    principal = await Principal.one_by_id(session, principal_id)
    if not (principal and principal.deleted_at is None):
        raise InvalidException(message=f"Principal {principal_id} not found")

    actual_kind = principal.kind
    if actual_kind != principal_type:
        mismatch = (
            f"Principal {principal_id} is a {actual_kind.value}, "
            f"not a {principal_type.value}"
        )
        raise InvalidException(message=mismatch)

    # Only USER principals need the extra user-row check below.
    if actual_kind != PrincipalType.USER:
        return principal

    # System users (workers etc.) must not receive explicit grants: they
    # already bypass cluster_access via ``is_system``.
    account = await User.one_by_field(session, "principal_id", principal_id)
    grantable = (
        account is not None
        and not account.is_system
        and account.deleted_at is None
    )
    if not grantable:
        raise InvalidException(message=f"User principal {principal_id} not found")
    return principal
async def _resolve_principal_views(
    session, rows: List[ClusterAccess]
) -> List[ClusterAccessPublic]:
    """Turn grant rows into public views enriched from ``principals``.

    All referenced principals are fetched with one ``IN`` query; no query
    is issued when ``rows`` is empty.

    Args:
        session: Database session used for the principals lookup.
        rows: Grant rows to convert.

    Returns:
        One ``ClusterAccessPublic`` per input row, in input order. A row
        whose principal was not returned by the lookup is reported with
        kind ``USER`` and no name or parent.
    """
    wanted_ids = {row.principal_id for row in rows}
    lookup: dict[int, Principal] = {}
    if wanted_ids:
        query = select(Principal).where(Principal.id.in_(wanted_ids))
        fetched = await session.exec(query)
        for principal in fetched.all():
            lookup[principal.id] = principal

    views: List[ClusterAccessPublic] = []
    for row in rows:
        principal: Optional[Principal] = lookup.get(row.principal_id)
        if not principal:
            # Principal row not found: fall back to USER with no labels.
            kind, name, parent = PrincipalType.USER, None, None
        else:
            kind, name = principal.kind, principal.name
            if kind == PrincipalType.GROUP:
                # GROUP principals expose their owning ORG via
                # parent_principal_id so the UI can render quota slots.
                parent = principal.parent_principal_id
            elif kind == PrincipalType.ORG:
                # For display purposes an ORG's "parent" is the ORG itself.
                parent = principal.id
            else:
                parent = None

        view = ClusterAccessPublic(
            cluster_id=row.cluster_id,
            principal_id=row.principal_id,
            principal_type=kind,
            principal_name=name,
            principal_parent_id=parent,
            granted_by=row.granted_by,
            created_at=row.created_at,
        )
        views.append(view)
    return views
@router.get("/clusters/{cluster_id}/access", response_model=List[ClusterAccessPublic])
async def list_cluster_access(
    session: SessionDep, ctx: TenantContextDep, cluster_id: int
):
    # List every access grant on one cluster, enriched with principal
    # kind, name and parent id.
    #
    # Raises NotFoundException (from _load_cluster) when the cluster is
    # missing or soft-deleted.
    #
    # ``ctx`` is never read in the body; it is only declared as a
    # dependency. NOTE(review): presumably that dependency is what
    # enforces the caller's tenant/admin context -- confirm in
    # gpustack.server.deps.
    await _load_cluster(session, cluster_id)

    query = select(ClusterAccess).where(ClusterAccess.cluster_id == cluster_id)
    result = await session.exec(query)
    grants = list(result.all())
    return await _resolve_principal_views(session, grants)
@router.post("/clusters/{cluster_id}/access", response_model=ClusterAccessPublic)
async def grant_cluster_access(
    session: SessionDep,
    ctx: TenantContextDep,
    cluster_id: int,
    body: ClusterAccessGrant,
):
    # Grant one principal access to a cluster and return the new grant,
    # enriched with the principal's kind, name and parent id.
    #
    # Raises:
    #   NotFoundException      - cluster missing or soft-deleted.
    #   InvalidException       - principal fails _validate_principal, or
    #                            the insert/commit/refresh fails.
    #   AlreadyExistsException - a grant for this (cluster, principal)
    #                            pair is already stored.
    await _load_cluster(session, cluster_id)
    await _validate_principal(session, body.principal_type, body.principal_id)

    # Reject duplicates up front so the caller gets a specific error
    # instead of a generic persistence failure.
    existing_stmt = select(ClusterAccess).where(
        ClusterAccess.cluster_id == cluster_id,
        ClusterAccess.principal_id == body.principal_id,
    )
    if (await session.exec(existing_stmt)).first() is not None:
        raise AlreadyExistsException(message="Access already granted")

    try:
        access = ClusterAccess(
            cluster_id=cluster_id,
            principal_id=body.principal_id,
            # The acting user is recorded as the granter.
            granted_by=ctx.user.id,
            # Current UTC time, stored as a naive datetime (tzinfo stripped).
            created_at=datetime.now(timezone.utc).replace(tzinfo=None),
        )
        session.add(access)
        await session.commit()
        await session.refresh(access)
    except Exception as e:
        await session.rollback()
        # Chain the original error so the root cause stays attached to
        # the exception (``__cause__``) for logs and debugging.
        raise InvalidException(
            message=f"Failed to grant cluster access: {e}"
        ) from e

    enriched = await _resolve_principal_views(session, [access])
    return enriched[0]
@router.delete("/clusters/{cluster_id}/access/{principal_id}")
async def revoke_cluster_access(
    session: SessionDep,
    ctx: TenantContextDep,
    cluster_id: int,
    principal_id: int,
):
    # Remove the grant that gives ``principal_id`` access to the cluster.
    # Returns nothing on success.
    #
    # Raises:
    #   NotFoundException - cluster missing or soft-deleted, or no grant
    #                       stored for this (cluster, principal) pair.
    #   InvalidException  - the delete or commit fails.
    #
    # ``ctx`` is never read in the body; it is only declared as a
    # dependency.
    await _load_cluster(session, cluster_id)

    stmt = select(ClusterAccess).where(
        ClusterAccess.cluster_id == cluster_id,
        ClusterAccess.principal_id == principal_id,
    )
    access = (await session.exec(stmt)).first()
    if not access:
        raise NotFoundException(message="Access grant not found")

    try:
        await session.delete(access)
        await session.commit()
    except Exception as e:
        await session.rollback()
        # Chain the original error so the root cause stays attached to
        # the exception (``__cause__``) for logs and debugging.
        raise InvalidException(
            message=f"Failed to revoke cluster access: {e}"
        ) from e
|