| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928 |
- import ipaddress
- import logging
- import os
- import secrets
- import socket
- import uuid
- from enum import Enum
- from typing import List, Optional, Dict
- from urllib.parse import urlparse
- import httpx
- import hmac
- import hashlib
- from gpustack_runtime.detector import (
- manufacturer_to_backend,
- available_manufacturers,
- available_backends,
- )
- from pydantic import model_validator
- from pydantic_settings import BaseSettings, SettingsConfigDict
- from gpustack.utils import validators
- from gpustack.schemas.workers import (
- CPUInfo,
- FileSystemInfo,
- GPUDeviceStatus,
- KernelInfo,
- MemoryInfo,
- MountPoint,
- OperatingSystemInfo,
- SwapInfo,
- SystemInfo,
- UptimeInfo,
- GPUDevicesStatus,
- GPUNetworkInfo,
- )
- from gpustack.schemas.users import AuthProviderEnum
- from gpustack.schemas.config import (
- ModelInstanceProxyModeEnum,
- PredefinedConfig,
- PredefinedConfigNoDefaults,
- GatewayModeEnum,
- )
- from gpustack import __version__
- from gpustack.config.registration import (
- read_registration_token,
- read_worker_token,
- determine_default_registry,
- )
- from gpustack.utils.network import (
- get_first_non_loopback_ip,
- get_system_trust_store_ssl_context,
- use_proxy_env_for_url,
- )
- from gpustack.utils import platform
logger = logging.getLogger(__name__)

# Process-wide Config instance. It starts out unset (None), is installed by
# set_global_config() and is read back by get_global_config().
_config: Optional["Config"] = None
class WorkerConfig(PredefinedConfig):
    """Per-node configuration fields layered on top of PredefinedConfig.

    Every field here defaults to None; the full Config class inherits from
    this class and fills in defaults during its own initialization.
    """

    # common config which should be dynamic or not configurable
    # Directory to store data; an OS-specific default is applied when None.
    data_dir: Optional[str] = None
    # Address exposed for external access; auto-detected when None.
    advertise_address: Optional[str] = None

    # Worker options which are different for each worker
    # Shared secret used to register the worker.
    token: Optional[str] = None
    # URL of the server; when set, the process runs in the worker role.
    server_url: Optional[str] = None
    # IP address of this worker node; auto-detected when None.
    worker_ip: Optional[str] = None
    # Network interface name of this worker node; auto-detected when None.
    worker_ifname: Optional[str] = None
    # Name of this worker node; presumably falls back to the hostname when None.
    worker_name: Optional[str] = None
