common.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import base64
  2. from typing import Dict, Tuple, Type, Callable
  3. from .abstract import ProviderClientBase, CloudInstanceCreate
  4. from .digital_ocean import DigitalOceanClient
  5. from gpustack.schemas.clusters import ClusterProvider, CloudCredential, Credential
  6. from gpustack.schemas.workers import Worker
  7. from cryptography.hazmat.primitives import serialization
  8. from cryptography.hazmat.primitives.asymmetric import rsa, ed25519
  9. factory: Dict[
  10. ClusterProvider,
  11. Tuple[Type[ProviderClientBase], Callable[[CloudCredential], ProviderClientBase]],
  12. ] = {
  13. ClusterProvider.DigitalOcean: (
  14. DigitalOceanClient,
  15. lambda credential: DigitalOceanClient(token=credential.secret),
  16. ),
  17. }
  18. def get_client_from_provider(
  19. provider: ClusterProvider,
  20. credential: CloudCredential,
  21. ) -> ProviderClientBase:
  22. type_factory = factory.get(provider, None)
  23. if type_factory is None:
  24. raise ValueError(f"Unsupported provider: {provider}")
  25. f = type_factory[1]
  26. return f(credential)
  27. def construct_cloud_instance(
  28. worker: Worker, ssh_key: Credential, user_data: str
  29. ) -> CloudInstanceCreate:
  30. """
  31. Assuming the cloud instance is not created
  32. """
  33. cluster = worker.cluster
  34. pool = worker.worker_pool
  35. labels = dict(worker.labels or {})
  36. labels.pop("provider", None)
  37. labels.pop("instance_type", None)
  38. return CloudInstanceCreate(
  39. name=worker.name,
  40. image=pool.os_image,
  41. type=pool.instance_type,
  42. region=cluster.region,
  43. ssh_key_id=ssh_key.external_id,
  44. user_data=user_data,
  45. labels={
  46. "cluster_id": cluster.id,
  47. "worker_id": worker.id,
  48. **labels,
  49. },
  50. )
  51. def generate_ssh_key_pair(
  52. algorithm: str = "ED25519", key_size: int = 2048
  53. ) -> Tuple[str, str]:
  54. """
  55. algorithm: RSA or ED25519
  56. returns private_key in base64 encoded, public_key in pem format
  57. """
  58. if algorithm.upper() == "RSA":
  59. key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
  60. key_bytes = key.private_bytes(
  61. encoding=serialization.Encoding.PEM,
  62. format=serialization.PrivateFormat.OpenSSH,
  63. encryption_algorithm=serialization.NoEncryption(),
  64. )
  65. public_key = (
  66. key.public_key()
  67. .public_bytes(
  68. encoding=serialization.Encoding.OpenSSH,
  69. format=serialization.PublicFormat.OpenSSH,
  70. )
  71. .decode()
  72. )
  73. elif algorithm.upper() == "ED25519":
  74. key = ed25519.Ed25519PrivateKey.generate()
  75. key_bytes = key.private_bytes(
  76. encoding=serialization.Encoding.Raw,
  77. format=serialization.PrivateFormat.Raw,
  78. encryption_algorithm=serialization.NoEncryption(),
  79. )
  80. public_key = (
  81. key.public_key()
  82. .public_bytes(
  83. encoding=serialization.Encoding.OpenSSH,
  84. format=serialization.PublicFormat.OpenSSH,
  85. )
  86. .decode()
  87. )
  88. else:
  89. raise ValueError(f"Unsupported algorithm: {algorithm}")
  90. private_key_b64 = base64.b64encode(key_bytes).decode()
  91. return private_key_b64, public_key
  92. def key_bytes_to_openssh_pem(key_bytes: bytes, algorithm: str):
  93. if algorithm.upper() == "RSA":
  94. return key_bytes
  95. elif algorithm.upper() == "ED25519":
  96. key = ed25519.Ed25519PrivateKey.from_private_bytes(key_bytes)
  97. pem = key.private_bytes(
  98. encoding=serialization.Encoding.PEM,
  99. format=serialization.PrivateFormat.OpenSSH,
  100. encryption_algorithm=serialization.NoEncryption(),
  101. )
  102. else:
  103. raise ValueError("Unsupported algorithm")
  104. return pem