import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from gpustack.api.exceptions import ForbiddenException, NotFoundException
from gpustack.routes.workers import (
    delete_worker,
    get_worker_privatekey,
    update_worker,
    update_worker_data,
)
from gpustack.schemas.principals import OrgRole
from gpustack.schemas.workers import (
    Worker,
    WorkerCreate,
    WorkerStateEnum,
    WorkerStatus,
    WorkerUpdate,
    Maintenance,
    SystemReserved,
)
from gpustack.schemas.clusters import Cluster, ClusterProvider


# ---------------------------------------------------------------------------
# update_worker_data: merging a (re-)registration payload into a Worker
# ---------------------------------------------------------------------------


def _existing_worker(**fields):
    """Build an already-registered ``Worker`` for ``update_worker_data`` tests.

    Identity and network fields get fixed defaults shared by every test.
    Callers pass the fields a test is actually about (labels, maintenance,
    state) explicitly, and may override any default.
    """
    defaults = dict(
        id=1,
        name="test-worker",
        cluster_id=1,
        hostname="test-host",
        ip="192.168.1.100",
        ifname="eth0",
        port=8080,
        worker_uuid="test-uuid-123",
        status=WorkerStatus.get_default_status(),
    )
    return Worker(**{**defaults, **fields})


def _worker_create(**fields):
    """Build a ``WorkerCreate`` registration payload.

    The defaults match the identity/network fields used by
    ``_existing_worker``, so by default the payload describes that same worker
    re-registering. Override them to model a brand-new worker.
    """
    defaults = dict(
        name="test-worker",
        hostname="test-host",
        ip="192.168.1.100",
        ifname="eth0",
        port=8080,
        worker_uuid="test-uuid-123",
        cluster_id=1,
        status=WorkerStatus.get_default_status(),
        system_reserved=SystemReserved(ram=0, vram=0),
    )
    return WorkerCreate(**{**defaults, **fields})


def test_update_worker_data_preserves_maintenance_mode():
    """
    Test that maintenance mode is preserved when a worker re-registers.

    This verifies the fix for the issue where workers automatically exit
    maintenance mode after restart.
    """
    existing_worker = _existing_worker(
        labels={"env": "test"},
        maintenance=Maintenance(enabled=True, message="Scheduled maintenance"),
        state=WorkerStateEnum.MAINTENANCE,
    )
    # A re-registration (simulating a worker restart) does not set the
    # maintenance field.
    worker_in = _worker_create(
        labels={"env": "test", "new": "label"},
        maintenance=None,
    )

    updated_worker = update_worker_data(worker_in, existing=existing_worker)

    # Maintenance mode must survive the re-registration.
    assert updated_worker.maintenance is not None
    assert updated_worker.maintenance.enabled is True
    assert updated_worker.maintenance.message == "Scheduled maintenance"
    # State will be computed as MAINTENANCE because of compute_state()
    assert updated_worker.state == WorkerStateEnum.MAINTENANCE


def test_update_worker_data_can_disable_maintenance_mode():
    """
    Test that maintenance mode can be explicitly disabled when provided.
    """
    existing_worker = _existing_worker(
        labels={"env": "test"},
        maintenance=Maintenance(enabled=True, message="Scheduled maintenance"),
        state=WorkerStateEnum.MAINTENANCE,
    )
    # This time the payload explicitly turns maintenance off.
    worker_in = _worker_create(
        labels={"env": "test"},
        maintenance=Maintenance(enabled=False, message=None),
    )

    updated_worker = update_worker_data(worker_in, existing=existing_worker)

    assert updated_worker.maintenance is not None
    assert updated_worker.maintenance.enabled is False
    assert updated_worker.maintenance.message is None
    # State will be computed based on heartbeat, but maintenance is disabled,
    # so the state won't be MAINTENANCE.
    assert updated_worker.state != WorkerStateEnum.MAINTENANCE


