test_worker_controller.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. from unittest.mock import AsyncMock, MagicMock, patch
  2. import pytest
  3. from gpustack.schemas.models import (
  4. DistributedServerCoordinateModeEnum,
  5. DistributedServers,
  6. ModelInstanceSubordinateWorker,
  7. ModelInstanceStateEnum,
  8. )
  9. from gpustack.schemas.workers import WorkerStateEnum
  10. from gpustack.server.bus import Event, EventType
  11. from gpustack.server.controllers import WorkerController
  12. from tests.fixtures.workers.fixtures import (
  13. linux_nvidia_1_4090_24gx1,
  14. linux_nvidia_2_4080_16gx2,
  15. )
  16. from tests.utils.mock import mock_async_session
  17. from tests.utils.model import new_model_instance
  18. @pytest.mark.asyncio
  19. async def test_worker_controller_marks_distributed_subordinate_unreachable_when_worker_unreachable():
  20. main_worker = linux_nvidia_1_4090_24gx1()
  21. subordinate_worker = linux_nvidia_2_4080_16gx2()
  22. subordinate_worker.state = WorkerStateEnum.UNREACHABLE
  23. subordinate_worker.unreachable = True
  24. instance = new_model_instance(
  25. 1,
  26. "distributed-instance",
  27. 1,
  28. worker_id=main_worker.id,
  29. state=ModelInstanceStateEnum.RUNNING,
  30. )
  31. instance.worker_name = main_worker.name
  32. instance.distributed_servers = DistributedServers(
  33. mode=DistributedServerCoordinateModeEnum.RUN_FIRST,
  34. subordinate_workers=[
  35. ModelInstanceSubordinateWorker(
  36. worker_id=subordinate_worker.id,
  37. worker_name=subordinate_worker.name,
  38. worker_ip=subordinate_worker.ip,
  39. state=ModelInstanceStateEnum.RUNNING,
  40. )
  41. ],
  42. )
  43. update = AsyncMock(return_value=instance.name)
  44. event = Event(
  45. type=EventType.UPDATED,
  46. data=subordinate_worker,
  47. changed_fields={"state": (WorkerStateEnum.READY, WorkerStateEnum.UNREACHABLE)},
  48. )
  49. with (
  50. patch(
  51. "gpustack.server.controllers.async_session",
  52. return_value=mock_async_session(),
  53. ),
  54. patch(
  55. "gpustack.server.controllers.ModelInstance.all_by_field",
  56. AsyncMock(return_value=[instance]),
  57. ),
  58. patch(
  59. "gpustack.server.controllers.ModelInstanceService.update",
  60. update,
  61. ),
  62. ):
  63. controller = WorkerController(MagicMock())
  64. await controller._reconcile(event)
  65. update.assert_awaited_once()
  66. updated_instance = update.await_args.args[-2]
  67. patch_dict = update.await_args.args[-1]
  68. assert updated_instance.name == instance.name
  69. assert "state" not in patch_dict
  70. assert "state_message" not in patch_dict
  71. subordinate_patch = patch_dict["distributed_servers"].subordinate_workers[0]
  72. assert subordinate_patch.state == ModelInstanceStateEnum.UNREACHABLE
  73. assert subordinate_patch.state_message == "Worker is unreachable from the server"
  74. @pytest.mark.asyncio
  75. async def test_worker_controller_deletes_distributed_instance_when_subordinate_worker_deleted():
  76. main_worker = linux_nvidia_1_4090_24gx1()
  77. subordinate_worker = linux_nvidia_2_4080_16gx2()
  78. subordinate_worker.state = WorkerStateEnum.READY
  79. instance = new_model_instance(
  80. 1,
  81. "distributed-instance",
  82. 1,
  83. worker_id=main_worker.id,
  84. state=ModelInstanceStateEnum.RUNNING,
  85. )
  86. instance.worker_name = main_worker.name
  87. instance.distributed_servers = DistributedServers(
  88. mode=DistributedServerCoordinateModeEnum.RUN_FIRST,
  89. subordinate_workers=[
  90. ModelInstanceSubordinateWorker(
  91. worker_id=subordinate_worker.id,
  92. worker_name=subordinate_worker.name,
  93. worker_ip=subordinate_worker.ip,
  94. state=ModelInstanceStateEnum.RUNNING,
  95. )
  96. ],
  97. )
  98. batch_delete = AsyncMock(return_value=[instance.name])
  99. event = Event(type=EventType.DELETED, data=subordinate_worker)
  100. with (
  101. patch(
  102. "gpustack.server.controllers.async_session",
  103. return_value=mock_async_session(),
  104. ),
  105. patch(
  106. "gpustack.server.controllers.ModelInstance.all_by_field",
  107. AsyncMock(return_value=[instance]),
  108. ),
  109. patch(
  110. "gpustack.server.controllers.ModelInstanceService.batch_delete",
  111. batch_delete,
  112. ),
  113. ):
  114. controller = WorkerController(MagicMock())
  115. await controller._reconcile(event)
  116. batch_delete.assert_awaited_once()
  117. deleted_instances = batch_delete.await_args.args[-1]
  118. assert [item.name for item in deleted_instances] == [instance.name]