digital_ocean.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. import logging
  2. import random
  3. import string
  4. import asyncio
  5. from typing import List, Optional, Dict, Any
  6. from .abstract import (
  7. ProviderClientBase,
  8. CloudInstance,
  9. CloudInstanceCreate,
  10. InstanceState,
  11. )
  12. from pydo.aio import Client
  13. from gpustack.schemas.clusters import Volume
  14. from gpustack.cloud_providers.user_data import UserDataTemplate
  15. from gpustack_runtime.detector import ManufacturerEnum
  16. logger = logging.getLogger(__name__)
  17. status_mapping = {
  18. "new": InstanceState.CREATED,
  19. "active": InstanceState.RUNNING,
  20. }
  21. class DigitalOceanClient(ProviderClientBase):
  22. client: Client
  23. def __init__(self, token: str):
  24. self.client = Client(token=token, timeout=30)
  25. async def create_instance(self, instance: CloudInstanceCreate) -> str:
  26. tags: List[str] = [f"{k}:{v}" for k, v in instance.labels.items()]
  27. req = {
  28. "name": instance.name,
  29. "image": instance.image,
  30. "size": instance.type,
  31. "region": instance.region,
  32. "ssh_keys": [instance.ssh_key_id],
  33. "user_data": instance.user_data,
  34. "tags": tags,
  35. }
  36. try:
  37. logger.info(f"Creating digital ocean droplet with name {instance.name}")
  38. logger.debug(f"Request body: {req}")
  39. droplet_resp = await self.client.droplets.create(body=req)
  40. id = droplet_resp['droplet']['id']
  41. return str(id)
  42. except Exception as e:
  43. logger.error(f"Failed to create digital ocean instance: {e}")
  44. raise e
  45. async def delete_instance(self, external_id: str):
  46. logger.info(f"Deleting digital ocean instance with id {external_id}")
  47. delete_response = (
  48. await self.client.droplets.destroy_with_associated_resources_dangerous(
  49. external_id,
  50. x_dangerous=True,
  51. )
  52. )
  53. if delete_response is None:
  54. return
  55. logger.warning(
  56. f"Failed to delete droplet {external_id}, Delete response: {delete_response}"
  57. )
  58. raise RuntimeError(
  59. f"Failed to delete droplet {external_id}, {delete_response.message}"
  60. )
  61. async def get_instance(self, external_id: str) -> Optional[CloudInstance]:
  62. response = await self.client.droplets.get(external_id)
  63. instance: Dict[str, Any] = response.get('droplet', None)
  64. if instance is None:
  65. return None
  66. ip_address = None
  67. v4_list = instance.get('networks', {}).get('v4', [])
  68. for net in v4_list:
  69. if net.get('type') == 'public':
  70. ip_address = net.get('ip_address')
  71. break
  72. status: InstanceState = status_mapping.get(
  73. instance.get('status'), InstanceState.UNKNOWN
  74. )
  75. return CloudInstance(
  76. external_id=str(instance.get('id')),
  77. name=instance.get('name'),
  78. image=instance.get('image', {}).get('slug', ''),
  79. type=instance.get('size_slug'),
  80. region=instance.get('region', {}).get('slug', ''),
  81. ssh_key_id=None,
  82. volume_ids=instance.get('volume_ids', []),
  83. user_data=None,
  84. status=status,
  85. ip_address=ip_address,
  86. )
  87. async def wait_for_started(
  88. self, external_id: str, backoff: int = 15, limit: int = 20
  89. ) -> CloudInstance:
  90. for _ in range(limit):
  91. instance = await self.get_instance(external_id)
  92. if instance and instance.status == InstanceState.RUNNING:
  93. return instance
  94. await asyncio.sleep(backoff)
  95. raise TimeoutError(
  96. f"DigitalOcean droplet {external_id} did not start within {limit} retries"
  97. )
  98. async def wait_for_public_ip(
  99. self, external_id: str, backoff: int = 15, limit: int = 20
  100. ) -> CloudInstance:
  101. for _ in range(limit):
  102. instance = await self.get_instance(external_id)
  103. if (
  104. instance
  105. and instance.ip_address is not None
  106. and instance.ip_address != ""
  107. ):
  108. return instance
  109. await asyncio.sleep(backoff)
  110. raise TimeoutError(
  111. f"DigitalOcean droplet {external_id} did not acquire a public IP within {limit} retries"
  112. )
  113. async def create_ssh_key(self, worker_name: str, public_key: str) -> str:
  114. ssh_key_resp = await self.client.ssh_keys.create(
  115. body={"name": f"sshkey-{worker_name}", "public_key": public_key},
  116. )
  117. id = ssh_key_resp['ssh_key']['id']
  118. return str(id)
  119. async def delete_ssh_key(self, id: str):
  120. await self.client.ssh_keys.delete(id)
  121. async def create_volumes_and_attach(
  122. self, worker_id: int, external_id: str, region: str, *volumes: Volume
  123. ) -> List[str]:
  124. # validate volumes
  125. volume_ids = []
  126. if len(volumes) == 0:
  127. return volume_ids
  128. for idx, volume in enumerate(volumes):
  129. size_gb = volume.size_gb
  130. if size_gb is None or size_gb <= 0:
  131. raise ValueError(
  132. f"Volume #{idx} missing or invalid 'size_gb': {volume}"
  133. )
  134. format = volume.format
  135. if format is None or format not in ['ext4', 'xfs']:
  136. raise ValueError(f"Volume #{idx} missing or invalid 'format': {volume}")
  137. if len(format) > (16 - 2 - 2 - 2):
  138. # 16 is max label length, 2 for underscores, 2 for index digits and 2 for hashed prefix
  139. raise ValueError(f"Volume #{idx} 'format' too long: {volume}")
  140. random_prefix = ''.join(random.choices(string.ascii_lowercase, k=6))
  141. for volume in volumes:
  142. index = volumes.index(volume)
  143. label = f"{random_prefix}_{volume.format}_{index}"
  144. if len(label) > 16:
  145. label = label[-16:]
  146. name = (
  147. f'{volume.name}-{worker_id}' if volume.name else label.replace('_', '-')
  148. )
  149. logger.info(
  150. f"Creating volume {name} of size {volume.size_gb}GB in region {region}"
  151. )
  152. vol_resp = await self.client.volumes.create(
  153. body={
  154. "size_gigabytes": volume.size_gb,
  155. "name": name,
  156. "region": region,
  157. "filesystem_type": volume.format,
  158. "filesystem_label": label,
  159. },
  160. )
  161. vol_id = vol_resp['volume']['id']
  162. volume_ids.append(str(vol_id))
  163. logger.info(f"Attaching volume {vol_id} to droplet {external_id}")
  164. resp = await self.client.volume_actions.post_by_id(
  165. volume_id=vol_id,
  166. body={"type": "attach", "droplet_id": external_id, "region": region},
  167. )
  168. id: str = resp.get('id', None)
  169. message: str = resp.get('message', None)
  170. if id is not None and message is not None:
  171. logger.error(
  172. f"Failed to attach volume {vol_id} to droplet {external_id}, response: {message}"
  173. )
  174. raise RuntimeError(
  175. f"Failed to attach volume {vol_id} to droplet {external_id}, response: {message}"
  176. )
  177. return volume_ids
  178. async def construct_user_data(
  179. self,
  180. server_url,
  181. token,
  182. image_name,
  183. os_image,
  184. worker_name,
  185. secret_configs: Dict[str, Any] = {},
  186. ) -> UserDataTemplate:
  187. image_info = await self.client.images.get(os_image)
  188. distribution = image_info.get('image', {}).get('distribution', '').lower()
  189. image_slug = image_info.get('image', {}).get('slug', '').lower()
  190. setup_driver = None
  191. install_driver = None
  192. # This is a trick to find out AI/ML ready images of DigitalOcean
  193. if image_slug.startswith("gpu"):
  194. # AMD will not take affect for now
  195. if os_image.lower().find("amd") != -1:
  196. setup_driver = ManufacturerEnum.AMD
  197. else:
  198. setup_driver = ManufacturerEnum.NVIDIA
  199. elif distribution in ['ubuntu', 'debian']:
  200. install_driver = ManufacturerEnum.NVIDIA
  201. setup_driver = ManufacturerEnum.NVIDIA
  202. user_data = await super().construct_user_data(
  203. server_url, token, image_name, os_image, worker_name, secret_configs
  204. )
  205. user_data.distribution = distribution
  206. user_data.setup_driver = setup_driver
  207. user_data.install_driver = install_driver
  208. user_data.insert_runcmd(
  209. "mkdir -p /var/lib/gpustack",
  210. "curl -s http://169.254.169.254/metadata/v1/id > /var/lib/gpustack/external_id",
  211. 'ip=$(curl -s http://169.254.169.254/metadata/v1/interfaces/public/0/ipv4/address); ip_lc=$(echo "$ip" | tr "A-Z" "a-z"); if [ "$ip_lc" != "not found" ]; then echo "$ip" > /var/lib/gpustack/advertise_address; fi',
  212. )
  213. return user_data
  214. @classmethod
  215. def get_api_endpoint(cls) -> str:
  216. return "https://api.digitalocean.com"
  217. @classmethod
  218. def process_header(cls, ak: str, sk: str, options: dict, headers: dict) -> dict:
  219. headers["Authorization"] = f"Bearer {sk}"
  220. return headers