prerun.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. import os
  2. import sys
  3. import argparse
  4. import logging
  5. from pathlib import Path
  6. from shutil import move
  7. from typing import List, Dict
  8. from gpustack.config.config import Config
  9. from gpustack.schemas.config import GatewayModeEnum
  10. from gpustack.envs import MIGRATION_DATA_DIR, DATA_MIGRATION
  11. from gpustack.logging import setup_logging
  12. from gpustack.cmd.start import (
  13. start_cmd_options,
  14. parse_args,
  15. )
  16. from gpustack.utils.envs import get_gpustack_env
  17. from gpustack.utils.ephemeral_ports import ensure_reserved_against_ephemeral
  18. from gpustack.utils.network import is_port_available, parse_port_range
  19. from gpustack.utils.s6_services import (
  20. gateway_services,
  21. postgres_services,
  22. migration_services,
  23. observability_services,
  24. gpustack_server_services,
  25. all_services,
  26. gpustack_service_name,
  27. )
# Module-level logger used by the prerun helpers defined in this file.
logger = logging.getLogger(__name__)
  29. def setup_prerun_cmd(subparsers: argparse._SubParsersAction):
  30. parser_server: argparse.ArgumentParser = subparsers.add_parser(
  31. "prerun",
  32. description="Perform pre-run checks and setup s6-overlay configuration for GPUStack.",
  33. )
  34. # following args are hidden and used for debugging or advanced usage
  35. parser_server.add_argument(
  36. "--s6-base-path",
  37. type=str,
  38. help=argparse.SUPPRESS,
  39. default=get_gpustack_env("S6_BASE_PATH"),
  40. )
  41. start_cmd_options(parser_server)
  42. parser_server.set_defaults(func=run)
  43. def run(args: argparse.Namespace):
  44. try:
  45. cfg = parse_args(args)
  46. setup_logging(cfg.debug)
  47. logger.info(
  48. "Starting pre-run checks and setup s6-overlay configuration for GPUStack..."
  49. )
  50. s6_base_path = args.s6_base_path or "/etc/s6-overlay/s6-rc.d"
  51. enabled_services = determine_enabled_services(cfg)
  52. # migrate hardcode postgres data dir if needed for determining dependency services
  53. migrate_hardcode_postgres_data_and_password(cfg, enabled_services)
  54. dependency_services = determine_dependency_services(cfg)
  55. if len(dependency_services) == 0 and len(enabled_services) == 0:
  56. logger.info("No extra s6 services for gpustack to enable.")
  57. else:
  58. logger.info(
  59. f"Enabled s6 services: {enabled_services}, dependencies for gpustack: {dependency_services}"
  60. )
  61. prepare_s6_overlay(enabled_services, dependency_services, Path(s6_base_path))
  62. check_ports_availability(cfg, *enabled_services)
  63. reserve_ports_against_ephemeral(cfg)
  64. if "postgres" in enabled_services:
  65. prepare_postgres_config(cfg)
  66. if cfg.gateway_mode == GatewayModeEnum.embedded:
  67. prepare_gateway_config(cfg)
  68. if use_builtin_grafana(cfg):
  69. prepare_observability_config(cfg)
  70. logger.info("Pre-run checks and setup completed successfully.")
  71. except Exception as e:
  72. logger.fatal(f"Failed to pre-check the configuration: {e}")
  73. sys.exit(1)
  74. def ports_for_services(cfg: Config) -> Dict[int, str]:
  75. ports = {}
  76. is_server = cfg.server_role() in [Config.ServerRole.SERVER, Config.ServerRole.BOTH]
  77. is_worker = cfg.server_role() in [Config.ServerRole.WORKER, Config.ServerRole.BOTH]
  78. # postgres
  79. if not cfg.database_url and is_server:
  80. postgres_services.set_ports(cfg, ports)
  81. if cfg.gateway_mode == GatewayModeEnum.embedded:
  82. gateway_services.set_ports(cfg, ports)
  83. if use_builtin_grafana(cfg):
  84. observability_services.set_ports(cfg, ports)
  85. # gpustack server/worker
  86. gateway_disabled = cfg.gateway_mode == GatewayModeEnum.disabled
  87. enabled_tls = cfg.ssl_certfile is not None and cfg.ssl_keyfile is not None
  88. if is_server:
  89. ports[cfg.port] = gpustack_service_name
  90. ports[cfg.proxy_port] = gpustack_service_name
  91. if enabled_tls:
  92. ports[cfg.tls_port] = gpustack_service_name
  93. if not cfg.disable_metrics:
  94. ports[cfg.metrics_port] = gpustack_service_name
  95. if is_worker:
  96. ports[cfg.worker_port] = gpustack_service_name
  97. if not cfg.disable_worker_metrics:
  98. ports[cfg.worker_metrics_port] = gpustack_service_name
  99. # when gateway is disabled, api port is not required
  100. if not gateway_disabled:
  101. ports[cfg.api_port] = gpustack_service_name
  102. return ports
  103. def check_ports_availability(cfg: Config, *services: str):
  104. # Implement port availability checks here
  105. all_services = list(services) + [gpustack_service_name]
  106. ports = ports_for_services(cfg)
  107. ports_to_check = {
  108. port: service
  109. for port, service in ports.items()
  110. if not all_services or service in all_services
  111. }
  112. should_fail = False
  113. for port, service in ports_to_check.items():
  114. if not is_port_available(port):
  115. logger.error(
  116. f"Port {port} required for service '{service}' is not available."
  117. )
  118. should_fail = True
  119. if should_fail:
  120. raise Exception("One or more required ports are not available.")
  121. def reserve_ports_against_ephemeral(cfg: Config):
  122. """
  123. Reserve gpustack's service and Ray port ranges against the kernel's
  124. ephemeral port range so outbound connections (by gpustack, higress, or
  125. any other sibling process sharing the netns) don't transiently squat on
  126. ports that Ray or inference servers will later bind().
  127. Runs here (in prerun) rather than inside the worker so the reservation
  128. is applied before any s6 service starts. Skipped on server-only hosts:
  129. the configured ranges are only bound by model instances and Ray, both
  130. of which run on worker hosts.
  131. """
  132. if cfg.server_role() not in [Config.ServerRole.WORKER, Config.ServerRole.BOTH]:
  133. return
  134. ranges = []
  135. for name in ("service_port_range", "ray_port_range"):
  136. value = getattr(cfg, name, None)
  137. if not value:
  138. continue
  139. try:
  140. ranges.append((name, parse_port_range(value)))
  141. except ValueError:
  142. logger.debug("Skipping unparseable %s=%r", name, value)
  143. if ranges:
  144. ensure_reserved_against_ephemeral(ranges)
  145. def cleanup_s6_services(base_path: Path, *services: str):
  146. for service in services:
  147. service_path = base_path / service
  148. if service_path.exists():
  149. service_path.unlink()
  150. def create_s6_services(base_path: Path, *services: str):
  151. for service in services:
  152. service_path = base_path / service
  153. service_path.parent.mkdir(parents=True, exist_ok=True)
  154. service_path.touch()
  155. def migrate_hardcode_postgres_data_and_password(
  156. cfg: Config, enabled_services: List[str]
  157. ):
  158. if "postgres" not in enabled_services:
  159. return
  160. # following paths are hardcoded in the postgres s6 service scripts in v2.0.0.
  161. # in post 2.0.0 versions, we support custom data dir via cfg.data_dir.
  162. # here we migrate the data from hardcoded paths to the new paths if needed.
  163. pair = {
  164. Path("/var/lib/gpustack/postgres/data"): Path(cfg.postgres_base_dir()) / "data",
  165. Path("/var/lib/gpustack/postgres_root_pass"): Path(cfg.data_dir)
  166. / "postgres_root_pass",
  167. Path("/var/lib/gpustack/run/migration_done"): get_migration_done_file(cfg),
  168. }
  169. for hardcode_path, target_path in pair.items():
  170. if hardcode_path == target_path or not hardcode_path.exists():
  171. continue
  172. if target_path.exists():
  173. logger.warning(
  174. f"Both hardcoded postgres file/dir {hardcode_path} and postgres file/dir with data_dir {target_path} exist. Only {target_path} will be used."
  175. )
  176. continue
  177. logger.info(
  178. f"Migrating hardcoded postgres file/dir {hardcode_path} to {target_path}"
  179. )
  180. target_path.parent.mkdir(parents=True, exist_ok=True)
  181. move(hardcode_path, target_path)
  182. def prepare_postgres_config(cfg: Config):
  183. # prepare postgres dirs
  184. # same reason as gateway_shared_config_dir
  185. config_path = prepare_env("GPUSTACK_POSTGRES_CONFIG", "postgresql")
  186. with open(config_path, "w") as f:
  187. f.write(f"DATA_DIR={cfg.data_dir}\n")
  188. f.write(f"LOG_DIR={cfg.log_dir}\n")
  189. f.write(f"EMBEDDED_DATABASE_PORT={cfg.database_port}\n")
  190. f.write(f"STATE_MIGRATION_DONE_FILE={get_migration_done_file(cfg)}\n")
  191. f.write(f"POSTGRES_DATA_DIR={os.path.join(cfg.postgres_base_dir(), 'data')}\n")
  192. def get_migration_done_file(cfg: Config) -> Path:
  193. return Path(cfg.data_dir) / "run" / "state_migration_done"
