abstract.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from dataclasses import dataclass
  2. from typing import Optional, List, Dict, Any
  3. from abc import ABC, abstractmethod
  4. from enum import Enum
  5. from gpustack.schemas.clusters import Volume
  6. from gpustack.cloud_providers.user_data import UserDataTemplate
  7. class InstanceState(str, Enum):
  8. CREATED = "created"
  9. RUNNING = "running"
  10. STOPPING = "stopping"
  11. STOPPED = "stopped"
  12. TERMINATED = "terminated"
  13. UNKNOWN = "unknown"
  14. @dataclass
  15. class CloudInstanceCreate:
  16. name: str
  17. image: str
  18. type: str
  19. region: str
  20. ssh_key_id: str
  21. user_data: Optional[str] = None
  22. labels: Optional[Dict[str, str]] = None
  23. @dataclass
  24. class CloudInstance(CloudInstanceCreate):
  25. external_id: Optional[str] = None
  26. status: InstanceState = InstanceState.CREATED
  27. ip_address: Optional[str] = None
  28. ssh_key_id: Optional[str] = None
  29. volume_ids: Optional[List[str]] = None
  30. class ProviderClientBase(ABC):
  31. """
  32. The lifecycle is like:
  33. 1. create_ssh_key
  34. 2. create_instance with created ssh_key
  35. 3. wait_for_started
  36. 4. wait_for_public_ip
  37. 5. [optional] create_volumes_and_attach
  38. 6. delete_instance
  39. 7. [optional] delete_ssh_key
  40. """
  41. @abstractmethod
  42. async def create_instance(self, instance: CloudInstanceCreate) -> Optional[str]:
  43. pass
  44. @abstractmethod
  45. async def delete_instance(self, external_id: str):
  46. pass
  47. @abstractmethod
  48. async def get_instance(self, external_id: str) -> Optional[CloudInstance]:
  49. pass
  50. @abstractmethod
  51. async def wait_for_started(
  52. self, external_id: str, backoff: int = 5, limit: int = 60
  53. ) -> CloudInstance:
  54. pass
  55. @abstractmethod
  56. async def wait_for_public_ip(
  57. self, external_id: str, backoff: int = 5, limit: int = 60
  58. ) -> CloudInstance:
  59. pass
  60. @abstractmethod
  61. async def create_ssh_key(self, worker_name: str, public_key: str) -> str:
  62. pass
  63. @abstractmethod
  64. async def delete_ssh_key(self, id: str):
  65. pass
  66. @abstractmethod
  67. async def create_volumes_and_attach(
  68. self, worker_id: int, external_id: str, region: str, *volumes: Volume
  69. ) -> List[str]:
  70. """
  71. Create volumes and attach them to the instance.
  72. Volumes should be tuple of {"size_gb": 10, "format": "ext4", "name": "my-volume"}, the name is optional.
  73. """
  74. pass
  75. async def construct_user_data(
  76. self,
  77. server_url: str,
  78. token: str,
  79. image_name: str,
  80. os_image: str,
  81. worker_name: str,
  82. secret_configs: Dict[str, Any] = {},
  83. ) -> UserDataTemplate:
  84. user_data = UserDataTemplate(
  85. server_url=server_url,
  86. token=token,
  87. image_name=image_name,
  88. secret_configs=secret_configs,
  89. worker_name=worker_name,
  90. )
  91. return user_data
  92. @classmethod
  93. def get_api_endpoint(cls) -> str:
  94. return ""
  95. @classmethod
  96. def process_header(cls, ak: str, sk: str, options: dict, headers: dict) -> dict:
  97. return headers