import os
import logging
from typing import Optional, Tuple

import httpx

from gpustack import __version__, __git_commit__
from gpustack.client import ClientSet
from gpustack.client.worker_manager_clients import (
    WorkerStatusClient,
    WorkerRegistrationClient,
)
from gpustack.config.config import Config
from gpustack.schemas.workers import (
    WorkerCreate,
    WorkerUpdate,
    WorkerRegistrationPublic,
)
from gpustack.schemas.config import PredefinedConfigNoDefaults
from gpustack.security import API_KEY_PREFIX
from gpustack.utils import platform
from gpustack.worker.collector import WorkerStatusCollector
from gpustack.config.registration import (
    registration_client,
    read_worker_token,
    write_worker_token,
)
from gpustack.utils.uuid import (
    set_worker_name,
    get_worker_name,
    set_legacy_uuid,
    get_legacy_uuid,
)
from gpustack.utils.version import is_worker_version_compatible

logger = logging.getLogger(__name__)


class WorkerManager:
    """Register this worker with the server and push its status to it.

    Authenticated clients are created from a worker token. The token is
    either read from ``data_dir`` at construction time or obtained by
    ``register_with_server``. Until a token is available, the status client
    is ``None`` and ``sync_worker_status`` is a no-op.
    """

    _is_embedded: bool
    _cfg: Config
    _collector: WorkerStatusCollector
    # The clients below stay None until a worker token is known. The None
    # defaults matter: sync_worker_status checks `_status_client is None`,
    # which would raise AttributeError if the attribute were never assigned.
    _clientset: Optional[ClientSet] = None
    _registration_client: Optional[WorkerRegistrationClient] = None
    _status_client: Optional[WorkerStatusClient] = None

    def __init__(
        self,
        cfg: Config,
        is_embedded: bool,
        collector: WorkerStatusCollector,
    ):
        """
        Args:
            cfg: Worker configuration (server URL, data dir, tokens, name).
            is_embedded: Whether the worker runs embedded; when True the
                registration token is reloaded before each registration and
                the registration client is asked to wait for the token file.
            collector: Source of worker status snapshots.
        """
        self._is_embedded = is_embedded
        self._cfg = cfg
        self._collector = collector

        # Reuse a token persisted by a previous registration, if any.
        worker_token = read_worker_token(self._cfg.data_dir)
        if worker_token:
            self._prepare_clients(worker_token)

    def _prepare_clients(self, token: str):
        """Build the authenticated client set and status client for ``token``."""
        self._clientset = ClientSet(
            base_url=self._cfg.get_server_url(),
            api_key=token,
        )
        self._status_client = WorkerStatusClient(self._clientset.http_client)

    def _resolve_worker_name(self) -> Optional[str]:
        """Return the configured worker name, else the one persisted in data_dir."""
        return self._cfg.worker_name or get_worker_name(self._cfg.data_dir)

    def sync_worker_status(self):
        """
        Should be called periodically to sync the worker node status with the server.

        Does nothing until the worker has a token (i.e. before it has been
        registered); registration itself is performed by
        ``register_with_server``. Collection and upload failures are logged
        and swallowed so one bad tick does not stop the periodic sync.
        """
        if self._status_client is None:
            return

        try:
            worker_status = self._collector.timed_collect(self._clientset)
        except Exception as e:
            logger.error(f"Failed to collect status for worker: {e}")
            return

        try:
            self._status_client.create(worker_status)
        except Exception as e:
            logger.error(f"Failed to update worker status: {e}")

    async def register_with_server(
        self,
    ) -> Tuple[ClientSet, Optional[PredefinedConfigNoDefaults]]:
        """Register with the server, persist the returned token and rebuild clients.

        Returns:
            The authenticated client set and the worker config returned by
            the server (may be None).

        Raises:
            Any exception raised during registration, after logging it.
        """
        # Always re-register the worker to retrieve a fresh token and config.
        try:
            worker_registered = await self._register_worker()
            token = worker_registered.token
            write_worker_token(self._cfg.data_dir, token)
            self._prepare_clients(token)
            return self._clientset, worker_registered.worker_config
        except Exception as e:
            logger.error(f"Failed to register worker: {e}")
            raise

    async def _register_worker(self) -> WorkerRegistrationPublic:
        """Send a WorkerCreate payload to the server and persist name/uuid it returns."""
        name = self._resolve_worker_name()
        logger.info(
            f"Registering worker with name: {name or ''}",
        )
        if self._is_embedded:
            # Always reload the token in embedded mode.
            self._cfg.reload_token()
        self._registration_client = registration_client(
            data_dir=self._cfg.data_dir,
            server_url=self._cfg.get_server_url(),
            registration_token=self._cfg.token,
            wait_token_file=self._is_embedded,
        )

        # Optional external identifier, read verbatim from <data_dir>/external_id
        # (not stripped, so the value sent matches earlier registrations).
        external_id = None
        external_id_path = os.path.join(self._cfg.data_dir, 'external_id')
        if os.path.exists(external_id_path):
            with open(external_id_path, 'r') as f:
                external_id = f.read()

        worker_status = self._collector.timed_collect(initial=True)
        # Use an empty name if none is specified to avoid a validation error.
        worker_update = WorkerUpdate(
            name=name or "",
            labels=self._ensure_builtin_labels(),
        )
        to_register = WorkerCreate.model_validate(
            {
                **worker_status.model_dump(),
                **worker_update.model_dump(),
                "external_id": external_id,
                "worker_version": __version__,
            }
        )
        created = await self._registration_client.create_async(to_register)
        logger.info(f"Worker {created.name} registered with worker_id {created.id}.")
        set_worker_name(self._cfg.data_dir, created.name)
        set_legacy_uuid(self._cfg.data_dir, created.worker_uuid)
        return created

    def _register_shutdown_hooks(self):
        # Intentionally a no-op placeholder.
        pass

    def _ensure_builtin_labels(self) -> dict:
        """Build the labels this worker always reports on registration.

        Always includes ``os`` and ``arch``; adds ``worker-name`` when a name
        is known, and ``gpustack.existence-check: "true"`` for previously
        registered workers that use a legacy token or lack a stored legacy
        uuid.
        """
        labels = {
            "os": platform.system(),
            "arch": platform.arch(),
        }
        # The worker name label will be set during registration.
        name = self._resolve_worker_name()
        if name:
            labels["worker-name"] = name

        # Legacy workers with version 0.7.x send worker_uuid as part of registration.
        # Legacy workers with version <0.7.x don't have worker_uuid, so this label
        # is used as part of the registration allowance.
        is_legacy_token = self._cfg.token and not self._cfg.token.startswith(
            API_KEY_PREFIX
        )
        is_legacy_worker = get_legacy_uuid(self._cfg.data_dir) is None
        is_existing_worker = get_worker_name(self._cfg.data_dir) is not None
        if (is_legacy_token or is_legacy_worker) and is_existing_worker:
            labels["gpustack.existence-check"] = "true"
        return labels

    async def _fetch_server_version(self) -> Optional[dict]:
        """
        Fetch the server version from the /version endpoint.

        Returns None if the endpoint is not available (e.g., old server).

        Raises:
            httpx.TimeoutException: if the request times out (5s).
            Exception: any other request/decoding failure, after logging it.
        """
        server_url = self._cfg.get_server_url()
        version_url = f"{server_url}/version"
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.get(version_url)
                if response.status_code == 404:
                    logger.warning(
                        "Server does not support version check. "
                        "Please upgrade your server to enable version verification."
                    )
                    return None
                response.raise_for_status()
                return response.json()
        except httpx.TimeoutException:
            logger.error(f"Timeout while fetching server version from {version_url}")
            raise
        except Exception as e:
            logger.error(f"Failed to fetch server version: {e}")
            raise

    async def check_server_version(self):
        """
        Check if the worker version is compatible with the server version.

        Logs a warning if versions are incompatible but allows startup to continue.
        Failures while fetching the server version are logged and ignored.
        """
        try:
            server_version_info = await self._fetch_server_version()
        except Exception as e:
            logger.error(f"Server version check failed: {e}")
            return

        if server_version_info is None:
            # Old server without version endpoint - warn and continue.
            logger.warning(
                "Server version unknown. Proceeding without version check. "
                "Consider upgrading your server."
            )
            return

        server_version = server_version_info.get("version", "unknown")
        server_git_commit = server_version_info.get("git_commit", "unknown")

        is_compatible = is_worker_version_compatible(__version__, server_version)
        if not is_compatible:
            warning_msg = (
                f"Version mismatch detected:\n"
                f"  Worker version: {__version__} (commit: {__git_commit__})\n"
                f"  Server version: {server_version} (commit: {server_git_commit})\n\n"
                f"Please upgrade your worker to match the server version."
            )
            logger.warning(warning_msg)
        else:
            logger.info(
                f"Version check passed: worker {__version__} matches server {server_version}"
            )