| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- import os
- import logging
- from typing import Optional, Tuple
- import httpx
- from gpustack import __version__, __git_commit__
- from gpustack.client import ClientSet
- from gpustack.client.worker_manager_clients import (
- WorkerStatusClient,
- WorkerRegistrationClient,
- )
- from gpustack.config.config import Config
- from gpustack.schemas.workers import (
- WorkerCreate,
- WorkerUpdate,
- WorkerRegistrationPublic,
- )
- from gpustack.schemas.config import PredefinedConfigNoDefaults
- from gpustack.security import API_KEY_PREFIX
- from gpustack.utils import platform
- from gpustack.worker.collector import WorkerStatusCollector
- from gpustack.config.registration import (
- registration_client,
- read_worker_token,
- write_worker_token,
- )
- from gpustack.utils.uuid import (
- set_worker_name,
- get_worker_name,
- set_legacy_uuid,
- get_legacy_uuid,
- )
- from gpustack.utils.version import is_worker_version_compatible
# Module-scoped logger; its name is this module's dotted import path.
logger = logging.getLogger(name=__name__)
class WorkerManager:
    """Register this worker with the server and report its status.

    Server clients are built from a worker token: either one persisted in
    ``data_dir`` by an earlier run (loaded in ``__init__``) or one returned
    by :meth:`register_with_server`.
    """

    _is_embedded: bool
    _collector: WorkerStatusCollector
    _clientset: Optional[ClientSet] = None
    # Assigned in _register_worker(); None until registration is attempted.
    _registration_client: Optional[WorkerRegistrationClient] = None
    # Assigned in _prepare_clients(). It must default to None: a bare class
    # annotation does not create the attribute, and sync_worker_status() tests
    # it before any token may have been loaded.
    _status_client: Optional[WorkerStatusClient] = None

    def __init__(
        self,
        cfg: Config,
        is_embedded: bool,
        collector: WorkerStatusCollector,
    ):
        """Create the manager and load a previously persisted worker token.

        Args:
            cfg: Worker configuration (server URL, data dir, tokens, name).
            is_embedded: When True, registration reloads the registration
                token from config and waits for the token file.
            collector: Source of the worker status sent to the server.
        """
        self._is_embedded = is_embedded
        self._cfg = cfg
        self._collector = collector
        # If an earlier registration persisted a token, build the clients now
        # so sync_worker_status() can report without registering first.
        worker_token = read_worker_token(self._cfg.data_dir)
        if worker_token:
            self._prepare_clients(worker_token)

    def _prepare_clients(self, token: str):
        """Build the client set and status client authenticated with ``token``."""
        self._clientset = ClientSet(
            base_url=self._cfg.get_server_url(),
            api_key=token,
        )
        self._status_client = WorkerStatusClient(self._clientset.http_client)

    def sync_worker_status(self):
        """Collect the current worker status and push it to the server.

        Meant to be called periodically. It does nothing until a worker token
        has been loaded or obtained via registration, and it never registers
        the worker itself. Collection and upload errors are logged and
        swallowed so the periodic caller keeps running.
        """
        if self._status_client is None:
            return
        try:
            worker_status = self._collector.timed_collect(self._clientset)
        except Exception as e:
            logger.error(f"Failed to collect status for worker: {e}")
            return
        try:
            self._status_client.create(worker_status)
        except Exception as e:
            logger.error(f"Failed to update worker status: {e}")

    async def register_with_server(
        self,
    ) -> Tuple[ClientSet, Optional[PredefinedConfigNoDefaults]]:
        """Register (or re-register) the worker and switch to the new token.

        Registration always runs, even when a token is already persisted, so
        both the token and the server-provided worker config are refreshed.

        Returns:
            The client set authenticated with the new token, and the worker
            config returned by the server (may be None).

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            registered = await self._register_worker()
            token = registered.token
            # Persist the token; __init__ reads it back on the next start.
            write_worker_token(self._cfg.data_dir, token)
            self._prepare_clients(token)
            return self._clientset, registered.worker_config
        except Exception as e:
            logger.error(f"Failed to register worker: {e}")
            raise

    async def _register_worker(self) -> WorkerRegistrationPublic:
        """Send a registration request built from the current worker status.

        Stores the worker name and worker UUID returned by the server in
        ``data_dir``, and returns the server's registration response.
        """
        name = self._cfg.worker_name or get_worker_name(self._cfg.data_dir)
        logger.info(
            f"Registering worker with name: {name or '<auto-generated-name>'}",
        )
        if self._is_embedded:
            # always reloads the token
            self._cfg.reload_token()
        self._registration_client = registration_client(
            data_dir=self._cfg.data_dir,
            server_url=self._cfg.get_server_url(),
            registration_token=self._cfg.token,
            wait_token_file=self._is_embedded,
        )
        external_id = self._read_external_id()
        worker_status = self._collector.timed_collect(initial=True)
        # Set empty name if not specified to avoid validation error
        worker_update = WorkerUpdate(
            name=name or "",
            labels=self._ensure_builtin_labels(),
        )
        to_register = WorkerCreate.model_validate(
            {
                **worker_status.model_dump(),
                **worker_update.model_dump(),
                "external_id": external_id,
                "worker_version": __version__,
            }
        )
        created = await self._registration_client.create_async(to_register)
        logger.info(f"Worker {created.name} registered with worker_id {created.id}.")
        set_worker_name(self._cfg.data_dir, created.name)
        set_legacy_uuid(self._cfg.data_dir, created.worker_uuid)
        return created

    def _read_external_id(self) -> Optional[str]:
        """Return the contents of ``<data_dir>/external_id``, or None if absent."""
        external_id_path = os.path.join(self._cfg.data_dir, 'external_id')
        if not os.path.exists(external_id_path):
            return None
        # NOTE(review): the content is sent verbatim, including any trailing
        # newline; confirm whether it should be stripped.
        with open(external_id_path, 'r', encoding='utf-8') as f:
            return f.read()

    def _register_shutdown_hooks(self):
        """Placeholder hook; currently does nothing."""
        pass

    def _ensure_builtin_labels(self) -> dict:
        """Return the labels always attached to the registration request.

        Always includes ``os`` and ``arch``. Adds ``worker-name`` when a name
        is configured or persisted. Adds ``gpustack.existence-check`` for a
        previously registered worker that uses a legacy token or has no
        stored legacy UUID.
        """
        labels = {
            "os": platform.system(),
            "arch": platform.arch(),
        }
        # Only label a name we already know; when there is none, the worker
        # registers with an empty name and adopts the name the server returns.
        name = self._cfg.worker_name or get_worker_name(self._cfg.data_dir)
        if name:
            labels["worker-name"] = name
        # Legacy workers with version 0.7.x send worker_uuid as part of registration.
        # Legacy workers with version <0.7.x don't have worker_uuid, so we use this
        # label as part of the registration allowance.
        is_legacy_token = bool(self._cfg.token) and not self._cfg.token.startswith(
            API_KEY_PREFIX
        )
        is_legacy_worker = get_legacy_uuid(self._cfg.data_dir) is None
        is_existing_worker = get_worker_name(self._cfg.data_dir) is not None
        if (is_legacy_token or is_legacy_worker) and is_existing_worker:
            labels["gpustack.existence-check"] = "true"
        return labels

    async def _fetch_server_version(self) -> Optional[dict]:
        """Fetch the server version from the ``/version`` endpoint.

        Returns:
            The decoded JSON body, or None if the endpoint returns 404
            (e.g. an older server without the endpoint).

        Raises:
            httpx.TimeoutException: The request timed out (logged here with
                the URL).
            Exception: Any other HTTP or decoding error, propagated for the
                caller to handle.
        """
        # Strip a trailing slash so the URL never becomes "...//version",
        # which could 404 and be misread as "endpoint not supported".
        server_url = self._cfg.get_server_url().rstrip("/")
        version_url = f"{server_url}/version"
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.get(version_url)
        except httpx.TimeoutException:
            logger.error(f"Timeout while fetching server version from {version_url}")
            raise
        if response.status_code == 404:
            logger.warning(
                "Server does not support version check. "
                "Please upgrade your server to enable version verification."
            )
            return None
        response.raise_for_status()
        return response.json()

    async def check_server_version(self):
        """Check whether this worker's version is compatible with the server's.

        Never raises. Fetch failures, unknown versions, and incompatible
        versions are only logged, so startup can continue.
        """
        try:
            server_version_info = await self._fetch_server_version()
        except Exception as e:
            logger.error(f"Server version check failed: {e}")
            return
        if server_version_info is None:
            # Old server without version endpoint - warn and continue
            logger.warning(
                "Server version unknown. Proceeding without version check. "
                "Consider upgrading your server."
            )
            return
        if not isinstance(server_version_info, dict):
            # A non-object JSON body would break the .get() calls below.
            logger.warning(
                "Unexpected response from server version endpoint. "
                "Proceeding without version check."
            )
            return
        server_version = server_version_info.get("version", "unknown")
        server_git_commit = server_version_info.get("git_commit", "unknown")
        is_compatible = is_worker_version_compatible(__version__, server_version)
        if not is_compatible:
            warning_msg = (
                "Version mismatch detected:\n"
                f"  Worker version: {__version__} (commit: {__git_commit__})\n"
                f"  Server version: {server_version} (commit: {server_git_commit})\n\n"
                "Please upgrade your worker to match the server version."
            )
            logger.warning(warning_msg)
        else:
            logger.info(
                f"Version check passed: worker {__version__} matches server {server_version}"
            )
|