class Config(WorkerConfig, BaseSettings):
    """A class used to define GPUStack configuration.

    Values are read from keyword arguments and from ``GPUSTACK_``-prefixed
    environment variables (see ``model_config``); derived defaults are filled
    in by ``__init__``.

    Attributes:
        port: Port to bind the server to. Default is 80.
        tls_port: Port to bind the TLS server to. Default is 443.
        api_port: Port to bind the gpustack API server to. Default is 30080.
        proxy_port: Port of the proxy endpoint returned by get_proxy_url(). Default is 30079.
        database_port: Port of the local PostgreSQL database used when database_url is unset. Default is 5432.
        advertise_address: The address to expose for external access. Auto-detected by default.
        debug: Enable debug mode.
        data_dir: Directory to store data. Default is OS specific.
        huggingface_token: User Access Token to authenticate to the Hugging Face Hub.
        metrics_port: Port to expose metrics on.
        disable_metrics: Disable server metrics.
        ssl_keyfile: Path to the SSL key file.
        ssl_certfile: Path to the SSL certificate file.
        database_url: URL of the database (postgresql:// or mysql://).
        disable_worker: (Deprecated) Disable embedded worker.
        enable_worker: Enable embedded worker.
        bootstrap_password: Password for the bootstrap admin user.
        jwt_secret_key: Secret key for JWT. Auto-generated by default.
        force_auth_localhost: Force authentication for requests originating from
            localhost (127.0.0.1). When set to True, all requests
            from localhost will require authentication.
        disable_update_check: Disable update check.
        update_check_url: URL to check for updates.
        model_catalog_file: Path or URL to the model catalog file.
        token: Shared secret used to register worker.
        server_url: URL of the server.
        worker_ip: IP address of the worker node. Auto-detected by default.
        worker_ifname: Network interface name of the worker node. Auto-detected by default.
        worker_name: Name of the worker node. Use the hostname by default.
        disable_worker_metrics: Disable worker metrics.
        worker_metrics_port: Port to expose metrics on.
        worker_port: Port to bind the worker to.
        service_port_range: Port range for inference services, specified as a string in the form 'N1-N2'. Both ends of the range are inclusive. Default is '40000-40063'.
        ray_port_range: Port range for Ray services (used by vLLM distributed deployment), specified as a string in the form 'N1-N2'. Both ends of the range are inclusive. Default is '41000-41999'.
        log_dir: Directory to store logs.
        bin_dir: Directory to store additional binaries, e.g., versioned backend executables.
        benchmark_dir: Directory to store benchmark results.
        benchmark_max_duration_seconds: Max duration for a benchmark before timeout. Disabled when unset.
        pipx_path: Path to the pipx executable, used to install versioned backends.
        system_reserved: Reserved system resources, in GiB.
        tools_download_base_url: Base URL to download dependency tools.
        enable_hf_transfer: [Deprecated] No-op since huggingface_hub v1.0 removed hf_transfer support; hf_xet is now the default downloader.
        enable_cors: Enable CORS in server.
        allow_origins: A list of origins that should be permitted to make cross-origin requests.
        allow_credentials: Indicate that cookies should be supported for cross-origin requests.
        allow_methods: A list of HTTP methods that should be allowed for cross-origin requests.
        allow_headers: A list of HTTP request headers that should be supported for cross-origin requests.
        server_external_url: Specified external URL for the server.
        system_default_container_registry: Default registry for container images (server and inference images).
        image_name_override: Force override of the image name.
        image_repo: Repository for the container images.
        service_discovery_name: Name of the service discovery service in DNS. Only useful when deployed in Kubernetes with service discovery.
        gateway_mode: Gateway deployment mode. Options are 'auto', 'embedded', 'incluster', 'external', 'disabled'. Default is 'auto'.
        gateway_kubeconfig: Path to the kubeconfig file for gateway. Only used when gateway_mode is 'external'.
        gateway_concurrency: Number of concurrent connections for the embedded gateway. Default is 16.
        gateway_namespace: The namespace where the gateway component is deployed.
        namespace: Kubernetes namespace for GPUStack to deploy gateway routing rules and model instances.
        disable_builtin_observability: Disable embedded Grafana and Prometheus services.
        grafana_url: Base URL for Grafana UI used by redirects and proxying. When unset, defaults to the embedded Grafana URL unless builtin observability is disabled.
        grafana_worker_dashboard_uid: Grafana dashboard UID for worker dashboard.
        grafana_model_dashboard_uid: Grafana dashboard UID for model dashboard.
        gateway_plugin_server_url: URL to fetch gateway plugin manifest for embedded gateway.
        server_id: Identifier of this server. Defaults to '<hostname>-<8 random hex chars>'.
    """

    # Server options
    # Deprecated: the server runs from a docker image, so host is not used.
    host: Optional[str] = None
    # The port and tls_port are used in gateway configuration.
    port: Optional[int] = 80
    tls_port: Optional[int] = 443
    # The api_port is used in gpustack server/worker serving API requests.
    api_port: Optional[int] = 30080
    proxy_port: Optional[int] = 30079
    database_port: Optional[int] = 5432
    database_url: Optional[str] = None
    disable_worker: Optional[bool] = None  # Deprecated
    enable_worker: bool = False
    bootstrap_password: Optional[str] = None
    jwt_secret_key: Optional[str] = None
    resources: Optional[dict] = None
    ssl_keyfile: Optional[str] = None
    ssl_certfile: Optional[str] = None
    force_auth_localhost: bool = False
    metrics_port: int = 10161
    disable_metrics: bool = False
    disable_update_check: bool = False
    disable_openapi_docs: bool = False
    update_check_url: Optional[str] = None
    model_catalog_file: Optional[str] = None
    enable_cors: bool = False
    allow_origins: Optional[List[str]] = ['*']
    allow_credentials: bool = False
    allow_methods: Optional[List[str]] = ['GET', 'POST']
    allow_headers: Optional[List[str]] = ['Authorization', 'Content-Type', 'X-API-Key']
    external_auth_type: Optional[str] = None  # external auth type
    external_auth_name: Optional[str] = None  # external auth name
    external_auth_full_name: Optional[str] = None  # external auth full name
    external_auth_avatar_url: Optional[str] = None  # external auth avatar url
    external_auth_default_inactive: bool = False  # external auth default inactive
    oidc_client_id: Optional[str] = None  # oidc client id
    oidc_client_secret: Optional[str] = None  # oidc client secret
    oidc_redirect_uri: Optional[str] = None  # oidc redirect uri
    oidc_issuer: Optional[str] = None  # oidc issuer
    oidc_skip_userinfo: bool = False  # skip requesting the oidc user_info endpoint
    oidc_use_userinfo: Optional[bool] = (
        None  # Deprecated, use oidc_skip_userinfo instead
    )
    openid_configuration: Optional[dict] = None  # fetched openid configuration
    saml_sp_entity_id: Optional[str] = None  # saml sp_entity_id
    saml_sp_acs_url: Optional[str] = None  # saml sp_acs_url
    saml_sp_x509_cert: Optional[str] = ''  # saml sp_x509_cert
    saml_sp_private_key: Optional[str] = ''  # saml sp_private_key
    saml_sp_attribute_prefix: Optional[str] = None  # saml sp attribute prefix
    saml_idp_entity_id: Optional[str] = None  # saml idp_entityId
    saml_idp_server_url: Optional[str] = None  # saml idp_server_url
    saml_idp_logout_url: Optional[str] = None
    saml_sp_slo_url: Optional[str] = None
    saml_idp_x509_cert: Optional[str] = ''  # saml idp_x509_cert
    saml_security: Optional[str] = '{}'  # saml security
    server_external_url: Optional[str] = None
    # custom post-logout redirection key for compatibility with different IdPs.
    external_auth_post_logout_redirect_key: Optional[str] = None
    # Number of concurrent connections for the embedded gateway.
    gateway_concurrency: int = 16
    gateway_plugin_server_url: Optional[str] = None
    gateway_ingress_class: str = "higress"
    disable_builtin_observability: bool = False
    builtin_prometheus_port: int = 19090
    builtin_grafana_port: int = 13000
    grafana_url: Optional[str] = None
    grafana_worker_dashboard_uid: Optional[str] = "gpustack-worker"
    grafana_model_dashboard_uid: Optional[str] = "gpustack-model"
    server_id: Optional[str] = None

    # Private (underscore) attributes, populated in __init__.
    _set_worker_fields = {}
    _derive_gateway_token = None
    _jwt_secret_key_user_provided = False

    model_config = SettingsConfigDict(
        env_prefix="GPUSTACK_", protected_namespaces=('settings_',), extra="allow"
    )

    def __init__(self, **values):
        """Parse the configuration and derive every runtime default.

        After pydantic parsing (which also runs check_all), this resolves the
        data/cache/bin/log/benchmark directories, loads the registration token
        and advertise address persisted in ``data_dir``, prepares the JWT
        secret, initializes external auth, creates the needed directories,
        resolves the gateway mode and derives the gateway token.

        Raises:
            Exception: when running as a worker with neither a registration
                token nor a persisted worker token.
        """
        super().__init__(**values)

        # Remember the PredefinedConfig fields the user set explicitly; they
        # take precedence over values pushed through reload_worker_config().
        self._set_worker_fields = self.model_dump(
            exclude_defaults=True,
            exclude_unset=True,
            exclude_none=True,
            include=self.__pydantic_fields_set__
            & set(PredefinedConfig.model_fields.keys()),
        )

        def prepare_dir(dir_path: Optional[str], default: str) -> str:
            # Use the default when unset, otherwise normalize to an absolute path.
            return default if dir_path is None else os.path.abspath(dir_path)

        # common options
        self.data_dir = prepare_dir(self.data_dir, self.get_data_dir())
        self.cache_dir = prepare_dir(
            self.cache_dir, os.path.join(self.data_dir, "cache")
        )
        self.bin_dir = prepare_dir(self.bin_dir, os.path.join(self.data_dir, "bin"))
        self.log_dir = prepare_dir(self.log_dir, os.path.join(self.data_dir, "log"))
        self.benchmark_dir = prepare_dir(
            self.benchmark_dir, os.path.join(self.data_dir, "benchmarks")
        )

        # A blank grafana_url is treated as unset.
        if isinstance(self.grafana_url, str) and not self.grafana_url.strip():
            self.grafana_url = None

        if self.token is None:
            self.token = read_registration_token(self.data_dir)

        advertise_address_path = os.path.join(self.data_dir, "advertise_address")
        if self.advertise_address is None and os.path.exists(advertise_address_path):
            with open(advertise_address_path, "r") as f:
                addr = f.read().strip()
            if addr:
                try:
                    ipaddress.ip_address(addr)
                    self.advertise_address = addr
                except ValueError:
                    # Best-effort: an invalid persisted address is ignored, but
                    # no longer silently.
                    logger.warning(
                        "Ignoring invalid advertise address %r in %s",
                        addr,
                        advertise_address_path,
                    )

        if (
            self._is_worker()
            and self.token is None
            and read_worker_token(self.data_dir) is None
        ):
            raise Exception("Token is required when running as worker")

        # Generate server_id if not provided
        if self.server_id is None:
            self.server_id = f"{socket.gethostname()}-{uuid.uuid4().hex[:8]}"

        # Snapshot whether jwt_secret_key came from user input (flag / env / config
        # file) before prepare_jwt_secret_key() auto-fills it. Server startup uses
        # this to enforce that distributed mode requires an explicit key.
        self._jwt_secret_key_user_provided = self.jwt_secret_key is not None
        self.prepare_jwt_secret_key()

        # server options
        self.init_auth()
        if self.system_reserved is None:
            self.system_reserved = {"ram": 0, "vram": 0}
        if self.service_discovery_name is None:
            self.service_discovery_name = "worker" if self._is_worker() else "server"
        self.make_dirs()
        self.detect_gateway_mode()

        # Default to worker proxy mode whenever no proxy mode was configured.
        if self.proxy_mode is None:
            self.proxy_mode = ModelInstanceProxyModeEnum.WORKER

        # HMAC-SHA256 of a fixed label keyed by the JWT secret; exposed through
        # get_derived_gateway_token().
        self._derive_gateway_token = hmac.new(
            self.jwt_secret_key.encode(), b"gateway-metrics-push", hashlib.sha256
        ).hexdigest()

    @model_validator(mode="after")
    def check_all(self):  # noqa: C901
        """Cross-field validation, run after pydantic field parsing.

        Also re-run by reload_worker_config(). The whole check is skipped
        when ``PYTEST_CURRENT_TEST`` is set in the environment.

        Raises:
            Exception: on inconsistent SSL files, an invalid server URL,
                invalid ``resources``, invalid port ranges or an unsupported
                database URL.
        """
        if 'PYTEST_CURRENT_TEST' in os.environ:
            # Skip validation during tests
            return self

        if (self.ssl_keyfile and not self.ssl_certfile) or (
            self.ssl_certfile and not self.ssl_keyfile
        ):
            raise Exception(
                'Both "ssl_keyfile" and "ssl_certfile" must be provided, or neither.'
            )

        if self.server_url:
            self.server_url = self.server_url.rstrip("/")
            if validators.url(self.server_url) is not True:
                raise Exception("Invalid server URL.")

        if self.resources:
            # Both parsers raise on malformed entries; results are discarded here.
            self.get_gpu_devices()
            self.get_system_info()

        if self.service_port_range:
            self.check_port_range(self.service_port_range)

        if self.ray_port_range:
            self.check_port_range(self.ray_port_range, diff=20)

        # Map the deprecated flag onto its replacement.
        if self.oidc_use_userinfo is not None:
            self.oidc_skip_userinfo = not self.oidc_use_userinfo

        if self.database_url is not None:
            self.check_database_url()

        return self

    def get_grafana_url(self) -> Optional[str]:
        """Return the Grafana base URL.

        An explicit grafana_url wins; otherwise the builtin Grafana on
        127.0.0.1 is used, unless builtin observability is disabled.
        """
        if self.grafana_url is not None:
            return self.grafana_url
        if self.disable_builtin_observability:
            return None
        return f"http://127.0.0.1:{self.builtin_grafana_port}"

    def get_builtin_prometheus_url(self) -> Optional[str]:
        """Return the builtin Prometheus URL.

        Returns None when builtin observability is disabled or an external
        grafana_url is configured.
        """
        if self.disable_builtin_observability or self.grafana_url is not None:
            return None
        return f"http://127.0.0.1:{self.builtin_prometheus_port}"

    @staticmethod
    def check_port_range(port_range: str, diff: Optional[int] = None):
        """Validate an inclusive 'N1-N2' port range.

        Args:
            port_range: Range string such as '40000-40063'; both ends inclusive.
            diff: Optional minimum number of ports the range must contain.

        Raises:
            Exception: if the string is malformed or non-numeric, if N1 > N2,
                if either end lies outside 1-65535, or if the range holds
                fewer than ``diff`` ports.
        """
        ports = port_range.split("-")
        if len(ports) != 2:
            raise Exception(f"Invalid port range: {port_range}")
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse;
        # isdigit() would let e.g. '²' through and crash in int().
        if not ports[0].isdecimal() or not ports[1].isdecimal():
            raise Exception("Port range must be numeric")
        start, end = int(ports[0]), int(ports[1])
        if start > end:
            raise Exception(f"Invalid port range: {ports[0]} > {ports[1]}")
        if start < 1 or end > 65535:
            raise Exception(f"Port range out of bounds (1-65535): {port_range}")
        if diff is not None and end - start + 1 < diff:
            raise Exception(
                f"Port range is too small: {port_range}, at least {diff} ports are required"
            )

    def make_dirs(self):
        """Create data, cache, bin, log, benchmark and higress directories.

        The PostgreSQL directory is created as well unless this process runs
        in the worker role.
        """
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(self.bin_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.benchmark_dir, exist_ok=True)
        # ensure higress data dir exists
        os.makedirs(self.higress_base_dir(), exist_ok=True)
        if self.server_role() != self.ServerRole.WORKER:
            os.makedirs(self.postgres_base_dir(), exist_ok=True)

    def get_system_info(self) -> Optional[SystemInfo]:  # noqa: C901
        """get system info from resources

        resource example:

        ```yaml
        resources:
          cpu:
            total: 10
          memory:
            total: 34359738368
            is_unified_memory: true
          swap:
            total: 3221225472
          filesystem:
            - name: Macintosh HD
              mount_point: /
              mount_from: /dev/disk3s1s1
              total: 994662584320
          os:
            name: macOS
            version: "14.5"
          kernel:
            name: Darwin
            release: 23.5.0
            version: "Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024;"
            architecture: ""
          uptime:
            uptime: 355250885
            boot_time: 2025-02-24T09:17:51.337+0800
        ```

        The cpu/memory/swap sections are only used when they carry a truthy
        ``total``; the other sections are validated strictly when present.

        Returns:
            A SystemInfo built from the present sections, or None when
            ``resources`` is empty or contains none of these sections.

        Raises:
            Exception: when a present filesystem/os/kernel/uptime section
                lacks a required key.
        """
        if not self.resources:
            return None
        system_info: SystemInfo = SystemInfo()

        cpu_dict = self.resources.get("cpu")
        if cpu_dict and cpu_dict.get("total"):
            system_info.cpu = CPUInfo(total=cpu_dict.get("total"))

        memory_dict = self.resources.get("memory")
        if memory_dict and memory_dict.get("total"):
            system_info.memory = MemoryInfo(total=memory_dict.get("total"))

        swap_dict = self.resources.get("swap")
        if swap_dict and swap_dict.get("total"):
            system_info.swap = SwapInfo(total=swap_dict.get("total"))

        filesystem_dict = self.resources.get("filesystem")
        if filesystem_dict:
            filesystem: FileSystemInfo = []
            for fs in filesystem_dict:
                name = fs.get("name")
                mount_point = fs.get("mount_point")
                mount_from = fs.get("mount_from")
                total = fs.get("total")
                if not name:
                    raise Exception("Filesystem name is required")
                if not mount_point:
                    raise Exception("Filesystem mount_point is required")
                if not mount_from:
                    raise Exception("Filesystem mount_from is required")
                if total is None:
                    raise Exception("Filesystem total is required")
                filesystem.append(
                    MountPoint(
                        name=name,
                        mount_point=mount_point,
                        mount_from=mount_from,
                        total=total,
                    )
                )
            system_info.filesystem = filesystem

        os_dict = self.resources.get("os")
        if os_dict:
            name = os_dict.get("name")
            version = os_dict.get("version")
            if not name:
                raise Exception("OS name is required")
            if not version:
                raise Exception("OS version is required")
            system_info.os = OperatingSystemInfo(name=name, version=version)

        kernel_dict = self.resources.get("kernel")
        if kernel_dict:
            name = kernel_dict.get("name")
            release = kernel_dict.get("release")
            version = kernel_dict.get("version")
            architecture = kernel_dict.get("architecture")
            if not name:
                raise Exception("Kernel name is required")
            if not release:
                raise Exception("Kernel release is required")
            if not version:
                raise Exception("Kernel version is required")
            system_info.kernel = KernelInfo(
                name=name, release=release, version=version, architecture=architecture
            )

        uptime_dict = self.resources.get("uptime")
        if uptime_dict:
            uptime = uptime_dict.get("uptime")
            boot_time = uptime_dict.get("boot_time")
            if uptime is None:
                raise Exception("Uptime is required")
            if not boot_time:
                raise Exception("Boot time is required")
            system_info.uptime = UptimeInfo(uptime=uptime, boot_time=boot_time)

        if not any(
            [
                system_info.cpu,
                system_info.memory,
                system_info.swap,
                system_info.filesystem,
                system_info.os,
                system_info.kernel,
                system_info.uptime,
            ]
        ):
            return None

        return system_info

    def get_gpu_devices(self) -> Optional[GPUDevicesStatus]:  # noqa: C901
        """get gpu devices from resources

        resource example:

        ```yaml
        resources:
          gpu_devices:
            - name: Ascend CANN 910b
              vendor: ascend
              arch_family: Ascend910B2
              index: 0
              device_index: 0 # optional
              device_chip_index: 0 # optional
              compute_capability: "9.0" # optional
              memory:
                total: 22906503168
                is_unified_memory: true
              network:
                status: "up"
                inet: "29.17.45.215"
                netmask: "255.255.0.0" # optional
                mac: "6c:34:91:87:3c:ae" # optional
                gateway: "29.17.0.1" # optional
                iface: "eth4" # optional
                mtu: 8192 # optional
        ```

        ``device_index`` defaults to ``index``, ``device_chip_index`` to 0,
        and ``type`` to the backend derived from ``vendor``.

        Returns:
            The parsed device list, or None when ``resources`` or its
            ``gpu_devices`` entry is empty.

        Raises:
            Exception: when a device entry is missing a required key or uses
                an unsupported vendor, network status/address or backend type.
        """
        if not self.resources:
            return None

        gpu_device_dict = self.resources.get("gpu_devices")
        if not gpu_device_dict:
            return None

        # The supported vendors/backends do not vary per device; look them up once.
        vendors = available_manufacturers()
        types = available_backends()

        gpu_devices: GPUDevicesStatus = []
        for gd in gpu_device_dict:
            name = gd.get("name")
            arch_family = gd.get("arch_family", None)
            index = gd.get("index")
            compute_capability = gd.get("compute_capability", None)
            device_index = gd.get("device_index", index)
            device_chip_index = gd.get("device_chip_index", 0)
            vendor = gd.get("vendor")
            memory = gd.get("memory")
            network = gd.get("network")
            runtime_version = gd.get("runtime_version")

            if not name:
                raise Exception("GPU device name is required")

            if index is None:
                raise Exception("GPU device index is required")

            if vendor not in vendors:
                raise Exception(
                    f"Unsupported GPU device vendor, available vendors are: {','.join(map(str, vendors))}"
                )

            if not memory:
                raise Exception("GPU device memory is required")
            if not memory.get("total"):
                raise Exception("GPU device memory total is required")

            if network:
                network_status = network.get("status", "up")
                if network_status not in ["up", "down"]:
                    raise Exception(
                        "GPU device network status is invalid, supported status are: up, down"
                    )

                network_inet = network.get("inet", None)
                if network_inet is None:
                    raise Exception("GPU device network inet is required")
                if not validators.ip(network_inet):
                    raise Exception("GPU device network inet is invalid")

                network_netmask = network.get("netmask", None)
                if network_netmask and not validators.ip(network_netmask):
                    raise Exception("GPU device network netmask is invalid")

                gateway = network.get("gateway", None)
                if gateway and not validators.ip(gateway):
                    raise Exception("GPU device network gateway is invalid")

            # Derive the backend only after the vendor is known to be valid, so
            # an unsupported vendor reports the vendor error above rather than
            # whatever manufacturer_to_backend() does with bad input.
            type_ = gd.get("type") or manufacturer_to_backend(vendor)
            if type_ not in types:
                raise Exception(
                    f"Unsupported GPU type, available type are: {','.join(map(str, types))}"
                )

            gpu_devices.append(
                GPUDeviceStatus(
                    index=index,
                    arch_family=arch_family,
                    compute_capability=compute_capability,
                    device_index=device_index,
                    device_chip_index=device_chip_index,
                    name=name,
                    vendor=vendor,
                    runtime_version=runtime_version,
                    memory=MemoryInfo(
                        total=memory.get("total"),
                        is_unified_memory=memory.get("is_unified_memory", False),
                    ),
                    network=(
                        None
                        if not network
                        else GPUNetworkInfo(
                            status=network.get("status", "up"),
                            inet=network.get("inet"),
                            netmask=network.get("netmask", ""),
                            mac=network.get("mac", ""),
                            gateway=network.get("gateway", ""),
                            iface=network.get("iface", None),
                            mtu=network.get("mtu", None),
                        )
                    ),
                    type=type_,
                )
            )

        return gpu_devices

    def get_database_url(self) -> str:
        """Return database_url, or a local PostgreSQL URL on database_port."""
        if self.database_url is not None:
            return self.database_url

        return (
            f"postgresql://root@127.0.0.1:{self.database_port}/gpustack?sslmode=disable"
        )

    def check_database_url(self):
        """Validate that database_url uses a supported scheme.

        Raises:
            Exception: if database_url starts with neither ``postgresql://``
                nor ``mysql://``.
        """
        if self.database_url is None:
            return

        # Accept exactly the schemes named in the error message below; a bare
        # startswith("sqlite") check would let any 'sqlite...' string through.
        if not self.database_url.startswith(("postgresql://", "mysql://")):
            raise Exception(
                "Unsupported database scheme. Supported databases are postgresql, and mysql."
            )

    def init_auth(self):
        """Select the external auth provider.

        OIDC takes precedence over SAML. For OIDC the issuer's discovery
        document is fetched eagerly (network I/O).
        """
        if self.oidc_issuer:
            self.external_auth_type = AuthProviderEnum.OIDC
            self.openid_configuration = get_openid_configuration(self.oidc_issuer)
        elif self.saml_idp_server_url:
            self.external_auth_type = AuthProviderEnum.SAML

    @staticmethod
    def get_data_dir():
        """Return the default data directory.

        This is %APPDATA%/gpustack on Windows and /var/lib/gpustack on POSIX.

        Raises:
            Exception: on any other ``os.name``.
        """
        app_name = "gpustack"
        if os.name == "nt":  # Windows
            data_dir = os.path.join(os.environ["APPDATA"], app_name)
        elif os.name == "posix":
            data_dir = f"/var/lib/{app_name}"
        else:
            raise Exception("Unsupported OS")

        return os.path.abspath(data_dir)

    def prepare_jwt_secret_key(self):
        """Ensure jwt_secret_key is set.

        A user-provided key is kept as-is. Otherwise the key persisted in
        ``<data_dir>/jwt_secret_key`` is reused. If that file is missing or
        empty, a fresh 256-bit hex key is generated and written there with
        owner-only (0600) permissions.
        """
        if self.jwt_secret_key is not None:
            return

        key_path = os.path.join(self.data_dir, "jwt_secret_key")
        key = ""
        if os.path.exists(key_path):
            with open(key_path, "r") as file:
                key = file.read().strip()

        # An empty key would make every HMAC/JWT signature forgeable, so an
        # empty file is treated the same as a missing one.
        if not key:
            key = secrets.token_hex(32)
            os.makedirs(self.data_dir, exist_ok=True)
            # Create with 0600 so the secret is not readable by other users
            # regardless of the process umask (mode applies on creation).
            fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, "w") as file:
                file.write(key)

        self.jwt_secret_key = key

    def _is_worker(self):
        """Return True when server_url is set, i.e. this process is a worker."""
        return self.server_url is not None

    def postgres_base_dir(self) -> str:
        """Return the PostgreSQL data directory under data_dir."""
        return os.path.join(self.data_dir, "postgresql")

    def higress_base_dir(self) -> str:
        """Return the Higress gateway data directory under data_dir."""
        return os.path.join(self.data_dir, "higress")

    def detect_gateway_mode(self):
        """Resolve ``auto`` gateway mode and validate the result.

        In auto mode:
        - a worker gets ``disabled`` and returns immediately;
        - otherwise the mode is ``incluster`` when running inside Kubernetes
          with supported Higress;
        - failing that, it is ``embedded`` when no gateway_kubeconfig is given;
        - otherwise it is ``external``.

        Afterwards it sets the embedded kubeconfig path, checks Higress
        support for external mode, and defaults gateway_plugin_server_url.

        Raises:
            Exception: for an embedded gateway on a worker, an external cluster
                without Higress support, or an explicit
                gateway_plugin_server_url combined with the embedded gateway.
        """
        if self.gateway_mode == GatewayModeEnum.auto:
            if self.server_role() == self.ServerRole.WORKER:
                self.gateway_mode = GatewayModeEnum.disabled
                return
            is_embedded = self.gateway_kubeconfig is None
            in_cluster = platform.is_inside_kubernetes()
            if in_cluster and platform.is_supported_higress(self.gateway_ingress_class):
                self.gateway_mode = GatewayModeEnum.incluster
            elif is_embedded:
                # in cluster but not supported higress will fallback to embedded
                self.gateway_mode = GatewayModeEnum.embedded
            else:
                self.gateway_mode = GatewayModeEnum.external

        if (
            self.server_role() == self.ServerRole.WORKER
            and self.gateway_mode == GatewayModeEnum.embedded
        ):
            raise Exception("Cannot run embedded gateway when running as worker.")

        if self.gateway_mode == GatewayModeEnum.embedded:
            # path to embed kubeconfig
            self.gateway_kubeconfig = os.path.join(
                self.higress_base_dir(), "kubeconfig"
            )

        if (
            self.gateway_mode == GatewayModeEnum.external
            and not platform.is_supported_higress(
                self.gateway_ingress_class, self.gateway_kubeconfig
            )
        ):
            raise Exception("The k8s cluster for gpustack does not support Higress.")

        if self.gateway_plugin_server_url is None:
            # For the embedded gateway, higress fetches plugins from the gpustack server.
            # For the disabled gateway the URL is unused, so its value does not matter.
            address = "127.0.0.1"
            if self.gateway_mode == GatewayModeEnum.incluster:
                address = get_first_non_loopback_ip()
            elif self.gateway_mode == GatewayModeEnum.external:
                address = self.get_advertise_address()
            self.gateway_plugin_server_url = f"http://{address}:{self.get_api_port()}"
        elif self.gateway_mode == GatewayModeEnum.embedded:
            raise Exception(
                "Cannot set gateway_plugin_server_url when running embedded gateway, as the embedded gateway will use the local plugin server."
            )

    class ServerRole(Enum):
        SERVER = "server"
        WORKER = "worker"
        BOTH = "both"

    def server_role(self) -> ServerRole:
        """Return WORKER, BOTH or SERVER for this process (see _is_both_role)."""
        if self._is_worker():
            return self.ServerRole.WORKER
        elif self._is_both_role():
            return self.ServerRole.BOTH
        else:
            return self.ServerRole.SERVER

    def _is_both_role(self) -> bool:
        """
        Determine whether the server runs in both server and worker mode.

        If `enable_worker` is True, the server runs in both modes. If
        `disable_worker` is True, it runs in server-only mode. If neither
        flag is set, the data directory is checked for a `bootstrap_version`
        file. A missing file indicates the server was installed by a version
        that defaults to running in both modes.

        Returns:
            bool: True if running in both server and worker mode, False otherwise.
        """
        if self._is_worker():
            return False
        elif self.enable_worker:
            return True
        elif self.disable_worker:
            return False
        # As of v2.0.1, a `bootstrap_version` file is created in data_dir.
        # If the file exists, it indicates that the server was installed
        # using a version that defaults to running server-only mode.
        bootstrap_version_path = os.path.join(self.data_dir, "bootstrap_version")
        if os.path.exists(bootstrap_version_path):
            return False
        return True

    def get_advertise_address(self) -> str:
        """Return advertise_address, or the first non-loopback IP if unset."""
        return self.advertise_address or get_first_non_loopback_ip()

    def get_namespace(self) -> str:
        """Return the namespace for gateway routing rules and model instances."""
        # The embedded and external gateways use the gateway namespace.
        if self.gateway_mode in [GatewayModeEnum.embedded, GatewayModeEnum.external]:
            return self.gateway_namespace
        return self.namespace

    def get_external_hostname(self) -> Optional[str]:
        """Return the DNS hostname of server_external_url.

        Returns None when the URL is unset, has no hostname, or its host is
        an IP literal.
        """
        hostname = None
        if self.server_external_url:
            parsed_url = urlparse(self.server_external_url)
            hostname = parsed_url.hostname
        if not hostname:
            return None
        try:
            ipaddress.ip_address(hostname)
            return None
        except ValueError:
            # Not an IP literal, so it is a DNS name.
            return hostname

    def get_tls_secret_name(self) -> Optional[str]:
        """Return the TLS secret name.

        Returns None unless both ssl_certfile and ssl_keyfile are set. The
        name embeds the external hostname (dots replaced by dashes) when one
        is available, and is "gpustack-tls-default" otherwise.
        """
        if not self.ssl_certfile or not self.ssl_keyfile:
            return None
        hostname = self.get_external_hostname()
        if hostname:
            return f"gpustack-tls-{hostname.replace('.', '-')}"
        else:
            return "gpustack-tls-default"

    def get_server_url(self) -> str:
        """Return server_url if set, otherwise the local embedded server URL.

        The local URL targets 127.0.0.1 on api_port, or the default HTTP port
        when api_port is unset. server_url always wins; previously an operator
        precedence slip (``a or b if c else d``) ignored it when api_port was
        falsy.
        """
        if self.server_url:
            return self.server_url
        if self.api_port:
            return f"http://127.0.0.1:{self.api_port}"
        return "http://127.0.0.1"

    def get_api_port(self, embedded_worker: bool = False) -> int:
        """Return the API port for this process.

        Args:
            embedded_worker: When True, always return worker_port.

        Returns:
            api_port for server/both roles. For the worker role, api_port with
            an embedded gateway and worker_port otherwise.
        """
        if embedded_worker:
            return self.worker_port
        if self.server_role() != self.ServerRole.WORKER:
            return self.api_port
        return (
            self.api_port
            if self.gateway_mode == GatewayModeEnum.embedded
            else self.worker_port
        )

    def get_gateway_port(self) -> int:
        """Return port for non-worker roles and worker_port for workers."""
        return (
            self.port
            if self.server_role() != self.ServerRole.WORKER
            else self.worker_port
        )

    def reload_token(self):
        """Re-read the registration token from data_dir.

        The current token is kept when nothing is found.
        """
        token = read_registration_token(self.data_dir)
        if token:
            self.token = token

    def reload_worker_config(self, worker_config: Optional[PredefinedConfigNoDefaults]):
        """Apply a pushed worker config, then re-validate.

        Fields the user set explicitly at startup (_set_worker_fields)
        override the pushed values. Only keys that are fields of this class
        are applied.
        """
        if worker_config is None:
            return
        updated = {
            **worker_config.model_dump(exclude_none=True),
            **self._set_worker_fields,
        }
        for key, value in updated.items():
            if key in self.__class__.model_fields:
                setattr(self, key, value)
        self.check_all()

    def get_system_reserved(self) -> Dict[str, int]:
        """Return system_reserved converted from GiB to bytes.

        The legacy keys "memory"/"gpu_memory" are accepted as fallbacks for
        "ram"/"vram". They are always removed from the result, and
        "ram"/"vram" win when both are present.
        """
        system_reserved_in_bytes = {**(self.system_reserved or {})}
        # The pop() default argument is evaluated eagerly, so legacy keys are
        # always dropped even when the new key exists.
        system_reserved_in_bytes["ram"] = (
            system_reserved_in_bytes.get(
                "ram", system_reserved_in_bytes.pop("memory", 0)
            )
            << 30
        )
        system_reserved_in_bytes["vram"] = (
            system_reserved_in_bytes.get(
                "vram", system_reserved_in_bytes.pop("gpu_memory", 0)
            )
            << 30
        )
        return system_reserved_in_bytes

    def get_proxy_port(self) -> int:
        """Return the configured proxy_port."""
        return self.proxy_port

    def get_proxy_listen_address(self, default: str = "0.0.0.0") -> str:
        """Return 127.0.0.1 for the embedded gateway, otherwise ``default``."""
        return "127.0.0.1" if self.gateway_mode == GatewayModeEnum.embedded else default

    def get_proxy_url(self) -> str:
        """Return the proxy URL.

        The host is 127.0.0.1 with the embedded gateway and the advertise
        address otherwise.
        """
        return f"http://{self.get_proxy_listen_address(self.get_advertise_address())}:{self.get_proxy_port()}"

    def get_derived_gateway_token(self) -> str:
        """Return the gateway token derived from the JWT secret in __init__."""
        return self._derive_gateway_token
