| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- import os
- import time
- from cachetools import TTLCache, cached
- from typing import Optional
- from gpustack.client import ClientSet
- from gpustack.client.worker_manager_clients import (
- WorkerRegistrationClient,
- )
- from gpustack.security import API_KEY_PREFIX
- from gpustack.utils.uuid import get_legacy_uuid, get_system_uuid
- from gpustack.utils.network import check_registry_reachable
# File names, relative to the data directory, used by the token helpers below.
# Holds the token that registration_client() waits for and reads.
registration_token_filename = "token"
# Holds the worker's own token (see read_worker_token / write_worker_token).
worker_token_filename = "worker_token"
def read_token(data_dir: str, filename: str) -> Optional[str]:
    """Read the token stored in ``filename`` inside ``data_dir``.

    Args:
        data_dir: Directory that holds the token file.
        filename: Name of the token file within ``data_dir``.

    Returns:
        The file content with leading/trailing whitespace stripped, or
        ``None`` when the file does not exist.
    """
    token_path = os.path.join(data_dir, filename)
    # Open directly rather than testing os.path.exists() first: the file can
    # disappear between the check and the open, which would turn a plain
    # "no token" situation into an unexpected exception.
    try:
        with open(token_path, "r") as f:
            return f.read().strip()
    except (FileNotFoundError, NotADirectoryError):
        # NotADirectoryError: a component of the path is a regular file, which
        # os.path.exists() also reported as "does not exist".
        return None
def write_token(data_dir: str, filename: str, token: str):
    """Store ``token`` in ``filename`` inside ``data_dir``.

    The file is left untouched when it already contains the same token
    (compared after stripping surrounding whitespace); otherwise it is
    overwritten with the token followed by a newline.

    Args:
        data_dir: Directory that holds the token file.
        filename: Name of the token file within ``data_dir``.
        token: Token value to persist.
    """
    token_path = os.path.join(data_dir, filename)
    # Read the current content directly instead of checking
    # os.path.exists() first, so a file removed in between does not raise.
    try:
        with open(token_path, "r") as f:
            if f.read().strip() == token:
                return  # Token is already written
    except (FileNotFoundError, NotADirectoryError):
        pass  # Nothing stored yet; fall through and write it.

    with open(token_path, "w") as f:
        f.write(token + "\n")
def read_worker_token(data_dir: str) -> Optional[str]:
    """Return the worker token stored under ``data_dir``, or ``None`` if absent."""
    worker_token = read_token(data_dir, worker_token_filename)
    return worker_token
def write_worker_token(data_dir: str, token: str):
    """Persist ``token`` as the worker token file under ``data_dir``."""
    write_token(data_dir, filename=worker_token_filename, token=token)
def read_registration_token(data_dir: str) -> Optional[str]:
    """Return the registration token stored under ``data_dir``, or ``None`` if absent."""
    stored_token = read_token(data_dir, registration_token_filename)
    return stored_token
def write_registration_token(data_dir: str, token: str):
    """Persist ``token`` as the registration token file under ``data_dir``."""
    write_token(data_dir, filename=registration_token_filename, token=token)
def registration_client(
    data_dir: str,
    server_url: str,
    registration_token: Optional[str] = None,
    wait_token_file: bool = False,
) -> Optional[WorkerRegistrationClient]:
    """Build a ``WorkerRegistrationClient`` authenticated with a registration token.

    Args:
        data_dir: Directory that holds the registration token file (and is
            passed to ``get_legacy_uuid``).
        server_url: Base URL handed to ``ClientSet``.
        registration_token: Token to use. When ``None`` and
            ``wait_token_file`` is true, the token is read from the
            registration token file in ``data_dir``.
        wait_token_file: Whether to poll for the token file (up to 10
            seconds) when no token was passed in.

    Returns:
        A registration client, or ``None`` when no non-empty token is
        available.

    Raises:
        FileNotFoundError: No usable token file appeared within the timeout.
        ValueError: The token lacks ``API_KEY_PREFIX`` and no UUID could be
            determined to build the prefixed form.
    """
    if registration_token is None and wait_token_file:
        timeout = 10
        poll_interval = 0.5
        # Use the monotonic clock so wall-clock adjustments (NTP, manual
        # changes) cannot shorten or stretch the wait.
        deadline = time.monotonic() + timeout
        while True:
            registration_token = read_registration_token(data_dir)
            # An empty read means the file exists but has no content yet
            # (e.g. it was just truncated for writing); keep polling instead
            # of giving up and returning None.
            if registration_token:
                break
            if time.monotonic() > deadline:
                raise FileNotFoundError("Registration token file not found")
            time.sleep(poll_interval)

    if not registration_token:
        return None

    if not registration_token.startswith(API_KEY_PREFIX):
        # Tokens without the API key prefix are expanded to
        # "<prefix>_<uuid>_<token>". NOTE(review): presumably the pre-API-key
        # token format; confirm against the server side.
        legacy_uuid = get_legacy_uuid(data_dir) or get_system_uuid()
        if not legacy_uuid:
            raise ValueError(
                "Legacy UUID not found, please re-register the worker."
            )
        registration_token = f"{API_KEY_PREFIX}_{legacy_uuid}_{registration_token}"

    clientset = ClientSet(
        base_url=server_url,
        api_key=registration_token,
    )
    return WorkerRegistrationClient(clientset.http_client)
# Results are memoized per ``override`` value: at most 3 entries, each kept
# for ttl=3600 (seconds, per cachetools' default timer).
cache = TTLCache(maxsize=3, ttl=3600)


@cached(cache)
def determine_default_registry(override: Optional[str] = None) -> Optional[str]:
    """Pick the default container registry.

    Args:
        override: Registry to use unconditionally when non-empty.

    Returns:
        ``override`` when it is a non-empty string; otherwise ``None`` when
        Docker Hub is reachable, ``"quay.io"`` when only quay.io is
        reachable, and ``None`` when neither is reachable.
    """
    if override:
        return override

    if check_registry_reachable("https://registry-1.docker.io"):
        return None
    # Probe quay.io only when Docker Hub is unreachable; its result is
    # irrelevant otherwise, so this avoids a needless network round trip.
    if check_registry_reachable("https://quay.io"):
        return "quay.io"
    return None
|