prerun.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. import os
  2. import sys
  3. import argparse
  4. import logging
  5. from pathlib import Path
  6. from shutil import move
  7. from typing import List, Dict
  8. from gpustack.config.config import Config
  9. from gpustack.schemas.config import GatewayModeEnum
  10. from gpustack.envs import MIGRATION_DATA_DIR, DATA_MIGRATION
  11. from gpustack.logging import setup_logging
  12. from gpustack.cmd.start import (
  13. start_cmd_options,
  14. parse_args,
  15. )
  16. from gpustack.utils.envs import get_gpustack_env
  17. from gpustack.utils.ephemeral_ports import ensure_reserved_against_ephemeral
  18. from gpustack.utils.network import is_port_available, parse_port_range
  19. from gpustack.utils.s6_services import (
  20. gateway_services,
  21. postgres_services,
  22. migration_services,
  23. observability_services,
  24. all_services,
  25. gpustack_service_name,
  26. )
  27. logger = logging.getLogger(__name__)
  28. def setup_prerun_cmd(subparsers: argparse._SubParsersAction):
  29. parser_server: argparse.ArgumentParser = subparsers.add_parser(
  30. "prerun",
  31. description="Perform pre-run checks and setup s6-overlay configuration for GPUStack.",
  32. )
  33. # following args are hidden and used for debugging or advanced usage
  34. parser_server.add_argument(
  35. "--s6-base-path",
  36. type=str,
  37. help=argparse.SUPPRESS,
  38. default=get_gpustack_env("S6_BASE_PATH"),
  39. )
  40. start_cmd_options(parser_server)
  41. parser_server.set_defaults(func=run)
  42. def run(args: argparse.Namespace):
  43. try:
  44. cfg = parse_args(args)
  45. setup_logging(cfg.debug)
  46. logger.info(
  47. "Starting pre-run checks and setup s6-overlay configuration for GPUStack..."
  48. )
  49. s6_base_path = args.s6_base_path or "/etc/s6-overlay/s6-rc.d"
  50. enabled_services = determine_enabled_services(cfg)
  51. # migrate hardcode postgres data dir if needed for determining dependency services
  52. migrate_hardcode_postgres_data_and_password(cfg, enabled_services)
  53. dependency_services = determine_dependency_services(cfg)
  54. if len(dependency_services) == 0 and len(enabled_services) == 0:
  55. logger.info("No extra s6 services for gpustack to enable.")
  56. else:
  57. logger.info(
  58. f"Enabled s6 services: {enabled_services}, dependencies for gpustack: {dependency_services}"
  59. )
  60. prepare_s6_overlay(enabled_services, dependency_services, Path(s6_base_path))
  61. check_ports_availability(cfg, *enabled_services)
  62. reserve_ports_against_ephemeral(cfg)
  63. if "postgres" in enabled_services:
  64. prepare_postgres_config(cfg)
  65. if cfg.gateway_mode == GatewayModeEnum.embedded:
  66. prepare_gateway_config(cfg)
  67. if use_builtin_grafana(cfg):
  68. prepare_observability_config(cfg)
  69. logger.info("Pre-run checks and setup completed successfully.")
  70. except Exception as e:
  71. logger.fatal(f"Failed to pre-check the configuration: {e}")
  72. sys.exit(1)
  73. def ports_for_services(cfg: Config) -> Dict[int, str]:
  74. ports = {}
  75. is_server = cfg.server_role() in [Config.ServerRole.SERVER, Config.ServerRole.BOTH]
  76. is_worker = cfg.server_role() in [Config.ServerRole.WORKER, Config.ServerRole.BOTH]
  77. # postgres
  78. if not cfg.database_url and is_server:
  79. postgres_services.set_ports(cfg, ports)
  80. if cfg.gateway_mode == GatewayModeEnum.embedded:
  81. gateway_services.set_ports(cfg, ports)
  82. if use_builtin_grafana(cfg):
  83. observability_services.set_ports(cfg, ports)
  84. # gpustack server/worker
  85. gateway_disabled = cfg.gateway_mode == GatewayModeEnum.disabled
  86. enabled_tls = cfg.ssl_certfile is not None and cfg.ssl_keyfile is not None
  87. if is_server:
  88. ports[cfg.port] = gpustack_service_name
  89. ports[cfg.proxy_port] = gpustack_service_name
  90. if enabled_tls:
  91. ports[cfg.tls_port] = gpustack_service_name
  92. if not cfg.disable_metrics:
  93. ports[cfg.metrics_port] = gpustack_service_name
  94. if is_worker:
  95. ports[cfg.worker_port] = gpustack_service_name
  96. if not cfg.disable_worker_metrics:
  97. ports[cfg.worker_metrics_port] = gpustack_service_name
  98. # when gateway is disabled, api port is not required
  99. if not gateway_disabled:
  100. ports[cfg.api_port] = gpustack_service_name
  101. return ports
  102. def check_ports_availability(cfg: Config, *services: str):
  103. # Implement port availability checks here
  104. all_services = list(services) + [gpustack_service_name]
  105. ports = ports_for_services(cfg)
  106. ports_to_check = {
  107. port: service
  108. for port, service in ports.items()
  109. if not all_services or service in all_services
  110. }
  111. should_fail = False
  112. for port, service in ports_to_check.items():
  113. if not is_port_available(port):
  114. logger.error(
  115. f"Port {port} required for service '{service}' is not available."
  116. )
  117. should_fail = True
  118. if should_fail:
  119. raise Exception("One or more required ports are not available.")
  120. def reserve_ports_against_ephemeral(cfg: Config):
  121. """
  122. Reserve gpustack's service and Ray port ranges against the kernel's
  123. ephemeral port range so outbound connections (by gpustack, higress, or
  124. any other sibling process sharing the netns) don't transiently squat on
  125. ports that Ray or inference servers will later bind().
  126. Runs here (in prerun) rather than inside the worker so the reservation
  127. is applied before any s6 service starts. Skipped on server-only hosts:
  128. the configured ranges are only bound by model instances and Ray, both
  129. of which run on worker hosts.
  130. """
  131. if cfg.server_role() not in [Config.ServerRole.WORKER, Config.ServerRole.BOTH]:
  132. return
  133. ranges = []
  134. for name in ("service_port_range", "ray_port_range"):
  135. value = getattr(cfg, name, None)
  136. if not value:
  137. continue
  138. try:
  139. ranges.append((name, parse_port_range(value)))
  140. except ValueError:
  141. logger.debug("Skipping unparseable %s=%r", name, value)
  142. if ranges:
  143. ensure_reserved_against_ephemeral(ranges)
  144. def cleanup_s6_services(base_path: Path, *services: str):
  145. for service in services:
  146. service_path = base_path / service
  147. if service_path.exists():
  148. service_path.unlink()
  149. def create_s6_services(base_path: Path, *services: str):
  150. for service in services:
  151. service_path = base_path / service
  152. service_path.parent.mkdir(parents=True, exist_ok=True)
  153. service_path.touch()
  154. def migrate_hardcode_postgres_data_and_password(
  155. cfg: Config, enabled_services: List[str]
  156. ):
  157. if "postgres" not in enabled_services:
  158. return
  159. # following paths are hardcoded in the postgres s6 service scripts in v2.0.0.
  160. # in post 2.0.0 versions, we support custom data dir via cfg.data_dir.
  161. # here we migrate the data from hardcoded paths to the new paths if needed.
  162. pair = {
  163. Path("/var/lib/gpustack/postgres/data"): Path(cfg.postgres_base_dir()) / "data",
  164. Path("/var/lib/gpustack/postgres_root_pass"): Path(cfg.data_dir)
  165. / "postgres_root_pass",
  166. Path("/var/lib/gpustack/run/migration_done"): get_migration_done_file(cfg),
  167. }
  168. for hardcode_path, target_path in pair.items():
  169. if hardcode_path == target_path or not hardcode_path.exists():
  170. continue
  171. if target_path.exists():
  172. logger.warning(
  173. f"Both hardcoded postgres file/dir {hardcode_path} and postgres file/dir with data_dir {target_path} exist. Only {target_path} will be used."
  174. )
  175. continue
  176. logger.info(
  177. f"Migrating hardcoded postgres file/dir {hardcode_path} to {target_path}"
  178. )
  179. target_path.parent.mkdir(parents=True, exist_ok=True)
  180. move(hardcode_path, target_path)
  181. def prepare_postgres_config(cfg: Config):
  182. # prepare postgres dirs
  183. # same reason as gateway_shared_config_dir
  184. config_path = prepare_env("GPUSTACK_POSTGRES_CONFIG", "postgresql")
  185. with open(config_path, "w") as f:
  186. f.write(f"DATA_DIR={cfg.data_dir}\n")
  187. f.write(f"LOG_DIR={cfg.log_dir}\n")
  188. f.write(f"EMBEDDED_DATABASE_PORT={cfg.database_port}\n")
  189. f.write(f"STATE_MIGRATION_DONE_FILE={get_migration_done_file(cfg)}\n")
  190. f.write(f"POSTGRES_DATA_DIR={os.path.join(cfg.postgres_base_dir(), 'data')}\n")
  191. def get_migration_done_file(cfg: Config) -> Path:
  192. return Path(cfg.data_dir) / "run" / "state_migration_done"
  193. def prepare_gateway_config(cfg: Config):
  194. # prepare gateway dirs
  195. config_path = prepare_env("GPUSTACK_GATEWAY_CONFIG", "gateway")
  196. higress_embedded_kubeconfig = Path(cfg.higress_base_dir()) / "kubeconfig"
  197. if cfg.gateway_mode == GatewayModeEnum.embedded:
  198. with open(config_path, "w") as f:
  199. f.write(f"DATA_DIR={cfg.data_dir}\n")
  200. f.write(f"LOG_DIR={cfg.log_dir}\n")
  201. f.write(f"GATEWAY_HTTP_PORT={cfg.get_gateway_port()}\n")
  202. f.write(f"GATEWAY_HTTPS_PORT={cfg.tls_port}\n")
  203. f.write(f"GATEWAY_CONCURRENCY={cfg.gateway_concurrency}\n")
  204. f.write(f"GPUSTACK_API_PORT={cfg.get_api_port()}\n")
  205. f.write(f"EMBEDDED_KUBECONFIG_PATH={higress_embedded_kubeconfig}\n")
  206. with open(higress_embedded_kubeconfig, "w") as f:
  207. f.write(
  208. f"""apiVersion: v1
  209. kind: Config
  210. clusters:
  211. - name: higress
  212. cluster:
  213. server: https://127.0.0.1:{os.getenv('APISERVER_PORT', '18443')}
  214. insecure-skip-tls-verify: true
  215. users:
  216. - name: higress-admin
  217. user: {{}}
  218. contexts:
  219. - name: higress
  220. context:
  221. cluster: higress
  222. user: higress-admin
  223. current-context: higress
  224. """
  225. )
  226. def prepare_observability_config(cfg: Config):
  227. env_config_path = prepare_env("GPUSTACK_OBSERVABILITY_CONFIG", "observability")
  228. with open(env_config_path, "w") as f:
  229. f.write(f"DATA_DIR={cfg.data_dir}\n")
  230. f.write(f"LOG_DIR={cfg.log_dir}\n")
  231. f.write(f"PROMETHEUS_PORT={cfg.builtin_prometheus_port}\n")
  232. f.write(f"GF_SERVER_HTTP_PORT={cfg.builtin_grafana_port}\n")
  233. f.write(f"PROMETHEUS_DATA_DIR={os.path.join(cfg.data_dir, 'prometheus')}\n")
  234. f.write(f"GF_PATHS_DATA={os.path.join(cfg.data_dir, 'grafana')}\n")
  235. f.write(f"GF_PATHS_LOGS={os.path.join(cfg.log_dir, 'grafana')}\n")
  236. f.write(
  237. f"GF_PATHS_PLUGINS={os.path.join(cfg.data_dir, 'grafana', 'plugins')}\n"
  238. )
  239. prometheus_config_path = Path(
  240. os.getenv("PROMETHEUS_CONFIG_FILE", "/etc/prometheus/prometheus.yml")
  241. )
  242. prometheus_config_path.parent.mkdir(parents=True, exist_ok=True)
  243. prometheus_config_path.write_text(
  244. f"""# Managed by GPUStack
  245. global:
  246. scrape_interval: 15s
  247. scrape_timeout: 10s
  248. evaluation_interval: 15s
  249. scrape_configs:
  250. - job_name: gpustack-worker-discovery
  251. scrape_interval: 5s
  252. http_sd_configs:
  253. - url: "http://127.0.0.1:{cfg.metrics_port}/metrics/targets"
  254. refresh_interval: 1m
  255. - job_name: gpustack-proxy-worker-discovery
  256. scrape_interval: 5s
  257. proxy_url: "http://127.0.0.1:{cfg.get_proxy_port()}"
  258. http_sd_configs:
  259. - url: "http://127.0.0.1:{cfg.metrics_port}/metrics/proxy-targets"
  260. refresh_interval: 1m
  261. - job_name: gpustack-server
  262. scrape_interval: 5s
  263. static_configs:
  264. - targets:
  265. - 127.0.0.1:{cfg.metrics_port}
  266. """
  267. )
  268. grafana_provisioning_dir = Path(
  269. os.getenv("GF_PATHS_PROVISIONING", "/etc/grafana/provisioning")
  270. )
  271. datasource_path = grafana_provisioning_dir / "datasources" / "datasource.yaml"
  272. datasource_path.parent.mkdir(parents=True, exist_ok=True)
  273. datasource_path.write_text(
  274. f"""apiVersion: 1
  275. datasources:
  276. - name: Prometheus
  277. type: prometheus
  278. uid: prometheus
  279. url: http://127.0.0.1:{cfg.builtin_prometheus_port}/prometheus
  280. isDefault: true
  281. access: proxy
  282. editable: true
  283. orgId: 1
  284. """
  285. )
  286. def determine_enabled_services(cfg: Config) -> List[str]:
  287. services = []
  288. # embedded database
  289. if cfg.database_url is None and cfg.server_role() in [
  290. Config.ServerRole.SERVER,
  291. Config.ServerRole.BOTH,
  292. ]:
  293. services.extend(postgres_services.all_services())
  294. # gateway services
  295. if cfg.gateway_mode == GatewayModeEnum.embedded:
  296. services.extend(gateway_services.all_services())
  297. # embedded observability
  298. if use_builtin_grafana(cfg):
  299. services.extend(observability_services.all_services())
  300. return services
  301. def use_builtin_grafana(cfg: Config) -> bool:
  302. if cfg.disable_builtin_observability:
  303. return False
  304. if cfg.grafana_url is not None:
  305. return False
  306. server_role = cfg.server_role() in [
  307. Config.ServerRole.SERVER,
  308. Config.ServerRole.BOTH,
  309. ]
  310. return server_role
  311. def determine_dependency_services(cfg: Config) -> List[str]:
  312. dependencies = []
  313. if cfg.server_role() in [
  314. Config.ServerRole.SERVER,
  315. Config.ServerRole.BOTH,
  316. ]:
  317. # embedded database
  318. if cfg.database_url is None:
  319. dependencies.extend(postgres_services.dep_services)
  320. # migration
  321. old_db_file = Path(cfg.data_dir) / "database.db"
  322. should_migrate = (
  323. MIGRATION_DATA_DIR is not None or DATA_MIGRATION
  324. ) and old_db_file.exists()
  325. if should_migrate and MIGRATION_DATA_DIR is not None:
  326. logger.warning(
  327. f"The environment variable GPUSTACK_MIGRATION_DATA_DIR is deprecated. The migration target dir will be set to {cfg.data_dir} instead."
  328. )
  329. # This is the hardcooded migration done file path
  330. migration_done = get_migration_done_file(cfg).exists()
  331. postgres_data_exists = (Path(cfg.data_dir) / "postgres" / "data").exists()
  332. if (
  333. cfg.database_url is None
  334. and should_migrate
  335. and not migration_done
  336. and not postgres_data_exists
  337. ):
  338. dependencies.extend(migration_services.dep_services)
  339. # gateway services
  340. if cfg.gateway_mode == GatewayModeEnum.embedded:
  341. dependencies.extend(gateway_services.dep_services)
  342. return dependencies
  343. def prepare_s6_overlay(
  344. enabled_services: List[str],
  345. dependency_services: List[str],
  346. s6_base_path: Path = Path("/etc/s6-overlay/s6-rc.d"),
  347. ):
  348. s6_overlay_path = s6_base_path / "user/contents.d"
  349. s6_overlay_path.mkdir(parents=True, exist_ok=True)
  350. cleanup_s6_services(s6_overlay_path, *all_services())
  351. create_s6_services(
  352. s6_overlay_path, *(set(enabled_services) | set(dependency_services))
  353. )
  354. def prepare_env(env_name: str, scope: str, env_file_name: str = ".env") -> Path:
  355. config_path = os.getenv(env_name)
  356. if config_path is None:
  357. base_dir = Path(os.getenv("GPUSTACK_RUN_DIR", "/run/gpustack"))
  358. config_path = base_dir / scope / env_file_name
  359. else:
  360. config_path = Path(config_path)
  361. config_path.parent.mkdir(parents=True, exist_ok=True)
  362. return config_path