def prepare_gateway_config(cfg: Config):
    """Write the gateway env file and the higress kubeconfig.

    The env file location is resolved by
    ``prepare_env("GPUSTACK_GATEWAY_CONFIG", "gateway")``; the kubeconfig is
    written to ``<cfg.higress_base_dir()>/kubeconfig``. Both are only written
    when the gateway mode is ``embedded``.
    """
    # NOTE(review): the nesting of both writes under the embedded-mode check and
    # the indentation inside the YAML literal below are reconstructed — confirm
    # against the upstream file.
    # prepare gateway dirs
    config_path = prepare_env("GPUSTACK_GATEWAY_CONFIG", "gateway")
    higress_embedded_kubeconfig = Path(cfg.higress_base_dir()) / "kubeconfig"
    if cfg.gateway_mode == GatewayModeEnum.embedded:
        # KEY=value settings for the gateway.
        with open(config_path, "w") as f:
            f.write(f"DATA_DIR={cfg.data_dir}\n")
            f.write(f"LOG_DIR={cfg.log_dir}\n")
            f.write(f"GATEWAY_HTTP_PORT={cfg.get_gateway_port()}\n")
            f.write(f"GATEWAY_HTTPS_PORT={cfg.tls_port}\n")
            f.write(f"GATEWAY_CONCURRENCY={cfg.gateway_concurrency}\n")
            f.write(f"GPUSTACK_API_PORT={cfg.get_api_port()}\n")
            f.write(f"EMBEDDED_KUBECONFIG_PATH={higress_embedded_kubeconfig}\n")
        # NOTE(review): the kubeconfig's parent directory is not created here;
        # presumably cfg.higress_base_dir() already exists — verify.
        # The API server port comes from APISERVER_PORT (default 18443).
        with open(higress_embedded_kubeconfig, "w") as f:
            f.write(
                f"""apiVersion: v1
kind: Config
clusters:
  - name: higress
    cluster:
      server: https://127.0.0.1:{os.getenv('APISERVER_PORT', '18443')}
      insecure-skip-tls-verify: true
users:
  - name: higress-admin
    user: {{}}
contexts:
  - name: higress
    context:
      cluster: higress
      user: higress-admin
current-context: higress
"""
            )
