test_tenant.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. """Unit tests for TenantContext resolution and role guards."""
  2. from unittest.mock import AsyncMock, MagicMock
  3. import pytest
  4. from gpustack.api.exceptions import ForbiddenException, InvalidException
  5. from gpustack.api.tenant import (
  6. _resolve_requested_principal_id,
  7. get_tenant_context,
  8. require_org_role,
  9. require_platform_admin,
  10. )
  11. from gpustack.schemas.principals import (
  12. OrgRole,
  13. PrincipalMembership,
  14. PrincipalType,
  15. )
  def _request(api_key=None):
      """Return a mock request whose ``state`` only exposes what is set here.

      ``state`` is created with an empty spec, so reading any attribute that
      was not explicitly assigned raises ``AttributeError``. ``api_key`` is
      attached to ``state`` only when a non-``None`` value is supplied.
      """
      state = MagicMock(spec=[])
      if api_key is not None:
          state.api_key = api_key

      fake_request = MagicMock()
      fake_request.state = state
      return fake_request
  def _user(
      id: int = 7,
      is_admin: bool = False,
      is_system: bool = False,
      principal_id=None,
  ):
      """Return a mock user carrying the given identity fields.

      The mock always reports ``is_active`` as ``True``; every other attribute
      mirrors the corresponding argument.
      """
      # ``id`` shadows the builtin, but it is part of the keyword interface
      # the tests call with, so the name is kept.
      return MagicMock(
          id=id,
          is_admin=is_admin,
          is_system=is_system,
          principal_id=principal_id,
          is_active=True,
      )
  def _api_key(owner_principal_id: int):
      """Return a mock API key whose ``owner_principal_id`` is the given id."""
      return MagicMock(owner_principal_id=owner_principal_id)
  def _principal(
      id: int = 99,
      kind: PrincipalType = PrincipalType.ORG,
      deleted_at=None,
  ):
      """Return a mock principal exposing ``id``, ``kind`` and ``deleted_at``.

      Defaults describe an ORG-kind principal with id 99 and no deletion
      timestamp.
      """
      return MagicMock(id=id, kind=kind, deleted_at=deleted_at)
  def _session_returning(*scalar_lists):
      """Create a mock session whose awaited ``exec(...)`` calls replay canned results.

      One result object is produced per positional argument, in order:

      - a ``list`` is exposed through both ``.scalars().all()`` and ``.all()``
      - anything else is exposed through ``.scalar_one_or_none()`` and
        ``.first()``
      """

      def _as_result(payload):
          # Wrap a single canned payload in the accessor shape matching its type.
          outcome = MagicMock()
          if not isinstance(payload, list):
              outcome.scalar_one_or_none = MagicMock(return_value=payload)
              outcome.first = MagicMock(return_value=payload)
              return outcome
          rows = MagicMock()
          rows.all = MagicMock(return_value=payload)
          outcome.scalars = MagicMock(return_value=rows)
          outcome.all = MagicMock(return_value=payload)
          return outcome

      db = MagicMock()
      # Each await of ``exec`` consumes the next prepared result.
      db.exec = AsyncMock(side_effect=[_as_result(item) for item in scalar_lists])
      return db
  70. # ---- _resolve_requested_principal_id ---------------------------------------
  def test_resolve_principal_id_prefers_api_key():
      """The API key's owner principal is returned even though a header value
      and a user principal are also supplied."""
      keyed_request = _request(api_key=_api_key(owner_principal_id=42))
      caller = _user(principal_id=1)

      resolved = _resolve_requested_principal_id(keyed_request, caller, "999")

      assert resolved == 42
  def test_resolve_principal_id_uses_header_when_no_api_key():
      """Without an API key on the request, the header string is parsed and returned."""
      plain_request = _request()
      caller = _user(principal_id=1)

      resolved = _resolve_requested_principal_id(plain_request, caller, "999")

      assert resolved == 999
  def test_resolve_principal_id_falls_back_to_user_principal():
      """With neither an API key nor a header, the user's own principal id is returned."""
      plain_request = _request()
      caller = _user(principal_id=1)

      resolved = _resolve_requested_principal_id(plain_request, caller, None)

      assert resolved == 1
  def test_resolve_principal_id_invalid_header_raises():
      """A header value that is not an integer string raises InvalidException."""
      plain_request = _request()
      caller = _user(principal_id=1)
      bad_header = "not-an-int"

      with pytest.raises(InvalidException):
          _resolve_requested_principal_id(plain_request, caller, bad_header)
  88. # ---- get_tenant_context -----------------------------------------------------
  @pytest.mark.asyncio
  async def test_platform_admin_without_header_has_no_org_filter():
      """A platform admin with no header and no principal gets an unscoped context.

      The context must report platform-admin status, carry no current
      principal or org role, and have an empty accessible-cluster set. The
      session must not be queried at all.
      """
      user = _user(id=1, is_admin=True, principal_id=None)
      request = _request()
      session = _session_returning()  # no DB calls expected

      ctx = await get_tenant_context(
          request=request,
          session=session,
          user=user,
          x_organization_id=None,
      )

      assert ctx.is_platform_admin is True
      assert ctx.current_principal_id is None
      assert ctx.org_role is None
      assert ctx.accessible_cluster_ids == set()
      # Enforce the "no DB calls" expectation explicitly rather than relying on
      # the exhausted mock raising (which a broad ``except`` could swallow).
      session.exec.assert_not_awaited()
  @pytest.mark.asyncio
  async def test_member_uses_team_org_via_header():
      """A non-admin selecting an Org they belong to via X-Organization-Id
      gets that Org as the current principal, with their membership role."""
      org_id = 5
      member = _user(id=10, is_admin=False, principal_id=100)
      link = PrincipalMembership(
          member_principal_id=100,
          parent_principal_id=org_id,
          role=OrgRole.USER,
      )
      canned_results = (
          link,  # _resolve_membership
          [11, 12],  # _user_group_principal_ids
          [101, 102],  # _accessible_clusters
          _principal(id=org_id, kind=PrincipalType.ORG),  # org existence check
      )
      db = _session_returning(*canned_results)

      tenant = await get_tenant_context(
          request=_request(),
          session=db,
          user=member,
          x_organization_id="5",
      )

      assert tenant.is_platform_admin is False
      assert tenant.current_principal_id == 5
      assert tenant.org_role == OrgRole.USER
      assert tenant.accessible_cluster_ids == {101, 102}
      assert tenant.current_is_personal_scope is False
  @pytest.mark.asyncio
  async def test_personal_scope_short_circuits():
      """If the resolved principal equals ``user.principal_id`` the context is
      personal scope: no org membership lookup and no group expansion."""
      own_principal = 100
      member = _user(id=10, is_admin=False, principal_id=own_principal)
      # Only one canned result is provided, for _accessible_clusters.
      db = _session_returning([])

      tenant = await get_tenant_context(
          request=_request(),
          session=db,
          user=member,
          x_organization_id=None,
      )

      assert tenant.current_principal_id == 100
      assert tenant.current_is_personal_scope is True
      assert tenant.org_role is None
  @pytest.mark.asyncio
  async def test_non_member_request_to_other_org_is_rejected():
      """A non-admin asking for an Org with no membership row is refused with
      ForbiddenException."""
      outsider = _user(id=11, is_admin=False, principal_id=100)
      incoming = _request()
      db = _session_returning(None)  # no membership row
      call_args = {
          "request": incoming,
          "session": db,
          "user": outsider,
          "x_organization_id": "2",
      }

      with pytest.raises(ForbiddenException):
          await get_tenant_context(**call_args)
  @pytest.mark.asyncio
  async def test_platform_admin_can_act_in_org_without_membership():
      """A platform admin may select an Org without a membership row; the
      context then has that Org as current principal and no org role."""
      admin = _user(id=1, is_admin=True, principal_id=None)
      canned_results = (
          None,  # no membership; admin should still pass
          [],
          [],
          _principal(id=7, kind=PrincipalType.ORG),
      )
      db = _session_returning(*canned_results)

      tenant = await get_tenant_context(
          request=_request(),
          session=db,
          user=admin,
          x_organization_id="7",
      )

      assert tenant.is_platform_admin is True
      assert tenant.current_principal_id == 7
      assert tenant.org_role is None
  @pytest.mark.asyncio
  async def test_api_key_overrides_header():
      """When the request carries an API key, its owner principal becomes the
      current principal and the X-Organization-Id header is ignored."""
      key_owner = 42
      member = _user(id=10, is_admin=False, principal_id=100)
      keyed_request = _request(api_key=_api_key(owner_principal_id=key_owner))
      link = PrincipalMembership(
          member_principal_id=100,
          parent_principal_id=key_owner,
          role=OrgRole.USER,
      )
      db = _session_returning(
          link,
          [],
          [],
          _principal(id=key_owner, kind=PrincipalType.ORG),
      )

      tenant = await get_tenant_context(
          request=keyed_request,
          session=db,
          user=member,
          x_organization_id="999",  # ignored when api_key is set
      )

      assert tenant.current_principal_id == 42
  202. # ---- require_platform_admin / require_org_role ------------------------------
  @pytest.mark.asyncio
  async def test_require_platform_admin_blocks_regular_user():
      """A context that is not platform admin is rejected with ForbiddenException."""
      regular_ctx = MagicMock(is_platform_admin=False)

      with pytest.raises(ForbiddenException):
          await require_platform_admin(regular_ctx)
  @pytest.mark.asyncio
  async def test_require_platform_admin_allows_admin():
      """A platform-admin context is handed back unchanged (same object)."""
      admin_ctx = MagicMock(is_platform_admin=True)

      returned = await require_platform_admin(admin_ctx)

      assert returned is admin_ctx
  @pytest.mark.asyncio
  async def test_require_org_role_admin_passthrough():
      """The role guard returns a platform-admin context as-is."""
      admin_ctx = MagicMock(is_platform_admin=True)
      guard = require_org_role(OrgRole.ADMIN)

      returned = await guard(admin_ctx)

      assert returned is admin_ctx
  @pytest.mark.asyncio
  async def test_require_org_role_blocks_when_no_org_context():
      """A non-admin context with no current principal is rejected with
      ForbiddenException."""
      orgless_ctx = MagicMock(is_platform_admin=False, current_principal_id=None)
      guard = require_org_role(OrgRole.ADMIN)

      with pytest.raises(ForbiddenException):
          await guard(orgless_ctx)
  @pytest.mark.asyncio
  async def test_require_org_role_blocks_insufficient_role():
      """A USER-role context is rejected by a guard that requires ADMIN."""

      def _reject_unless_user_allowed(*allowed):
          # Stand-in for the context's role check: this context holds USER,
          # so refuse whenever USER is not among the permitted roles.
          if OrgRole.USER in allowed:
              return
          raise ForbiddenException(message="nope")

      member_ctx = MagicMock(
          is_platform_admin=False,
          current_principal_id=1,
          org_role=OrgRole.USER,
      )
      member_ctx.assert_org_role = _reject_unless_user_allowed
      guard = require_org_role(OrgRole.ADMIN)

      with pytest.raises(ForbiddenException):
          await guard(member_ctx)
  @pytest.mark.asyncio
  async def test_require_org_role_passes_for_matching_role():
      """An ADMIN-role context passes a guard that requires ADMIN.

      The stubbed ``assert_org_role`` also verifies that the guard forwards
      the required role(s) to the context's own role check.
      """
      # The role was previously listed twice; once is sufficient.
      dep = require_org_role(OrgRole.ADMIN)

      def _assert_role(*allowed):
          # Fail loudly if the guard did not forward the ADMIN requirement.
          if OrgRole.ADMIN not in allowed:
              raise AssertionError("did not pass admin role through")

      ctx = MagicMock()
      ctx.is_platform_admin = False
      ctx.current_principal_id = 1
      ctx.org_role = OrgRole.ADMIN
      ctx.assert_org_role = _assert_role

      assert await dep(ctx) is ctx