config.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928
  1. import ipaddress
  2. import logging
  3. import os
  4. import secrets
  5. import socket
  6. import uuid
  7. from enum import Enum
  8. from typing import List, Optional, Dict
  9. from urllib.parse import urlparse
  10. import httpx
  11. import hmac
  12. import hashlib
  13. from gpustack_runtime.detector import (
  14. manufacturer_to_backend,
  15. available_manufacturers,
  16. available_backends,
  17. )
  18. from pydantic import model_validator
  19. from pydantic_settings import BaseSettings, SettingsConfigDict
  20. from gpustack.utils import validators
  21. from gpustack.schemas.workers import (
  22. CPUInfo,
  23. FileSystemInfo,
  24. GPUDeviceStatus,
  25. KernelInfo,
  26. MemoryInfo,
  27. MountPoint,
  28. OperatingSystemInfo,
  29. SwapInfo,
  30. SystemInfo,
  31. UptimeInfo,
  32. GPUDevicesStatus,
  33. GPUNetworkInfo,
  34. )
  35. from gpustack.schemas.users import AuthProviderEnum
  36. from gpustack.schemas.config import (
  37. ModelInstanceProxyModeEnum,
  38. PredefinedConfig,
  39. PredefinedConfigNoDefaults,
  40. GatewayModeEnum,
  41. )
  42. from gpustack import __version__
  43. from gpustack.config.registration import (
  44. read_registration_token,
  45. read_worker_token,
  46. determine_default_registry,
  47. )
  48. from gpustack.utils.network import (
  49. get_first_non_loopback_ip,
  50. get_system_trust_store_ssl_context,
  51. use_proxy_env_for_url,
  52. )
  53. from gpustack.utils import platform
  54. _config = None
  55. logger = logging.getLogger(__name__)
  class WorkerConfig(PredefinedConfig):
      """Per-worker options layered on top of ``PredefinedConfig``.

      Every field declared here is optional and defaults to ``None``.
      """

      # common config which should be dynamic or not configurable
      # Directory to store data.
      data_dir: Optional[str] = None
      # The address to expose for external access.
      advertise_address: Optional[str] = None
      # Worker options which are different for each worker
      # Shared secret used to register worker.
      token: Optional[str] = None
      # URL of the server.
      server_url: Optional[str] = None
      # IP address of the worker node.
      worker_ip: Optional[str] = None
      # Network interface name of the worker node.
      worker_ifname: Optional[str] = None
      # Name of the worker node.
      worker_name: Optional[str] = None
  66. class Config(WorkerConfig, BaseSettings):
  67. """A class used to define GPUStack configuration.
  68. Attributes:
  69. port: Port to bind the server to. Default is 80.
  70. tls_port: Port to bind the TLS server to. Default is 443.
  71. api_port: Port to bind the gpustack API server to. Default is 30080.
  72. advertise_address: The address to expose for external access. Auto-detected by default.
  73. debug: Enable debug mode.
  74. data_dir: Directory to store data. Default is OS specific.
  75. huggingface_token: User Access Token to authenticate to the Hugging Face Hub.
  76. metrics_port: Port to expose metrics on.
  77. disable_metrics: Disable server metrics.
  78. ssl_keyfile: Path to the SSL key file.
  79. ssl_certfile: Path to the SSL certificate file.
  80. database_url: URL of the database.
  81. disable_worker: (Deprecated) Disable embedded worker.
  82. enable_worker: Enable embedded worker.
  83. bootstrap_password: Password for the bootstrap admin user.
  84. jwt_secret_key: Secret key for JWT. Auto-generated by default.
  85. force_auth_localhost: Force authentication for requests originating from
  86. localhost (127.0.0.1). When set to True, all requests
  87. from localhost will require authentication.
  88. disable_update_check: Disable update check.
  89. update_check_url: URL to check for updates.
  90. model_catalog_file: Path or URL to the model catalog file.
  91. token: Shared secret used to register worker.
  92. server_url: URL of the server.
  93. worker_ip: IP address of the worker node. Auto-detected by default.
  94. worker_ifname: Network interface name of the worker node. Auto-detected by default.
  95. worker_name: Name of the worker node. Use the hostname by default.
  96. disable_worker_metrics: Disable worker metrics.
  97. worker_metrics_port: Port to expose metrics on.
  98. worker_port: Port to bind the worker to.
  99. service_port_range: Port range for inference services, specified as a string in the form 'N1-N2'. Both ends of the range are inclusive. Default is '40000-40063'.
  ray_port_range: Port range for Ray services (used by vLLM distributed deployments), specified as a string in the form 'N1-N2'. Both ends of the range are inclusive. Default is '41000-41999'.
  101. log_dir: Directory to store logs.
  102. bin_dir: Directory to store additional binaries, e.g., versioned backend executables.
  103. benchmark_dir: Directory to store benchmark results.
  104. benchmark_max_duration_seconds: Max duration for a benchmark before timeout. Disabled when unset.
  105. pipx_path: Path to the pipx executable, used to install versioned backends.
  106. system_reserved: Reserved system resources.
  107. tools_download_base_url: Base URL to download dependency tools.
  108. enable_hf_transfer: [Deprecated] No-op since huggingface_hub v1.0 removed hf_transfer support; hf_xet is now the default downloader.
  109. enable_cors: Enable CORS in server.
  110. allow_origins: A list of origins that should be permitted to make cross-origin requests.
  111. allow_credentials: Indicate that cookies should be supported for cross-origin requests.
  112. allow_methods: A list of HTTP methods that should be allowed for cross-origin requests.
  113. allow_headers: A list of HTTP request headers that should be supported for cross-origin requests.
  114. server_external_url: Specified external URL for the server.
  115. system_default_container_registry: Default registry for container images (server and inference images).
  116. image_name_override: Force override of the image name.
  117. image_repo: Repository for the container images.
  118. service_discovery_name: Name of the service discovery service in DNS. Only useful when deployed in Kubernetes with service discovery.
  119. gateway_mode: Gateway deployment mode. Options are 'auto', 'embedded', 'incluster', 'external', 'disabled'. Default is 'auto'.
  120. gateway_kubeconfig: Path to the kubeconfig file for gateway. Only used when gateway_mode is 'external'.
  121. gateway_concurrency: Number of concurrent connections for the embedded gateway. Default is 16.
  122. gateway_namespace: The namespace where the gateway component is deployed.
  123. namespace: Kubernetes namespace for GPUStack to deploy gateway routing rules and model instances.
  124. disable_builtin_observability: Disable embedded Grafana and Prometheus services.
  125. grafana_url: Base URL for Grafana UI used by redirects and proxying. When unset, defaults to the embedded Grafana URL unless builtin observability is disabled.
  126. grafana_worker_dashboard_uid: Grafana dashboard UID for worker dashboard.
  127. grafana_model_dashboard_uid: Grafana dashboard UID for model dashboard.
  128. gateway_plugin_server_url: URL to fetch gateway plugin manifest for embedded gateway.
  129. """
  130. # Server options
  # Deprecated: the server runs from a Docker image, so host is not used.
  132. host: Optional[str] = None
  133. # The port and tls_port are used in gateway configuration.
  134. port: Optional[int] = 80
  135. tls_port: Optional[int] = 443
  136. # The api_port is used in gpustack server/worker serving API requests.
  137. api_port: Optional[int] = 30080
  138. proxy_port: Optional[int] = 30079
  139. database_port: Optional[int] = 5432
  140. database_url: Optional[str] = None
  141. disable_worker: Optional[bool] = None # Deprecated
  142. enable_worker: bool = False
  143. bootstrap_password: Optional[str] = None
  144. jwt_secret_key: Optional[str] = None
  145. resources: Optional[dict] = None
  146. ssl_keyfile: Optional[str] = None
  147. ssl_certfile: Optional[str] = None
  148. force_auth_localhost: bool = False
  149. metrics_port: int = 10161
  150. disable_metrics: bool = False
  151. disable_update_check: bool = False
  152. disable_openapi_docs: bool = False
  153. update_check_url: Optional[str] = None
  154. model_catalog_file: Optional[str] = None
  155. enable_cors: bool = False
  156. allow_origins: Optional[List[str]] = ['*']
  157. allow_credentials: bool = False
  158. allow_methods: Optional[List[str]] = ['GET', 'POST']
  159. allow_headers: Optional[List[str]] = ['Authorization', 'Content-Type', 'X-API-Key']
  160. external_auth_type: Optional[str] = None # external auth type
  161. external_auth_name: Optional[str] = None # external auth name
  162. external_auth_full_name: Optional[str] = None # external auth full name
  163. external_auth_avatar_url: Optional[str] = None # external auth avatar url
  164. external_auth_default_inactive: bool = False # external auth default inactive
  165. oidc_client_id: Optional[str] = None # oidc client id
  166. oidc_client_secret: Optional[str] = None # oidc client secret
  167. oidc_redirect_uri: Optional[str] = None # oidc redirect uri
  168. oidc_issuer: Optional[str] = None # oidc issuer
  169. oidc_skip_userinfo: bool = False # skip to request the oidc user_info endpoint
  170. oidc_use_userinfo: Optional[bool] = (
  171. None # Deprecated, use oidc_skip_userinfo instead
  172. )
  173. openid_configuration: Optional[dict] = None # fetched openid configuration
  174. saml_sp_entity_id: Optional[str] = None # saml sp_entity_id
  175. saml_sp_acs_url: Optional[str] = None # saml sp_acs_url
  176. saml_sp_x509_cert: Optional[str] = '' # saml sp_x509_cert
  177. saml_sp_private_key: Optional[str] = '' # saml sp_private_key
  178. saml_sp_attribute_prefix: Optional[str] = None # saml sp attribute prefix
  179. saml_idp_entity_id: Optional[str] = None # saml idp_entityId
  180. saml_idp_server_url: Optional[str] = None # saml idp_server_url
  181. saml_idp_logout_url: Optional[str] = None
  182. saml_sp_slo_url: Optional[str] = None
  183. saml_idp_x509_cert: Optional[str] = '' # saml idp_x509_cert
  184. saml_security: Optional[str] = '{}' # saml security
  185. server_external_url: Optional[str] = None
  186. # custom post-logout redirection key for compatibility with different IdPs.
  187. external_auth_post_logout_redirect_key: Optional[str] = None
  188. # Number of concurrent connections for the embedded gateway.
  189. gateway_concurrency: int = 16
  190. gateway_plugin_server_url: Optional[str] = None
  191. gateway_ingress_class: str = "higress"
  192. disable_builtin_observability: bool = False
  193. builtin_prometheus_port: int = 19090
  194. builtin_grafana_port: int = 13000
  195. grafana_url: Optional[str] = None
  196. grafana_worker_dashboard_uid: Optional[str] = "gpustack-worker"
  197. grafana_model_dashboard_uid: Optional[str] = "gpustack-model"
  198. server_id: Optional[str] = None
  199. _set_worker_fields = {}
  200. _derive_gateway_token = None
  201. _jwt_secret_key_user_provided = False
  202. model_config = SettingsConfigDict(
  203. env_prefix="GPUSTACK_", protected_namespaces=('settings_',), extra="allow"
  204. )
  def __init__(self, **values):
      """Build the configuration and fill in every derived default.

      After pydantic has populated the declared fields, this method, in order:

      1. records which ``PredefinedConfig`` fields were explicitly set,
      2. resolves ``data_dir`` and the cache/bin/log/benchmark directories,
      3. normalizes a blank ``grafana_url`` to ``None``,
      4. falls back to ``read_registration_token(data_dir)`` for ``token``,
      5. loads ``advertise_address`` from ``<data_dir>/advertise_address``
         when it was not configured and the file holds a valid IP address,
      6. generates ``server_id`` and prepares ``jwt_secret_key``,
      7. initializes external auth, creates directories, resolves the
         gateway mode and derives the gateway token from the JWT secret.

      Args:
          **values: Field values forwarded to the pydantic initializer.

      Raises:
          Exception: If running as a worker and neither ``token`` nor
              ``read_worker_token(data_dir)`` yields a token.
      """
      super().__init__(**values)
      # Keep only the PredefinedConfig fields the caller explicitly set to a
      # non-default, non-None value.
      self._set_worker_fields = self.model_dump(
          exclude_defaults=True,
          exclude_unset=True,
          exclude_none=True,
          include=self.__pydantic_fields_set__
          & set(PredefinedConfig.model_fields.keys()),
      )

      def prepare_dir(dir_path: Optional[str], default: str) -> str:
          # An explicit path is made absolute; otherwise the default is used.
          return default if dir_path is None else os.path.abspath(dir_path)

      # common options
      self.data_dir = prepare_dir(self.data_dir, self.get_data_dir())
      self.cache_dir = prepare_dir(
          self.cache_dir, os.path.join(self.data_dir, "cache")
      )
      self.bin_dir = prepare_dir(self.bin_dir, os.path.join(self.data_dir, "bin"))
      self.log_dir = prepare_dir(self.log_dir, os.path.join(self.data_dir, "log"))
      self.benchmark_dir = prepare_dir(
          self.benchmark_dir, os.path.join(self.data_dir, "benchmarks")
      )

      # A whitespace-only grafana_url is treated as "not configured".
      if isinstance(self.grafana_url, str) and not self.grafana_url.strip():
          self.grafana_url = None

      if self.token is None:
          self.token = read_registration_token(self.data_dir)

      advertise_address_file = os.path.join(self.data_dir, "advertise_address")
      if self.advertise_address is None and os.path.exists(advertise_address_file):
          with open(advertise_address_file, "r") as f:
              addr = f.read().strip()
          if addr:
              try:
                  ipaddress.ip_address(addr)
              except ValueError:
                  # Best effort: an unusable file must not prevent startup,
                  # but the operator should be able to see why it was ignored.
                  logger.warning(
                      "Ignoring invalid advertise address %r read from %s",
                      addr,
                      advertise_address_file,
                  )
              else:
                  self.advertise_address = addr

      if (
          self._is_worker()
          and self.token is None
          and read_worker_token(self.data_dir) is None
      ):
          raise Exception("Token is required when running as worker")

      # Generate server_id if not provided
      if self.server_id is None:
          self.server_id = f"{socket.gethostname()}-{uuid.uuid4().hex[:8]}"

      # Snapshot whether jwt_secret_key came from user input (flag / env / config
      # file) before prepare_jwt_secret_key() auto-fills it. Server startup uses
      # this to enforce that distributed mode requires an explicit key.
      self._jwt_secret_key_user_provided = self.jwt_secret_key is not None
      self.prepare_jwt_secret_key()

      # server options
      self.init_auth()
      if self.system_reserved is None:
          self.system_reserved = {"ram": 0, "vram": 0}
      if self.service_discovery_name is None:
          self.service_discovery_name = "worker" if self._is_worker() else "server"

      self.make_dirs()
      self.detect_gateway_mode()

      # default to worker proxy mode if running as worker
      if self.proxy_mode is None:
          self.proxy_mode = ModelInstanceProxyModeEnum.WORKER

      # HMAC-SHA256 of a fixed label keyed with the JWT secret.
      self._derive_gateway_token = hmac.new(
          self.jwt_secret_key.encode(), b"gateway-metrics-push", hashlib.sha256
      ).hexdigest()
  @model_validator(mode="after")
  def check_all(self):  # noqa: C901
      """Cross-field validation run after pydantic has populated the model.

      Returns:
          The validated instance itself.

      Raises:
          Exception: If only one of the TLS files is set, the server URL is
              invalid, or a delegated resource / port-range / database check
              fails.
      """
      if 'PYTEST_CURRENT_TEST' in os.environ:
          # Skip validation during tests
          return self

      # The TLS key and certificate must be configured together.
      if bool(self.ssl_keyfile) != bool(self.ssl_certfile):
          raise Exception(
              'Both "ssl_keyfile" and "ssl_certfile" must be provided, or neither.'
          )

      if self.server_url:
          # Normalize before validating so trailing slashes do not matter.
          self.server_url = self.server_url.rstrip("/")
          if validators.url(self.server_url) is not True:
              raise Exception("Invalid server URL.")

      if self.resources:
          # Called for their validation side effects; results are discarded.
          self.get_gpu_devices()
          self.get_system_info()

      if self.service_port_range:
          self.check_port_range(self.service_port_range)
      if self.ray_port_range:
          self.check_port_range(self.ray_port_range, diff=20)

      # The deprecated flag, when given, overrides its replacement.
      if self.oidc_use_userinfo is not None:
          self.oidc_skip_userinfo = not self.oidc_use_userinfo

      if self.database_url is not None:
          self.check_database_url()

      return self
  def get_grafana_url(self) -> Optional[str]:
      """Return the Grafana base URL to use, or ``None`` if there is none.

      An explicitly configured ``grafana_url`` always wins. Otherwise the
      local URL on ``builtin_grafana_port`` is returned unless builtin
      observability is disabled.
      """
      configured = self.grafana_url
      if configured is not None:
          return configured
      return (
          None
          if self.disable_builtin_observability
          else f"http://127.0.0.1:{self.builtin_grafana_port}"
      )
  def get_builtin_prometheus_url(self) -> Optional[str]:
      """Return the local Prometheus URL on ``builtin_prometheus_port``.

      ``None`` is returned when builtin observability is disabled or when a
      ``grafana_url`` has been configured.
      """
      builtin_in_use = (
          not self.disable_builtin_observability and self.grafana_url is None
      )
      if not builtin_in_use:
          return None
      return f"http://127.0.0.1:{self.builtin_prometheus_port}"
  305. @staticmethod
  306. def check_port_range(port_range: str, diff: Optional[int] = None):
  307. ports = port_range.split("-")
  308. if len(ports) != 2:
  309. raise Exception(f"Invalid port range: {port_range}")
  310. if not ports[0].isdigit() or not ports[1].isdigit():
  311. raise Exception("Port range must be numeric")
  312. if int(ports[0]) > int(ports[1]):
  313. raise Exception(f"Invalid port range: {ports[0]} > {ports[1]}")
  314. if diff is not None:
  315. if int(ports[1]) - int(ports[0]) + 1 < diff:
  316. raise Exception(
  317. f"Port range is too small: {port_range}, at least {diff} ports are required"
  318. )
  def make_dirs(self):
      """Create the directories this configuration refers to.

      Existing directories are left alone. The PostgreSQL directory is only
      created when the role is not ``ServerRole.WORKER``.
      """
      always_needed = (
          self.data_dir,
          self.cache_dir,
          self.bin_dir,
          self.log_dir,
          self.benchmark_dir,
          # ensure higress data dir exists
          self.higress_base_dir(),
      )
      for directory in always_needed:
          os.makedirs(directory, exist_ok=True)

      if self.server_role() != self.ServerRole.WORKER:
          os.makedirs(self.postgres_base_dir(), exist_ok=True)
  def get_system_info(self) -> Optional[SystemInfo]:  # noqa: C901
      """Build a ``SystemInfo`` from the user-supplied ``resources`` mapping.

      resource example:
      ```yaml
      resources:
        cpu:
          total: 10
        memory:
          total: 34359738368
          is_unified_memory: true
        swap:
          total: 3221225472
        filesystem:
          - name: Macintosh HD
            mount_point: /
            mount_from: /dev/disk3s1s1
            total: 994662584320
        os:
          name: macOS
          version: "14.5"
        kernel:
          name: Darwin
          release: 23.5.0
          version: "Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024;"
          architecture: ""
        uptime:
          uptime: 355250885
          boot_time: 2025-02-24T09:17:51.337+0800
      ```

      Returns:
          The populated ``SystemInfo``, or ``None`` when ``resources`` is
          empty or none of the recognized sections produced a value.

      Raises:
          Exception: If a filesystem, os, kernel or uptime section is present
              but lacks one of its required keys.
      """
      if not self.resources:
          return None

      system_info: SystemInfo = SystemInfo()

      cpu_dict = self.resources.get("cpu")
      if cpu_dict and cpu_dict.get("total"):
          system_info.cpu = CPUInfo(total=cpu_dict.get("total"))

      memory_dict = self.resources.get("memory")
      if memory_dict and memory_dict.get("total"):
          # Honor the documented is_unified_memory key, matching how GPU
          # device memory is built in get_gpu_devices().
          system_info.memory = MemoryInfo(
              total=memory_dict.get("total"),
              is_unified_memory=memory_dict.get("is_unified_memory", False),
          )

      swap_dict = self.resources.get("swap")
      if swap_dict and swap_dict.get("total"):
          system_info.swap = SwapInfo(total=swap_dict.get("total"))

      filesystem_dict = self.resources.get("filesystem")
      if filesystem_dict:
          filesystem: FileSystemInfo = []
          for fs in filesystem_dict:
              name = fs.get("name")
              mount_point = fs.get("mount_point")
              mount_from = fs.get("mount_from")
              total = fs.get("total")
              if not name:
                  raise Exception("Filesystem name is required")
              if not mount_point:
                  raise Exception("Filesystem mount_point is required")
              if not mount_from:
                  raise Exception("Filesystem mount_from is required")
              # A total of 0 is accepted; only a missing value is an error.
              if total is None:
                  raise Exception("Filesystem total is required")
              filesystem.append(
                  MountPoint(
                      name=name,
                      mount_point=mount_point,
                      mount_from=mount_from,
                      total=total,
                  )
              )
          system_info.filesystem = filesystem

      os_dict = self.resources.get("os")
      if os_dict:
          name = os_dict.get("name")
          version = os_dict.get("version")
          if not name:
              raise Exception("OS name is required")
          if not version:
              raise Exception("OS version is required")
          system_info.os = OperatingSystemInfo(name=name, version=version)

      kernel_dict = self.resources.get("kernel")
      if kernel_dict:
          name = kernel_dict.get("name")
          release = kernel_dict.get("release")
          version = kernel_dict.get("version")
          # architecture is passed through unchecked (may be empty or None).
          architecture = kernel_dict.get("architecture")
          if not name:
              raise Exception("Kernel name is required")
          if not release:
              raise Exception("Kernel release is required")
          if not version:
              raise Exception("Kernel version is required")
          system_info.kernel = KernelInfo(
              name=name, release=release, version=version, architecture=architecture
          )

      uptime_dict = self.resources.get("uptime")
      if uptime_dict:
          uptime = uptime_dict.get("uptime")
          boot_time = uptime_dict.get("boot_time")
          if uptime is None:
              raise Exception("Uptime is required")
          if not boot_time:
              raise Exception("Boot time is required")
          system_info.uptime = UptimeInfo(uptime=uptime, boot_time=boot_time)

      # Report "nothing configured" when no section produced a value.
      if not any(
          [
              system_info.cpu,
              system_info.memory,
              system_info.swap,
              system_info.filesystem,
              system_info.os,
              system_info.kernel,
              system_info.uptime,
          ]
      ):
          return None
      return system_info
  def get_gpu_devices(self) -> GPUDevicesStatus:  # noqa: C901
      """Build the GPU device list from the user-supplied ``resources`` mapping.

      resource example:
      ```yaml
      resources:
        gpu_devices:
          - name: Ascend CANN 910b
            vendor: ascend
            arch_family: Ascend910B2
            index: 0
            device_index: 0 # optional
            device_chip_index: 0 # optional
            compute_capability: "9.0" # optional
            memory:
              total: 22906503168
              is_unified_memory: true
            network:
              status: "up"
              inet: "29.17.45.215"
              netmask: "255.255.0.0" # optional
              mac: "6c34:91:87:3c:ae" # optional
              gateway: "29.17.0.1" # optional
              iface: "eth4" # optional
              mtu: 8192 # optional
      ```

      Returns:
          A list with one ``GPUDeviceStatus`` per entry, or ``None`` when
          ``resources`` is empty or has no ``gpu_devices``.

      Raises:
          Exception: If an entry lacks a required key, names an unsupported
              vendor or type, or carries an invalid network address.
      """
      resources = self.resources
      if not resources:
          return None
      device_specs = resources.get("gpu_devices")
      if not device_specs:
          return None

      devices: GPUDevicesStatus = []
      for spec in device_specs:
          name = spec.get("name")
          index = spec.get("index")
          vendor = spec.get("vendor")
          memory = spec.get("memory")
          network = spec.get("network")
          # An explicit "type" wins; otherwise it is derived from the vendor.
          backend_type = spec.get("type") or manufacturer_to_backend(vendor)

          if not name:
              raise Exception("GPU device name is required")
          if index is None:
              raise Exception("GPU device index is required")

          supported_vendors = available_manufacturers()
          if vendor not in supported_vendors:
              raise Exception(
                  f"Unsupported GPU device vendor, available vendors are: {','.join(map(str, supported_vendors))}"
              )

          if not memory:
              raise Exception("GPU device memory is required")
          if not memory.get("total"):
              raise Exception("GPU device memory total is required")

          if network:
              if network.get("status", "up") not in ("up", "down"):
                  raise Exception(
                      "GPU device network status is invalid, supported status are: up, down"
                  )
              inet = network.get("inet", None)
              if inet is None:
                  raise Exception("GPU device network inet is required")
              if not validators.ip(inet):
                  raise Exception("GPU device network inet is invalid")
              # netmask and gateway are optional but must be valid when given.
              netmask = network.get("netmask", None)
              if netmask and not validators.ip(netmask):
                  raise Exception("GPU device network netmask is invalid")
              gateway = network.get("gateway", None)
              if gateway and not validators.ip(gateway):
                  raise Exception("GPU device network gateway is invalid")

          supported_types = available_backends()
          if backend_type not in supported_types:
              raise Exception(
                  f"Unsupported GPU type, available type are: {','.join(map(str, supported_types))}"
              )

          memory_info = MemoryInfo(
              total=memory.get("total"),
              is_unified_memory=memory.get("is_unified_memory", False),
          )
          network_info = (
              GPUNetworkInfo(
                  status=network.get("status", "up"),
                  inet=network.get("inet"),
                  netmask=network.get("netmask", ""),
                  mac=network.get("mac", ""),
                  gateway=network.get("gateway", ""),
                  iface=network.get("iface", None),
                  mtu=network.get("mtu", None),
              )
              if network
              else None
          )
          devices.append(
              GPUDeviceStatus(
                  index=index,
                  arch_family=spec.get("arch_family", None),
                  compute_capability=spec.get("compute_capability", None),
                  # device_index falls back to index when not given.
                  device_index=spec.get("device_index", index),
                  device_chip_index=spec.get("device_chip_index", 0),
                  name=name,
                  vendor=vendor,
                  runtime_version=spec.get("runtime_version"),
                  memory=memory_info,
                  network=network_info,
                  type=backend_type,
              )
          )
      return devices
  def get_database_url(self) -> str:
      """Return ``database_url``, or the local PostgreSQL URL when it is unset.

      The fallback targets ``127.0.0.1`` on ``database_port``.
      """
      if self.database_url is None:
          return f"postgresql://root@127.0.0.1:{self.database_port}/gpustack?sslmode=disable"
      return self.database_url
  def check_database_url(self):
      """Reject a ``database_url`` whose prefix is not an accepted one.

      A ``None`` URL is accepted without checks.

      Raises:
          Exception: If the URL does not start with ``postgresql://``,
              ``mysql://`` or ``sqlite``.
      """
      url = self.database_url
      if url is None:
          return
      accepted_prefixes = ("postgresql://", "mysql://", "sqlite")
      if not url.startswith(accepted_prefixes):
          raise Exception(
              "Unsupported database scheme. Supported databases are postgresql, and mysql."
          )
  def init_auth(self):
      """Select the external auth provider from the configured options.

      A configured ``oidc_issuer`` takes precedence: the auth type becomes
      OIDC and ``openid_configuration`` is filled from
      ``get_openid_configuration``. Otherwise a configured
      ``saml_idp_server_url`` selects SAML. With neither set, nothing changes.
      """
      if self.oidc_issuer:
          self.external_auth_type = AuthProviderEnum.OIDC
          self.openid_configuration = get_openid_configuration(self.oidc_issuer)
          return
      if self.saml_idp_server_url:
          self.external_auth_type = AuthProviderEnum.SAML
  573. @staticmethod
  574. def get_data_dir():
  575. app_name = "gpustack"
  576. if os.name == "nt": # Windows
  577. data_dir = os.path.join(os.environ["APPDATA"], app_name)
  578. elif os.name == "posix":
  579. data_dir = f"/var/lib/{app_name}"
  580. else:
  581. raise Exception("Unsupported OS")
  582. return os.path.abspath(data_dir)
  def prepare_jwt_secret_key(self):
      """Ensure ``jwt_secret_key`` is set, persisting a generated one if needed.

      An already configured key is left untouched. Otherwise the key is read
      from ``<data_dir>/jwt_secret_key``. When that file is missing or
      empty, a new random key is generated and written there.
      """
      if self.jwt_secret_key is not None:
          return

      key_path = os.path.join(self.data_dir, "jwt_secret_key")
      key = ""
      if os.path.exists(key_path):
          with open(key_path, "r") as file:
              key = file.read().strip()

      if not key:
          # An empty key file must not yield an empty signing secret.
          key = secrets.token_hex(32)
          os.makedirs(self.data_dir, exist_ok=True)
          # Create the file owner-read/write only: it holds a signing secret.
          fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
          with os.fdopen(fd, "w") as file:
              file.write(key)

      self.jwt_secret_key = key
  596. def _is_worker(self):
  597. return self.server_url is not None
  def postgres_base_dir(self) -> str:
      """Return the ``postgresql`` directory path under ``data_dir``."""
      base_dir = os.path.join(self.data_dir, "postgresql")
      return base_dir
  def higress_base_dir(self) -> str:
      """Return the ``higress`` directory path under ``data_dir``."""
      base_dir = os.path.join(self.data_dir, "higress")
      return base_dir
  def detect_gateway_mode(self):
      """Resolve ``gateway_mode`` and the gateway settings derived from it.

      In ``auto`` mode a worker gets ``disabled`` and the method returns
      immediately, so none of the later steps run. Otherwise ``auto`` becomes
      ``incluster``, ``embedded`` or ``external``. After that the method
      validates the mode, points ``gateway_kubeconfig`` at the embedded
      kubeconfig when embedded, and fills in ``gateway_plugin_server_url``
      when it was not configured.

      Raises:
          Exception: If a worker is asked to run the embedded gateway, the
              external cluster does not support Higress, or
              ``gateway_plugin_server_url`` is set together with the
              embedded gateway.
      """
      if self.gateway_mode == GatewayModeEnum.auto:
          if self.server_role() == self.ServerRole.WORKER:
              self.gateway_mode = GatewayModeEnum.disabled
              return
          # Without a kubeconfig there is no external cluster to target.
          is_embedded = self.gateway_kubeconfig is None
          in_cluster = platform.is_inside_kubernetes()
          if in_cluster and platform.is_supported_higress(self.gateway_ingress_class):
              self.gateway_mode = GatewayModeEnum.incluster
          elif is_embedded:
              # in cluster but not supported higress will fallback to embedded
              self.gateway_mode = GatewayModeEnum.embedded
          else:
              self.gateway_mode = GatewayModeEnum.external
      # Reached for explicitly configured modes and for resolved auto modes.
      if (
          self.server_role() == self.ServerRole.WORKER
          and self.gateway_mode == GatewayModeEnum.embedded
      ):
          raise Exception("Cannot run embedded gateway when running as worker.")
      if self.gateway_mode == GatewayModeEnum.embedded:
          # path to embed kubeconfig
          self.gateway_kubeconfig = os.path.join(
              self.higress_base_dir(), "kubeconfig"
          )
      if (
          self.gateway_mode == GatewayModeEnum.external
          and not platform.is_supported_higress(
              self.gateway_ingress_class, self.gateway_kubeconfig
          )
      ):
          raise Exception("The k8s cluster for gpustack does not support Higress.")
      if self.gateway_plugin_server_url is None:
          # for embedded gateway mode, higress will fetch plugins from gpustack server
          # for disabled gateway mode, gateway_plugin_server_url is not used, so it doesn't matter if it's set or not.
          address = "127.0.0.1"
          if self.gateway_mode == GatewayModeEnum.incluster:
              address = get_first_non_loopback_ip()
          elif self.gateway_mode == GatewayModeEnum.external:
              address = self.get_advertise_address()
          self.gateway_plugin_server_url = f"http://{address}:{self.get_api_port()}"
      else:
          if self.gateway_mode == GatewayModeEnum.embedded:
              raise Exception(
                  "Cannot set gateway_plugin_server_url when running embedded gateway, as the embedded gateway will use the local plugin server."
              )
  class ServerRole(Enum):
      """Role of this process, as returned by ``server_role()``."""

      # Neither a worker nor running in both modes.
      SERVER = "server"
      # A server_url is configured.
      WORKER = "worker"
      # Server with a worker in the same process.
      BOTH = "both"
  def server_role(self) -> ServerRole:
      """Return the role of this process.

      The worker check runs first, then the both-roles check; anything else
      is a plain server.
      """
      if self._is_worker():
          return self.ServerRole.WORKER
      return (
          self.ServerRole.BOTH if self._is_both_role() else self.ServerRole.SERVER
      )
  658. def _is_both_role(self) -> bool:
  659. """
  660. Determine if the server is running in both server and worker mode. If the
  661. `enable_worker` flag is set to True, the server is running in both modes. If the
  662. `disable_worker` flag is set to True, the server is running in server-only mode.
  663. If neither flag is set, the presence of a `bootstrap_version` file in the data
  664. directory is checked. If the file does not exist, it indicates that the server was
  665. installed using a version that defaults to running in both modes.
  666. Returns:
  667. bool: True if running in both server and worker mode, False otherwise.
  668. """
  669. if self._is_worker():
  670. return False
  671. elif self.enable_worker:
  672. return True
  673. elif self.disable_worker:
  674. return False
  675. # As of v2.0.1, a `bootstrap_version` file is created in data_dir.
  676. # If the file exists, it indicates that the server was installed
  677. # using a version that defaults to running server-only mode.
  678. bootstrap_version_path = os.path.join(self.data_dir, "bootstrap_version")
  679. if os.path.exists(bootstrap_version_path):
  680. return False
  681. return True
  def get_advertise_address(self) -> str:
      """Return ``advertise_address`` if set, else the first non-loopback IP.

      The fallback comes from ``get_first_non_loopback_ip()`` and is only
      computed when no advertise address is configured.
      """
      configured = self.advertise_address
      if configured:
          return configured
      return get_first_non_loopback_ip()
  def get_namespace(self) -> str:
      """Return the namespace that applies to the current gateway mode.

      Embedded and external gateway modes both use ``gateway_namespace``;
      every other mode uses ``namespace``.
      """
      uses_gateway_namespace = self.gateway_mode in [
          GatewayModeEnum.embedded,
          GatewayModeEnum.external,
      ]
      return self.gateway_namespace if uses_gateway_namespace else self.namespace
  def get_external_hostname(self) -> Optional[str]:
      """Return the DNS hostname of ``server_external_url``, if it has one.

      Returns:
          The hostname parsed from ``server_external_url``, or ``None`` when
          the URL is unset, has no host component, or its host is an IP
          address literal (IPv4 or IPv6).
      """
      if not self.server_external_url:
          return None

      hostname = urlparse(self.server_external_url).hostname
      if not hostname:
          return None

      try:
          ipaddress.ip_address(hostname)
      except ValueError:
          # ip_address() rejects anything that is not an IP literal with
          # ValueError, so what remains is a genuine hostname.
          return hostname
      # The host parsed as an IP address; there is no hostname to report.
      return None
  def get_tls_secret_name(self) -> Optional[str]:
      """Build the TLS secret name for the configured certificate.

      Returns:
          ``None`` unless both ``ssl_certfile`` and ``ssl_keyfile`` are set.
          Otherwise ``gpustack-tls-<hostname>`` with the dots of the external
          hostname replaced by dashes, or ``gpustack-tls-default`` when
          ``get_external_hostname()`` yields nothing.
      """
      has_cert_and_key = self.ssl_certfile and self.ssl_keyfile
      if not has_cert_and_key:
          return None

      hostname = self.get_external_hostname()
      if not hostname:
          return "gpustack-tls-default"
      dashed = hostname.replace(".", "-")
      return f"gpustack-tls-{dashed}"
  def get_server_url(self) -> str:
      """Return the URL used to reach the server.

      An explicitly configured ``server_url`` always wins. Otherwise the
      embedded server on the loopback interface is used, with ``api_port``
      appended when one is set.

      The previous single-expression form parsed as
      ``(server_url or f"...") if api_port else "http://127.0.0.1"``, so a
      configured ``server_url`` was dropped whenever ``api_port`` was falsy.

      Returns:
          str: ``server_url`` if set, else the loopback URL.
      """
      if self.server_url:
          return self.server_url
      if self.api_port:
          return f"http://127.0.0.1:{self.api_port}"
      return "http://127.0.0.1"
  def get_api_port(self, embedded_worker: bool = False) -> int:
      """Return the API port for this instance.

      Args:
          embedded_worker: When true, ``worker_port`` is returned without
              looking at the role or the gateway mode.

      Returns:
          int: ``worker_port`` for an embedded worker, or for a worker-role
          instance whose gateway mode is not embedded; ``api_port`` in every
          other case.
      """
      if embedded_worker:
          return self.worker_port

      is_worker = self.server_role() == self.ServerRole.WORKER
      if is_worker and self.gateway_mode != GatewayModeEnum.embedded:
          return self.worker_port
      return self.api_port
  def get_gateway_port(self) -> int:
      """Return ``worker_port`` for a worker-role instance, else ``port``."""
      if self.server_role() == self.ServerRole.WORKER:
          return self.worker_port
      return self.port
  def reload_token(self):
      """Refresh ``self.token`` from the registration token under ``data_dir``.

      The current token is left untouched when ``read_registration_token``
      returns a falsy value.
      """
      fresh_token = read_registration_token(self.data_dir)
      if not fresh_token:
          return
      self.token = fresh_token
  def reload_worker_config(self, worker_config: Optional[PredefinedConfigNoDefaults]):
      """Overlay a worker configuration onto this config and re-validate.

      The non-``None`` values dumped from ``worker_config`` are merged with
      ``self._set_worker_fields``, which win on conflicting keys. Only keys
      present in the class's ``model_fields`` are assigned. ``check_all()``
      runs once the assignments are done.

      Args:
          worker_config: Configuration to apply; ``None`` makes this a no-op.
      """
      if worker_config is None:
          return

      merged = dict(worker_config.model_dump(exclude_none=True))
      merged.update(self._set_worker_fields)

      for field_name in merged:
          # Silently ignore keys that are not fields of this model.
          if field_name in self.__class__.model_fields:
              setattr(self, field_name, merged[field_name])

      self.check_all()
  def get_system_reserved(self) -> Dict[str, int]:
      """Return a copy of ``system_reserved`` with ``ram``/``vram`` scaled.

      The alternative keys ``memory`` and ``gpu_memory`` are always removed
      from the copy; each is used only when its counterpart (``ram`` /
      ``vram``) is absent. A missing amount counts as 0. Both resulting
      values are shifted left by 30 bits, i.e. multiplied by 2**30.
      ``self.system_reserved`` itself is not modified.

      Returns:
          Dict[str, int]: The copied mapping with ``ram`` and ``vram`` set.
      """
      reserved = dict(self.system_reserved or {})
      for key, alternative_key in (("ram", "memory"), ("vram", "gpu_memory")):
          # Pop unconditionally so the alternative key never leaks into the result.
          fallback = reserved.pop(alternative_key, 0)
          reserved[key] = reserved.get(key, fallback) << 30
      return reserved
  def get_proxy_port(self) -> int:
      """Return the configured ``proxy_port``."""
      port: int = self.proxy_port
      return port
  def get_proxy_listen_address(self, default: str = "0.0.0.0") -> str:
      """Return the address the proxy listens on.

      Args:
          default: Address returned for every gateway mode except embedded.

      Returns:
          str: ``"127.0.0.1"`` in embedded gateway mode, otherwise ``default``.
      """
      if self.gateway_mode == GatewayModeEnum.embedded:
          return "127.0.0.1"
      return default
  def get_proxy_url(self) -> Optional[str]:
      """Return the proxy's ``http://host:port`` URL.

      The host is ``get_proxy_listen_address`` called with the advertise
      address as its default; the port is ``get_proxy_port()``.
      """
      host = self.get_proxy_listen_address(self.get_advertise_address())
      port = self.get_proxy_port()
      return f"http://{host}:{port}"
  def get_derived_gateway_token(self) -> str:
      """Return the value stored in ``self._derive_gateway_token``."""
      derived_token: str = self._derive_gateway_token
      return derived_token
  def get_image_name(
      image_name_override: Optional[str],
      registry: Optional[str] = None,
      image_repo: str = "gpustack/gpustack",
  ) -> str:
      """Build the container image reference.

      Args:
          image_name_override: Returned unchanged when truthy; nothing else
              is consulted in that case.
          registry: Optional registry, prepended as ``<registry>/``.
          image_repo: Repository part of the image name.

      Returns:
          str: ``[<registry>/]<image_repo>:<tag>``, where the tag is
          ``__version__``, or ``dev`` when the version with any leading
          ``v`` removed equals ``0.0.0``.
      """
      if image_name_override:
          return image_name_override

      tag = "dev" if __version__.removeprefix("v") == "0.0.0" else __version__
      repository = f"{registry}/{image_repo}" if registry else image_repo
      return f"{repository}:{tag}"
  def get_cluster_image_name(worker_config: Optional[PredefinedConfigNoDefaults]) -> str:
      """Resolve the image name for a cluster.

      Each of the three settings (default container registry, image name
      override, image repo) is taken from ``worker_config`` when it provides
      a truthy value, and from the global config otherwise. With no
      ``worker_config`` the global config is used for all three.

      Args:
          worker_config: Optional per-worker settings.

      Returns:
          str: The result of ``get_image_name`` for the resolved settings.
      """
      cfg = get_global_config()

      if worker_config is None:
          registry_setting = cfg.system_default_container_registry
          name_override = cfg.image_name_override
          repo = cfg.image_repo
      else:
          registry_setting = (
              worker_config.system_default_container_registry
              or cfg.system_default_container_registry
          )
          name_override = worker_config.image_name_override or cfg.image_name_override
          repo = worker_config.image_repo or cfg.image_repo

      return get_image_name(
          image_name_override=name_override,
          registry=determine_default_registry(registry_setting),
          image_repo=repo,
      )
  def get_openid_configuration(issuer: str) -> dict:
      """Fetch the OpenID configuration document published by ``issuer``.

      Args:
          issuer: Issuer base URL; trailing slashes are stripped before
              ``/.well-known/openid-configuration`` is appended.

      Returns:
          dict: The decoded JSON body of the response.

      Raises:
          Exception: If anything fails while preparing or performing the
              request (including a non-success HTTP status); the original
              error is chained as the cause.
      """
      url = f"{issuer.rstrip('/')}/.well-known/openid-configuration"
      try:
          trust_env = use_proxy_env_for_url(url)
          ssl_context = get_system_trust_store_ssl_context()
          with httpx.Client(
              timeout=10, verify=ssl_context, trust_env=trust_env
          ) as http:
              response = http.get(url)
              response.raise_for_status()
              return response.json()
      except Exception as e:
          raise Exception(
              f"Failed to get OpenID configuration: {str(e)}. Please check the issuer URL and ensure {url} is accessible."
          ) from e
  def get_global_config() -> Config:
      """Return the module-level ``_config`` object."""
      current = _config
      return current
  def set_global_config(cfg: Config):
      """Store ``cfg`` as the module-level ``_config`` and return it.

      Args:
          cfg: The configuration object to install globally.

      Returns:
          The same ``cfg`` object that was passed in.
      """
      global _config
      _config = cfg
      return _config