registration.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import os
  2. import time
  3. from cachetools import TTLCache, cached
  4. from typing import Optional
  5. from gpustack.client import ClientSet
  6. from gpustack.client.worker_manager_clients import (
  7. WorkerRegistrationClient,
  8. )
  9. from gpustack.security import API_KEY_PREFIX
  10. from gpustack.utils.uuid import get_legacy_uuid, get_system_uuid
  11. from gpustack.utils.network import check_registry_reachable
# Name of the file, inside the data directory, that holds the registration token.
registration_token_filename = "token"
# Name of the file, inside the data directory, that holds the worker token.
worker_token_filename = "worker_token"
  14. def read_token(data_dir: str, filename) -> Optional[str]:
  15. token_path = os.path.join(data_dir, filename)
  16. if os.path.exists(token_path):
  17. with open(token_path, "r") as f:
  18. return f.read().strip()
  19. return None
  20. def write_token(data_dir: str, filename: str, token: str):
  21. token_path = os.path.join(data_dir, filename)
  22. if os.path.exists(token_path):
  23. with open(token_path, "r") as f:
  24. existing_token = f.read().strip()
  25. if existing_token == token:
  26. return # Token is already written
  27. with open(token_path, "w") as f:
  28. f.write(token + "\n")
  29. def read_worker_token(data_dir: str) -> Optional[str]:
  30. return read_token(data_dir, worker_token_filename)
  31. def write_worker_token(data_dir: str, token: str):
  32. write_token(data_dir, worker_token_filename, token)
  33. def read_registration_token(data_dir: str) -> Optional[str]:
  34. return read_token(data_dir, registration_token_filename)
  35. def write_registration_token(data_dir: str, token: str):
  36. write_token(data_dir, registration_token_filename, token)
  37. def registration_client(
  38. data_dir: str,
  39. server_url: str,
  40. registration_token: Optional[str] = None,
  41. wait_token_file: bool = False,
  42. ) -> Optional[WorkerRegistrationClient]:
  43. # if token exists, skip registration
  44. if registration_token is None and wait_token_file:
  45. timeout = 10
  46. start_time = time.time()
  47. while True:
  48. registration_token = read_registration_token(data_dir)
  49. if registration_token is not None:
  50. break
  51. if time.time() - start_time > timeout:
  52. raise FileNotFoundError("Registration token file not found")
  53. time.sleep(0.5)
  54. if registration_token:
  55. if not registration_token.startswith(API_KEY_PREFIX):
  56. legacy_uuid = get_legacy_uuid(data_dir) or get_system_uuid()
  57. if not legacy_uuid:
  58. raise ValueError(
  59. "Legacy UUID not found, please re-register the worker."
  60. )
  61. registration_token = f"{API_KEY_PREFIX}_{legacy_uuid}_{registration_token}"
  62. clientset = ClientSet(
  63. base_url=server_url,
  64. api_key=registration_token,
  65. )
  66. return WorkerRegistrationClient(clientset.http_client)
  67. return None
  68. cache = TTLCache(maxsize=3, ttl=3600)
  69. @cached(cache)
  70. def determine_default_registry(override: Optional[str] = None) -> Optional[str]:
  71. if override is not None and len(override) > 0:
  72. return override
  73. docker_hub_reachable = check_registry_reachable("https://registry-1.docker.io")
  74. quay_io_reachable = check_registry_reachable("https://quay.io")
  75. if docker_hub_reachable:
  76. return None
  77. elif quay_io_reachable:
  78. return "quay.io"
  79. else:
  80. return None