import hashlib
import hmac
import ipaddress
import logging
import os
import secrets
import socket
import uuid
from enum import Enum
from typing import List, Optional, Dict
from urllib.parse import urlparse

import httpx
from gpustack_runtime.detector import (
    manufacturer_to_backend,
    available_manufacturers,
    available_backends,
)
from pydantic import model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from gpustack.utils import validators
from gpustack.schemas.workers import (
    CPUInfo,
    FileSystemInfo,
    GPUDeviceStatus,
    KernelInfo,
    MemoryInfo,
    MountPoint,
    OperatingSystemInfo,
    SwapInfo,
    SystemInfo,
    UptimeInfo,
    GPUDevicesStatus,
    GPUNetworkInfo,
)
from gpustack.schemas.users import AuthProviderEnum
from gpustack.schemas.config import (
    ModelInstanceProxyModeEnum,
    PredefinedConfig,
    PredefinedConfigNoDefaults,
    GatewayModeEnum,
)
from gpustack import __version__
from gpustack.config.registration import (
    read_registration_token,
    read_worker_token,
    determine_default_registry,
)
from gpustack.utils.network import (
    get_first_non_loopback_ip,
    get_system_trust_store_ssl_context,
    use_proxy_env_for_url,
)
from gpustack.utils import platform

# Process-wide configuration instance; see get_global_config / set_global_config.
_config = None
logger = logging.getLogger(__name__)


class WorkerConfig(PredefinedConfig):
    # common config which should be dynamic or not configurable
    data_dir: Optional[str] = None
    advertise_address: Optional[str] = None

    # Worker options which are different for each worker
    token: Optional[str] = None
    server_url: Optional[str] = None
    worker_ip: Optional[str] = None
    worker_ifname: Optional[str] = None
    worker_name: Optional[str] = None


