test_p4_routes.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. """Unit tests for P4 (ALLOWED_PRINCIPALS extension) route logic."""
  2. from unittest.mock import AsyncMock, MagicMock
  3. import pytest
  4. from gpustack.api.exceptions import (
  5. AlreadyExistsException,
  6. InvalidException,
  7. NotFoundException,
  8. )
  9. from gpustack.routes import model_route_principals as principals_route
  10. from gpustack.schemas.model_routes import ModelRoute
  11. from gpustack.schemas.principals import Principal, PrincipalType
  12. from gpustack.schemas.users import User
  13. def _route(id: int = 1):
  14. route = MagicMock(spec=ModelRoute)
  15. route.id = id
  16. route.deleted_at = None
  17. return route
  18. def _principal(
  19. id: int = 5,
  20. kind: PrincipalType = PrincipalType.ORG,
  21. ):
  22. p = MagicMock(spec=Principal)
  23. p.id = id
  24. p.kind = kind
  25. p.deleted_at = None
  26. p.parent_principal_id = None
  27. return p
  28. def _exec_returning(*results):
  29. queue = []
  30. for value in results:
  31. result = MagicMock()
  32. if isinstance(value, list):
  33. result.all = MagicMock(return_value=value)
  34. scalars = MagicMock()
  35. scalars.all = MagicMock(return_value=value)
  36. result.scalars = MagicMock(return_value=scalars)
  37. result.first = MagicMock(return_value=value[0] if value else None)
  38. result.scalar_one_or_none = MagicMock(
  39. return_value=value[0] if value else None
  40. )
  41. else:
  42. result.scalar_one_or_none = MagicMock(return_value=value)
  43. result.first = MagicMock(return_value=value)
  44. scalars = MagicMock()
  45. scalars.all = MagicMock(return_value=[])
  46. result.scalars = MagicMock(return_value=scalars)
  47. result.all = MagicMock(return_value=[])
  48. queue.append(result)
  49. return AsyncMock(side_effect=queue)
  50. def _session(*results):
  51. s = MagicMock()
  52. s.exec = _exec_returning(*results)
  53. s.commit = AsyncMock()
  54. s.rollback = AsyncMock()
  55. s.refresh = AsyncMock()
  56. s.delete = AsyncMock()
  57. s.add = MagicMock()
  58. return s
  59. # ---- list -------------------------------------------------------------------
  60. @pytest.mark.asyncio
  61. async def test_list_principals_returns_attached_links(monkeypatch):
  62. monkeypatch.setattr(
  63. principals_route.ModelRoute,
  64. "one_by_id",
  65. AsyncMock(return_value=_route()),
  66. )
  67. link1 = MagicMock()
  68. link1.id = 100
  69. link1.route_id = 1
  70. link1.principal_id = 5
  71. # Two exec() calls: list rows, then bulk lookup of principals.
  72. session = MagicMock()
  73. session.exec = _exec_returning([link1], [_principal(id=5, kind=PrincipalType.ORG)])
  74. result = await principals_route.list_route_principals(session=session, id=1)
  75. assert len(result) == 1
  76. assert result[0].route_id == 1
  77. assert result[0].principal_type == PrincipalType.ORG
  78. assert result[0].principal_id == 5
  79. @pytest.mark.asyncio
  80. async def test_list_principals_404_when_route_missing(monkeypatch):
  81. monkeypatch.setattr(
  82. principals_route.ModelRoute,
  83. "one_by_id",
  84. AsyncMock(return_value=None),
  85. )
  86. with pytest.raises(NotFoundException):
  87. await principals_route.list_route_principals(session=MagicMock(), id=999)
  88. # ---- add --------------------------------------------------------------------
  89. @pytest.mark.asyncio
  90. async def test_add_principal_validates_principal_exists(monkeypatch):
  91. monkeypatch.setattr(
  92. principals_route.ModelRoute,
  93. "one_by_id",
  94. AsyncMock(return_value=_route()),
  95. )
  96. monkeypatch.setattr(
  97. principals_route.Principal,
  98. "one_by_id",
  99. AsyncMock(return_value=None),
  100. )
  101. with pytest.raises(InvalidException):
  102. await principals_route.add_route_principal(
  103. session=MagicMock(),
  104. id=1,
  105. body=principals_route.PrincipalRef(
  106. principal_type=PrincipalType.ORG, principal_id=999
  107. ),
  108. )
  109. @pytest.mark.asyncio
  110. async def test_add_principal_rejects_kind_mismatch(monkeypatch):
  111. monkeypatch.setattr(
  112. principals_route.ModelRoute,
  113. "one_by_id",
  114. AsyncMock(return_value=_route()),
  115. )
  116. # Caller declared GROUP, but the principal row is actually an ORG.
  117. monkeypatch.setattr(
  118. principals_route.Principal,
  119. "one_by_id",
  120. AsyncMock(return_value=_principal(id=5, kind=PrincipalType.ORG)),
  121. )
  122. with pytest.raises(InvalidException):
  123. await principals_route.add_route_principal(
  124. session=MagicMock(),
  125. id=1,
  126. body=principals_route.PrincipalRef(
  127. principal_type=PrincipalType.GROUP, principal_id=5
  128. ),
  129. )
  130. @pytest.mark.asyncio
  131. async def test_add_principal_rejects_system_user(monkeypatch):
  132. monkeypatch.setattr(
  133. principals_route.ModelRoute,
  134. "one_by_id",
  135. AsyncMock(return_value=_route()),
  136. )
  137. monkeypatch.setattr(
  138. principals_route.Principal,
  139. "one_by_id",
  140. AsyncMock(return_value=_principal(id=2, kind=PrincipalType.USER)),
  141. )
  142. sys_user = MagicMock(spec=User)
  143. sys_user.is_system = True
  144. sys_user.deleted_at = None
  145. monkeypatch.setattr(
  146. principals_route.User,
  147. "one_by_field",
  148. AsyncMock(return_value=sys_user),
  149. )
  150. with pytest.raises(InvalidException):
  151. await principals_route.add_route_principal(
  152. session=MagicMock(),
  153. id=1,
  154. body=principals_route.PrincipalRef(
  155. principal_type=PrincipalType.USER, principal_id=2
  156. ),
  157. )
  158. @pytest.mark.asyncio
  159. async def test_add_principal_rejects_duplicate(monkeypatch):
  160. monkeypatch.setattr(
  161. principals_route.ModelRoute,
  162. "one_by_id",
  163. AsyncMock(return_value=_route()),
  164. )
  165. monkeypatch.setattr(
  166. principals_route.Principal,
  167. "one_by_id",
  168. AsyncMock(return_value=_principal(id=5, kind=PrincipalType.ORG)),
  169. )
  170. existing_link = MagicMock()
  171. session = _session(existing_link)
  172. with pytest.raises(AlreadyExistsException):
  173. await principals_route.add_route_principal(
  174. session=session,
  175. id=1,
  176. body=principals_route.PrincipalRef(
  177. principal_type=PrincipalType.ORG, principal_id=5
  178. ),
  179. )
  180. @pytest.mark.asyncio
  181. async def test_add_principal_creates_link_and_invalidates_cache(monkeypatch):
  182. monkeypatch.setattr(
  183. principals_route.ModelRoute,
  184. "one_by_id",
  185. AsyncMock(return_value=_route()),
  186. )
  187. monkeypatch.setattr(
  188. principals_route.Principal,
  189. "one_by_id",
  190. AsyncMock(return_value=_principal(id=5, kind=PrincipalType.ORG)),
  191. )
  192. cache_mock = AsyncMock()
  193. monkeypatch.setattr(principals_route, "revoke_model_access_cache", cache_mock)
  194. session = _session(None) # no existing link
  195. await principals_route.add_route_principal(
  196. session=session,
  197. id=1,
  198. body=principals_route.PrincipalRef(
  199. principal_type=PrincipalType.ORG, principal_id=5
  200. ),
  201. )
  202. session.add.assert_called_once()
  203. session.commit.assert_awaited()
  204. cache_mock.assert_awaited_once()
  205. # ---- remove -----------------------------------------------------------------
  206. @pytest.mark.asyncio
  207. async def test_remove_principal_404_when_missing(monkeypatch):
  208. monkeypatch.setattr(
  209. principals_route.ModelRoute,
  210. "one_by_id",
  211. AsyncMock(return_value=_route()),
  212. )
  213. session = _session(None)
  214. with pytest.raises(NotFoundException):
  215. await principals_route.remove_route_principal(
  216. session=session,
  217. id=1,
  218. principal_type=PrincipalType.USER,
  219. principal_id=99,
  220. )
  221. @pytest.mark.asyncio
  222. async def test_remove_principal_invalidates_cache(monkeypatch):
  223. monkeypatch.setattr(
  224. principals_route.ModelRoute,
  225. "one_by_id",
  226. AsyncMock(return_value=_route()),
  227. )
  228. cache_mock = AsyncMock()
  229. monkeypatch.setattr(principals_route, "revoke_model_access_cache", cache_mock)
  230. link = MagicMock()
  231. session = _session(link)
  232. await principals_route.remove_route_principal(
  233. session=session,
  234. id=1,
  235. principal_type=PrincipalType.ORG,
  236. principal_id=5,
  237. )
  238. session.delete.assert_awaited_once_with(link)
  239. session.commit.assert_awaited()
  240. cache_mock.assert_awaited_once()