# test_p2_routes.py
"""Unit tests for P2 (Org / Group / Membership / cluster_access) route logic.

Tests in this repo are mock-based, and this module follows that pattern:
route functions are called directly with ``MagicMock``/``AsyncMock`` sessions
and request contexts, and model lookups are replaced via ``monkeypatch``.
Coverage focuses on the authorization branches and on corner cases that are
not trivially expressible as declarative SQL constraints.
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock

import pytest

from gpustack.api.exceptions import (
    AlreadyExistsException,
    ConflictException,
    ForbiddenException,
    InvalidException,
    NotFoundException,
)
from gpustack.routes import (
    cluster_access as cluster_access_route,
    organization_members,
    organizations as organizations_route,
    user_groups as user_groups_route,
)
from gpustack.schemas.principals import (
    OrgRole,
    Principal,
    PrincipalMembership,
    PrincipalType,
)
  28. def _ctx(
  29. *,
  30. user_id: int = 1,
  31. user_principal_id: int = 100,
  32. is_admin: bool = False,
  33. current_principal_id: int | None = None,
  34. org_role: OrgRole | None = None,
  35. ):
  36. ctx = MagicMock()
  37. ctx.user = MagicMock()
  38. ctx.user.id = user_id
  39. ctx.user.principal_id = user_principal_id
  40. ctx.is_platform_admin = is_admin
  41. ctx.current_principal_id = current_principal_id
  42. ctx.org_role = org_role
  43. return ctx
  44. def _principal(
  45. id: int = 10,
  46. kind: PrincipalType = PrincipalType.ORG,
  47. parent_principal_id: int | None = None,
  48. name: str = "Acme",
  49. slug: str | None = "acme",
  50. ):
  51. p = MagicMock(spec=Principal)
  52. p.id = id
  53. p.kind = kind
  54. p.parent_principal_id = parent_principal_id
  55. p.name = name
  56. p.slug = slug
  57. p.description = None
  58. p.deleted_at = None
  59. p.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
  60. p.updated_at = p.created_at
  61. return p
  62. def _user_row(id: int = 2, principal_id: int = 200):
  63. u = MagicMock()
  64. u.id = id
  65. u.principal_id = principal_id
  66. u.username = f"user-{id}"
  67. u.full_name = None
  68. u.is_system = False
  69. u.deleted_at = None
  70. return u
  71. def _session_returning(*results):
  72. """Make a mock async session whose successive .exec() return the queued results."""
  73. session = MagicMock()
  74. queue = []
  75. for value in results:
  76. result = MagicMock()
  77. if isinstance(value, list):
  78. scalars = MagicMock()
  79. scalars.all = MagicMock(return_value=value)
  80. result.scalars = MagicMock(return_value=scalars)
  81. result.first = MagicMock(return_value=value[0] if value else None)
  82. result.all = MagicMock(return_value=value)
  83. else:
  84. result.scalar_one_or_none = MagicMock(return_value=value)
  85. result.first = MagicMock(return_value=value)
  86. queue.append(result)
  87. session.exec = AsyncMock(side_effect=queue)
  88. session.commit = AsyncMock()
  89. session.rollback = AsyncMock()
  90. session.refresh = AsyncMock()
  91. session.delete = AsyncMock()
  92. session.add = MagicMock()
  93. return session
  94. # ---- _can_manage (organization_members) ------------------------------------
  95. def test_can_manage_platform_admin_always():
  96. ctx = _ctx(is_admin=True)
  97. assert organization_members._can_manage(ctx, 1) is True
  98. assert organization_members._can_manage(ctx, 99) is True
  99. def test_can_manage_admin_in_org_can_manage():
  100. ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
  101. assert organization_members._can_manage(ctx, 10) is True
  102. def test_can_manage_admin_cannot_manage_other_org():
  103. ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
  104. assert organization_members._can_manage(ctx, 99) is False
  105. def test_can_manage_member_cannot_manage():
  106. ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
  107. assert organization_members._can_manage(ctx, 10) is False
  108. # ---- _can_manage_groups ----------------------------------------------------
  109. def test_can_manage_groups_admin_passthrough():
  110. ctx = _ctx(is_admin=True, current_principal_id=None)
  111. assert user_groups_route._can_manage_groups(ctx, org_id=42) is True
  112. def test_can_manage_groups_member_blocked():
  113. ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
  114. assert user_groups_route._can_manage_groups(ctx, org_id=10) is False
  115. def test_can_manage_groups_admin_role_in_org_passes():
  116. ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
  117. assert user_groups_route._can_manage_groups(ctx, org_id=10) is True
  118. def test_can_manage_groups_wrong_org_blocked():
  119. ctx = _ctx(current_principal_id=99, org_role=OrgRole.ADMIN)
  120. assert user_groups_route._can_manage_groups(ctx, org_id=10) is False
  121. # ---- organizations route ---------------------------------------------------
  122. @pytest.mark.asyncio
  123. async def test_create_organization_rejects_duplicate_slug(monkeypatch):
  124. session = MagicMock()
  125. monkeypatch.setattr(
  126. organizations_route.Principal,
  127. "one_by_fields",
  128. AsyncMock(return_value=_principal(name="Existing", slug="acme")),
  129. )
  130. with pytest.raises(AlreadyExistsException):
  131. await organizations_route.create_organization(
  132. session=session,
  133. org_in=organizations_route.OrganizationCreate(name="Acme", slug="acme"),
  134. )
  135. @pytest.mark.asyncio
  136. async def test_delete_platform_org_blocked(monkeypatch):
  137. platform = _principal(id=1, name="Platform", slug="default")
  138. monkeypatch.setattr(
  139. organizations_route.Principal,
  140. "one_by_id",
  141. AsyncMock(return_value=platform),
  142. )
  143. with pytest.raises(ConflictException):
  144. await organizations_route.delete_organization(session=MagicMock(), id=1)
  145. @pytest.mark.asyncio
  146. async def test_delete_org_blocked_when_resources_exist(monkeypatch):
  147. org = _principal(id=2, name="Acme", slug="acme")
  148. monkeypatch.setattr(
  149. organizations_route.Principal,
  150. "one_by_id",
  151. AsyncMock(return_value=org),
  152. )
  153. monkeypatch.setattr(
  154. organizations_route,
  155. "_has_resources",
  156. AsyncMock(return_value=["models", "api_keys"]),
  157. )
  158. with pytest.raises(ConflictException) as excinfo:
  159. await organizations_route.delete_organization(session=MagicMock(), id=2)
  160. assert "models" in excinfo.value.message
  161. # ---- organization_members route -------------------------------------------
  162. @pytest.mark.asyncio
  163. async def test_remove_only_admin_blocked(monkeypatch):
  164. org = _principal(id=10, name="Acme", slug="acme")
  165. user = _user_row(id=2, principal_id=200)
  166. membership = MagicMock(spec=PrincipalMembership)
  167. membership.parent_principal_id = 10
  168. membership.member_principal_id = 200
  169. membership.role = OrgRole.ADMIN
  170. membership.deleted_at = None
  171. monkeypatch.setattr(
  172. organization_members.Principal,
  173. "one_by_id",
  174. AsyncMock(return_value=org),
  175. )
  176. monkeypatch.setattr(
  177. organization_members,
  178. "_resolve_user",
  179. AsyncMock(return_value=user),
  180. )
  181. monkeypatch.setattr(
  182. organization_members,
  183. "_find_membership",
  184. AsyncMock(return_value=membership),
  185. )
  186. monkeypatch.setattr(
  187. organization_members,
  188. "_has_other_admin",
  189. AsyncMock(return_value=False),
  190. )
  191. ctx = _ctx(is_admin=True)
  192. with pytest.raises(ConflictException):
  193. await organization_members.remove_org_member(
  194. session=MagicMock(), ctx=ctx, org_id=10, user_id=2
  195. )
  196. @pytest.mark.asyncio
  197. async def test_demote_only_admin_blocked(monkeypatch):
  198. org = _principal(id=10, name="Acme", slug="acme")
  199. user = _user_row(id=2, principal_id=200)
  200. membership = MagicMock(spec=PrincipalMembership)
  201. membership.parent_principal_id = 10
  202. membership.member_principal_id = 200
  203. membership.role = OrgRole.ADMIN
  204. membership.deleted_at = None
  205. monkeypatch.setattr(
  206. organization_members.Principal,
  207. "one_by_id",
  208. AsyncMock(return_value=org),
  209. )
  210. monkeypatch.setattr(
  211. organization_members,
  212. "_resolve_user",
  213. AsyncMock(return_value=user),
  214. )
  215. monkeypatch.setattr(
  216. organization_members,
  217. "_find_membership",
  218. AsyncMock(return_value=membership),
  219. )
  220. monkeypatch.setattr(
  221. organization_members,
  222. "_has_other_admin",
  223. AsyncMock(return_value=False),
  224. )
  225. ctx = _ctx(is_admin=True)
  226. with pytest.raises(ConflictException):
  227. await organization_members.update_org_member(
  228. session=MagicMock(),
  229. ctx=ctx,
  230. org_id=10,
  231. user_id=2,
  232. body=organization_members.MembershipUpdate(role=OrgRole.USER),
  233. )
  234. # ---- cluster_access route --------------------------------------------------
  235. @pytest.mark.asyncio
  236. async def test_grant_cluster_access_validates_principal(monkeypatch):
  237. cluster = MagicMock()
  238. cluster.id = 1
  239. cluster.deleted_at = None
  240. monkeypatch.setattr(
  241. cluster_access_route.Cluster,
  242. "one_by_id",
  243. AsyncMock(return_value=cluster),
  244. )
  245. monkeypatch.setattr(
  246. cluster_access_route.Principal,
  247. "one_by_id",
  248. AsyncMock(return_value=None), # principal does not exist
  249. )
  250. ctx = _ctx(is_admin=True)
  251. with pytest.raises(InvalidException):
  252. await cluster_access_route.grant_cluster_access(
  253. session=MagicMock(),
  254. ctx=ctx,
  255. cluster_id=1,
  256. body=cluster_access_route.ClusterAccessGrant(
  257. principal_type=PrincipalType.ORG, principal_id=999
  258. ),
  259. )
  260. @pytest.mark.asyncio
  261. async def test_grant_cluster_access_rejects_duplicate(monkeypatch):
  262. cluster = MagicMock()
  263. cluster.id = 1
  264. cluster.deleted_at = None
  265. org = _principal(id=2, kind=PrincipalType.ORG, name="Acme", slug="acme")
  266. monkeypatch.setattr(
  267. cluster_access_route.Cluster,
  268. "one_by_id",
  269. AsyncMock(return_value=cluster),
  270. )
  271. monkeypatch.setattr(
  272. cluster_access_route.Principal,
  273. "one_by_id",
  274. AsyncMock(return_value=org),
  275. )
  276. session = _session_returning([MagicMock()]) # exec returns existing row
  277. ctx = _ctx(is_admin=True)
  278. with pytest.raises(AlreadyExistsException):
  279. await cluster_access_route.grant_cluster_access(
  280. session=session,
  281. ctx=ctx,
  282. cluster_id=1,
  283. body=cluster_access_route.ClusterAccessGrant(
  284. principal_type=PrincipalType.ORG, principal_id=2
  285. ),
  286. )
  287. @pytest.mark.asyncio
  288. async def test_revoke_cluster_access_404_when_missing(monkeypatch):
  289. cluster = MagicMock()
  290. cluster.id = 1
  291. cluster.deleted_at = None
  292. monkeypatch.setattr(
  293. cluster_access_route.Cluster,
  294. "one_by_id",
  295. AsyncMock(return_value=cluster),
  296. )
  297. session = _session_returning(None)
  298. ctx = _ctx(is_admin=True)
  299. with pytest.raises(NotFoundException):
  300. await cluster_access_route.revoke_cluster_access(
  301. session=session,
  302. ctx=ctx,
  303. cluster_id=1,
  304. principal_id=42,
  305. )
  306. # ---- user_groups route -----------------------------------------------------
  307. @pytest.mark.asyncio
  308. async def test_create_group_blocked_for_member(monkeypatch):
  309. org = _principal(id=10, kind=PrincipalType.ORG, name="Acme", slug="acme")
  310. monkeypatch.setattr(
  311. user_groups_route.Principal,
  312. "one_by_id",
  313. AsyncMock(return_value=org),
  314. )
  315. ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
  316. with pytest.raises(ForbiddenException):
  317. await user_groups_route.create_group(
  318. session=MagicMock(),
  319. ctx=ctx,
  320. org_id=10,
  321. body=user_groups_route.UserGroupCreate(name="team-a"),
  322. )
@pytest.mark.asyncio
async def test_add_group_member_requires_org_membership(monkeypatch):
    """Expect InvalidException when the user being added is not in the group's org.

    Group 5 has parent org 10. User 99 (principal 999) resolves successfully,
    but the single queued ``session.exec`` result is ``None``. Presumably this
    answers the route's org-membership lookup; TODO confirm against
    ``add_group_member``. The caller is a platform admin.
    """
    group = _principal(
        id=5,
        kind=PrincipalType.GROUP,
        parent_principal_id=10,
        name="team-a",
        slug=None,
    )
    monkeypatch.setattr(
        user_groups_route.Principal,
        "one_by_id",
        AsyncMock(return_value=group),
    )
    user = _user_row(id=99, principal_id=999)
    monkeypatch.setattr(
        user_groups_route,
        "_resolve_user",
        AsyncMock(return_value=user),
    )
    session = _session_returning(None)  # no org membership
    ctx = _ctx(is_admin=True)
    with pytest.raises(InvalidException):
        await user_groups_route.add_group_member(
            session=session,
            ctx=ctx,
            org_id=10,
            group_id=5,
            body=user_groups_route.GroupMembershipCreate(user_id=99),
        )