class Config(WorkerConfig, BaseSettings):
    """A class used to define GPUStack configuration.

    Attributes:
        port: Port to bind the server to. Default is 80.
        tls_port: Port to bind the TLS server to. Default is 443.
        api_port: Port to bind the gpustack API server to. Default is 30080.
        advertise_address: The address to expose for external access. Auto-detected by default.
        debug: Enable debug mode.
        data_dir: Directory to store data. Default is OS specific.
        huggingface_token: User Access Token to authenticate to the Hugging Face Hub.
        metrics_port: Port to expose metrics on.
        disable_metrics: Disable server metrics.
        ssl_keyfile: Path to the SSL key file.
        ssl_certfile: Path to the SSL certificate file.
        database_url: URL of the database.
        disable_worker: (Deprecated) Disable embedded worker.
        enable_worker: Enable embedded worker.
        bootstrap_password: Password for the bootstrap admin user.
        jwt_secret_key: Secret key for JWT. Auto-generated by default.
        force_auth_localhost: Force authentication for requests originating from
            localhost (127.0.0.1). When set to True, all requests from localhost
            will require authentication.
        disable_update_check: Disable update check.
        update_check_url: URL to check for updates.
        model_catalog_file: Path or URL to the model catalog file.
        token: Shared secret used to register worker.
        server_url: URL of the server.
        worker_ip: IP address of the worker node. Auto-detected by default.
        worker_ifname: Network interface name of the worker node. Auto-detected by default.
        worker_name: Name of the worker node. Use the hostname by default.
        disable_worker_metrics: Disable worker metrics.
        worker_metrics_port: Port to expose metrics on.
        worker_port: Port to bind the worker to.
        service_port_range: Port range for inference services, specified as a string
            in the form 'N1-N2'. Both ends of the range are inclusive. Default is
            '40000-40063'.
        ray_port_range: Port range for Ray services (used by vLLM distributed
            deployments), specified as a string in the form 'N1-N2'. Both ends of
            the range are inclusive. Default is '41000-41999'.
        log_dir: Directory to store logs.
        bin_dir: Directory to store additional binaries, e.g., versioned backend executables.
        benchmark_dir: Directory to store benchmark results.
        benchmark_max_duration_seconds: Max duration for a benchmark before timeout.
            Disabled when unset.
        pipx_path: Path to the pipx executable, used to install versioned backends.
        system_reserved: Reserved system resources.
        tools_download_base_url: Base URL to download dependency tools.
        enable_hf_transfer: [Deprecated] No-op since huggingface_hub v1.0 removed
            hf_transfer support; hf_xet is now the default downloader.
        enable_cors: Enable CORS in server.
        allow_origins: A list of origins that should be permitted to make cross-origin requests.
        allow_credentials: Indicate that cookies should be supported for cross-origin requests.
        allow_methods: A list of HTTP methods that should be allowed for cross-origin requests.
        allow_headers: A list of HTTP request headers that should be supported for
            cross-origin requests.
        server_external_url: Specified external URL for the server.
        system_default_container_registry: Default registry for container images
            (server and inference images).
        image_name_override: Force override of the image name.
        image_repo: Repository for the container images.
        service_discovery_name: Name of the service discovery service in DNS. Only
            useful when deployed in Kubernetes with service discovery.
        gateway_mode: Gateway deployment mode. Options are 'auto', 'embedded',
            'incluster', 'external', 'disabled'. Default is 'auto'.
        gateway_kubeconfig: Path to the kubeconfig file for gateway. Only used when
            gateway_mode is 'external'.
        gateway_concurrency: Number of concurrent connections for the embedded
            gateway. Default is 16.
        gateway_namespace: The namespace where the gateway component is deployed.
        namespace: Kubernetes namespace for GPUStack to deploy gateway routing rules
            and model instances.
        disable_builtin_observability: Disable embedded Grafana and Prometheus services.
        grafana_url: Base URL for Grafana UI used by redirects and proxying. When
            unset, defaults to the embedded Grafana URL unless builtin observability
            is disabled.
        grafana_worker_dashboard_uid: Grafana dashboard UID for worker dashboard.
        grafana_model_dashboard_uid: Grafana dashboard UID for model dashboard.
        gateway_plugin_server_url: URL to fetch gateway plugin manifest for embedded gateway.
    """

    # Server options
    # Deprecated: the server now runs from a docker image, so host is not used.
    host: Optional[str] = None
    # The port and tls_port are used in gateway configuration.
    port: Optional[int] = 80
    tls_port: Optional[int] = 443
    # The api_port is used in gpustack server/worker serving API requests.
    api_port: Optional[int] = 30080
    proxy_port: Optional[int] = 30079
    database_port: Optional[int] = 5432
    database_url: Optional[str] = None
    disable_worker: Optional[bool] = None  # Deprecated
    enable_worker: bool = False
    bootstrap_password: Optional[str] = None
    jwt_secret_key: Optional[str] = None
    resources: Optional[dict] = None
    ssl_keyfile: Optional[str] = None
    ssl_certfile: Optional[str] = None
    force_auth_localhost: bool = False
    metrics_port: int = 10161
    disable_metrics: bool = False
    disable_update_check: bool = False
    disable_openapi_docs: bool = False
    update_check_url: Optional[str] = None
    model_catalog_file: Optional[str] = None
    enable_cors: bool = False
    allow_origins: Optional[List[str]] = ['*']
    allow_credentials: bool = False
    allow_methods: Optional[List[str]] = ['GET', 'POST']
    allow_headers: Optional[List[str]] = ['Authorization', 'Content-Type', 'X-API-Key']
    external_auth_type: Optional[str] = None  # external auth type
    external_auth_name: Optional[str] = None  # external auth name
    external_auth_full_name: Optional[str] = None  # external auth full name
    external_auth_avatar_url: Optional[str] = None  # external auth avatar url
    external_auth_default_inactive: bool = False  # external auth default inactive
    oidc_client_id: Optional[str] = None  # oidc client id
    oidc_client_secret: Optional[str] = None  # oidc client secret
    oidc_redirect_uri: Optional[str] = None  # oidc redirect uri
    oidc_issuer: Optional[str] = None  # oidc issuer
    oidc_skip_userinfo: bool = False  # skip requesting the oidc user_info endpoint
    oidc_use_userinfo: Optional[bool] = (
        None  # Deprecated, use oidc_skip_userinfo instead
    )
    openid_configuration: Optional[dict] = None  # fetched openid configuration
    saml_sp_entity_id: Optional[str] = None  # saml sp_entity_id
    saml_sp_acs_url: Optional[str] = None  # saml sp_acs_url
    saml_sp_x509_cert: Optional[str] = ''  # saml sp_x509_cert
    saml_sp_private_key: Optional[str] = ''  # saml sp_private_key
    saml_sp_attribute_prefix: Optional[str] = None  # saml sp attribute prefix
    saml_idp_entity_id: Optional[str] = None  # saml idp_entityId
    saml_idp_server_url: Optional[str] = None  # saml idp_server_url
    saml_idp_logout_url: Optional[str] = None
    saml_sp_slo_url: Optional[str] = None
    saml_idp_x509_cert: Optional[str] = ''  # saml idp_x509_cert
    saml_security: Optional[str] = '{}'  # saml security
    server_external_url: Optional[str] = None
    # custom post-logout redirection key for compatibility with different IdPs.
    external_auth_post_logout_redirect_key: Optional[str] = None
    # Number of concurrent connections for the embedded gateway.
    gateway_concurrency: int = 16
    gateway_plugin_server_url: Optional[str] = None
    gateway_ingress_class: str = "higress"
    disable_builtin_observability: bool = False
    builtin_prometheus_port: int = 19090
    builtin_grafana_port: int = 13000
    grafana_url: Optional[str] = None
    grafana_worker_dashboard_uid: Optional[str] = "gpustack-worker"
    grafana_model_dashboard_uid: Optional[str] = "gpustack-model"
    server_id: Optional[str] = None

    # Private state (leading underscore, so not part of the settings schema).
    # Worker-level fields explicitly set by the user; they win over server-pushed config.
    _set_worker_fields = {}
    # HMAC-derived token computed from jwt_secret_key at the end of __init__.
    _derive_gateway_token = None
    # True when jwt_secret_key was supplied by the user rather than auto-generated.
    _jwt_secret_key_user_provided = False

    model_config = SettingsConfigDict(
        env_prefix="GPUSTACK_", protected_namespaces=('settings_',), extra="allow"
    )

    def __init__(self, **values):
        """Build the configuration and derive all dependent defaults.

        After pydantic/settings parsing (which also runs ``check_all``), this
        resolves directories, loads persisted tokens/addresses from ``data_dir``,
        prepares the JWT secret, configures auth, creates directories and
        resolves the gateway mode.

        Raises:
            Exception: If running as a worker without any registration token.
        """
        super().__init__(**values)
        # Remember which PredefinedConfig fields were explicitly set here, so
        # reload_worker_config() does not overwrite them with server values.
        self._set_worker_fields = self.model_dump(
            exclude_defaults=True,
            exclude_unset=True,
            exclude_none=True,
            include=self.__pydantic_fields_set__
            & set(PredefinedConfig.model_fields.keys()),
        )

        def prepare_dir(dir_path: Optional[str], default: str) -> str:
            # Use the default when unset, otherwise normalize to an absolute path.
            return default if dir_path is None else os.path.abspath(dir_path)

        # common options
        self.data_dir = prepare_dir(self.data_dir, self.get_data_dir())
        self.cache_dir = prepare_dir(
            self.cache_dir, os.path.join(self.data_dir, "cache")
        )
        self.bin_dir = prepare_dir(self.bin_dir, os.path.join(self.data_dir, "bin"))
        self.log_dir = prepare_dir(self.log_dir, os.path.join(self.data_dir, "log"))
        self.benchmark_dir = prepare_dir(
            self.benchmark_dir, os.path.join(self.data_dir, "benchmarks")
        )
        # Treat a blank grafana_url the same as "not configured".
        if isinstance(self.grafana_url, str) and not self.grafana_url.strip():
            self.grafana_url = None

        if self.token is None:
            self.token = read_registration_token(self.data_dir)

        advertise_address_path = os.path.join(self.data_dir, "advertise_address")
        if self.advertise_address is None and os.path.exists(advertise_address_path):
            with open(advertise_address_path, "r") as f:
                addr = f.read().strip()
            # Best effort: only adopt the persisted value if it is a valid IP.
            try:
                if len(addr) > 0 and ipaddress.ip_address(addr):
                    self.advertise_address = addr
            except ValueError:
                logger.warning(
                    "Ignoring invalid advertise address %r in %s",
                    addr,
                    advertise_address_path,
                )

        if (
            self._is_worker()
            and self.token is None
            and read_worker_token(self.data_dir) is None
        ):
            raise Exception("Token is required when running as worker")

        # Generate server_id if not provided
        if self.server_id is None:
            self.server_id = f"{socket.gethostname()}-{uuid.uuid4().hex[:8]}"

        # Snapshot whether jwt_secret_key came from user input (flag / env / config
        # file) before prepare_jwt_secret_key() auto-fills it. Server startup uses
        # this to enforce that distributed mode requires an explicit key.
        self._jwt_secret_key_user_provided = self.jwt_secret_key is not None
        self.prepare_jwt_secret_key()

        # server options
        self.init_auth()
        if self.system_reserved is None:
            self.system_reserved = {"ram": 0, "vram": 0}

        if self.service_discovery_name is None:
            self.service_discovery_name = "worker" if self._is_worker() else "server"

        self.make_dirs()
        self.detect_gateway_mode()
        # default to worker proxy mode when no proxy mode is configured
        if self.proxy_mode is None:
            self.proxy_mode = ModelInstanceProxyModeEnum.WORKER

        self._derive_gateway_token = hmac.new(
            self.jwt_secret_key.encode(), b"gateway-metrics-push", hashlib.sha256
        ).hexdigest()

    @model_validator(mode="after")
    def check_all(self):  # noqa: C901
        """Validate cross-field constraints; skipped when running under pytest.

        Raises:
            Exception: On mismatched SSL files, an invalid server URL, invalid
                resources, a malformed port range, or an unsupported database URL.
        """
        if 'PYTEST_CURRENT_TEST' in os.environ:
            # Skip validation during tests
            return self

        if (self.ssl_keyfile and not self.ssl_certfile) or (
            self.ssl_certfile and not self.ssl_keyfile
        ):
            raise Exception(
                'Both "ssl_keyfile" and "ssl_certfile" must be provided, or neither.'
            )

        if self.server_url:
            self.server_url = self.server_url.rstrip("/")
            if validators.url(self.server_url) is not True:
                raise Exception("Invalid server URL.")

        if self.resources:
            # Called for their validation side effects (they raise on bad input).
            self.get_gpu_devices()
            self.get_system_info()

        if self.service_port_range:
            self.check_port_range(self.service_port_range)

        if self.ray_port_range:
            self.check_port_range(self.ray_port_range, diff=20)

        # Map the deprecated flag onto its replacement.
        if self.oidc_use_userinfo is not None:
            self.oidc_skip_userinfo = not self.oidc_use_userinfo

        if self.database_url is not None:
            self.check_database_url()

        return self

    def get_grafana_url(self) -> Optional[str]:
        """Return the configured Grafana URL, or the builtin one unless disabled."""
        if self.grafana_url is not None:
            return self.grafana_url
        if self.disable_builtin_observability:
            return None
        return f"http://127.0.0.1:{self.builtin_grafana_port}"

    def get_builtin_prometheus_url(self) -> Optional[str]:
        """Return the builtin Prometheus URL, or None when an external Grafana is
        configured or builtin observability is disabled."""
        if self.disable_builtin_observability or self.grafana_url is not None:
            return None
        return f"http://127.0.0.1:{self.builtin_prometheus_port}"

    @staticmethod
    def check_port_range(port_range: str, diff: Optional[int] = None):
        """Validate an inclusive 'N1-N2' port range.

        Args:
            port_range: Range string such as '40000-40063'.
            diff: Optional minimum number of ports the range must contain.

        Raises:
            Exception: If the range is malformed, reversed, or too small.
        """
        ports = port_range.split("-")
        if len(ports) != 2:
            raise Exception(f"Invalid port range: {port_range}")
        if not ports[0].isdigit() or not ports[1].isdigit():
            raise Exception("Port range must be numeric")
        if int(ports[0]) > int(ports[1]):
            raise Exception(f"Invalid port range: {ports[0]} > {ports[1]}")
        if diff is not None:
            # Both ends are inclusive, hence the +1.
            if int(ports[1]) - int(ports[0]) + 1 < diff:
                raise Exception(
                    f"Port range is too small: {port_range}, at least {diff} ports are required"
                )

    def make_dirs(self):
        """Create all data/cache/bin/log/benchmark directories if missing."""
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(self.bin_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.benchmark_dir, exist_ok=True)
        # ensure higress data dir exists
        os.makedirs(self.higress_base_dir(), exist_ok=True)
        # Only server roles run the embedded PostgreSQL.
        if self.server_role() != self.ServerRole.WORKER:
            os.makedirs(self.postgres_base_dir(), exist_ok=True)

    def get_system_info(self) -> Optional[SystemInfo]:  # noqa: C901
        """Build SystemInfo from the manually specified ``resources``.

        Example ``resources``:
        ```yaml
        resources:
          cpu:
            total: 10
          memory:
            total: 34359738368
            is_unified_memory: true
          swap:
            total: 3221225472
          filesystem:
            - name: Macintosh HD
              mount_point: /
              mount_from: /dev/disk3s1s1
              total: 994662584320
          os:
            name: macOS
            version: "14.5"
          kernel:
            name: Darwin
            release: 23.5.0
            version: "Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024;"
            architecture: ""
          uptime:
            uptime: 355250885
            boot_time: 2025-02-24T09:17:51.337+0800
        ```

        Returns:
            The populated SystemInfo, or None when no resources (or none of
            the recognized sections) are specified.

        Raises:
            Exception: If a present section lacks one of its required keys.
        """
        if not self.resources:
            return None
        system_info: SystemInfo = SystemInfo()

        cpu_dict = self.resources.get("cpu")
        if cpu_dict and cpu_dict.get("total"):
            system_info.cpu = CPUInfo(total=cpu_dict.get("total"))

        memory_dict = self.resources.get("memory")
        if memory_dict and memory_dict.get("total"):
            system_info.memory = MemoryInfo(total=memory_dict.get("total"))

        swap_dict = self.resources.get("swap")
        if swap_dict and swap_dict.get("total"):
            system_info.swap = SwapInfo(total=swap_dict.get("total"))

        filesystem_dict = self.resources.get("filesystem")
        if filesystem_dict:
            filesystem: FileSystemInfo = []
            for fs in filesystem_dict:
                name = fs.get("name")
                mount_point = fs.get("mount_point")
                mount_from = fs.get("mount_from")
                total = fs.get("total")
                if not name:
                    raise Exception("Filesystem name is required")
                if not mount_point:
                    raise Exception("Filesystem mount_point is required")
                if not mount_from:
                    raise Exception("Filesystem mount_from is required")
                if total is None:
                    raise Exception("Filesystem total is required")
                filesystem.append(
                    MountPoint(
                        name=name,
                        mount_point=mount_point,
                        mount_from=mount_from,
                        total=total,
                    )
                )
            system_info.filesystem = filesystem

        os_dict = self.resources.get("os")
        if os_dict:
            name = os_dict.get("name")
            version = os_dict.get("version")
            if not name:
                raise Exception("OS name is required")
            if not version:
                raise Exception("OS version is required")
            system_info.os = OperatingSystemInfo(name=name, version=version)

        kernel_dict = self.resources.get("kernel")
        if kernel_dict:
            name = kernel_dict.get("name")
            release = kernel_dict.get("release")
            version = kernel_dict.get("version")
            architecture = kernel_dict.get("architecture")
            if not name:
                raise Exception("Kernel name is required")
            if not release:
                raise Exception("Kernel release is required")
            if not version:
                raise Exception("Kernel version is required")
            system_info.kernel = KernelInfo(
                name=name, release=release, version=version, architecture=architecture
            )

        uptime_dict = self.resources.get("uptime")
        if uptime_dict:
            uptime = uptime_dict.get("uptime")
            boot_time = uptime_dict.get("boot_time")
            if uptime is None:
                raise Exception("Uptime is required")
            if not boot_time:
                raise Exception("Boot time is required")
            system_info.uptime = UptimeInfo(uptime=uptime, boot_time=boot_time)

        if not any(
            [
                system_info.cpu,
                system_info.memory,
                system_info.swap,
                system_info.filesystem,
                system_info.os,
                system_info.kernel,
                system_info.uptime,
            ]
        ):
            return None

        return system_info

    def get_gpu_devices(self) -> GPUDevicesStatus:  # noqa: C901
        """Build GPU device statuses from the manually specified ``resources``.

        Example ``resources``:
        ```yaml
        resources:
          gpu_devices:
            - name: Ascend CANN 910b
              vendor: ascend
              arch_family: Ascend910B2
              index: 0
              device_index: 0 # optional, defaults to index
              device_chip_index: 0 # optional, defaults to 0
              compute_capability: "9.0" # optional
              type: cann # optional, derived from vendor when omitted
              runtime_version: "8.0" # optional
              memory:
                total: 22906503168
                is_unified_memory: true
              network:
                status: "up"
                inet: "29.17.45.215"
                netmask: "255.255.0.0" # optional
                mac: "6c:34:91:87:3c:ae" # optional
                gateway: "29.17.0.1" # optional
                iface: "eth4" # optional
                mtu: 8192 # optional
        ```

        Returns:
            A list of GPUDeviceStatus, or None when no GPU devices are specified.

        Raises:
            Exception: If a device entry is missing required keys or has an
                unsupported vendor/type or invalid network settings.
        """
        gpu_devices: GPUDevicesStatus = []
        if not self.resources:
            return None

        gpu_device_dict = self.resources.get("gpu_devices")
        if not gpu_device_dict:
            return None

        # Loop-invariant lookups of the supported vendors/backends.
        vendors = available_manufacturers()
        types = available_backends()
        for gd in gpu_device_dict:
            name = gd.get("name")
            arch_family = gd.get("arch_family", None)
            index = gd.get("index")
            compute_capability = gd.get("compute_capability", None)
            device_index = gd.get("device_index", index)
            device_chip_index = gd.get("device_chip_index", 0)
            vendor = gd.get("vendor")
            memory = gd.get("memory")
            network = gd.get("network")
            runtime_version = gd.get("runtime_version")

            if not name:
                raise Exception("GPU device name is required")

            if index is None:
                raise Exception("GPU device index is required")

            if vendor not in vendors:
                raise Exception(
                    f"Unsupported GPU device vendor, available vendors are: {','.join(map(str, vendors))}"
                )

            # Derive the backend type only once the vendor is known to be valid.
            type_ = gd.get("type") or manufacturer_to_backend(vendor)

            if not memory:
                raise Exception("GPU device memory is required")
            elif not memory.get("total"):
                raise Exception("GPU device memory total is required")

            if network:
                network_status = network.get("status", "up")
                if network_status not in ["up", "down"]:
                    raise Exception(
                        "GPU device network status is invalid, supported status are: up, down"
                    )
                network_inet = network.get("inet", None)
                if network_inet is None:
                    raise Exception("GPU device network inet is required")
                elif not validators.ip(network_inet):
                    raise Exception("GPU device network inet is invalid")
                network_netmask = network.get("netmask", None)
                if network_netmask and not validators.ip(network_netmask):
                    raise Exception("GPU device network netmask is invalid")
                gateway = network.get("gateway", None)
                if gateway and not validators.ip(gateway):
                    raise Exception("GPU device network gateway is invalid")

            if type_ not in types:
                raise Exception(
                    f"Unsupported GPU type, available type are: {','.join(map(str, types))}"
                )

            gpu_devices.append(
                GPUDeviceStatus(
                    index=index,
                    arch_family=arch_family,
                    compute_capability=compute_capability,
                    device_index=device_index,
                    device_chip_index=device_chip_index,
                    name=name,
                    vendor=vendor,
                    runtime_version=runtime_version,
                    memory=MemoryInfo(
                        total=memory.get("total"),
                        is_unified_memory=memory.get("is_unified_memory", False),
                    ),
                    network=(
                        None
                        if not network
                        else GPUNetworkInfo(
                            status=network.get("status", "up"),
                            inet=network.get("inet"),
                            netmask=network.get("netmask", ""),
                            mac=network.get("mac", ""),
                            gateway=network.get("gateway", ""),
                            iface=network.get("iface", None),
                            mtu=network.get("mtu", None),
                        )
                    ),
                    type=type_,
                )
            )
        return gpu_devices

    def get_database_url(self) -> str:
        """Return the configured database URL, or the embedded PostgreSQL URL."""
        if self.database_url is not None:
            return self.database_url
        return (
            f"postgresql://root@127.0.0.1:{self.database_port}/gpustack?sslmode=disable"
        )

    def check_database_url(self):
        """Validate that ``database_url`` (when set) uses a supported scheme.

        Raises:
            Exception: If the URL starts with neither ``postgresql://`` nor
                ``mysql://``.
        """
        if self.database_url is None:
            return

        # The accepted schemes must match the ones named in the error message.
        if not self.database_url.startswith(("postgresql://", "mysql://")):
            raise Exception(
                "Unsupported database scheme. Supported databases are postgresql, and mysql."
            )

    def init_auth(self):
        """Select the external auth provider (OIDC takes precedence over SAML).

        For OIDC this also fetches the issuer's OpenID configuration over HTTP.
        """
        if self.oidc_issuer:
            self.external_auth_type = AuthProviderEnum.OIDC
            self.openid_configuration = get_openid_configuration(self.oidc_issuer)
        elif self.saml_idp_server_url:
            self.external_auth_type = AuthProviderEnum.SAML

    @staticmethod
    def get_data_dir():
        """Return the OS-specific default data directory.

        Raises:
            Exception: On operating systems other than Windows or POSIX.
        """
        app_name = "gpustack"
        if os.name == "nt":  # Windows
            data_dir = os.path.join(os.environ["APPDATA"], app_name)
        elif os.name == "posix":
            data_dir = f"/var/lib/{app_name}"
        else:
            raise Exception("Unsupported OS")

        return os.path.abspath(data_dir)

    def prepare_jwt_secret_key(self):
        """Ensure ``jwt_secret_key`` is set.

        A user-provided key is kept as-is. Otherwise the key persisted in
        ``<data_dir>/jwt_secret_key`` is reused; if that file is missing or
        empty, a new random key is generated and written there.
        """
        if self.jwt_secret_key is not None:
            return

        key_path = os.path.join(self.data_dir, "jwt_secret_key")
        key = ""
        if os.path.exists(key_path):
            with open(key_path, "r") as file:
                key = file.read().strip()

        # An empty secret would make JWT signing and the derived gateway
        # token trivially forgeable, so regenerate instead of using it.
        if not key:
            key = secrets.token_hex(32)
            os.makedirs(self.data_dir, exist_ok=True)
            with open(key_path, "w") as file:
                file.write(key)

        self.jwt_secret_key = key

    def _is_worker(self):
        # A standalone worker is identified by having a server to register with.
        return self.server_url is not None

    def postgres_base_dir(self) -> str:
        return os.path.join(self.data_dir, "postgresql")

    def higress_base_dir(self) -> str:
        return os.path.join(self.data_dir, "higress")

    def detect_gateway_mode(self):
        """Resolve ``gateway_mode`` (when 'auto') and derive gateway settings.

        Raises:
            Exception: For an embedded gateway on a worker, an external cluster
                without Higress support, or an explicit plugin server URL with
                the embedded gateway.
        """
        if self.gateway_mode == GatewayModeEnum.auto:
            if self.server_role() == self.ServerRole.WORKER:
                self.gateway_mode = GatewayModeEnum.disabled
                return
            is_embedded = self.gateway_kubeconfig is None
            in_cluster = platform.is_inside_kubernetes()
            if in_cluster and platform.is_supported_higress(self.gateway_ingress_class):
                self.gateway_mode = GatewayModeEnum.incluster
            elif is_embedded:
                # in cluster but without Higress support falls back to embedded
                self.gateway_mode = GatewayModeEnum.embedded
            else:
                self.gateway_mode = GatewayModeEnum.external
        if (
            self.server_role() == self.ServerRole.WORKER
            and self.gateway_mode == GatewayModeEnum.embedded
        ):
            raise Exception("Cannot run embedded gateway when running as worker.")
        if self.gateway_mode == GatewayModeEnum.embedded:
            # path to embed kubeconfig
            self.gateway_kubeconfig = os.path.join(
                self.higress_base_dir(), "kubeconfig"
            )
        if (
            self.gateway_mode == GatewayModeEnum.external
            and not platform.is_supported_higress(
                self.gateway_ingress_class, self.gateway_kubeconfig
            )
        ):
            raise Exception("The k8s cluster for gpustack does not support Higress.")
        if self.gateway_plugin_server_url is None:
            # For the embedded gateway, higress fetches plugins from the gpustack server.
            # For the disabled gateway, gateway_plugin_server_url is unused, so its value doesn't matter.
            address = "127.0.0.1"
            if self.gateway_mode == GatewayModeEnum.incluster:
                address = get_first_non_loopback_ip()
            elif self.gateway_mode == GatewayModeEnum.external:
                address = self.get_advertise_address()
            self.gateway_plugin_server_url = f"http://{address}:{self.get_api_port()}"
        elif self.gateway_mode == GatewayModeEnum.embedded:
            raise Exception(
                "Cannot set gateway_plugin_server_url when running embedded gateway, as the embedded gateway will use the local plugin server."
            )

    class ServerRole(Enum):
        SERVER = "server"
        WORKER = "worker"
        BOTH = "both"

    def server_role(self) -> ServerRole:
        """Return whether this process acts as server, worker, or both."""
        if self._is_worker():
            return self.ServerRole.WORKER
        elif self._is_both_role():
            return self.ServerRole.BOTH
        else:
            return self.ServerRole.SERVER

    def _is_both_role(self) -> bool:
        """
        Determine if the server is running in both server and worker mode.
        If the `enable_worker` flag is set to True, the server is running in both modes.
        If the `disable_worker` flag is set to True, the server is running in server-only mode.
        If neither flag is set, the presence of a `bootstrap_version` file in the data
        directory is checked. If the file does not exist, it indicates that the server
        was installed using a version that defaults to running in both modes.

        Returns:
            bool: True if running in both server and worker mode, False otherwise.
        """
        if self._is_worker():
            return False
        elif self.enable_worker:
            return True
        elif self.disable_worker:
            return False

        # As of v2.0.1, a `bootstrap_version` file is created in data_dir.
        # If the file exists, it indicates that the server was installed
        # using a version that defaults to running server-only mode.
        bootstrap_version_path = os.path.join(self.data_dir, "bootstrap_version")
        if os.path.exists(bootstrap_version_path):
            return False
        return True

    def get_advertise_address(self) -> str:
        return self.advertise_address or get_first_non_loopback_ip()

    def get_namespace(self) -> str:
        # for the embedded or external gateway, use the gateway namespace
        if self.gateway_mode in [GatewayModeEnum.embedded, GatewayModeEnum.external]:
            return self.gateway_namespace
        return self.namespace

    def get_external_hostname(self) -> Optional[str]:
        """Return the hostname of ``server_external_url``, or None when unset
        or when it is an IP address literal."""
        hostname = None
        if self.server_external_url:
            parsed_url = urlparse(self.server_external_url)
            hostname = parsed_url.hostname
        if not hostname:
            return None
        try:
            ipaddress.ip_address(hostname)
            return None
        except Exception:
            return hostname

    def get_tls_secret_name(self) -> Optional[str]:
        """Return the TLS secret name, or None when no cert/key pair is configured."""
        if not self.ssl_certfile or not self.ssl_keyfile:
            return None
        hostname = self.get_external_hostname()
        if hostname:
            return f"gpustack-tls-{hostname.replace('.', '-')}"
        else:
            return "gpustack-tls-default"

    def get_server_url(self) -> str:
        """Return ``server_url`` if set, otherwise the embedded server URL.

        Note: ``server_url`` must win regardless of ``api_port``; the previous
        single-expression form parsed as ``(a or b) if c else d`` and dropped
        ``server_url`` whenever ``api_port`` was falsy.
        """
        if self.server_url:
            return self.server_url
        if self.api_port:
            return f"http://127.0.0.1:{self.api_port}"
        return "http://127.0.0.1"

    def get_api_port(self, embedded_worker: bool = False) -> int:
        """Return the port serving the API for this process's role."""
        if embedded_worker:
            return self.worker_port
        if self.server_role() != self.ServerRole.WORKER:
            return self.api_port
        return (
            self.api_port
            if self.gateway_mode == GatewayModeEnum.embedded
            else self.worker_port
        )

    def get_gateway_port(self) -> int:
        return (
            self.port if self.server_role() != self.ServerRole.WORKER else self.worker_port
        )

    def reload_token(self):
        """Re-read the registration token from data_dir, keeping the current one
        if nothing is found."""
        token = read_registration_token(self.data_dir)
        if token:
            self.token = token

    def reload_worker_config(self, worker_config: Optional[PredefinedConfigNoDefaults]):
        """Apply server-provided worker config, then re-validate.

        Fields explicitly set locally (``_set_worker_fields``) override the
        values in ``worker_config``; unknown keys are ignored.
        """
        if worker_config is None:
            return
        updated = {
            **worker_config.model_dump(exclude_none=True),
            **self._set_worker_fields,
        }
        for key, value in updated.items():
            if key in self.__class__.model_fields:
                setattr(self, key, value)
        self.check_all()

    def get_system_reserved(self) -> Dict[str, int]:
        """Return ``system_reserved`` with "ram"/"vram" converted to bytes.

        Values are shifted left by 30 bits (multiplied by 2**30), so they are
        presumably given in GiB. The legacy keys "memory"/"gpu_memory" are used
        as fallbacks and removed from the result.
        """
        system_reserved_in_bytes = {**(self.system_reserved or {})}
        system_reserved_in_bytes["ram"] = (
            system_reserved_in_bytes.get(
                "ram", system_reserved_in_bytes.pop("memory", 0)
            )
            << 30
        )
        system_reserved_in_bytes["vram"] = (
            system_reserved_in_bytes.get(
                "vram", system_reserved_in_bytes.pop("gpu_memory", 0)
            )
            << 30
        )
        return system_reserved_in_bytes

    def get_proxy_port(self) -> int:
        return self.proxy_port

    def get_proxy_listen_address(self, default: str = "0.0.0.0") -> str:
        # The embedded gateway runs locally, so the proxy only needs loopback.
        return "127.0.0.1" if self.gateway_mode == GatewayModeEnum.embedded else default

    def get_proxy_url(self) -> Optional[str]:
        return f"http://{self.get_proxy_listen_address(self.get_advertise_address())}:{self.get_proxy_port()}"

    def get_derived_gateway_token(self) -> str:
        return self._derive_gateway_token


