| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412 |
- """Unit tests for P2 (Org / Group / Membership / cluster_access) route logic.
- The codebase pattern for tests in this repo is mock-based; we follow that
- here. Coverage focuses on the authorization branches and the corner cases
- that aren't trivially expressible in declarative SQL constraints.
- """
- from datetime import datetime, timezone
- from unittest.mock import AsyncMock, MagicMock
- import pytest
- from gpustack.api.exceptions import (
- AlreadyExistsException,
- ConflictException,
- ForbiddenException,
- InvalidException,
- NotFoundException,
- )
- from gpustack.routes import (
- cluster_access as cluster_access_route,
- organization_members,
- organizations as organizations_route,
- user_groups as user_groups_route,
- )
- from gpustack.schemas.principals import (
- OrgRole,
- Principal,
- PrincipalMembership,
- PrincipalType,
- )
- def _ctx(
- *,
- user_id: int = 1,
- user_principal_id: int = 100,
- is_admin: bool = False,
- current_principal_id: int | None = None,
- org_role: OrgRole | None = None,
- ):
- ctx = MagicMock()
- ctx.user = MagicMock()
- ctx.user.id = user_id
- ctx.user.principal_id = user_principal_id
- ctx.is_platform_admin = is_admin
- ctx.current_principal_id = current_principal_id
- ctx.org_role = org_role
- return ctx
- def _principal(
- id: int = 10,
- kind: PrincipalType = PrincipalType.ORG,
- parent_principal_id: int | None = None,
- name: str = "Acme",
- slug: str | None = "acme",
- ):
- p = MagicMock(spec=Principal)
- p.id = id
- p.kind = kind
- p.parent_principal_id = parent_principal_id
- p.name = name
- p.slug = slug
- p.description = None
- p.deleted_at = None
- p.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
- p.updated_at = p.created_at
- return p
- def _user_row(id: int = 2, principal_id: int = 200):
- u = MagicMock()
- u.id = id
- u.principal_id = principal_id
- u.username = f"user-{id}"
- u.full_name = None
- u.is_system = False
- u.deleted_at = None
- return u
- def _session_returning(*results):
- """Make a mock async session whose successive .exec() return the queued results."""
- session = MagicMock()
- queue = []
- for value in results:
- result = MagicMock()
- if isinstance(value, list):
- scalars = MagicMock()
- scalars.all = MagicMock(return_value=value)
- result.scalars = MagicMock(return_value=scalars)
- result.first = MagicMock(return_value=value[0] if value else None)
- result.all = MagicMock(return_value=value)
- else:
- result.scalar_one_or_none = MagicMock(return_value=value)
- result.first = MagicMock(return_value=value)
- queue.append(result)
- session.exec = AsyncMock(side_effect=queue)
- session.commit = AsyncMock()
- session.rollback = AsyncMock()
- session.refresh = AsyncMock()
- session.delete = AsyncMock()
- session.add = MagicMock()
- return session
- # ---- _can_manage (organization_members) ------------------------------------
- def test_can_manage_platform_admin_always():
- ctx = _ctx(is_admin=True)
- assert organization_members._can_manage(ctx, 1) is True
- assert organization_members._can_manage(ctx, 99) is True
- def test_can_manage_admin_in_org_can_manage():
- ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
- assert organization_members._can_manage(ctx, 10) is True
- def test_can_manage_admin_cannot_manage_other_org():
- ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
- assert organization_members._can_manage(ctx, 99) is False
- def test_can_manage_member_cannot_manage():
- ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
- assert organization_members._can_manage(ctx, 10) is False
- # ---- _can_manage_groups ----------------------------------------------------
- def test_can_manage_groups_admin_passthrough():
- ctx = _ctx(is_admin=True, current_principal_id=None)
- assert user_groups_route._can_manage_groups(ctx, org_id=42) is True
- def test_can_manage_groups_member_blocked():
- ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
- assert user_groups_route._can_manage_groups(ctx, org_id=10) is False
- def test_can_manage_groups_admin_role_in_org_passes():
- ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
- assert user_groups_route._can_manage_groups(ctx, org_id=10) is True
- def test_can_manage_groups_wrong_org_blocked():
- ctx = _ctx(current_principal_id=99, org_role=OrgRole.ADMIN)
- assert user_groups_route._can_manage_groups(ctx, org_id=10) is False
- # ---- organizations route ---------------------------------------------------
- @pytest.mark.asyncio
- async def test_create_organization_rejects_duplicate_slug(monkeypatch):
- session = MagicMock()
- monkeypatch.setattr(
- organizations_route.Principal,
- "one_by_fields",
- AsyncMock(return_value=_principal(name="Existing", slug="acme")),
- )
- with pytest.raises(AlreadyExistsException):
- await organizations_route.create_organization(
- session=session,
- org_in=organizations_route.OrganizationCreate(name="Acme", slug="acme"),
- )
- @pytest.mark.asyncio
- async def test_delete_platform_org_blocked(monkeypatch):
- platform = _principal(id=1, name="Platform", slug="default")
- monkeypatch.setattr(
- organizations_route.Principal,
- "one_by_id",
- AsyncMock(return_value=platform),
- )
- with pytest.raises(ConflictException):
- await organizations_route.delete_organization(session=MagicMock(), id=1)
- @pytest.mark.asyncio
- async def test_delete_org_blocked_when_resources_exist(monkeypatch):
- org = _principal(id=2, name="Acme", slug="acme")
- monkeypatch.setattr(
- organizations_route.Principal,
- "one_by_id",
- AsyncMock(return_value=org),
- )
- monkeypatch.setattr(
- organizations_route,
- "_has_resources",
- AsyncMock(return_value=["models", "api_keys"]),
- )
- with pytest.raises(ConflictException) as excinfo:
- await organizations_route.delete_organization(session=MagicMock(), id=2)
- assert "models" in excinfo.value.message
- # ---- organization_members route -------------------------------------------
- @pytest.mark.asyncio
- async def test_remove_only_admin_blocked(monkeypatch):
- org = _principal(id=10, name="Acme", slug="acme")
- user = _user_row(id=2, principal_id=200)
- membership = MagicMock(spec=PrincipalMembership)
- membership.parent_principal_id = 10
- membership.member_principal_id = 200
- membership.role = OrgRole.ADMIN
- membership.deleted_at = None
- monkeypatch.setattr(
- organization_members.Principal,
- "one_by_id",
- AsyncMock(return_value=org),
- )
- monkeypatch.setattr(
- organization_members,
- "_resolve_user",
- AsyncMock(return_value=user),
- )
- monkeypatch.setattr(
- organization_members,
- "_find_membership",
- AsyncMock(return_value=membership),
- )
- monkeypatch.setattr(
- organization_members,
- "_has_other_admin",
- AsyncMock(return_value=False),
- )
- ctx = _ctx(is_admin=True)
- with pytest.raises(ConflictException):
- await organization_members.remove_org_member(
- session=MagicMock(), ctx=ctx, org_id=10, user_id=2
- )
- @pytest.mark.asyncio
- async def test_demote_only_admin_blocked(monkeypatch):
- org = _principal(id=10, name="Acme", slug="acme")
- user = _user_row(id=2, principal_id=200)
- membership = MagicMock(spec=PrincipalMembership)
- membership.parent_principal_id = 10
- membership.member_principal_id = 200
- membership.role = OrgRole.ADMIN
- membership.deleted_at = None
- monkeypatch.setattr(
- organization_members.Principal,
- "one_by_id",
- AsyncMock(return_value=org),
- )
- monkeypatch.setattr(
- organization_members,
- "_resolve_user",
- AsyncMock(return_value=user),
- )
- monkeypatch.setattr(
- organization_members,
- "_find_membership",
- AsyncMock(return_value=membership),
- )
- monkeypatch.setattr(
- organization_members,
- "_has_other_admin",
- AsyncMock(return_value=False),
- )
- ctx = _ctx(is_admin=True)
- with pytest.raises(ConflictException):
- await organization_members.update_org_member(
- session=MagicMock(),
- ctx=ctx,
- org_id=10,
- user_id=2,
- body=organization_members.MembershipUpdate(role=OrgRole.USER),
- )
- # ---- cluster_access route --------------------------------------------------
- @pytest.mark.asyncio
- async def test_grant_cluster_access_validates_principal(monkeypatch):
- cluster = MagicMock()
- cluster.id = 1
- cluster.deleted_at = None
- monkeypatch.setattr(
- cluster_access_route.Cluster,
- "one_by_id",
- AsyncMock(return_value=cluster),
- )
- monkeypatch.setattr(
- cluster_access_route.Principal,
- "one_by_id",
- AsyncMock(return_value=None), # principal does not exist
- )
- ctx = _ctx(is_admin=True)
- with pytest.raises(InvalidException):
- await cluster_access_route.grant_cluster_access(
- session=MagicMock(),
- ctx=ctx,
- cluster_id=1,
- body=cluster_access_route.ClusterAccessGrant(
- principal_type=PrincipalType.ORG, principal_id=999
- ),
- )
- @pytest.mark.asyncio
- async def test_grant_cluster_access_rejects_duplicate(monkeypatch):
- cluster = MagicMock()
- cluster.id = 1
- cluster.deleted_at = None
- org = _principal(id=2, kind=PrincipalType.ORG, name="Acme", slug="acme")
- monkeypatch.setattr(
- cluster_access_route.Cluster,
- "one_by_id",
- AsyncMock(return_value=cluster),
- )
- monkeypatch.setattr(
- cluster_access_route.Principal,
- "one_by_id",
- AsyncMock(return_value=org),
- )
- session = _session_returning([MagicMock()]) # exec returns existing row
- ctx = _ctx(is_admin=True)
- with pytest.raises(AlreadyExistsException):
- await cluster_access_route.grant_cluster_access(
- session=session,
- ctx=ctx,
- cluster_id=1,
- body=cluster_access_route.ClusterAccessGrant(
- principal_type=PrincipalType.ORG, principal_id=2
- ),
- )
- @pytest.mark.asyncio
- async def test_revoke_cluster_access_404_when_missing(monkeypatch):
- cluster = MagicMock()
- cluster.id = 1
- cluster.deleted_at = None
- monkeypatch.setattr(
- cluster_access_route.Cluster,
- "one_by_id",
- AsyncMock(return_value=cluster),
- )
- session = _session_returning(None)
- ctx = _ctx(is_admin=True)
- with pytest.raises(NotFoundException):
- await cluster_access_route.revoke_cluster_access(
- session=session,
- ctx=ctx,
- cluster_id=1,
- principal_id=42,
- )
- # ---- user_groups route -----------------------------------------------------
- @pytest.mark.asyncio
- async def test_create_group_blocked_for_member(monkeypatch):
- org = _principal(id=10, kind=PrincipalType.ORG, name="Acme", slug="acme")
- monkeypatch.setattr(
- user_groups_route.Principal,
- "one_by_id",
- AsyncMock(return_value=org),
- )
- ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
- with pytest.raises(ForbiddenException):
- await user_groups_route.create_group(
- session=MagicMock(),
- ctx=ctx,
- org_id=10,
- body=user_groups_route.UserGroupCreate(name="team-a"),
- )
- @pytest.mark.asyncio
- async def test_add_group_member_requires_org_membership(monkeypatch):
- group = _principal(
- id=5,
- kind=PrincipalType.GROUP,
- parent_principal_id=10,
- name="team-a",
- slug=None,
- )
- monkeypatch.setattr(
- user_groups_route.Principal,
- "one_by_id",
- AsyncMock(return_value=group),
- )
- user = _user_row(id=99, principal_id=999)
- monkeypatch.setattr(
- user_groups_route,
- "_resolve_user",
- AsyncMock(return_value=user),
- )
- session = _session_returning(None) # no org membership
- ctx = _ctx(is_admin=True)
- with pytest.raises(InvalidException):
- await user_groups_route.add_group_member(
- session=session,
- ctx=ctx,
- org_id=10,
- group_id=5,
- body=user_groups_route.GroupMembershipCreate(user_id=99),
- )
|