def test_update_worker_data_new_worker_without_maintenance():
    """
    Test that a new worker can be created without maintenance mode.
    """
    worker_in = _worker_create(
        name="new-worker",
        labels={"env": "prod"},
        maintenance=None,
        hostname="new-host",
        ip="192.168.1.101",
        worker_uuid="new-uuid-456",
    )
    # A brand-new worker needs its cluster passed in.
    cluster = Cluster(
        id=1,
        name="test-cluster",
        provider=ClusterProvider.Docker,
    )

    new_worker = update_worker_data(worker_in, existing=None, cluster=cluster)

    assert new_worker.maintenance is None
    # State may be NOT_READY due to missing heartbeat, but not MAINTENANCE
    assert new_worker.state != WorkerStateEnum.MAINTENANCE


def test_update_worker_data_preserves_labels_merge():
    """
    Test that labels are properly merged when updating a worker.
    """
    existing_worker = _existing_worker(
        labels={"env": "test", "region": "us-west"},
        maintenance=None,
        state=WorkerStateEnum.READY,
    )
    worker_in = _worker_create(
        labels={"env": "prod", "zone": "a"},  # env changes, zone is new
        maintenance=None,
    )

    updated_worker = update_worker_data(worker_in, existing=existing_worker)

    assert updated_worker.labels["env"] == "prod"  # Updated
    assert updated_worker.labels["region"] == "us-west"  # Preserved
    assert updated_worker.labels["zone"] == "a"  # New


# ---------------------------------------------------------------------------
# Tenant authorization on the worker write paths
# ---------------------------------------------------------------------------


def _ctx(*, is_platform_admin=False, current_principal_id=None, org_role=None):
    """Minimal TenantContext stub for the worker write paths.

    The handlers only consult ``is_platform_admin`` / ``current_principal_id``
    / ``org_role`` (via ``assert_org_owned_writable`` and
    ``assert_cluster_resource_visible``), so a SimpleNamespace is enough.
    """
    return SimpleNamespace(
        user=SimpleNamespace(is_system=False),
        is_platform_admin=is_platform_admin,
        current_principal_id=current_principal_id,
        org_role=org_role,
        accessible_cluster_ids=set(),
    )


def _worker(owner_principal_id=None, deleted_at=None):
    """Return a lightweight worker stub (id 42, cluster 1, state READY)."""
    return SimpleNamespace(
        id=42,
        cluster_id=1,
        owner_principal_id=owner_principal_id,
        deleted_at=deleted_at,
        ssh_key_id=None,
        external_id=None,
        state=WorkerStateEnum.READY,
    )


def _soft_deleted_worker():
    """Return a worker stub owned by principal 10 whose ``deleted_at`` is now (UTC)."""
    return _worker(
        owner_principal_id=10,
        deleted_at=datetime.datetime.now(datetime.timezone.utc),
    )


def _patch_worker_one_by_id(worker):
    """Stub Worker.one_by_id to return the given worker."""
    return patch(
        "gpustack.routes.workers.Worker.one_by_id",
        AsyncMock(return_value=worker),
    )


@pytest.mark.asyncio
async def test_delete_worker_allowed_for_org_admin_in_owning_org():
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
    session = MagicMock()

    with (
        _patch_worker_one_by_id(worker),
        patch("gpustack.routes.workers.WorkerService") as service_cls,
    ):
        service = service_cls.return_value
        service.delete = AsyncMock()
        await delete_worker(ctx=ctx, session=session, id=worker.id)

    service.delete.assert_awaited_once()


@pytest.mark.asyncio
async def test_delete_worker_forbidden_for_plain_org_user():
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
    session = MagicMock()

    with _patch_worker_one_by_id(worker):
        with pytest.raises(ForbiddenException):
            await delete_worker(ctx=ctx, session=session, id=worker.id)


@pytest.mark.asyncio
async def test_delete_worker_returns_not_found_for_other_org():
    """Cross-org access must 404, not 403, to avoid leaking row existence."""
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=99, org_role=OrgRole.ADMIN)
    session = MagicMock()

    with _patch_worker_one_by_id(worker):
        with pytest.raises(NotFoundException):
            await delete_worker(ctx=ctx, session=session, id=worker.id)


