cluster_access.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. """Cluster access authorization — platform admin only.
  2. Lets the platform admin grant or revoke a cluster's accessibility to a
  3. specific principal (USER / ORG / GROUP). The grant row stores a single
  4. ``principal_id`` FK; kind comes from the joined principals row at read
  5. time.
  6. """
  7. from datetime import datetime, timezone
  8. from typing import List, Optional
  9. from fastapi import APIRouter
  10. from pydantic import BaseModel
  11. from sqlmodel import select
  12. from gpustack.api.exceptions import (
  13. AlreadyExistsException,
  14. InvalidException,
  15. NotFoundException,
  16. )
  17. from gpustack.schemas.cluster_access import ClusterAccess, ClusterAccessPublic
  18. from gpustack.schemas.clusters import Cluster
  19. from gpustack.schemas.principals import Principal, PrincipalType
  20. from gpustack.schemas.users import User
  21. from gpustack.server.deps import SessionDep, TenantContextDep
  # Router that collects the cluster-access endpoints declared in this module.
  router = APIRouter()
  class ClusterAccessGrant(BaseModel):
      """Request body naming the principal that should receive cluster access."""

      # The declared kind stays on the input shape so clients can validate
      # their own payloads. The server compares it with the kind stored on
      # the ``principals`` row and rejects the request when the two differ.
      principal_type: PrincipalType
      principal_id: int
  async def _load_cluster(session, cluster_id: int) -> Cluster:
      """Fetch a cluster by id, treating soft-deleted rows as missing.

      Returns:
          The cluster row when it exists and ``deleted_at`` is unset.

      Raises:
          NotFoundException: if no cluster has that id or it is soft-deleted.
      """
      found = await Cluster.one_by_id(session, cluster_id)
      if found and found.deleted_at is None:
          return found
      raise NotFoundException(message="Cluster not found")
  async def _validate_principal(
      session, principal_type: PrincipalType, principal_id: int
  ) -> Principal:
      """Check that a principal is a valid target for a cluster grant.

      The principal must exist, must not be soft-deleted, and its stored kind
      must equal ``principal_type``. A USER principal additionally needs a
      backing user row that is neither a system user nor soft-deleted.

      Returns:
          The validated ``Principal`` row.

      Raises:
          InvalidException: if any of the checks above fails.
      """
      principal = await Principal.one_by_id(session, principal_id)
      unavailable = not principal or principal.deleted_at is not None
      if unavailable:
          raise InvalidException(message=f"Principal {principal_id} not found")

      stored_kind = principal.kind
      if stored_kind != principal_type:
          mismatch = (
              f"Principal {principal_id} is a {stored_kind.value}, "
              f"not a {principal_type.value}"
          )
          raise InvalidException(message=mismatch)

      if stored_kind == PrincipalType.USER:
          # System users (workers etc.) must not receive grants: they already
          # bypass cluster_access via ``is_system``.
          user = await User.one_by_field(session, "principal_id", principal_id)
          grantable = (
              user is not None
              and not user.is_system
              and user.deleted_at is None
          )
          if not grantable:
              raise InvalidException(
                  message=f"User principal {principal_id} not found"
              )

      return principal
  async def _resolve_principal_views(
      session, rows: List[ClusterAccess]
  ) -> List[ClusterAccessPublic]:
      """Turn grant rows into public views using one principals query.

      Every distinct ``principal_id`` in ``rows`` is fetched in a single
      ``IN`` lookup. A row whose principal is not found is still returned,
      with kind USER and no name or parent id.

      Returns:
          One ``ClusterAccessPublic`` per input row, in input order.
      """
      wanted_ids = {row.principal_id for row in rows}
      lookup: dict[int, Principal] = {}
      if wanted_ids:
          # Skip the query entirely when there are no rows to resolve.
          fetched = await session.exec(
              select(Principal).where(Principal.id.in_(wanted_ids))
          )
          lookup = {principal.id: principal for principal in fetched.all()}

      def to_public(row: ClusterAccess) -> ClusterAccessPublic:
          """Build the public view of a single grant row."""
          principal: Optional[Principal] = lookup.get(row.principal_id)
          if not principal:
              kind, name, parent_id = PrincipalType.USER, None, None
          else:
              kind = principal.kind
              name = principal.name
              if kind == PrincipalType.GROUP:
                  # GROUP principals expose their owning ORG so the UI can
                  # render quota slots.
                  parent_id = principal.parent_principal_id
              elif kind == PrincipalType.ORG:
                  # For display purposes an ORG's "parent" is itself.
                  parent_id = principal.id
              else:
                  parent_id = None
          return ClusterAccessPublic(
              cluster_id=row.cluster_id,
              principal_id=row.principal_id,
              principal_type=kind,
              principal_name=name,
              principal_parent_id=parent_id,
              granted_by=row.granted_by,
              created_at=row.created_at,
          )

      return [to_public(row) for row in rows]
  @router.get("/clusters/{cluster_id}/access", response_model=List[ClusterAccessPublic])
  async def list_cluster_access(
      session: SessionDep, ctx: TenantContextDep, cluster_id: int
  ):
      """List every access grant on a cluster, enriched for display.

      Raises:
          NotFoundException: if the cluster is missing or soft-deleted.
      """
      # NOTE(review): ``ctx`` is not read in the body; presumably the
      # dependency itself performs the caller check. Confirm in deps.
      # The result is discarded: the call only rejects unknown clusters.
      await _load_cluster(session, cluster_id)

      query = select(ClusterAccess).where(ClusterAccess.cluster_id == cluster_id)
      result = await session.exec(query)
      grants = list(result.all())
      return await _resolve_principal_views(session, grants)
  @router.post("/clusters/{cluster_id}/access", response_model=ClusterAccessPublic)
  async def grant_cluster_access(
      session: SessionDep,
      ctx: TenantContextDep,
      cluster_id: int,
      body: ClusterAccessGrant,
  ):
      """Grant a principal access to a cluster.

      Returns:
          The new grant, enriched with the principal's kind, name and
          parent id.

      Raises:
          NotFoundException: if the cluster is missing or soft-deleted.
          InvalidException: if the principal fails validation, or if
              creating the grant row fails.
          AlreadyExistsException: if the cluster/principal pair already
              has a grant.
      """
      await _load_cluster(session, cluster_id)
      await _validate_principal(session, body.principal_type, body.principal_id)

      existing_stmt = select(ClusterAccess).where(
          ClusterAccess.cluster_id == cluster_id,
          ClusterAccess.principal_id == body.principal_id,
      )
      if (await session.exec(existing_stmt)).first() is not None:
          raise AlreadyExistsException(message="Access already granted")

      try:
          access = ClusterAccess(
              cluster_id=cluster_id,
              principal_id=body.principal_id,
              granted_by=ctx.user.id,
              # Current UTC time, stored without tzinfo (naive datetime).
              created_at=datetime.now(timezone.utc).replace(tzinfo=None),
          )
          session.add(access)
          await session.commit()
          await session.refresh(access)
      except Exception as e:
          await session.rollback()
          # Chain the original error so the root cause is kept as the
          # explicit ``__cause__`` instead of being reported as an error
          # raised while handling another one.
          raise InvalidException(
              message=f"Failed to grant cluster access: {e}"
          ) from e

      enriched = await _resolve_principal_views(session, [access])
      return enriched[0]
  @router.delete("/clusters/{cluster_id}/access/{principal_id}")
  async def revoke_cluster_access(
      session: SessionDep,
      ctx: TenantContextDep,
      cluster_id: int,
      principal_id: int,
  ):
      """Revoke a principal's access to a cluster.

      Returns nothing on success.

      Raises:
          NotFoundException: if the cluster is missing or soft-deleted, or
              if the cluster/principal pair has no grant.
          InvalidException: if deleting the grant row fails.
      """
      await _load_cluster(session, cluster_id)

      stmt = select(ClusterAccess).where(
          ClusterAccess.cluster_id == cluster_id,
          ClusterAccess.principal_id == principal_id,
      )
      access = (await session.exec(stmt)).first()
      if not access:
          raise NotFoundException(message="Access grant not found")

      try:
          await session.delete(access)
          await session.commit()
      except Exception as e:
          await session.rollback()
          # Chain the original error so the root cause is kept as the
          # explicit ``__cause__`` instead of being reported as an error
          # raised while handling another one.
          raise InvalidException(
              message=f"Failed to revoke cluster access: {e}"
          ) from e