| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- import base64
- from typing import Dict, Tuple, Type, Callable
- from .abstract import ProviderClientBase, CloudInstanceCreate
- from .digital_ocean import DigitalOceanClient
- from gpustack.schemas.clusters import ClusterProvider, CloudCredential, Credential
- from gpustack.schemas.workers import Worker
- from cryptography.hazmat.primitives import serialization
- from cryptography.hazmat.primitives.asymmetric import rsa, ed25519
def _new_digital_ocean_client(credential: CloudCredential) -> ProviderClientBase:
    """Build a DigitalOcean client that authenticates with the credential's secret."""
    return DigitalOceanClient(token=credential.secret)


# Provider registry: each supported ClusterProvider maps to a pair of
# (client class, builder that creates a client from a CloudCredential).
factory: Dict[
    ClusterProvider,
    Tuple[Type[ProviderClientBase], Callable[[CloudCredential], ProviderClientBase]],
] = {
    ClusterProvider.DigitalOcean: (DigitalOceanClient, _new_digital_ocean_client),
}
def get_client_from_provider(
    provider: ClusterProvider,
    credential: CloudCredential,
) -> ProviderClientBase:
    """
    Create the API client for a cloud provider.

    The provider is looked up in the module-level ``factory`` registry and the
    registered builder is called with ``credential``.

    Raises:
        ValueError: if ``provider`` has no entry in ``factory``.
    """
    registered = factory.get(provider)
    if registered is not None:
        # Each entry is (client class, builder); only the builder is needed.
        build_client = registered[1]
        return build_client(credential)
    raise ValueError(f"Unsupported provider: {provider}")
def construct_cloud_instance(
    worker: Worker, ssh_key: Credential, user_data: str
) -> CloudInstanceCreate:
    """
    Build the creation request for a worker's cloud instance.

    Assumes the cloud instance has not been created yet.

    The name comes from the worker, the image and instance type from its
    worker pool, and the region from its cluster. Worker labels are carried
    over, except for the "provider" and "instance_type" keys.
    """
    cluster = worker.cluster
    pool = worker.worker_pool

    excluded_keys = ("provider", "instance_type")
    carried_labels = {
        key: value
        for key, value in dict(worker.labels or {}).items()
        if key not in excluded_keys
    }

    # Identifying labels go first; the worker's own labels are merged on top.
    instance_labels = {"cluster_id": cluster.id, "worker_id": worker.id}
    instance_labels.update(carried_labels)

    return CloudInstanceCreate(
        name=worker.name,
        image=pool.os_image,
        type=pool.instance_type,
        region=cluster.region,
        ssh_key_id=ssh_key.external_id,
        user_data=user_data,
        labels=instance_labels,
    )
def generate_ssh_key_pair(
    algorithm: str = "ED25519", key_size: int = 2048
) -> Tuple[str, str]:
    """
    Generate a new, unencrypted SSH key pair.

    Args:
        algorithm: "RSA" or "ED25519" (case-insensitive).
        key_size: RSA key size in bits; not used for ED25519.

    Returns:
        A ``(private_key_b64, public_key)`` tuple. ``private_key_b64`` is the
        base64 encoding of the private key bytes: the OpenSSH PEM
        serialization for RSA, the raw private key bytes for ED25519.
        ``public_key`` is the public key in OpenSSH format, as text.

    Raises:
        ValueError: if ``algorithm`` is neither RSA nor ED25519.
    """
    normalized = algorithm.upper()
    if normalized == "RSA":
        key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
        # The RSA private key is kept in its OpenSSH PEM serialization.
        key_bytes = key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.OpenSSH,
            encryption_algorithm=serialization.NoEncryption(),
        )
    elif normalized == "ED25519":
        key = ed25519.Ed25519PrivateKey.generate()
        # The ED25519 private key is kept as raw bytes, not as PEM.
        key_bytes = key.private_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PrivateFormat.Raw,
            encryption_algorithm=serialization.NoEncryption(),
        )
    else:
        raise ValueError(f"Unsupported algorithm: {algorithm}")

    # Both key types serialize their public half the same way.
    public_key = (
        key.public_key()
        .public_bytes(
            encoding=serialization.Encoding.OpenSSH,
            format=serialization.PublicFormat.OpenSSH,
        )
        .decode()
    )
    private_key_b64 = base64.b64encode(key_bytes).decode()
    return private_key_b64, public_key
def key_bytes_to_openssh_pem(key_bytes: bytes, algorithm: str) -> bytes:
    """
    Convert private key bytes to an unencrypted OpenSSH PEM private key.

    Args:
        key_bytes: the private key bytes. For RSA they are returned
            unchanged; for ED25519 they are loaded as raw private key bytes
            and re-serialized.
        algorithm: "RSA" or "ED25519" (case-insensitive).

    Returns:
        The private key as OpenSSH PEM bytes (for RSA, ``key_bytes`` as given).

    Raises:
        ValueError: if ``algorithm`` is neither RSA nor ED25519.
    """
    normalized = algorithm.upper()
    if normalized == "RSA":
        # Passed through as-is; no parsing or validation is done here.
        return key_bytes
    if normalized == "ED25519":
        key = ed25519.Ed25519PrivateKey.from_private_bytes(key_bytes)
        return key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.OpenSSH,
            encryption_algorithm=serialization.NoEncryption(),
        )
    raise ValueError(f"Unsupported algorithm: {algorithm}")
|