| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415 |
- from types import SimpleNamespace
- from unittest.mock import AsyncMock, MagicMock, patch
- import pytest
- from gpustack.api.exceptions import ForbiddenException, NotFoundException
- from gpustack.routes.workers import (
- delete_worker,
- get_worker_privatekey,
- update_worker,
- update_worker_data,
- )
- from gpustack.schemas.principals import OrgRole
- from gpustack.schemas.workers import (
- Worker,
- WorkerCreate,
- WorkerStateEnum,
- WorkerStatus,
- WorkerUpdate,
- Maintenance,
- SystemReserved,
- )
- from gpustack.schemas.clusters import Cluster, ClusterProvider
def test_update_worker_data_preserves_maintenance_mode():
    """Re-registration must keep an existing worker in maintenance mode.

    A restarting worker re-registers with ``maintenance=None``; the merged
    record has to retain the maintenance flag and message that were stored
    before the restart (regression test for workers silently leaving
    maintenance mode after a restart).
    """
    identity = dict(
        name="test-worker",
        hostname="test-host",
        ip="192.168.1.100",
        ifname="eth0",
        port=8080,
        worker_uuid="test-uuid-123",
        cluster_id=1,
    )
    stored = Worker(
        id=1,
        labels={"env": "test"},
        maintenance=Maintenance(enabled=True, message="Scheduled maintenance"),
        state=WorkerStateEnum.MAINTENANCE,
        status=WorkerStatus.get_default_status(),
        **identity,
    )
    # What a restarted worker sends: no maintenance information at all.
    registration = WorkerCreate(
        labels={"env": "test", "new": "label"},
        maintenance=None,
        status=WorkerStatus.get_default_status(),
        system_reserved=SystemReserved(ram=0, vram=0),
        **identity,
    )

    merged = update_worker_data(registration, existing=stored)

    kept = merged.maintenance
    assert kept is not None
    assert kept.enabled is True
    assert kept.message == "Scheduled maintenance"
    # The state is recomputed on merge (the original note attributes this to
    # compute_state()) and must follow the preserved maintenance flag.
    assert merged.state == WorkerStateEnum.MAINTENANCE
def test_update_worker_data_can_disable_maintenance_mode():
    """An explicit ``Maintenance(enabled=False)`` turns maintenance mode off.

    Unlike a bare re-registration (``maintenance=None``), a request that
    carries a maintenance object replaces the stored one.
    """
    identity = dict(
        name="test-worker",
        hostname="test-host",
        ip="192.168.1.100",
        ifname="eth0",
        port=8080,
        worker_uuid="test-uuid-123",
        cluster_id=1,
    )
    stored = Worker(
        id=1,
        labels={"env": "test"},
        maintenance=Maintenance(enabled=True, message="Scheduled maintenance"),
        state=WorkerStateEnum.MAINTENANCE,
        status=WorkerStatus.get_default_status(),
        **identity,
    )
    request = WorkerCreate(
        labels={"env": "test"},
        maintenance=Maintenance(enabled=False, message=None),
        status=WorkerStatus.get_default_status(),
        system_reserved=SystemReserved(ram=0, vram=0),
        **identity,
    )

    merged = update_worker_data(request, existing=stored)

    maintenance = merged.maintenance
    assert maintenance is not None
    assert maintenance.enabled is False
    assert maintenance.message is None
    # The exact recomputed state depends on heartbeat data, so only check
    # that it is no longer MAINTENANCE.
    assert merged.state != WorkerStateEnum.MAINTENANCE