def get_image_name(
    image_name_override: Optional[str],
    registry: Optional[str] = None,
    image_repo: str = "gpustack/gpustack",
) -> str:
    """Compose the container image reference for this GPUStack build.

    Args:
        image_name_override: When truthy, returned verbatim.
        registry: Optional registry host prepended as ``<registry>/``.
        image_repo: Repository path of the image.

    Returns:
        ``[<registry>/]<image_repo>:<tag>``. The tag is the package version,
        or ``dev`` for the unreleased ``0.0.0`` / ``v0.0.0`` version.
    """
    if image_name_override:
        return image_name_override

    is_dev_build = __version__.removeprefix("v") == "0.0.0"
    tag = "dev" if is_dev_build else __version__
    repository = f"{registry}/{image_repo}" if registry else image_repo
    return f"{repository}:{tag}"
def get_cluster_image_name(worker_config: Optional[PredefinedConfigNoDefaults]) -> str:
    """Return the image name for a cluster.

    Each image setting (override, repository, default registry) comes from
    ``worker_config`` when it is truthy there, and otherwise from the global
    Config. The registry is resolved through determine_default_registry().
    """
    cfg = get_global_config()

    if worker_config is None:
        override = cfg.image_name_override
        repo = cfg.image_repo
        registry_setting = cfg.system_default_container_registry
    else:
        override = worker_config.image_name_override or cfg.image_name_override
        repo = worker_config.image_repo or cfg.image_repo
        registry_setting = (
            worker_config.system_default_container_registry
            or cfg.system_default_container_registry
        )

    return get_image_name(
        image_name_override=override,
        image_repo=repo,
        registry=determine_default_registry(registry_setting),
    )
