# me_orgs.py
  1. """Self-service tenant endpoints — what orgs am I in, what clusters can I use."""
  2. from typing import List
  3. from fastapi import APIRouter
  4. from pydantic import BaseModel
  5. from sqlmodel import select
  6. from gpustack.api.exceptions import ForbiddenException, NotFoundException
  7. from gpustack.schemas.cluster_access import ClusterAccess
  8. from gpustack.schemas.clusters import Cluster, ClusterPublic
  9. from gpustack.schemas.organizations import OrganizationPublic
  10. from gpustack.schemas.principals import (
  11. OrgRole,
  12. Principal,
  13. PrincipalMembership,
  14. PrincipalType,
  15. )
  16. from gpustack.server.deps import CurrentUserDep, SessionDep, TenantContextDep
  17. router = APIRouter()
  18. class MyOrganization(BaseModel):
  19. organization: OrganizationPublic
  20. role: OrgRole
  21. model_config = {"from_attributes": True}
  22. @router.get("/organizations", response_model=List[MyOrganization])
  23. async def list_my_orgs(session: SessionDep, user: CurrentUserDep):
  24. """Return the org switcher list — user's Personal scope first,
  25. then any ORG-principals they're a member of.
  26. "Personal" is no longer a stored Org row. After the principals
  27. refactor it's the user's own USER-principal (pre-refactor flag
  28. ``is_personal=True`` is now ``kind == USER`` rendered by
  29. ``OrganizationPublic.from_principal``). Synthesizing it here keeps
  30. the OrgSwitcher render path unchanged on the UI side.
  31. """
  32. items: List[MyOrganization] = []
  33. user_principal = await Principal.one_by_id(session, user.principal_id)
  34. if user_principal is not None and user_principal.deleted_at is None:
  35. items.append(
  36. MyOrganization(
  37. organization=OrganizationPublic.from_principal(user_principal),
  38. role=OrgRole.ADMIN,
  39. )
  40. )
  41. stmt = (
  42. select(PrincipalMembership, Principal)
  43. .join(
  44. Principal,
  45. Principal.id == PrincipalMembership.parent_principal_id,
  46. )
  47. .where(
  48. PrincipalMembership.member_principal_id == user.principal_id,
  49. PrincipalMembership.deleted_at.is_(None),
  50. Principal.deleted_at.is_(None),
  51. Principal.kind == PrincipalType.ORG,
  52. )
  53. )
  54. rows = (await session.exec(stmt)).all()
  55. items.extend(
  56. MyOrganization(
  57. organization=OrganizationPublic.from_principal(org),
  58. role=membership.role or OrgRole.USER,
  59. )
  60. for membership, org in rows
  61. )
  62. return items
  63. @router.get("/organizations/{org_id}/clusters", response_model=List[ClusterPublic])
  64. async def list_my_clusters_in_org(
  65. session: SessionDep, ctx: TenantContextDep, org_id: int
  66. ):
  67. """List clusters accessible to the caller in a specific Org context."""
  68. org = await Principal.one_by_id(session, org_id)
  69. if not org or org.deleted_at is not None or org.kind != PrincipalType.ORG:
  70. raise NotFoundException(message="Organization not found")
  71. if not ctx.is_platform_admin and ctx.current_principal_id != org_id:
  72. raise ForbiddenException(
  73. message="Cannot inspect clusters of an organization you are not in"
  74. )
  75. # Group-principals that the user is a member of, scoped to this Org.
  76. user_principal_id = ctx.user.principal_id
  77. group_stmt = (
  78. select(PrincipalMembership.parent_principal_id)
  79. .join(
  80. Principal,
  81. Principal.id == PrincipalMembership.parent_principal_id,
  82. )
  83. .where(
  84. PrincipalMembership.member_principal_id == user_principal_id,
  85. PrincipalMembership.deleted_at.is_(None),
  86. Principal.kind == PrincipalType.GROUP,
  87. Principal.parent_principal_id == org_id,
  88. Principal.deleted_at.is_(None),
  89. )
  90. )
  91. group_principal_ids = list((await session.exec(group_stmt)).all())
  92. principal_ids = [user_principal_id, org_id, *group_principal_ids]
  93. cluster_id_stmt = select(ClusterAccess.cluster_id).where(
  94. ClusterAccess.principal_id.in_(principal_ids)
  95. )
  96. cluster_ids = set((await session.exec(cluster_id_stmt)).all())
  97. if not cluster_ids:
  98. return []
  99. cluster_stmt = select(Cluster).where(
  100. Cluster.id.in_(cluster_ids), Cluster.deleted_at.is_(None)
  101. )
  102. return list((await session.exec(cluster_stmt)).all())