def test_update_worker_data_new_worker_without_maintenance():
    """A first-time registration without maintenance yields no maintenance."""
    registration = WorkerCreate(
        name="new-worker",
        hostname="new-host",
        ip="192.168.1.101",
        ifname="eth0",
        port=8080,
        worker_uuid="new-uuid-456",
        cluster_id=1,
        labels={"env": "prod"},
        maintenance=None,
        status=WorkerStatus.get_default_status(),
        system_reserved=SystemReserved(ram=0, vram=0),
    )
    # There is no existing row, so the owning cluster is supplied explicitly.
    owning_cluster = Cluster(
        id=1,
        name="test-cluster",
        provider=ClusterProvider.Docker,
    )

    created = update_worker_data(
        registration, existing=None, cluster=owning_cluster
    )

    assert created.maintenance is None
    # Without a heartbeat the state may be NOT_READY; it only must not be
    # MAINTENANCE.
    assert created.state != WorkerStateEnum.MAINTENANCE
def test_update_worker_data_preserves_labels_merge():
    """Incoming labels are merged over the stored ones, key by key."""
    identity = dict(
        name="test-worker",
        hostname="test-host",
        ip="192.168.1.100",
        ifname="eth0",
        port=8080,
        worker_uuid="test-uuid-123",
        cluster_id=1,
        maintenance=None,
    )
    stored = Worker(
        id=1,
        labels={"env": "test", "region": "us-west"},
        state=WorkerStateEnum.READY,
        status=WorkerStatus.get_default_status(),
        **identity,
    )
    incoming = WorkerCreate(
        labels={"env": "prod", "zone": "a"},
        status=WorkerStatus.get_default_status(),
        system_reserved=SystemReserved(ram=0, vram=0),
        **identity,
    )

    merged_labels = update_worker_data(incoming, existing=stored).labels

    expected = {
        "env": "prod",  # present on both sides: incoming value wins
        "region": "us-west",  # stored only: preserved
        "zone": "a",  # incoming only: added
    }
    for key, value in expected.items():
        assert merged_labels[key] == value
- def _ctx(*, is_platform_admin=False, current_principal_id=None, org_role=None):
- """Minimal TenantContext stub for the worker write paths.
- The handlers only consult ``is_platform_admin`` / ``current_principal_id`` /
- ``org_role`` (via ``assert_org_owned_writable`` and
- ``assert_cluster_resource_visible``), so a SimpleNamespace is enough.
- """
- return SimpleNamespace(
- user=SimpleNamespace(is_system=False),
- is_platform_admin=is_platform_admin,
- current_principal_id=current_principal_id,
- org_role=org_role,
- accessible_cluster_ids=set(),
- )
def _worker(owner_principal_id=None, deleted_at=None):
    """Return a lightweight stand-in for a READY Worker row (id 42, cluster 1).

    Args:
        owner_principal_id: Owning principal id, or None for a worker on a
            global (unowned) cluster.
        deleted_at: Soft-deletion timestamp, or None for a live worker.
    """
    fields = {
        "id": 42,
        "cluster_id": 1,
        "owner_principal_id": owner_principal_id,
        "deleted_at": deleted_at,
        "ssh_key_id": None,
        "external_id": None,
        "state": WorkerStateEnum.READY,
    }
    return SimpleNamespace(**fields)
- def _patch_worker_one_by_id(worker):
- """Stub Worker.one_by_id to return the given worker."""
- return patch(
- "gpustack.routes.workers.Worker.one_by_id",
- AsyncMock(return_value=worker),
- )
@pytest.mark.asyncio
async def test_delete_worker_allowed_for_org_admin_in_owning_org():
    """An org admin acting as the owning principal may delete the worker."""
    target = _worker(owner_principal_id=10)
    admin_ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
    db_session = MagicMock()
    with _patch_worker_one_by_id(target):
        with patch("gpustack.routes.workers.WorkerService") as service_cls:
            fake_service = service_cls.return_value
            fake_service.delete = AsyncMock()
            await delete_worker(ctx=admin_ctx, session=db_session, id=target.id)
            fake_service.delete.assert_awaited_once()
