test_workers.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. from types import SimpleNamespace
  2. from unittest.mock import AsyncMock, MagicMock, patch
  3. import pytest
  4. from gpustack.api.exceptions import ForbiddenException, NotFoundException
  5. from gpustack.routes.workers import (
  6. delete_worker,
  7. get_worker_privatekey,
  8. update_worker,
  9. update_worker_data,
  10. )
  11. from gpustack.schemas.principals import OrgRole
  12. from gpustack.schemas.workers import (
  13. Worker,
  14. WorkerCreate,
  15. WorkerStateEnum,
  16. WorkerStatus,
  17. WorkerUpdate,
  18. Maintenance,
  19. SystemReserved,
  20. )
  21. from gpustack.schemas.clusters import Cluster, ClusterProvider
  22. def test_update_worker_data_preserves_maintenance_mode():
  23. """
  24. Test that maintenance mode is preserved when a worker re-registers.
  25. This verifies the fix for the issue where workers automatically exit
  26. maintenance mode after restart.
  27. """
  28. # Create an existing worker with maintenance mode enabled
  29. existing_worker = Worker(
  30. id=1,
  31. name="test-worker",
  32. labels={"env": "test"},
  33. maintenance=Maintenance(enabled=True, message="Scheduled maintenance"),
  34. state=WorkerStateEnum.MAINTENANCE,
  35. cluster_id=1,
  36. hostname="test-host",
  37. ip="192.168.1.100",
  38. ifname="eth0",
  39. port=8080,
  40. worker_uuid="test-uuid-123",
  41. status=WorkerStatus.get_default_status(),
  42. )
  43. # Create a worker registration request without maintenance field
  44. # (simulating a worker restart/re-registration)
  45. worker_in = WorkerCreate(
  46. name="test-worker",
  47. labels={"env": "test", "new": "label"},
  48. maintenance=None, # Not set during re-registration
  49. hostname="test-host",
  50. ip="192.168.1.100",
  51. ifname="eth0",
  52. port=8080,
  53. worker_uuid="test-uuid-123",
  54. cluster_id=1,
  55. status=WorkerStatus.get_default_status(),
  56. system_reserved=SystemReserved(ram=0, vram=0),
  57. )
  58. # Update the worker data
  59. updated_worker = update_worker_data(worker_in, existing=existing_worker)
  60. # Verify that maintenance mode is preserved
  61. assert updated_worker.maintenance is not None
  62. assert updated_worker.maintenance.enabled is True
  63. assert updated_worker.maintenance.message == "Scheduled maintenance"
  64. # State will be computed as MAINTENANCE because of compute_state()
  65. assert updated_worker.state == WorkerStateEnum.MAINTENANCE
  66. def test_update_worker_data_can_disable_maintenance_mode():
  67. """
  68. Test that maintenance mode can be explicitly disabled when provided.
  69. """
  70. # Create an existing worker with maintenance mode enabled
  71. existing_worker = Worker(
  72. id=1,
  73. name="test-worker",
  74. labels={"env": "test"},
  75. maintenance=Maintenance(enabled=True, message="Scheduled maintenance"),
  76. state=WorkerStateEnum.MAINTENANCE,
  77. cluster_id=1,
  78. hostname="test-host",
  79. ip="192.168.1.100",
  80. ifname="eth0",
  81. port=8080,
  82. worker_uuid="test-uuid-123",
  83. status=WorkerStatus.get_default_status(),
  84. )
  85. # Create a worker update request with maintenance explicitly disabled
  86. worker_in = WorkerCreate(
  87. name="test-worker",
  88. labels={"env": "test"},
  89. maintenance=Maintenance(enabled=False, message=None),
  90. hostname="test-host",
  91. ip="192.168.1.100",
  92. ifname="eth0",
  93. port=8080,
  94. worker_uuid="test-uuid-123",
  95. cluster_id=1,
  96. status=WorkerStatus.get_default_status(),
  97. system_reserved=SystemReserved(ram=0, vram=0),
  98. )
  99. # Update the worker data
  100. updated_worker = update_worker_data(worker_in, existing=existing_worker)
  101. # Verify that maintenance mode is disabled
  102. assert updated_worker.maintenance is not None
  103. assert updated_worker.maintenance.enabled is False
  104. assert updated_worker.maintenance.message is None
  105. # State will be computed based on heartbeat, but maintenance is disabled
  106. # Since maintenance is disabled, the state won't be MAINTENANCE
  107. assert updated_worker.state != WorkerStateEnum.MAINTENANCE
  108. def test_update_worker_data_new_worker_without_maintenance():
  109. """
  110. Test that a new worker can be created without maintenance mode.
  111. """
  112. # Create a new worker registration request
  113. worker_in = WorkerCreate(
  114. name="new-worker",
  115. labels={"env": "prod"},
  116. maintenance=None,
  117. hostname="new-host",
  118. ip="192.168.1.101",
  119. ifname="eth0",
  120. port=8080,
  121. worker_uuid="new-uuid-456",
  122. cluster_id=1,
  123. status=WorkerStatus.get_default_status(),
  124. system_reserved=SystemReserved(ram=0, vram=0),
  125. )
  126. # Create cluster for new worker
  127. cluster = Cluster(
  128. id=1,
  129. name="test-cluster",
  130. provider=ClusterProvider.Docker,
  131. )
  132. # Create a new worker (no existing worker)
  133. new_worker = update_worker_data(worker_in, existing=None, cluster=cluster)
  134. # Verify that the new worker is created without maintenance mode
  135. assert new_worker.maintenance is None
  136. # State may be NOT_READY due to missing heartbeat, but not MAINTENANCE
  137. assert new_worker.state != WorkerStateEnum.MAINTENANCE
  138. def test_update_worker_data_preserves_labels_merge():
  139. """
  140. Test that labels are properly merged when updating a worker.
  141. """
  142. # Create an existing worker with some labels
  143. existing_worker = Worker(
  144. id=1,
  145. name="test-worker",
  146. labels={"env": "test", "region": "us-west"},
  147. maintenance=None,
  148. state=WorkerStateEnum.READY,
  149. cluster_id=1,
  150. hostname="test-host",
  151. ip="192.168.1.100",
  152. ifname="eth0",
  153. port=8080,
  154. worker_uuid="test-uuid-123",
  155. status=WorkerStatus.get_default_status(),
  156. )
  157. # Create a worker update with new labels
  158. worker_in = WorkerCreate(
  159. name="test-worker",
  160. labels={"env": "prod", "zone": "a"}, # env changes, zone is new
  161. maintenance=None,
  162. hostname="test-host",
  163. ip="192.168.1.100",
  164. ifname="eth0",
  165. port=8080,
  166. worker_uuid="test-uuid-123",
  167. cluster_id=1,
  168. status=WorkerStatus.get_default_status(),
  169. system_reserved=SystemReserved(ram=0, vram=0),
  170. )
  171. # Update the worker data
  172. updated_worker = update_worker_data(worker_in, existing=existing_worker)
  173. # Verify that labels are properly merged
  174. assert updated_worker.labels["env"] == "prod" # Updated
  175. assert updated_worker.labels["region"] == "us-west" # Preserved
  176. assert updated_worker.labels["zone"] == "a" # New
  177. def _ctx(*, is_platform_admin=False, current_principal_id=None, org_role=None):
  178. """Minimal TenantContext stub for the worker write paths.
  179. The handlers only consult ``is_platform_admin`` / ``current_principal_id`` /
  180. ``org_role`` (via ``assert_org_owned_writable`` and
  181. ``assert_cluster_resource_visible``), so a SimpleNamespace is enough.
  182. """
  183. return SimpleNamespace(
  184. user=SimpleNamespace(is_system=False),
  185. is_platform_admin=is_platform_admin,
  186. current_principal_id=current_principal_id,
  187. org_role=org_role,
  188. accessible_cluster_ids=set(),
  189. )
  190. def _worker(owner_principal_id=None, deleted_at=None):
  191. return SimpleNamespace(
  192. id=42,
  193. cluster_id=1,
  194. owner_principal_id=owner_principal_id,
  195. deleted_at=deleted_at,
  196. ssh_key_id=None,
  197. external_id=None,
  198. state=WorkerStateEnum.READY,
  199. )
  200. def _patch_worker_one_by_id(worker):
  201. """Stub Worker.one_by_id to return the given worker."""
  202. return patch(
  203. "gpustack.routes.workers.Worker.one_by_id",
  204. AsyncMock(return_value=worker),
  205. )
  206. @pytest.mark.asyncio
  207. async def test_delete_worker_allowed_for_org_admin_in_owning_org():
  208. worker = _worker(owner_principal_id=10)
  209. ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
  210. session = MagicMock()
  211. with (
  212. _patch_worker_one_by_id(worker),
  213. patch("gpustack.routes.workers.WorkerService") as service_cls,
  214. ):
  215. service = service_cls.return_value
  216. service.delete = AsyncMock()
  217. await delete_worker(ctx=ctx, session=session, id=worker.id)
  218. service.delete.assert_awaited_once()
  219. @pytest.mark.asyncio
  220. async def test_delete_worker_forbidden_for_plain_org_user():
  221. worker = _worker(owner_principal_id=10)
  222. ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
  223. session = MagicMock()
  224. with _patch_worker_one_by_id(worker):
  225. with pytest.raises(ForbiddenException):
  226. await delete_worker(ctx=ctx, session=session, id=worker.id)
  227. @pytest.mark.asyncio
  228. async def test_delete_worker_returns_not_found_for_other_org():
  229. """Cross-org access must 404, not 403, to avoid leaking row existence."""
  230. worker = _worker(owner_principal_id=10)
  231. ctx = _ctx(current_principal_id=99, org_role=OrgRole.ADMIN)
  232. session = MagicMock()
  233. with _patch_worker_one_by_id(worker):
  234. with pytest.raises(NotFoundException):
  235. await delete_worker(ctx=ctx, session=session, id=worker.id)
  236. @pytest.mark.asyncio
  237. async def test_delete_worker_on_globally_shared_cluster_requires_platform_admin():
  238. """Workers on a global cluster (owner is None) shared via cluster_access
  239. are visible to the recipient Org admin but not writable — only platform
  240. admin may mutate global rows."""
  241. worker = _worker(owner_principal_id=None)
  242. ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
  243. ctx.accessible_cluster_ids = {worker.cluster_id}
  244. session = MagicMock()
  245. with _patch_worker_one_by_id(worker):
  246. with pytest.raises(ForbiddenException):
  247. await delete_worker(ctx=ctx, session=session, id=worker.id)
  248. @pytest.mark.asyncio
  249. async def test_delete_worker_allowed_for_platform_admin():
  250. worker = _worker(owner_principal_id=10)
  251. ctx = _ctx(is_platform_admin=True, current_principal_id=None)
  252. session = MagicMock()
  253. with (
  254. _patch_worker_one_by_id(worker),
  255. patch("gpustack.routes.workers.WorkerService") as service_cls,
  256. ):
  257. service = service_cls.return_value
  258. service.delete = AsyncMock()
  259. await delete_worker(ctx=ctx, session=session, id=worker.id)
  260. service.delete.assert_awaited_once()
  261. @pytest.mark.asyncio
  262. async def test_delete_already_soft_deleted_worker_returns_not_found():
  263. """A soft-deleted worker must look like it doesn't exist for any caller."""
  264. import datetime
  265. worker = _worker(
  266. owner_principal_id=10,
  267. deleted_at=datetime.datetime.now(datetime.timezone.utc),
  268. )
  269. ctx = _ctx(is_platform_admin=True, current_principal_id=None)
  270. session = MagicMock()
  271. with _patch_worker_one_by_id(worker):
  272. with pytest.raises(NotFoundException):
  273. await delete_worker(ctx=ctx, session=session, id=worker.id)
  274. @pytest.mark.asyncio
  275. async def test_update_worker_allowed_for_org_admin_in_owning_org():
  276. worker = _worker(owner_principal_id=10)
  277. ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
  278. session = MagicMock()
  279. worker_in = WorkerUpdate(name="test-worker", maintenance=None)
  280. with (
  281. _patch_worker_one_by_id(worker),
  282. patch("gpustack.routes.workers.WorkerService") as service_cls,
  283. ):
  284. service = service_cls.return_value
  285. service.update = AsyncMock()
  286. result = await update_worker(
  287. ctx=ctx, session=session, id=worker.id, worker_in=worker_in
  288. )
  289. service.update.assert_awaited_once()
  290. assert result is worker
  291. @pytest.mark.asyncio
  292. async def test_update_worker_forbidden_for_plain_org_user():
  293. worker = _worker(owner_principal_id=10)
  294. ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
  295. session = MagicMock()
  296. worker_in = WorkerUpdate(name="test-worker", maintenance=None)
  297. with _patch_worker_one_by_id(worker):
  298. with pytest.raises(ForbiddenException):
  299. await update_worker(
  300. ctx=ctx, session=session, id=worker.id, worker_in=worker_in
  301. )
  302. @pytest.mark.asyncio
  303. async def test_get_worker_privatekey_forbidden_for_plain_org_user():
  304. """Private key is a write-class secret — same gate as delete/update."""
  305. worker = _worker(owner_principal_id=10)
  306. ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
  307. session = MagicMock()
  308. with _patch_worker_one_by_id(worker):
  309. with pytest.raises(ForbiddenException):
  310. await get_worker_privatekey(ctx=ctx, session=session, id=worker.id)
  311. @pytest.mark.asyncio
  312. async def test_get_worker_privatekey_returns_not_found_for_other_org():
  313. worker = _worker(owner_principal_id=10)
  314. ctx = _ctx(current_principal_id=99, org_role=OrgRole.ADMIN)
  315. session = MagicMock()
  316. with _patch_worker_one_by_id(worker):
  317. with pytest.raises(NotFoundException):
  318. await get_worker_privatekey(ctx=ctx, session=session, id=worker.id)
  319. @pytest.mark.asyncio
  320. async def test_update_already_soft_deleted_worker_returns_not_found():
  321. """Mirrors delete: update must treat a soft-deleted worker as absent."""
  322. import datetime
  323. worker = _worker(
  324. owner_principal_id=10,
  325. deleted_at=datetime.datetime.now(datetime.timezone.utc),
  326. )
  327. ctx = _ctx(is_platform_admin=True, current_principal_id=None)
  328. session = MagicMock()
  329. worker_in = WorkerUpdate(name="test-worker", maintenance=None)
  330. with _patch_worker_one_by_id(worker):
  331. with pytest.raises(NotFoundException):
  332. await update_worker(
  333. ctx=ctx, session=session, id=worker.id, worker_in=worker_in
  334. )
  335. @pytest.mark.asyncio
  336. async def test_get_worker_privatekey_for_soft_deleted_worker_returns_not_found():
  337. """Soft-deleted workers must not leak private keys."""
  338. import datetime
  339. worker = _worker(
  340. owner_principal_id=10,
  341. deleted_at=datetime.datetime.now(datetime.timezone.utc),
  342. )
  343. ctx = _ctx(is_platform_admin=True, current_principal_id=None)
  344. session = MagicMock()
  345. with _patch_worker_one_by_id(worker):
  346. with pytest.raises(NotFoundException):
  347. await get_worker_privatekey(ctx=ctx, session=session, id=worker.id)