# test_worker_instance_cleaner.py (4.3 KB)

  1. from datetime import datetime, timedelta, timezone
  2. from unittest.mock import AsyncMock, patch
  3. import pytest
  4. from gpustack.schemas.models import (
  5. DistributedServerCoordinateModeEnum,
  6. DistributedServers,
  7. ModelInstanceSubordinateWorker,
  8. ModelInstanceStateEnum,
  9. )
  10. from gpustack.schemas.workers import WorkerStateEnum
  11. from gpustack.server.worker_instance_cleaner import WorkerInstanceCleaner
  12. from tests.fixtures.workers.fixtures import (
  13. linux_nvidia_1_4090_24gx1,
  14. linux_nvidia_2_4080_16gx2,
  15. )
  16. from tests.utils.mock import mock_async_session
  17. from tests.utils.model import new_model_instance
  18. def _offline_worker(worker, *, heartbeat_age_seconds: int = 600):
  19. worker.state = WorkerStateEnum.NOT_READY
  20. worker.maintenance = None
  21. worker.heartbeat_time = datetime.now(timezone.utc) - timedelta(
  22. seconds=heartbeat_age_seconds
  23. )
  24. return worker
  25. @pytest.mark.asyncio
  26. async def test_cleanup_offline_worker_instances_deletes_main_worker_instances():
  27. offline_worker = _offline_worker(linux_nvidia_1_4090_24gx1())
  28. other_worker = linux_nvidia_2_4080_16gx2()
  29. instance = new_model_instance(
  30. 1,
  31. "main-worker-instance",
  32. 1,
  33. worker_id=offline_worker.id,
  34. state=ModelInstanceStateEnum.RUNNING,
  35. )
  36. instance.worker_name = offline_worker.name
  37. batch_delete = AsyncMock(return_value=[instance.name])
  38. with (
  39. patch(
  40. "gpustack.server.worker_instance_cleaner.async_session",
  41. return_value=mock_async_session(),
  42. ),
  43. patch(
  44. "gpustack.server.worker_instance_cleaner.Worker.all",
  45. AsyncMock(return_value=[offline_worker, other_worker]),
  46. ),
  47. patch(
  48. "gpustack.server.worker_instance_cleaner.ModelInstance.all",
  49. AsyncMock(return_value=[instance]),
  50. ),
  51. patch(
  52. "gpustack.server.worker_instance_cleaner.ModelInstanceService.batch_delete",
  53. batch_delete,
  54. ),
  55. patch(
  56. "gpustack.server.worker_instance_cleaner.envs.MODEL_INSTANCE_RESCHEDULE_GRACE_PERIOD",
  57. 300,
  58. ),
  59. ):
  60. cleaner = WorkerInstanceCleaner()
  61. await cleaner._cleanup_offline_worker_instances()
  62. batch_delete.assert_awaited_once()
  63. deleted_instances = batch_delete.await_args.args[-1]
  64. assert [item.name for item in deleted_instances] == [instance.name]
  65. @pytest.mark.asyncio
  66. async def test_cleanup_offline_worker_instances_deletes_distributed_instances_with_offline_subordinate():
  67. main_worker = linux_nvidia_1_4090_24gx1()
  68. offline_subordinate = _offline_worker(linux_nvidia_2_4080_16gx2())
  69. instance = new_model_instance(
  70. 1,
  71. "distributed-instance",
  72. 1,
  73. worker_id=main_worker.id,
  74. state=ModelInstanceStateEnum.RUNNING,
  75. )
  76. instance.worker_name = main_worker.name
  77. instance.distributed_servers = DistributedServers(
  78. mode=DistributedServerCoordinateModeEnum.RUN_FIRST,
  79. subordinate_workers=[
  80. ModelInstanceSubordinateWorker(
  81. worker_id=offline_subordinate.id,
  82. worker_name=offline_subordinate.name,
  83. worker_ip=offline_subordinate.ip,
  84. state=ModelInstanceStateEnum.RUNNING,
  85. )
  86. ],
  87. )
  88. batch_delete = AsyncMock(return_value=[instance.name])
  89. with (
  90. patch(
  91. "gpustack.server.worker_instance_cleaner.async_session",
  92. return_value=mock_async_session(),
  93. ),
  94. patch(
  95. "gpustack.server.worker_instance_cleaner.Worker.all",
  96. AsyncMock(return_value=[main_worker, offline_subordinate]),
  97. ),
  98. patch(
  99. "gpustack.server.worker_instance_cleaner.ModelInstance.all",
  100. AsyncMock(return_value=[instance]),
  101. ),
  102. patch(
  103. "gpustack.server.worker_instance_cleaner.ModelInstanceService.batch_delete",
  104. batch_delete,
  105. ),
  106. patch(
  107. "gpustack.server.worker_instance_cleaner.envs.MODEL_INSTANCE_RESCHEDULE_GRACE_PERIOD",
  108. 300,
  109. ),
  110. ):
  111. cleaner = WorkerInstanceCleaner()
  112. await cleaner._cleanup_offline_worker_instances()
  113. batch_delete.assert_awaited_once()
  114. deleted_instances = batch_delete.await_args.args[-1]
  115. assert [item.name for item in deleted_instances] == [instance.name]