| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- import logging
- import random
- import string
- import asyncio
- from typing import List, Optional, Dict, Any
- from .abstract import (
- ProviderClientBase,
- CloudInstance,
- CloudInstanceCreate,
- InstanceState,
- )
- from pydo.aio import Client
- from gpustack.schemas.clusters import Volume
- from gpustack.cloud_providers.user_data import UserDataTemplate
- from gpustack_runtime.detector import ManufacturerEnum
# Logger for this provider module.
logger = logging.getLogger(__name__)

# Maps a DigitalOcean droplet ``status`` string to the provider-neutral
# InstanceState. Statuses not listed here are looked up with a default of
# InstanceState.UNKNOWN in DigitalOceanClient.get_instance.
status_mapping: Dict[str, InstanceState] = dict(
    new=InstanceState.CREATED,
    active=InstanceState.RUNNING,
)
class DigitalOceanClient(ProviderClientBase):
    """Cloud provider client backed by the DigitalOcean API (``pydo`` async client).

    Manages droplets, SSH keys and block-storage volumes for GPUStack workers.
    API responses are consumed as dicts: a created/fetched resource is read
    from a resource-specific key (``droplet``, ``ssh_key``, ``volume``), and
    error payloads are recognised by their ``message`` field.
    """

    client: Client

    def __init__(self, token: str):
        """Create the async API client authenticated with ``token`` (timeout=30)."""
        self.client = Client(token=token, timeout=30)

    @staticmethod
    def _extract_id(response: Optional[Dict[str, Any]], key: str) -> str:
        """Return ``str(response[key]["id"])`` from a successful API response.

        Raises:
            RuntimeError: if ``response`` has no ``key`` object with an ``id``
                (e.g. an error payload); the API ``message`` is included when
                present.
        """
        payload = response or {}
        resource = payload.get(key) or {}
        resource_id = resource.get("id")
        if resource_id is None:
            raise RuntimeError(
                f"DigitalOcean response has no '{key}' id: "
                f"{payload.get('message', response)}"
            )
        return str(resource_id)

    async def create_instance(self, instance: CloudInstanceCreate) -> str:
        """Create a droplet from ``instance`` and return its id as a string.

        Instance labels are sent as ``key:value`` droplet tags.

        Raises:
            RuntimeError: if the API response carries no droplet id.
            Exception: any error from the API client is logged and re-raised.
        """
        tags: List[str] = [f"{k}:{v}" for k, v in instance.labels.items()]
        req = {
            "name": instance.name,
            "image": instance.image,
            "size": instance.type,
            "region": instance.region,
            # Avoid sending [null] when no SSH key id is set.
            "ssh_keys": (
                [instance.ssh_key_id] if instance.ssh_key_id is not None else []
            ),
            "user_data": instance.user_data,
            "tags": tags,
        }
        try:
            logger.info(f"Creating digital ocean droplet with name {instance.name}")
            # user_data is likely the rendered template from construct_user_data,
            # which is built with the server token, so keep it out of the logs.
            redacted = {**req, "user_data": "<redacted>"}
            logger.debug(f"Request body: {redacted}")
            droplet_resp = await self.client.droplets.create(body=req)
            droplet_id = self._extract_id(droplet_resp, "droplet")
        except Exception as e:
            logger.error(f"Failed to create digital ocean instance: {e}")
            raise
        return droplet_id

    async def delete_instance(self, external_id: str):
        """Destroy a droplet via ``destroy_with_associated_resources_dangerous``.

        A ``None`` response is treated as success; any other response body is
        treated as a failure.

        Raises:
            RuntimeError: if the API returned a response body.
        """
        logger.info(f"Deleting digital ocean instance with id {external_id}")
        delete_response = (
            await self.client.droplets.destroy_with_associated_resources_dangerous(
                external_id,
                x_dangerous=True,
            )
        )
        if delete_response is None:
            return
        logger.warning(
            f"Failed to delete droplet {external_id}, Delete response: {delete_response}"
        )
        # Responses are dicts, so read the error text by key, not by attribute.
        message = (
            delete_response.get("message", delete_response)
            if isinstance(delete_response, dict)
            else delete_response
        )
        raise RuntimeError(f"Failed to delete droplet {external_id}, {message}")

    async def get_instance(self, external_id: str) -> Optional[CloudInstance]:
        """Fetch a droplet and convert it to a ``CloudInstance``.

        Returns ``None`` when the response has no ``droplet`` object. The
        ``ip_address`` is the first IPv4 network entry of type ``public``, or
        ``None`` if there is none. Unmapped statuses become
        ``InstanceState.UNKNOWN``.
        """
        response = await self.client.droplets.get(external_id)
        droplet: Optional[Dict[str, Any]] = response.get('droplet')
        if droplet is None:
            return None

        v4_networks = (droplet.get('networks') or {}).get('v4') or []
        ip_address = next(
            (
                net.get('ip_address')
                for net in v4_networks
                if net.get('type') == 'public'
            ),
            None,
        )
        status: InstanceState = status_mapping.get(
            droplet.get('status'), InstanceState.UNKNOWN
        )
        return CloudInstance(
            external_id=str(droplet.get('id')),
            name=droplet.get('name'),
            # Nested objects or their slug may be null; fall back to ''.
            image=(droplet.get('image') or {}).get('slug') or '',
            type=droplet.get('size_slug'),
            region=(droplet.get('region') or {}).get('slug') or '',
            ssh_key_id=None,
            volume_ids=droplet.get('volume_ids') or [],
            user_data=None,
            status=status,
            ip_address=ip_address,
        )

    async def _poll_instance(
        self, external_id: str, ready, backoff: int, limit: int
    ) -> Optional[CloudInstance]:
        """Fetch the droplet up to ``limit`` times until ``ready(instance)`` is true.

        Sleeps ``backoff`` seconds between attempts (not after the last one)
        and returns ``None`` if the condition was never met.
        """
        for attempt in range(limit):
            instance = await self.get_instance(external_id)
            if instance is not None and ready(instance):
                return instance
            if attempt + 1 < limit:
                await asyncio.sleep(backoff)
        return None

    async def wait_for_started(
        self, external_id: str, backoff: int = 15, limit: int = 20
    ) -> CloudInstance:
        """Poll until the droplet reports ``InstanceState.RUNNING``.

        Raises:
            TimeoutError: if the droplet is not running after ``limit`` polls.
        """
        instance = await self._poll_instance(
            external_id,
            lambda inst: inst.status == InstanceState.RUNNING,
            backoff,
            limit,
        )
        if instance is None:
            raise TimeoutError(
                f"DigitalOcean droplet {external_id} did not start within {limit} retries"
            )
        return instance

    async def wait_for_public_ip(
        self, external_id: str, backoff: int = 15, limit: int = 20
    ) -> CloudInstance:
        """Poll until the droplet has a non-empty public IPv4 address.

        Raises:
            TimeoutError: if no public IP appears after ``limit`` polls.
        """
        instance = await self._poll_instance(
            external_id,
            lambda inst: inst.ip_address is not None and inst.ip_address != "",
            backoff,
            limit,
        )
        if instance is None:
            raise TimeoutError(
                f"DigitalOcean droplet {external_id} did not acquire a public IP within {limit} retries"
            )
        return instance

    async def create_ssh_key(self, worker_name: str, public_key: str) -> str:
        """Register ``public_key`` under the name ``sshkey-<worker_name>``.

        Returns:
            The new SSH key id as a string.

        Raises:
            RuntimeError: if the API response carries no ssh_key id.
        """
        ssh_key_resp = await self.client.ssh_keys.create(
            body={"name": f"sshkey-{worker_name}", "public_key": public_key},
        )
        return self._extract_id(ssh_key_resp, "ssh_key")

    async def delete_ssh_key(self, id: str):
        """Delete the SSH key with the given id."""
        await self.client.ssh_keys.delete(id)

    @staticmethod
    def _validate_volumes(volumes) -> None:
        """Check every volume has a positive ``size_gb`` and an ext4/xfs format.

        Raises:
            ValueError: naming the index of the first invalid volume.
        """
        for idx, volume in enumerate(volumes):
            size_gb = volume.size_gb
            if size_gb is None or size_gb <= 0:
                raise ValueError(
                    f"Volume #{idx} missing or invalid 'size_gb': {volume}"
                )
            if volume.format not in ('ext4', 'xfs'):
                raise ValueError(f"Volume #{idx} missing or invalid 'format': {volume}")

    @staticmethod
    def _filesystem_label(prefix: str, fs_format: str, index: int) -> str:
        """Build ``<prefix>_<format>_<index>`` within DigitalOcean's label limit.

        The volumes API allows 16-character ext4 labels but only 12-character
        xfs labels. Over-long labels keep their trailing characters, so the
        format and index stay visible.
        """
        max_len = 12 if fs_format == 'xfs' else 16
        label = f"{prefix}_{fs_format}_{index}"
        return label[-max_len:]

    @staticmethod
    def _droplet_id(external_id: str):
        """Return ``external_id`` as an int when numeric (the API's documented type)."""
        return int(external_id) if str(external_id).isdigit() else external_id

    async def create_volumes_and_attach(
        self, worker_id: int, external_id: str, region: str, *volumes: Volume
    ) -> List[str]:
        """Create each volume in ``region`` and attach it to droplet ``external_id``.

        Volumes without a name are named after their filesystem label. All
        volumes in one call share a random 6-letter label prefix.

        Returns:
            The created volume ids, in input order.

        Raises:
            ValueError: if a volume has an invalid size or format.
            RuntimeError: if creation or attachment is rejected by the API.
        """
        if not volumes:
            return []
        self._validate_volumes(volumes)

        random_prefix = ''.join(random.choices(string.ascii_lowercase, k=6))
        volume_ids: List[str] = []
        # enumerate gives each volume its own position; volumes.index() would
        # return the first *equal* volume and duplicate labels/names.
        for index, volume in enumerate(volumes):
            label = self._filesystem_label(random_prefix, volume.format, index)
            name = (
                f'{volume.name}-{worker_id}' if volume.name else label.replace('_', '-')
            )
            logger.info(
                f"Creating volume {name} of size {volume.size_gb}GB in region {region}"
            )
            vol_resp = await self.client.volumes.create(
                body={
                    "size_gigabytes": volume.size_gb,
                    "name": name,
                    "region": region,
                    "filesystem_type": volume.format,
                    "filesystem_label": label,
                },
            )
            vol_id = self._extract_id(vol_resp, "volume")
            volume_ids.append(vol_id)

            logger.info(f"Attaching volume {vol_id} to droplet {external_id}")
            resp = await self.client.volume_actions.post_by_id(
                volume_id=vol_id,
                body={
                    "type": "attach",
                    "droplet_id": self._droplet_id(external_id),
                    "region": region,
                },
            )
            resp = resp or {}
            error_id = resp.get('id')
            message = resp.get('message')
            # An error payload carries both an 'id' and a 'message'.
            if error_id is not None and message is not None:
                logger.error(
                    f"Failed to attach volume {vol_id} to droplet {external_id}, response: {message}"
                )
                raise RuntimeError(
                    f"Failed to attach volume {vol_id} to droplet {external_id}, response: {message}"
                )
        return volume_ids

    async def construct_user_data(
        self,
        server_url,
        token,
        image_name,
        os_image,
        worker_name,
        secret_configs: Optional[Dict[str, Any]] = None,
    ) -> UserDataTemplate:
        """Build the base user data, then add DigitalOcean-specific driver and metadata setup.

        Driver selection is based on the OS image:

        * slug starting with ``gpu`` (AI/ML-ready images): only ``setup_driver``
          is set (AMD if ``os_image`` mentions "amd", otherwise NVIDIA);
        * Ubuntu/Debian distributions: NVIDIA drivers are installed and set up.

        Extra run commands write the droplet id and, if available, its public
        IPv4 address (from the metadata service) under ``/var/lib/gpustack``.
        """
        # A fresh dict per call; the former `= {}` default was shared across calls.
        if secret_configs is None:
            secret_configs = {}
        image_info = await self.client.images.get(os_image)
        image = image_info.get('image') or {}
        # Either field may be null in the API payload; treat null as ''.
        distribution = (image.get('distribution') or '').lower()
        image_slug = (image.get('slug') or '').lower()

        setup_driver = None
        install_driver = None
        # AI/ML-ready DigitalOcean images are recognised by their "gpu" slug prefix.
        if image_slug.startswith("gpu"):
            # NOTE: per the original author, the AMD choice does not take effect yet.
            if "amd" in str(os_image).lower():
                setup_driver = ManufacturerEnum.AMD
            else:
                setup_driver = ManufacturerEnum.NVIDIA
        elif distribution in ('ubuntu', 'debian'):
            install_driver = ManufacturerEnum.NVIDIA
            setup_driver = ManufacturerEnum.NVIDIA

        user_data = await super().construct_user_data(
            server_url, token, image_name, os_image, worker_name, secret_configs
        )
        user_data.distribution = distribution
        user_data.setup_driver = setup_driver
        user_data.install_driver = install_driver
        user_data.insert_runcmd(
            "mkdir -p /var/lib/gpustack",
            "curl -s http://169.254.169.254/metadata/v1/id > /var/lib/gpustack/external_id",
            'ip=$(curl -s http://169.254.169.254/metadata/v1/interfaces/public/0/ipv4/address); ip_lc=$(echo "$ip" | tr "A-Z" "a-z"); if [ "$ip_lc" != "not found" ]; then echo "$ip" > /var/lib/gpustack/advertise_address; fi',
        )
        return user_data

    @classmethod
    def get_api_endpoint(cls) -> str:
        """Return the DigitalOcean API base URL."""
        return "https://api.digitalocean.com"

    @classmethod
    def process_header(cls, ak: str, sk: str, options: dict, headers: dict) -> dict:
        """Set ``Authorization: Bearer <sk>`` on ``headers`` and return it.

        ``ak`` and ``options`` are unused; the token is passed as ``sk``.
        """
        headers["Authorization"] = f"Bearer {sk}"
        return headers
|