test_provisioning.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import pytest
  2. from unittest.mock import AsyncMock, MagicMock
  3. from gpustack.schemas.workers import Worker, WorkerStateEnum
  4. from gpustack.schemas.clusters import (
  5. Cluster,
  6. WorkerPool,
  7. CloudCredential,
  8. ClusterProvider,
  9. ClusterStateEnum,
  10. CloudOptions,
  11. )
  12. from gpustack.server.controllers import WorkerProvisioningController
  13. from gpustack.cloud_providers.abstract import InstanceState, Volume
  14. @pytest.mark.asyncio
  15. async def test_provisioning_flow(monkeypatch):
  16. session = AsyncMock()
  17. session.info = {}
  18. client = AsyncMock()
  19. cluster = Cluster(
  20. id=1, provider=ClusterProvider.DigitalOcean, region="nyc3", credential_id=1
  21. )
  22. cluster.state = ClusterStateEnum.PROVISIONED
  23. pool = WorkerPool(
  24. id=1,
  25. cluster=cluster,
  26. cloud_options=CloudOptions(
  27. volumes=[
  28. Volume(size_gb=10, format="ext4"),
  29. Volume(size_gb=20, format="ext4"),
  30. ]
  31. ),
  32. )
  33. worker = Worker(
  34. id=1,
  35. name="test-worker",
  36. cluster=cluster,
  37. worker_pool=pool,
  38. state=WorkerStateEnum.PENDING,
  39. provider_config={},
  40. cluster_id=1,
  41. )
  42. credential = CloudCredential(id=1, token="dummy")
  43. cfg = MagicMock()
  44. cfg.server_external_url = "http://dummy-server"
  45. cfg.image_name_override = "dummy-image"
  46. monkeypatch.setattr("gpustack.config.config.get_global_config", lambda: cfg)
  47. mock_sshkey = MagicMock()
  48. mock_sshkey.id = "ssh-key-id"
  49. monkeypatch.setattr(
  50. "gpustack.schemas.clusters.Credential.create",
  51. AsyncMock(return_value=mock_sshkey),
  52. )
  53. monkeypatch.setattr(
  54. "gpustack.cloud_providers.common.get_client_from_provider",
  55. lambda provider, credential: client,
  56. )
  57. monkeypatch.setattr(
  58. "gpustack.schemas.clusters.Credential.one_by_id",
  59. AsyncMock(return_value=MagicMock(id=1, external_id="ssh-key-id")),
  60. )
  61. monkeypatch.setattr(
  62. "gpustack.schemas.workers.Worker.one_by_id", AsyncMock(return_value=worker)
  63. )
  64. monkeypatch.setattr(
  65. "gpustack.schemas.clusters.CloudCredential.one_by_id",
  66. AsyncMock(return_value=credential),
  67. )
  68. monkeypatch.setattr("gpustack.server.services.WorkerService.update", AsyncMock())
  69. mock_instance = MagicMock()
  70. mock_instance.id = "instance-id"
  71. client.get_instance = AsyncMock(return_value=mock_instance)
  72. client.create_ssh_key = AsyncMock(return_value="ssh-key-id")
  73. mock_user_data = MagicMock()
  74. mock_user_data.format.return_value = "#!/bin/bash\necho hello"
  75. client.construct_user_data = AsyncMock(return_value=mock_user_data)
  76. client.create_instance = AsyncMock(return_value="instance-id")
  77. client.wait_for_started = AsyncMock(return_value={"id": "instance-id"})
  78. client.wait_for_public_ip = AsyncMock(
  79. return_value={"id": "instance-id", "ip_address": "1.2.3.4"}
  80. )
  81. client.determine_linux_distribution = AsyncMock(return_value=("ubuntu", True))
  82. client.create_volumes_and_attach = AsyncMock(return_value=["vol-1", "vol-2"])
  83. # First call, should enter the SSH key creation process
  84. await WorkerProvisioningController._provisioning_instance(
  85. session, client, worker, cfg
  86. )
  87. assert worker.state == WorkerStateEnum.PROVISIONING
  88. assert worker.state_message == "Creating SSH key"
  89. # Second call, should create SSH key and assign to worker.ssh_key_id
  90. # Here, simulate SSH key not yet created, worker.ssh_key_id should be assigned
  91. await WorkerProvisioningController._provisioning_instance(
  92. session, client, worker, cfg
  93. )
  94. assert worker.ssh_key_id == "ssh-key-id"
  95. assert worker.state_message == "Creating cloud instance"
  96. # Third call, should enter the cloud instance creation process
  97. await WorkerProvisioningController._provisioning_instance(
  98. session, client, worker, cfg
  99. )
  100. assert worker.external_id == "instance-id"
  101. assert worker.state_message == "Waiting for cloud instance started"
  102. # Fourth call, should wait for cloud instance to start
  103. client.wait_for_started.return_value = {"id": "instance-id"}
  104. await WorkerProvisioningController._provisioning_instance(
  105. session, client, worker, cfg
  106. )
  107. assert worker.state_message == "Waiting for instance's public ip"
  108. # Fifth call, the instance should have public ip
  109. mock_instance = MagicMock()
  110. mock_instance.id = "instance-id"
  111. mock_instance.ip_address = "1.2.3.4"
  112. mock_instance.status = InstanceState.RUNNING
  113. client.get_instance.return_value = mock_instance
  114. client.wait_for_public_ip.return_value = mock_instance
  115. await WorkerProvisioningController._provisioning_instance(
  116. session, client, worker, cfg
  117. )
  118. assert worker.state_message == "Waiting for volumes to attach"
  119. # Sixth call, should create and attach volumes
  120. client.create_volumes_and_attach.return_value = ["vol-1", "vol-2"]
  121. await WorkerProvisioningController._provisioning_instance(
  122. session, client, worker, cfg
  123. )
  124. assert worker.provider_config is not None
  125. assert worker.provider_config.get("volume_ids") == ["vol-1", "vol-2"]
  126. # final call, worker provisioning state should have provisioned
  127. await WorkerProvisioningController._provisioning_instance(
  128. session, client, worker, cfg
  129. )
  130. assert worker.state == WorkerStateEnum.INITIALIZING
  131. @pytest.mark.asyncio
  132. async def test_deleting_flow(monkeypatch):
  133. session = AsyncMock()
  134. client = AsyncMock()
  135. cluster = Cluster(id=1, provider="DigitalOcean", region="nyc3", credential_id=1)
  136. pool = WorkerPool(id=1, cluster=cluster)
  137. worker = Worker(
  138. id=1,
  139. name="test-worker",
  140. cluster=cluster,
  141. worker_pool=pool,
  142. state=WorkerStateEnum.DELETING,
  143. external_id="instance-id",
  144. deleted_at="2025-08-29",
  145. )
  146. credential = CloudCredential(id=1, token="dummy")
  147. monkeypatch.setattr(
  148. "gpustack.cloud_providers.common.get_client_from_provider",
  149. lambda provider, credential: client,
  150. )
  151. monkeypatch.setattr(
  152. "gpustack.schemas.workers.Worker.one_by_id", AsyncMock(return_value=worker)
  153. )
  154. monkeypatch.setattr(
  155. "gpustack.schemas.clusters.CloudCredential.one_by_id",
  156. AsyncMock(return_value=credential),
  157. )
  158. monkeypatch.setattr("gpustack.server.services.WorkerService.delete", AsyncMock())
  159. client.delete_instance = AsyncMock()
  160. await WorkerProvisioningController._deleting_instance(session, client, worker)
  161. client.delete_instance.assert_awaited_with("instance-id")