def get_image_name(
    image_name_override: Optional[str],
    registry: Optional[str] = None,
    image_repo: str = "gpustack/gpustack",
) -> str:
    """Return the full container image name for the current GPUStack version.

    ``image_name_override`` wins outright; development builds (version 0.0.0)
    use the "dev" tag.
    """
    if image_name_override:
        return image_name_override
    version = __version__
    if version.removeprefix("v") == "0.0.0":
        version = "dev"
    prefix = f"{registry}/" if registry else ""
    return f"{prefix}{image_repo}:{version}"


def get_cluster_image_name(worker_config: Optional[PredefinedConfigNoDefaults]) -> str:
    """Return the image name for a cluster, falling back to global config values
    for anything ``worker_config`` leaves unset."""
    cfg = get_global_config()
    if worker_config is None:
        return get_image_name(
            image_repo=cfg.image_repo,
            image_name_override=cfg.image_name_override,
            registry=determine_default_registry(cfg.system_default_container_registry),
        )
    registry = determine_default_registry(
        worker_config.system_default_container_registry
        or cfg.system_default_container_registry
    )
    return get_image_name(
        image_name_override=worker_config.image_name_override
        or cfg.image_name_override,
        image_repo=worker_config.image_repo or cfg.image_repo,
        registry=registry,
    )


def get_openid_configuration(issuer: str) -> dict:
    """Fetch OpenID configuration from the issuer.

    Raises:
        Exception: If the discovery document cannot be fetched or parsed.
    """
    url = f"{issuer.rstrip('/')}/.well-known/openid-configuration"
    try:
        use_proxy_env = use_proxy_env_for_url(url)
        verify = get_system_trust_store_ssl_context()
        with httpx.Client(timeout=10, verify=verify, trust_env=use_proxy_env) as client:
            resp = client.get(url)
            resp.raise_for_status()
            return resp.json()
    except Exception as e:
        raise Exception(
            f"Failed to get OpenID configuration: {str(e)}. Please check the issuer URL and ensure {url} is accessible."
        ) from e


def get_global_config() -> Config:
    return _config


def set_global_config(cfg: Config):
    global _config
    _config = cfg
    return cfg