def get_openid_configuration(issuer: str) -> dict:
    """Fetch the OpenID Connect discovery document of ``issuer``.

    Requests ``<issuer>/.well-known/openid-configuration`` with a 10 second
    timeout, verifying TLS against the system trust store. Proxy environment
    variables are honoured only when use_proxy_env_for_url() allows it for
    this URL.

    Raises:
        Exception: wrapping any request, HTTP-status or JSON decoding error.
    """
    discovery_url = "{}/.well-known/openid-configuration".format(issuer.rstrip('/'))
    try:
        honour_proxy_env = use_proxy_env_for_url(discovery_url)
        ssl_context = get_system_trust_store_ssl_context()
        with httpx.Client(
            timeout=10, verify=ssl_context, trust_env=honour_proxy_env
        ) as http:
            response = http.get(discovery_url)
            response.raise_for_status()
            return response.json()
    except Exception as err:
        raise Exception(
            f"Failed to get OpenID configuration: {str(err)}. Please check the issuer URL and ensure {discovery_url} is accessible."
        ) from err
def get_global_config() -> Config:
    """Return the process-wide Config installed by set_global_config().

    The value is None until set_global_config() has been called.
    """
    current = _config
    return current
def set_global_config(cfg: Config):
    """Install ``cfg`` as the process-wide Config and return it."""
    global _config
    _config = cfg
    return _config
|