| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- from unittest.mock import AsyncMock, MagicMock, patch
- import pytest
- from gpustack.schemas.models import (
- DistributedServerCoordinateModeEnum,
- DistributedServers,
- ModelInstanceSubordinateWorker,
- ModelInstanceStateEnum,
- )
- from gpustack.schemas.workers import WorkerStateEnum
- from gpustack.server.bus import Event, EventType
- from gpustack.server.controllers import WorkerController
- from tests.fixtures.workers.fixtures import (
- linux_nvidia_1_4090_24gx1,
- linux_nvidia_2_4080_16gx2,
- )
- from tests.utils.mock import mock_async_session
- from tests.utils.model import new_model_instance
@pytest.mark.asyncio
async def test_worker_controller_marks_distributed_subordinate_unreachable_when_worker_unreachable():
    """Reconcile a READY -> UNREACHABLE update for a subordinate worker.

    The instance update is expected to be awaited exactly once with a patch
    that carries no top-level ``state`` / ``state_message`` keys and whose
    first subordinate worker entry is flagged as unreachable.
    """
    # Arrange: a main worker plus a subordinate worker already marked unreachable.
    primary = linux_nvidia_1_4090_24gx1()
    secondary = linux_nvidia_2_4080_16gx2()
    secondary.state = WorkerStateEnum.UNREACHABLE
    secondary.unreachable = True

    # A running instance placed on the main worker, with one running subordinate.
    model_instance = new_model_instance(
        1,
        "distributed-instance",
        1,
        worker_id=primary.id,
        state=ModelInstanceStateEnum.RUNNING,
    )
    model_instance.worker_name = primary.name
    running_subordinate = ModelInstanceSubordinateWorker(
        worker_id=secondary.id,
        worker_name=secondary.name,
        worker_ip=secondary.ip,
        state=ModelInstanceStateEnum.RUNNING,
    )
    model_instance.distributed_servers = DistributedServers(
        mode=DistributedServerCoordinateModeEnum.RUN_FIRST,
        subordinate_workers=[running_subordinate],
    )

    update_mock = AsyncMock(return_value=model_instance.name)
    all_by_field_mock = AsyncMock(return_value=[model_instance])
    state_transition = (WorkerStateEnum.READY, WorkerStateEnum.UNREACHABLE)
    worker_event = Event(
        type=EventType.UPDATED,
        data=secondary,
        changed_fields={"state": state_transition},
    )

    # Act: run the reconcile step with the session and model-layer calls patched.
    with (
        patch(
            "gpustack.server.controllers.async_session",
            return_value=mock_async_session(),
        ),
        patch(
            "gpustack.server.controllers.ModelInstance.all_by_field",
            all_by_field_mock,
        ),
        patch(
            "gpustack.server.controllers.ModelInstanceService.update",
            update_mock,
        ),
    ):
        await WorkerController(MagicMock())._reconcile(worker_event)

    # Assert: the last two positional arguments are the instance and its patch.
    update_mock.assert_awaited_once()
    awaited_args = update_mock.await_args.args
    updated_instance = awaited_args[-2]
    patch_dict = awaited_args[-1]

    assert updated_instance.name == model_instance.name
    for absent_key in ("state", "state_message"):
        assert absent_key not in patch_dict

    subordinate_patch = patch_dict["distributed_servers"].subordinate_workers[0]
    assert subordinate_patch.state == ModelInstanceStateEnum.UNREACHABLE
    assert subordinate_patch.state_message == "Worker is unreachable from the server"
@pytest.mark.asyncio
async def test_worker_controller_deletes_distributed_instance_when_subordinate_worker_deleted():
    """Reconcile a DELETED event for a subordinate worker.

    The batch delete is expected to be awaited exactly once, and its last
    positional argument should contain only the distributed instance that
    listed the deleted worker as a subordinate.
    """
    # Arrange: a main worker and a READY subordinate worker that gets deleted.
    primary = linux_nvidia_1_4090_24gx1()
    secondary = linux_nvidia_2_4080_16gx2()
    secondary.state = WorkerStateEnum.READY

    # A running instance placed on the main worker, with one running subordinate.
    model_instance = new_model_instance(
        1,
        "distributed-instance",
        1,
        worker_id=primary.id,
        state=ModelInstanceStateEnum.RUNNING,
    )
    model_instance.worker_name = primary.name
    running_subordinate = ModelInstanceSubordinateWorker(
        worker_id=secondary.id,
        worker_name=secondary.name,
        worker_ip=secondary.ip,
        state=ModelInstanceStateEnum.RUNNING,
    )
    model_instance.distributed_servers = DistributedServers(
        mode=DistributedServerCoordinateModeEnum.RUN_FIRST,
        subordinate_workers=[running_subordinate],
    )

    batch_delete_mock = AsyncMock(return_value=[model_instance.name])
    all_by_field_mock = AsyncMock(return_value=[model_instance])
    worker_event = Event(type=EventType.DELETED, data=secondary)

    # Act: run the reconcile step with the session and model-layer calls patched.
    with (
        patch(
            "gpustack.server.controllers.async_session",
            return_value=mock_async_session(),
        ),
        patch(
            "gpustack.server.controllers.ModelInstance.all_by_field",
            all_by_field_mock,
        ),
        patch(
            "gpustack.server.controllers.ModelInstanceService.batch_delete",
            batch_delete_mock,
        ),
    ):
        await WorkerController(MagicMock())._reconcile(worker_event)

    # Assert: the last positional argument holds the instances to delete.
    batch_delete_mock.assert_awaited_once()
    deleted_instances = batch_delete_mock.await_args.args[-1]
    deleted_names = [item.name for item in deleted_instances]
    assert deleted_names == [model_instance.name]
|