@pytest.mark.asyncio
async def test_delete_worker_forbidden_for_plain_org_user():
    """A non-admin member of the owning org gets 403 and nothing is deleted.

    ``WorkerService`` is patched, as in the positive delete tests, so the
    test stays hermetic. If the permission check regressed, the failure is a
    clean ``pytest.raises`` failure plus an observable ``delete`` call. It no
    longer depends on how the real service behaves with a MagicMock session.
    """
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
    session = MagicMock()
    with (
        _patch_worker_one_by_id(worker),
        patch("gpustack.routes.workers.WorkerService") as service_cls,
    ):
        service = service_cls.return_value
        service.delete = AsyncMock()
        with pytest.raises(ForbiddenException):
            await delete_worker(ctx=ctx, session=session, id=worker.id)
        # Authorization must fail before any mutation is attempted.
        service.delete.assert_not_awaited()
@pytest.mark.asyncio
async def test_delete_worker_returns_not_found_for_other_org():
    """Cross-org access must 404, not 403, to avoid leaking row existence.

    ``WorkerService`` is patched, as in the positive delete tests, so the
    test is hermetic and can also assert that no deletion was attempted on
    another principal's worker.
    """
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=99, org_role=OrgRole.ADMIN)
    session = MagicMock()
    with (
        _patch_worker_one_by_id(worker),
        patch("gpustack.routes.workers.WorkerService") as service_cls,
    ):
        service = service_cls.return_value
        service.delete = AsyncMock()
        with pytest.raises(NotFoundException):
            await delete_worker(ctx=ctx, session=session, id=worker.id)
        # The foreign worker must be left untouched.
        service.delete.assert_not_awaited()
@pytest.mark.asyncio
async def test_delete_worker_on_globally_shared_cluster_requires_platform_admin():
    """Workers on a global cluster (owner is None) shared via cluster_access
    are visible to the recipient Org admin but not writable — only platform
    admin may mutate global rows.

    ``WorkerService`` is patched, as in the positive delete tests, so the
    test is hermetic and can assert that the global worker was not deleted.
    """
    worker = _worker(owner_principal_id=None)
    ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
    # Make the worker's cluster visible to this org, so the expected outcome
    # is "visible but forbidden" (403) rather than "invisible" (404).
    ctx.accessible_cluster_ids = {worker.cluster_id}
    session = MagicMock()
    with (
        _patch_worker_one_by_id(worker),
        patch("gpustack.routes.workers.WorkerService") as service_cls,
    ):
        service = service_cls.return_value
        service.delete = AsyncMock()
        with pytest.raises(ForbiddenException):
            await delete_worker(ctx=ctx, session=session, id=worker.id)
        service.delete.assert_not_awaited()
@pytest.mark.asyncio
async def test_delete_worker_allowed_for_platform_admin():
    """A platform admin with no current principal may delete an owned worker."""
    owned = _worker(owner_principal_id=10)
    platform_ctx = _ctx(is_platform_admin=True, current_principal_id=None)
    db_session = MagicMock()
    with _patch_worker_one_by_id(owned):
        with patch("gpustack.routes.workers.WorkerService") as service_cls:
            delete_mock = AsyncMock()
            service_cls.return_value.delete = delete_mock
            await delete_worker(
                ctx=platform_ctx, session=db_session, id=owned.id
            )
            delete_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_delete_already_soft_deleted_worker_returns_not_found():
    """A soft-deleted worker must look like it doesn't exist for any caller.

    Even a platform admin gets 404 once ``deleted_at`` is set.
    """
    from datetime import datetime, timezone

    tombstoned = _worker(
        owner_principal_id=10,
        deleted_at=datetime.now(timezone.utc),
    )
    platform_ctx = _ctx(is_platform_admin=True, current_principal_id=None)
    db_session = MagicMock()
    with _patch_worker_one_by_id(tombstoned), pytest.raises(NotFoundException):
        await delete_worker(
            ctx=platform_ctx, session=db_session, id=tombstoned.id
        )