def prepare_observability_config(cfg: Config):
    """Write the observability env file, Prometheus config and Grafana datasource.

    * The env file location is resolved by
      ``prepare_env("GPUSTACK_OBSERVABILITY_CONFIG", "observability")``.
    * The Prometheus config goes to ``PROMETHEUS_CONFIG_FILE`` (default
      ``/etc/prometheus/prometheus.yml``).
    * The Grafana datasource goes to ``datasources/datasource.yaml`` under
      ``GF_PATHS_PROVISIONING`` (default ``/etc/grafana/provisioning``).
    """
    # NOTE(review): the indentation inside both YAML literals below is
    # reconstructed — confirm against the upstream file.
    env_config_path = prepare_env("GPUSTACK_OBSERVABILITY_CONFIG", "observability")
    # KEY=value settings: ports plus data/log locations derived from cfg.
    with open(env_config_path, "w") as f:
        f.write(f"DATA_DIR={cfg.data_dir}\n")
        f.write(f"LOG_DIR={cfg.log_dir}\n")
        f.write(f"PROMETHEUS_PORT={cfg.builtin_prometheus_port}\n")
        f.write(f"GF_SERVER_HTTP_PORT={cfg.builtin_grafana_port}\n")
        f.write(f"PROMETHEUS_DATA_DIR={os.path.join(cfg.data_dir, 'prometheus')}\n")
        f.write(f"GF_PATHS_DATA={os.path.join(cfg.data_dir, 'grafana')}\n")
        f.write(f"GF_PATHS_LOGS={os.path.join(cfg.log_dir, 'grafana')}\n")
        f.write(
            f"GF_PATHS_PLUGINS={os.path.join(cfg.data_dir, 'grafana', 'plugins')}\n"
        )
    # Prometheus scrape configuration; every target points at 127.0.0.1 and
    # uses cfg.metrics_port, the second job additionally sets a proxy_url on
    # cfg.get_proxy_port().
    prometheus_config_path = Path(
        os.getenv("PROMETHEUS_CONFIG_FILE", "/etc/prometheus/prometheus.yml")
    )
    prometheus_config_path.parent.mkdir(parents=True, exist_ok=True)
    prometheus_config_path.write_text(
        f"""# Managed by GPUStack
global:
  scrape_interval: 15s
  scrape_timeout: 10s
  evaluation_interval: 15s
scrape_configs:
  - job_name: gpustack-worker-discovery
    scrape_interval: 5s
    http_sd_configs:
      - url: "http://127.0.0.1:{cfg.metrics_port}/metrics/targets"
        refresh_interval: 1m
  - job_name: gpustack-proxy-worker-discovery
    scrape_interval: 5s
    proxy_url: "http://127.0.0.1:{cfg.get_proxy_port()}"
    http_sd_configs:
      - url: "http://127.0.0.1:{cfg.metrics_port}/metrics/proxy-targets"
        refresh_interval: 1m
  - job_name: gpustack-server
    scrape_interval: 5s
    static_configs:
      - targets:
          - 127.0.0.1:{cfg.metrics_port}
"""
    )
    # Grafana datasource provisioning: a single default Prometheus datasource
    # pointing at the built-in Prometheus port.
    grafana_provisioning_dir = Path(
        os.getenv("GF_PATHS_PROVISIONING", "/etc/grafana/provisioning")
    )
    datasource_path = grafana_provisioning_dir / "datasources" / "datasource.yaml"
    datasource_path.parent.mkdir(parents=True, exist_ok=True)
    datasource_path.write_text(
        f"""apiVersion: 1
datasources:
  - name: Prometheus
    type: prometheus
    uid: prometheus
    url: http://127.0.0.1:{cfg.builtin_prometheus_port}/prometheus
    isDefault: true
    access: proxy
    editable: true
    orgId: 1
"""
    )
  287. def determine_enabled_services(cfg: Config) -> List[str]:
  288. services = []
  289. # embedded database
  290. if cfg.database_url is None and cfg.server_role() in [
  291. Config.ServerRole.SERVER,
  292. Config.ServerRole.BOTH,
  293. ]:
  294. services.extend(postgres_services.all_services())
  295. # gateway services
  296. if cfg.gateway_mode == GatewayModeEnum.embedded:
  297. services.extend(gateway_services.all_services())
  298. # embedded observability
  299. if use_builtin_grafana(cfg):
  300. services.extend(observability_services.all_services())
  301. # gpustack server (always enabled for server/both roles)
  302. if cfg.server_role() in [
  303. Config.ServerRole.SERVER,
  304. Config.ServerRole.BOTH,
  305. ]:
  306. services.extend(gpustack_server_services.all_services())
  307. return services
  308. def use_builtin_grafana(cfg: Config) -> bool:
  309. if cfg.disable_builtin_observability:
  310. return False
  311. if cfg.grafana_url is not None:
  312. return False
  313. server_role = cfg.server_role() in [
  314. Config.ServerRole.SERVER,
  315. Config.ServerRole.BOTH,
  316. ]
  317. return server_role
  318. def determine_dependency_services(cfg: Config) -> List[str]:
  319. dependencies = []
  320. if cfg.server_role() in [
  321. Config.ServerRole.SERVER,
  322. Config.ServerRole.BOTH,
  323. ]:
  324. # embedded database
  325. if cfg.database_url is None:
  326. dependencies.extend(postgres_services.dep_services)
  327. # migration
  328. old_db_file = Path(cfg.data_dir) / "database.db"
  329. should_migrate = (
  330. MIGRATION_DATA_DIR is not None or DATA_MIGRATION
  331. ) and old_db_file.exists()
  332. if should_migrate and MIGRATION_DATA_DIR is not None:
  333. logger.warning(
  334. f"The environment variable GPUSTACK_MIGRATION_DATA_DIR is deprecated. The migration target dir will be set to {cfg.data_dir} instead."
  335. )
  336. # This is the hardcooded migration done file path
  337. migration_done = get_migration_done_file(cfg).exists()
  338. postgres_data_exists = (Path(cfg.data_dir) / "postgres" / "data").exists()
  339. if (
  340. cfg.database_url is None
  341. and should_migrate
  342. and not migration_done
  343. and not postgres_data_exists
  344. ):
  345. dependencies.extend(migration_services.dep_services)
  346. # gateway services
  347. if cfg.gateway_mode == GatewayModeEnum.embedded:
  348. dependencies.extend(gateway_services.dep_services)
  349. # gpustack server dependencies
  350. if cfg.server_role() in [
  351. Config.ServerRole.SERVER,
  352. Config.ServerRole.BOTH,
  353. ]:
  354. dependencies.extend(gpustack_server_services.dep_services)
  355. return dependencies
  356. def prepare_s6_overlay(
  357. enabled_services: List[str],
  358. dependency_services: List[str],
  359. s6_base_path: Path = Path("/etc/s6-overlay/s6-rc.d"),
  360. ):
  361. s6_overlay_path = s6_base_path / "user/contents.d"
  362. s6_overlay_path.mkdir(parents=True, exist_ok=True)
  363. cleanup_s6_services(s6_overlay_path, *all_services())
  364. create_s6_services(
  365. s6_overlay_path, *(set(enabled_services) | set(dependency_services))
  366. )
  367. def prepare_env(env_name: str, scope: str, env_file_name: str = ".env") -> Path:
  368. config_path = os.getenv(env_name)
  369. if config_path is None:
  370. base_dir = Path(os.getenv("GPUSTACK_RUN_DIR", "/run/gpustack"))
  371. config_path = base_dir / scope / env_file_name
  372. else:
  373. config_path = Path(config_path)
  374. config_path.parent.mkdir(parents=True, exist_ok=True)
  375. return config_path