@pytest.mark.asyncio
async def test_delete_worker_on_globally_shared_cluster_requires_platform_admin():
    """Workers on a global cluster (owner is None) shared via cluster_access
    are visible to the recipient Org admin but not writable — only platform
    admin may mutate global rows."""
    worker = _worker(owner_principal_id=None)
    ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
    # Make the worker's cluster visible to the caller, so the expected
    # failure is the write check (403), not the visibility check (404).
    ctx.accessible_cluster_ids = {worker.cluster_id}
    session = MagicMock()

    with _patch_worker_one_by_id(worker):
        with pytest.raises(ForbiddenException):
            await delete_worker(ctx=ctx, session=session, id=worker.id)


@pytest.mark.asyncio
async def test_delete_worker_allowed_for_platform_admin():
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(is_platform_admin=True, current_principal_id=None)
    session = MagicMock()

    with (
        _patch_worker_one_by_id(worker),
        patch("gpustack.routes.workers.WorkerService") as service_cls,
    ):
        service = service_cls.return_value
        service.delete = AsyncMock()
        await delete_worker(ctx=ctx, session=session, id=worker.id)

    service.delete.assert_awaited_once()


@pytest.mark.asyncio
async def test_delete_already_soft_deleted_worker_returns_not_found():
    """A soft-deleted worker must look like it doesn't exist for any caller."""
    worker = _soft_deleted_worker()
    ctx = _ctx(is_platform_admin=True, current_principal_id=None)
    session = MagicMock()

    with _patch_worker_one_by_id(worker):
        with pytest.raises(NotFoundException):
            await delete_worker(ctx=ctx, session=session, id=worker.id)


@pytest.mark.asyncio
async def test_update_worker_allowed_for_org_admin_in_owning_org():
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
    session = MagicMock()
    worker_in = WorkerUpdate(name="test-worker", maintenance=None)

    with (
        _patch_worker_one_by_id(worker),
        patch("gpustack.routes.workers.WorkerService") as service_cls,
    ):
        service = service_cls.return_value
        service.update = AsyncMock()
        result = await update_worker(
            ctx=ctx, session=session, id=worker.id, worker_in=worker_in
        )

    service.update.assert_awaited_once()
    assert result is worker


@pytest.mark.asyncio
async def test_update_worker_forbidden_for_plain_org_user():
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
    session = MagicMock()
    worker_in = WorkerUpdate(name="test-worker", maintenance=None)

    with _patch_worker_one_by_id(worker):
        with pytest.raises(ForbiddenException):
            await update_worker(
                ctx=ctx, session=session, id=worker.id, worker_in=worker_in
            )


@pytest.mark.asyncio
async def test_get_worker_privatekey_forbidden_for_plain_org_user():
    """Private key is a write-class secret — same gate as delete/update."""
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
    session = MagicMock()

    with _patch_worker_one_by_id(worker):
        with pytest.raises(ForbiddenException):
            await get_worker_privatekey(ctx=ctx, session=session, id=worker.id)


@pytest.mark.asyncio
async def test_get_worker_privatekey_returns_not_found_for_other_org():
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=99, org_role=OrgRole.ADMIN)
    session = MagicMock()

    with _patch_worker_one_by_id(worker):
        with pytest.raises(NotFoundException):
            await get_worker_privatekey(ctx=ctx, session=session, id=worker.id)


@pytest.mark.asyncio
async def test_update_already_soft_deleted_worker_returns_not_found():
    """Mirrors delete: update must treat a soft-deleted worker as absent."""
    worker = _soft_deleted_worker()
    ctx = _ctx(is_platform_admin=True, current_principal_id=None)
    session = MagicMock()
    worker_in = WorkerUpdate(name="test-worker", maintenance=None)

    with _patch_worker_one_by_id(worker):
        with pytest.raises(NotFoundException):
            await update_worker(
                ctx=ctx, session=session, id=worker.id, worker_in=worker_in
            )


@pytest.mark.asyncio
async def test_get_worker_privatekey_for_soft_deleted_worker_returns_not_found():
    """Soft-deleted workers must not leak private keys."""
    worker = _soft_deleted_worker()
    ctx = _ctx(is_platform_admin=True, current_principal_id=None)
    session = MagicMock()

    with _patch_worker_one_by_id(worker):
        with pytest.raises(NotFoundException):
            await get_worker_privatekey(ctx=ctx, session=session, id=worker.id)