@pytest.mark.asyncio
async def test_update_worker_allowed_for_org_admin_in_owning_org():
    """An org admin may update their own worker; the looked-up row is returned."""
    target = _worker(owner_principal_id=10)
    admin_ctx = _ctx(current_principal_id=10, org_role=OrgRole.ADMIN)
    db_session = MagicMock()
    payload = WorkerUpdate(name="test-worker", maintenance=None)
    with _patch_worker_one_by_id(target):
        with patch("gpustack.routes.workers.WorkerService") as service_cls:
            fake_service = service_cls.return_value
            fake_service.update = AsyncMock()
            returned = await update_worker(
                ctx=admin_ctx,
                session=db_session,
                id=target.id,
                worker_in=payload,
            )
            fake_service.update.assert_awaited_once()
            assert returned is target
@pytest.mark.asyncio
async def test_update_worker_forbidden_for_plain_org_user():
    """A non-admin member of the owning org gets 403 and nothing is updated.

    ``WorkerService`` is patched, as in the positive update test, so the
    test stays hermetic and can assert that no update was attempted.
    """
    worker = _worker(owner_principal_id=10)
    ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
    session = MagicMock()
    worker_in = WorkerUpdate(name="test-worker", maintenance=None)
    with (
        _patch_worker_one_by_id(worker),
        patch("gpustack.routes.workers.WorkerService") as service_cls,
    ):
        service = service_cls.return_value
        service.update = AsyncMock()
        with pytest.raises(ForbiddenException):
            await update_worker(
                ctx=ctx, session=session, id=worker.id, worker_in=worker_in
            )
        # Authorization must fail before any mutation is attempted.
        service.update.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_worker_privatekey_forbidden_for_plain_org_user():
    """Private key is a write-class secret — same gate as delete/update.

    A plain org user acting as the owning principal is refused with 403.
    """
    owned = _worker(owner_principal_id=10)
    member_ctx = _ctx(current_principal_id=10, org_role=OrgRole.USER)
    db_session = MagicMock()
    with _patch_worker_one_by_id(owned), pytest.raises(ForbiddenException):
        await get_worker_privatekey(
            ctx=member_ctx, session=db_session, id=owned.id
        )
@pytest.mark.asyncio
async def test_get_worker_privatekey_returns_not_found_for_other_org():
    """An admin acting as a different principal gets 404 for the private key."""
    foreign = _worker(owner_principal_id=10)
    outsider_ctx = _ctx(current_principal_id=99, org_role=OrgRole.ADMIN)
    db_session = MagicMock()
    with _patch_worker_one_by_id(foreign), pytest.raises(NotFoundException):
        await get_worker_privatekey(
            ctx=outsider_ctx, session=db_session, id=foreign.id
        )
@pytest.mark.asyncio
async def test_update_already_soft_deleted_worker_returns_not_found():
    """Mirrors delete: update must treat a soft-deleted worker as absent."""
    from datetime import datetime, timezone

    tombstoned = _worker(
        owner_principal_id=10,
        deleted_at=datetime.now(timezone.utc),
    )
    platform_ctx = _ctx(is_platform_admin=True, current_principal_id=None)
    db_session = MagicMock()
    payload = WorkerUpdate(name="test-worker", maintenance=None)
    with _patch_worker_one_by_id(tombstoned), pytest.raises(NotFoundException):
        await update_worker(
            ctx=platform_ctx,
            session=db_session,
            id=tombstoned.id,
            worker_in=payload,
        )
@pytest.mark.asyncio
async def test_get_worker_privatekey_for_soft_deleted_worker_returns_not_found():
    """Soft-deleted workers must not leak private keys.

    Even a platform admin gets 404 once ``deleted_at`` is set.
    """
    from datetime import datetime, timezone

    tombstoned = _worker(
        owner_principal_id=10,
        deleted_at=datetime.now(timezone.utc),
    )
    platform_ctx = _ctx(is_platform_admin=True, current_principal_id=None)
    db_session = MagicMock()
    with _patch_worker_one_by_id(tombstoned), pytest.raises(NotFoundException):
        await get_worker_privatekey(
            ctx=platform_ctx, session=db_session, id=tombstoned.id
        )
|