server.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990
  1. import asyncio
  2. from multiprocessing import Process
  3. import os
  4. import re
  5. import importlib.util
  6. import aiohttp
  7. import uvicorn
  8. from fastapi import FastAPI
  9. import logging
  10. import secrets
  11. import tenacity
  12. from sqlmodel.ext.asyncio.session import AsyncSession
  13. from gpustack.logging import setup_logging
  14. from gpustack.schemas.users import (
  15. User,
  16. UserRole,
  17. get_default_cluster_user,
  18. default_cluster_user_name,
  19. )
  20. from gpustack.schemas.principals import PLATFORM_PRINCIPAL_ID
  21. from gpustack.schemas.models import ModelInstance
  22. from gpustack.schemas.api_keys import ApiKey
  23. from gpustack.schemas.workers import Worker
  24. from gpustack.schemas.clusters import Cluster, ClusterProvider, ClusterStateEnum
  25. from gpustack.schemas.model_routes import ModelRoute, ModelRouteTarget
  26. from gpustack.schemas.model_provider import ModelProvider
  27. from gpustack.security import (
  28. generate_secure_password,
  29. get_secret_hash,
  30. API_KEY_PREFIX,
  31. )
  32. from gpustack.routes.auth import remove_initial_password_file_if_exists
  33. from gpustack.server.app import create_app
  34. from gpustack.server.services import (
  35. create_user_with_principal,
  36. provision_bootstrap_admin_orgs,
  37. )
  38. from gpustack.config.config import Config
  39. from gpustack.schemas.config import GatewayModeEnum
  40. from gpustack.config import registration
  41. from gpustack.server.catalog import init_model_catalog
  42. from gpustack.server.controllers import (
  43. ModelController,
  44. ModelFileController,
  45. ModelInstanceController,
  46. WorkerController,
  47. ClusterController,
  48. WorkerPoolController,
  49. InferenceBackendController,
  50. ModelRouteController,
  51. ModelRouteTargetController,
  52. ModelProviderController,
  53. )
  54. from gpustack.server.db import async_session
  55. from gpustack.server.init_db import init_db, get_query_count
  56. from gpustack.scheduler.scheduler import Scheduler
  57. from gpustack.server.system_load import SystemLoadCollector
  58. from gpustack.server.update_check import UpdateChecker
  59. from gpustack.server.worker_status_buffer import flush_worker_status_to_db
  60. from gpustack.server.metrics_collector import flush_gateway_metrics_to_db
  61. from gpustack.server.usage_details_archiver import UsageDetailsArchiver
  62. from gpustack.server.worker_instance_cleaner import WorkerInstanceCleaner
  63. from gpustack.server.worker_syncer import WorkerSyncer
  64. from gpustack.utils.process import add_signal_handlers_in_loop
  65. from gpustack.config.registration import write_registration_token
  66. from gpustack.exporter.exporter import MetricExporter
  67. from gpustack.gateway.utils import (
  68. model_ingress_prefix,
  69. model_route_ingress_prefix,
  70. model_route_ingress_name,
  71. fallback_ingress_name,
  72. cleanup_ingresses,
  73. cleanup_model_mapper,
  74. cleanup_fallback_filters,
  75. cleanup_ai_proxy_config,
  76. cleanup_generic_route_transformer,
  77. cleanup_mcpbridge_registry,
  78. resolve_instance_address_from_model_header,
  79. )
  80. from gpustack.gateway import get_async_k8s_config
  81. from gpustack.envs import (
  82. GATEWAY_PORT_CHECK_INTERVAL,
  83. GATEWAY_PORT_CHECK_RETRY_COUNT,
  84. DEFAULT_CLUSTER_KUBERNETES,
  85. )
  86. from gpustack.server.coordinator import LocalCoordinator
  87. from gpustack.server.coordinator.cache import preload_cache
  88. from gpustack.server.coordinator.models import get_model_for_topic
  89. from gpustack.server import bus
  90. from gpustack.server import cache as cache_module
  91. from alembic import command
  92. from alembic.config import Config as AlembicConfig
  93. from gpustack.websocket_proxy.proxy_server import HTTPSProxyServer
  94. from gpustack.api.auth import (
  95. authenticate_worker_by_request_headers,
  96. )
  97. logger = logging.getLogger(__name__)
  98. class Server:
  def __init__(self, config: Config, worker_process: Process):
      """Store the configuration and worker process and reset runtime state.

      Args:
          config: Server configuration read by the other methods.
          worker_process: Process object kept for later; ``start`` adds it
              to the managed sub processes when the server role is BOTH.
      """
      self._config: Config = config
      self._worker_process = worker_process

      # Bookkeeping for everything this server launches.
      self._async_tasks = []
      self._sub_processes = []

      # Coordination components
      self._leader_election_task = None
      self._coordinator = None
  107. @property
  108. def all_processes(self):
  109. return self._sub_processes
  110. def _create_async_task(self, coro):
  111. self._async_tasks.append(asyncio.create_task(coro))
  112. @property
  113. def config(self):
  114. return self._config
  async def start(self):
      """Bring the server up and block until its background tasks finish.

      Order of operations: database migration, data bootstrap, model
      catalog, FastAPI app creation, coordinator setup, change-detection
      cache preload, sub processes, leader-only tasks, the tasks that run
      on every instance, and finally the uvicorn API server. The method
      then awaits every task registered through ``_create_async_task``.
      """
      logger.info("Starting GPUStack server.")
      add_signal_handlers_in_loop()

      self._run_migrations()
      await self._prepare_data()
      init_model_catalog(self._config.model_catalog_file)

      # The server role can only be determined safely after migration.
      if self._config.server_role() == Config.ServerRole.BOTH:
          self._sub_processes.append(self._worker_process)

      # Building the FastAPI app runs each plugin's ``__init__(app, cfg)``,
      # which may attach a distributed-mode coordinator to the plugin.
      self._app = app = create_app(self._config)

      # Pick the coordinator from the plugin instances (LocalCoordinator
      # when none is supplied). This has to happen before the event bus
      # goes online so early publishes are routed correctly.
      await self._init_coordinator(app)

      # With the coordinator up, preload the change-detection cache. In
      # distributed mode this lets the first cross-instance event on each
      # topic carry accurate ``changed_fields``.
      await self._preload_change_detector_cache()

      self._start_sub_processes()

      # Leader-only tasks (scheduler and controllers included): started
      # right away in single-node mode, and only once this node becomes
      # leader in distributed mode.
      await self._start_leader_only_tasks()

      # Everything below may run on every instance.
      self._start_worker_status_flusher()
      self._start_gateway_metrics_flusher()
      self._start_metrics_exporter()
      self._start_query_count_logger()
      self._start_default_registry_checker()
      self._start_proxy_servers(app)
      self._start_extension_plugins(app)

      # With the embedded gateway the API only listens on loopback.
      if self._config.gateway_mode == GatewayModeEnum.embedded:
          bind_host = "127.0.0.1"
      else:
          bind_host = "0.0.0.0"

      uvicorn_config = uvicorn.Config(
          app,
          host=bind_host,
          port=self._config.get_api_port(),
          access_log=False,
          log_level="error",
      )

      # Called after the uvicorn config has been built, as in the original
      # ordering.
      setup_logging()

      logger.info(f"Gateway mode: {self._config.gateway_mode.value}.")
      api_banner = (
          f"Serving GPUStack API on {uvicorn_config.host}:{uvicorn_config.port}."
      )
      if self._config.gateway_mode != GatewayModeEnum.embedded:
          logger.info(api_banner)
      else:
          logger.debug(api_banner)
          logger.info(
              f"GPUStack Server will serve on 0.0.0.0:{self._config.get_gateway_port()}."
          )
          if self._config.get_tls_secret_name() is not None:
              logger.info(
                  f"GPUStack Server will serve TLS on 0.0.0.0:{self._config.tls_port}."
              )

      api_server = uvicorn.Server(uvicorn_config)
      self._create_async_task(api_server.serve())

      await asyncio.gather(*self._async_tasks)
  def _start_default_registry_checker(self):
      """Ask the registration module to determine the default container registry.

      Passes the configured ``system_default_container_registry`` to
      ``registration.determine_default_registry``. The return value is
      ignored.
      """
      # The original statement ended with a stray comma, which wrapped the
      # call result in a one-element tuple that was thrown away.
      registration.determine_default_registry(
          self._config.system_default_container_registry,
      )
  def _run_migrations(self):
      """Upgrade the database schema to the latest Alembic revision.

      Locates the ``migrations`` directory inside the installed
      ``gpustack`` package, points an Alembic config at it and at the
      configured database URL, then runs ``upgrade head``.

      Raises:
          ImportError: If the ``gpustack`` package cannot be located.
          RuntimeError: If the Alembic upgrade fails; the original
              exception is chained as the cause.
      """
      logger.info("Running database migration.")

      package_spec = importlib.util.find_spec("gpustack")
      if package_spec is None:
          raise ImportError("The 'gpustack' package is not found.")
      package_root = package_spec.submodule_search_locations[0]
      migrations_dir = os.path.join(package_root, "migrations")

      alembic_cfg = AlembicConfig()
      alembic_cfg.set_main_option("script_location", migrations_dir)

      database_url = self._config.get_database_url()
      # Run migrations through the pymysql driver to avoid compatibility
      # issues between asynchronous drivers and Alembic.
      if database_url.startswith("mysql://"):
          database_url = re.sub(r'^mysql://', 'mysql+pymysql://', database_url)

      # Double every '%' before handing the URL to Alembic — presumably to
      # survive the config's %-interpolation; confirm against Alembic docs.
      alembic_cfg.set_main_option(
          "sqlalchemy.url", database_url.replace("%", "%%")
      )

      try:
          command.upgrade(alembic_cfg, "head")
      except Exception as e:
          raise RuntimeError(f"Database migration failed: {e}") from e

      logger.info("Database migration completed.")
  async def _prepare_data(self):
      """Create the data directory, initialise the database and seed initial data."""
      cfg = self._config
      self._setup_data_dir(cfg.data_dir)

      await init_db(cfg.get_database_url())

      # All bootstrap steps share one session.
      async with async_session() as db_session:
          await self._init_data(db_session)

      logger.debug("Data initialization completed.")
  def _start_scheduler(self):
      """Launch the scheduler as an asyncio task and return that task.

      The task is handed back to the caller; it is not added to
      ``_async_tasks`` here.
      """
      scheduler_task = asyncio.create_task(Scheduler(self._config).start())
      logger.debug("Scheduler started.")
      return scheduler_task
  def _start_controllers(self):
      """Launch every controller as an asyncio task and return the tasks.

      Controllers are constructed and started one after another in the
      order listed below. The returned list follows the same order and is
      not added to ``_async_tasks`` here.
      """
      cfg = self._config
      # (controller class, constructor arguments) in start-up order.
      controller_specs = (
          (ModelProviderController, (cfg,)),
          (ModelRouteTargetController, (cfg,)),
          (ModelRouteController, (cfg,)),
          (ModelController, (cfg,)),
          (ModelInstanceController, (cfg,)),
          (WorkerController, (cfg,)),
          (ModelFileController, ()),
          (ClusterController, (cfg,)),
          (WorkerPoolController, ()),
          (InferenceBackendController, ()),
      )

      tasks = []
      for controller_cls, ctor_args in controller_specs:
          controller = controller_cls(*ctor_args)
          tasks.append(asyncio.create_task(controller.start()))

      logger.debug("Controllers started.")
      return tasks
  def _start_system_load_collector(self):
      """Schedule the system load collector as a tracked background task."""
      self._create_async_task(SystemLoadCollector().start())
      logger.debug("System load collector started.")
  def _start_worker_syncer(self, app: FastAPI):
      """Schedule the worker syncer as a tracked background task.

      The syncer receives two zero-argument getters rather than client
      objects, so ``app.state`` is consulted each time a getter is called.
      Each getter returns ``None`` when the attribute is not set.

      Args:
          app: The FastAPI application whose ``state`` holds the clients.
      """

      def _http_client():
          return getattr(app.state, "http_client", None)

      def _http_client_no_proxy():
          return getattr(app.state, "http_client_no_proxy", None)

      syncer = WorkerSyncer(_http_client, _http_client_no_proxy)
      self._create_async_task(syncer.start())
      logger.debug("Worker syncer started.")
  def _start_worker_status_flusher(self):
      """Schedule ``flush_worker_status_to_db`` as a tracked background task."""
      flusher = flush_worker_status_to_db()
      self._create_async_task(flusher)
      logger.debug("Worker status flusher started.")
  def _start_gateway_metrics_flusher(self):
      """Schedule ``flush_gateway_metrics_to_db`` as a tracked background task."""
      # Started unconditionally: the gateway report endpoint and the
      # in-process ModelUsageMiddleware both feed the same buffer, so the
      # flusher is needed even when the external gateway is disabled.
      flusher = flush_gateway_metrics_to_db()
      self._create_async_task(flusher)
      logger.debug("Gateway metrics flusher started.")
  def _start_worker_instance_cleaner(self):
      """Schedule the worker instance cleaner as a tracked background task."""
      self._create_async_task(WorkerInstanceCleaner().start())
      logger.debug("Worker instance cleaner started.")
  def _start_usage_details_archiver(self):
      """Schedule the usage details archiver, or log loudly if it cannot be built.

      A failure while constructing the archiver is logged at CRITICAL
      level with the traceback and is not propagated; no task is
      scheduled in that case.
      """
      # Building the archiver can fail on schema drift between the hot and
      # archive tables or on an invalid cron expression. That must not
      # take down the remaining leader tasks (or the leader-election
      # retry), but operators have to see it: without the archiver the
      # model_usage_details hot table grows unbounded, and the outer
      # election handler would otherwise bury it as "Leader election
      # error".
      try:
          archiver = UsageDetailsArchiver()
      except Exception:
          logger.critical(
              "Usage details archiver failed to initialize — archival is "
              "DISABLED. The model_usage_details hot table will grow "
              "unbounded until this is resolved.",
              exc_info=True,
          )
      else:
          self._create_async_task(archiver.start())
          logger.debug("Usage details archiver started.")
  def _start_update_checker(self):
      """Schedule the update checker unless update checks are disabled in the config."""
      if not self._config.disable_update_check:
          checker = UpdateChecker(update_check_url=self._config.update_check_url)
          self._create_async_task(checker.start())
          logger.debug("Update checker started.")
  291. async def _monitor_sub_processes(self):
  292. while self._sub_processes:
  293. for process in self._sub_processes[:]:
  294. if not process.is_alive():
  295. if process.exitcode != 0:
  296. raise RuntimeError(
  297. f"Sub process {process.name} died with exit code {process.exitcode}"
  298. )
  299. self._sub_processes.remove(process)
  300. await asyncio.sleep(5)
  def _start_sub_processes(self):
      """Start the managed sub processes once the local API answers health checks.

      Does nothing when there are no sub processes. Otherwise a tracked
      background task polls ``/healthz`` on loopback every two seconds
      until it returns HTTP 200, starts every sub process, and then
      monitors them via ``_monitor_sub_processes``.
      """

      async def start_process_after_api_ready():
          # Probe the port uvicorn is bound to in ``start``
          # (``get_api_port()``), not the raw ``api_port`` setting, so the
          # probe cannot wait forever on a port nothing listens on.
          api_url = f"http://127.0.0.1:{self._config.get_api_port()}/healthz"
          async with aiohttp.ClientSession() as session:
              while True:
                  try:
                      await asyncio.sleep(2)
                      async with session.get(api_url) as response:
                          if response.status == 200:
                              break
                  except aiohttp.ClientError:
                      # API not reachable yet; try again on the next tick.
                      pass
                  except asyncio.CancelledError:
                      # Cancelled before the API came up: start nothing.
                      return

          for process in self._sub_processes:
              process.start()

          await self._monitor_sub_processes()

      if not self._sub_processes:
          return

      self._create_async_task(start_process_after_api_ready())
  async def _wait_for_gateway_ready(self):
      """Wait until the embedded gateway's ports accept connections.

      Returns immediately unless the gateway mode is ``embedded``. The
      TLS port is included only when a TLS secret name is configured.
      Failures from ``_check_ports_ready`` propagate to the caller.
      """
      if self._config.gateway_mode != GatewayModeEnum.embedded:
          return

      # The plain HTTP port is always started.
      ports = [self._config.port]
      tls_enabled = self._config.get_tls_secret_name() is not None
      if tls_enabled:
          ports.append(self._config.tls_port)

      logger.info(f"Waiting for ports {ports} of GPUStack to be ready...")
      # Retry count and interval are set by the decorator on
      # ``_check_ports_ready`` (about 60s according to the original note).
      await self._check_ports_ready(*ports)
      logger.info("GPUStack Server is ready.")
  @tenacity.retry(
      stop=tenacity.stop_after_attempt(GATEWAY_PORT_CHECK_RETRY_COUNT),
      wait=tenacity.wait_fixed(GATEWAY_PORT_CHECK_INTERVAL),
      reraise=True,
      # ``args`` is ``(self, *ports)``: skip ``self`` and report every port
      # instead of only the first one (which also failed with IndexError
      # when no port was passed).
      before_sleep=lambda retry_state: logger.debug(
          f"Waiting for ports {list(retry_state.args[1:])} to be healthy (attempt {retry_state.attempt_number}) due to: {retry_state.outcome.exception()}"
      ),
  )
  async def _check_ports_ready(self, *ports: int):
      """Verify that every given port accepts a TCP connection on loopback.

      The whole check is retried by the ``tenacity`` decorator, with a
      fixed wait of ``GATEWAY_PORT_CHECK_INTERVAL`` and at most
      ``GATEWAY_PORT_CHECK_RETRY_COUNT`` attempts; the last error is
      re-raised once the attempts are used up.

      Args:
          *ports: Port numbers to probe on ``127.0.0.1``.

      Raises:
          RuntimeError: If connecting to, or closing the connection to,
              any port fails. The underlying exception is chained as the
              cause.
      """
      for port in ports:
          try:
              _, writer = await asyncio.open_connection("127.0.0.1", port)
              writer.close()
              await writer.wait_closed()
          except Exception as exc:
              raise RuntimeError(
                  f"Port {port} is not healthy or not listening"
              ) from exc
  def _start_metrics_exporter(self):
      """Schedule the metric exporter's two tasks unless metrics are disabled.

      When enabled, both ``generate_metrics_cache()`` and ``start()`` of
      one ``MetricExporter`` are scheduled as tracked background tasks.
      """
      if not self._config.disable_metrics:
          metric_exporter = MetricExporter(cfg=self._config)
          self._create_async_task(metric_exporter.generate_metrics_cache())
          self._create_async_task(metric_exporter.start())
  354. def _start_query_count_logger(self):
  355. """Start a background task to log query count periodically."""
  356. async def log_query_count():
  357. while True:
  358. await asyncio.sleep(60) # Log every minute
  359. count = get_query_count()
  360. logger.debug(f"[DB QUERY COUNT] Total queries since startup: {count}")
  361. self._create_async_task(log_query_count())
  362. @staticmethod
  363. def _setup_data_dir(data_dir: str):
  364. if not os.path.exists(data_dir):
  365. os.makedirs(data_dir)
  async def _init_data(self, session: AsyncSession):
      """Run every data bootstrap step, one after another, on ``session``.

      Args:
          session: Database session handed to each step.
      """
      # Steps are awaited strictly in this order.
      bootstrap_steps = (
          self._init_user,
          self._init_default_cluster,
          self._migrate_legacy_token,
          self._migrate_legacy_workers,
          self._ensure_registration_token,
          self._cleanup_orphaned_gateway_data,
      )
      for step in bootstrap_steps:
          await step(session)
  async def _init_user(self, session: AsyncSession):
      """Create the bootstrap ``admin`` user when no active admin exists.

      If the config has no ``bootstrap_password``, a password is
      generated, written to ``<data_dir>/initial_admin_password`` and the
      user is flagged to change it.

      Args:
          session: Database session used for the lookup and the inserts;
              committed at the end.
      """
      # Skip bootstrap when any non-system admin already exists, so that
      # renaming the default "admin" account does not cause a duplicate
      # admin to be regenerated on master restart.
      existing_admin = await User.first_by_fields(
          session=session,
          fields={"is_admin": True, "is_system": False, "is_active": True},
      )
      if existing_admin:
          return

      # Drop any stale initial password file from a prior bootstrap before
      # generating a new one, so the login page does not show an outdated
      # "retrieve initial password" hint.
      remove_initial_password_file_if_exists(self._config)

      bootstrap_password = self._config.bootstrap_password
      require_password_change = False
      if not bootstrap_password:
          require_password_change = True
          bootstrap_password = generate_secure_password()
          bootstrap_password_file = os.path.join(
              self._config.data_dir, "initial_admin_password"
          )
          # The file holds a plaintext admin credential: create it
          # readable and writable by the owner only, instead of relying on
          # the process umask.
          fd = os.open(
              bootstrap_password_file,
              os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
              0o600,
          )
          with os.fdopen(fd, "w") as file:
              file.write(bootstrap_password + "\n")
          logger.info(
              "Generated initial admin password. "
              f"You can get it from {bootstrap_password_file}"
          )

      user = User(
          username="admin",
          full_name="Default System Admin",
          hashed_password=get_secret_hash(bootstrap_password),
          is_admin=True,
          require_password_change=require_password_change,
      )
      user = await create_user_with_principal(session, user)
      await provision_bootstrap_admin_orgs(session, user)
      await session.commit()
  async def _migrate_legacy_token(self, session: AsyncSession):
      """Adopt the legacy config ``token`` as the default cluster's registration token.

      Does nothing when no token is configured, when the default cluster
      user or its cluster is missing, or when the cluster already has a
      registration token. Otherwise the token is stored on the cluster and
      a matching API key is created for the default cluster user, all in
      one commit.

      Args:
          session: Database session; committed on success, rolled back on
              failure.

      Raises:
          RuntimeError: If the default cluster user cannot be found or
              already owns API keys.
          Exception: Any other failure is logged, rolled back and
              re-raised unchanged.
      """
      if not self._config.token:
          return

      # The default cluster user should have been created by the SQL
      # migration script.
      cluster_user = await get_default_cluster_user(session)
      if cluster_user is None or cluster_user.cluster is None:
          logger.debug(
              "Default cluster user not exist, skipping legacy token migration."
          )
          return

      # The guard above already guarantees a cluster, so the former second
      # "cluster does not exist" check was unreachable and is gone.
      default_cluster = cluster_user.cluster
      if default_cluster.registration_token:
          return

      try:
          default_cluster.registration_token = self._config.token
          await default_cluster.update(session=session, auto_commit=False)

          default_cluster_user = await User.one_by_fields(
              session=session,
              fields={
                  "cluster_id": default_cluster.id,
                  "is_system": True,
                  "role": UserRole.Cluster,
              },
          )
          if default_cluster_user is None:
              raise RuntimeError("Default cluster user does not exist.")
          if len(default_cluster_user.api_keys) > 0:
              raise RuntimeError(
                  "Default cluster user already has API keys, cannot migrate legacy token."
              )

          new_key = ApiKey(
              name="Legacy Cluster Token",
              access_key="",
              hashed_secret_key=get_secret_hash(self._config.token),
              user_id=default_cluster_user.id,
              user=default_cluster_user,
          )
          await ApiKey.create(session, new_key, auto_commit=False)
          await session.commit()
      except Exception as e:
          logger.error(f"Failed to migrate legacy token: {e}")
          await session.rollback()
          # Bare ``raise`` re-raises the same exception with its original
          # traceback.
          raise
  async def _migrate_legacy_workers(self, session: AsyncSession):
      """Give every token-less worker of cluster 1 a system user, API key and token.

      For each worker in the default cluster whose ``token`` is ``None``:
      reuse or create its system worker user, create an API key for that
      user, store the resulting token on the worker and commit. Each
      worker is committed separately.

      Args:
          session: Database session; committed per worker, rolled back on
              the first failure.

      Raises:
          Exception: The first failure is logged, rolled back and
              re-raised, which stops the migration.
      """
      # Cluster id 1 is hard-coded so that only the cluster created by the
      # migration step is considered.
      default_cluster = await Cluster.one_by_id(session=session, id=1)
      if not default_cluster:
          logger.debug(
              "Default cluster does not exist, skipping legacy worker migration."
          )
          return

      legacy_workers = await Worker.all_by_fields(
          session=session,
          fields={
              "cluster_id": default_cluster.id,
              "token": None,
          },
      )
      if len(legacy_workers) == 0:
          return

      username_prefix = "system/worker"
      legacy_worker_ids = [item.id for item in legacy_workers]
      existing_users = await User.all_by_fields(
          session=session,
          fields={
              "cluster_id": default_cluster.id,
              "is_system": True,
              "role": UserRole.Worker,
          },
          extra_conditions=[User.worker_id.in_(legacy_worker_ids)],
      )
      # Index the already existing worker users by the worker they belong to.
      users_by_worker = {}
      for existing in existing_users:
          users_by_worker[existing.worker_id] = existing

      for worker in legacy_workers:
          try:
              worker_user = users_by_worker.get(worker.id, None)
              if not worker_user:
                  new_user = User(
                      username=f'{username_prefix}-{worker.id}',
                      is_system=True,
                      role=UserRole.Worker,
                      hashed_password="",
                      cluster=default_cluster,
                      cluster_id=default_cluster.id,
                      worker=worker,
                      worker_id=worker.id,
                  )
                  worker_user = await create_user_with_principal(session, new_user)

              access_key = secrets.token_hex(8)
              secret_key = secrets.token_hex(16)
              api_key = ApiKey(
                  name=worker_user.username,
                  access_key=access_key,
                  hashed_secret_key=get_secret_hash(secret_key),
                  user=worker_user,
                  user_id=worker_user.id,
              )
              await ApiKey.create(session, api_key, auto_commit=False)

              worker_token = f"{API_KEY_PREFIX}_{access_key}_{secret_key}"
              await worker.update(
                  session=session,
                  source={"token": worker_token},
                  auto_commit=False,
              )
              await session.commit()
          except Exception as e:
              logger.error(
                  f"Failed to migrate worker {worker.id} ({worker.name}): {e}"
              )
              await session.rollback()
              raise e
  async def _ensure_registration_token(self, session: AsyncSession):
      """Make sure the default cluster has a registration token and write it to disk.

      When the default cluster has no token yet, an API key is created for
      the default cluster user and the derived token is stored on the
      cluster. In every case the token is then passed to
      ``write_registration_token`` together with the data directory.

      Args:
          session: Database session; committed when a token is created,
              rolled back on failure.

      Raises:
          Exception: A failure while creating the token is logged, rolled
              back and re-raised.
      """
      cluster_user = await get_default_cluster_user(session)
      if cluster_user is None or cluster_user.cluster is None:
          logger.debug(
              "Default cluster user not exist, skipping registration token generation."
          )
          return

      # Keep our own reference to the cluster: ``ApiKey.create`` triggers
      # ``ActiveRecordMixin._refresh_related_objects``, which calls
      # ``session.refresh(cluster_user)`` and thereby expires the
      # eagerly-loaded ``cluster`` attribute. Because ``User.cluster`` is
      # ``lazy="noload"``, reading ``cluster_user.cluster`` afterwards
      # yields ``None`` and the update below would blow up.
      cluster = cluster_user.cluster
      registration_token = cluster.registration_token

      if not registration_token:
          try:
              access_key = secrets.token_hex(8)
              secret_key = secrets.token_hex(16)
              registration_token = f"{API_KEY_PREFIX}_{access_key}_{secret_key}"
              cluster_key = ApiKey(
                  name="Default Cluster Token",
                  access_key=access_key,
                  hashed_secret_key=get_secret_hash(secret_key),
                  user_id=cluster_user.id,
                  user=cluster_user,
              )
              await ApiKey.create(session, cluster_key, auto_commit=False)
              await cluster.update(
                  session=session,
                  source={"registration_token": registration_token},
                  auto_commit=False,
              )
              await session.commit()
          except Exception as e:
              logger.error(f"Failed to ensure registration token: {e}")
              await session.rollback()
              raise e

      write_registration_token(
          data_dir=self._config.data_dir,
          token=registration_token,
      )
  async def _cleanup_orphaned_gateway_data(self, session: AsyncSession):
      """Run the gateway cleanup helpers against the current database state.

      Skipped entirely when the gateway mode is ``disabled``. Otherwise
      the non-deleted model routes, route targets, providers, model
      instances and workers are loaded, the expected ingress names are
      derived from them, and the cleanup helpers are awaited in a fixed
      order.

      Args:
          session: Database session used for the read queries.
      """
      cfg = self.config
      if cfg.gateway_mode == GatewayModeEnum.disabled:
          return

      not_deleted = {"deleted_at": None}

      # Current (non-deleted) records the gateway state is compared against.
      model_routes = await ModelRoute.all_by_field(
          session=session, field="deleted_at", value=None
      )
      route_targets = await ModelRouteTarget.all_by_fields(
          session=session,
          fields=dict(not_deleted),
      )
      providers = await ModelProvider.all_by_fields(
          session=session,
          fields=dict(not_deleted),
      )
      model_instances = await ModelInstance.all_by_fields(
          session=session,
          fields=dict(not_deleted),
      )
      workers = await Worker.all_by_fields(
          session=session,
          fields=dict(not_deleted),
      )

      # Routes owning at least one target with fallback status codes get
      # an additional fallback ingress.
      fallback_route_ids = []
      for target in route_targets:
          codes = target.fallback_status_codes
          if codes is not None and len(codes) > 0:
              fallback_route_ids.append(target.route_id)

      expected_ingress_names = [
          model_route_ingress_name(route.id) for route in model_routes
      ]
      expected_names = expected_ingress_names + [
          fallback_ingress_name(model_route_ingress_name(route_id))
          for route_id in fallback_route_ids
      ]

      k8s_config = get_async_k8s_config(cfg=cfg)

      # Ingresses under the model-route prefix that are no longer expected.
      await cleanup_ingresses(
          namespace=cfg.get_namespace(),
          expected_names=expected_names,
          config=k8s_config,
          cleanup_prefix=model_route_ingress_prefix,
          reason="orphaned",
      )
      # Same expected names, but for the legacy model ingress prefix.
      await cleanup_ingresses(
          namespace=cfg.get_namespace(),
          expected_names=expected_names,
          config=k8s_config,
          cleanup_prefix=model_ingress_prefix,
          reason="legacy",
      )
      await cleanup_model_mapper(
          namespace=cfg.gateway_namespace,
          expected_ingresses=expected_ingress_names,
          config=k8s_config,
      )
      await cleanup_fallback_filters(
          namespace=cfg.get_namespace(),
          expected_names=expected_names,
          cleanup_prefix=model_route_ingress_prefix,
          reason="orphaned",
          k8s_config=k8s_config,
      )
      await cleanup_ai_proxy_config(
          namespace=cfg.gateway_namespace,
          providers=providers,
          routes=model_routes,
          k8s_config=k8s_config,
      )
      await cleanup_generic_route_transformer(
          routes=model_routes,
          k8s_config=k8s_config,
          namespace=cfg.gateway_namespace,
      )
      await cleanup_mcpbridge_registry(
          providers=providers,
          namespace=cfg.gateway_namespace,
          model_instances=model_instances,
          workers=workers,
          k8s_config=k8s_config,
      )
  def _should_create_default_cluster(self) -> bool:
      """Tell whether a default cluster should be bootstrapped.

      Returns:
          ``True`` when the server role is BOTH or a ``token`` is
          configured, otherwise ``False``.
      """
      # Only the "server" and "both" roles ever reach this code path.
      runs_embedded_worker = self._config.server_role() == Config.ServerRole.BOTH
      if runs_embedded_worker:
          return True
      return bool(self._config.token)
  async def _init_default_cluster(self, session: AsyncSession):
      """Create the default cluster and its system user if they are missing.

      Nothing happens when ``_should_create_default_cluster`` says no or
      when a default cluster user already exists. The new cluster is
      marked as default only if the platform Org has no default cluster
      yet.

      Args:
          session: Database session; committed once after both the cluster
              and its user have been created.
      """
      if not self._should_create_default_cluster():
          return

      if await get_default_cluster_user(session):
          return

      existing_default = await self.user_defined_default_cluster(session)
      mark_as_default = existing_default is None

      logger.info("Creating default cluster...")

      provider = (
          ClusterProvider.Kubernetes
          if DEFAULT_CLUSTER_KUBERNETES
          else ClusterProvider.Docker
      )
      new_cluster = Cluster(
          name="Default Cluster",
          description="The default cluster for GPUStack",
          provider=provider,
          state=ClusterStateEnum.READY,
          hashed_suffix=secrets.token_hex(6),
          registration_token="",
          is_default=mark_as_default,
          owner_principal_id=PLATFORM_PRINCIPAL_ID,
      )
      new_cluster = await Cluster.create(session, new_cluster, auto_commit=False)

      system_user = User(
          username=default_cluster_user_name,
          is_system=True,
          is_admin=False,
          require_password_change=False,
          role=UserRole.Cluster,
          hashed_password="",
          cluster=new_cluster,
      )
      await create_user_with_principal(session, system_user)

      # No cluster_access grant is needed: the cluster's
      # `owner_principal_id` already binds it to the platform Org, whose
      # members are implicit USER-level consumers. cluster_access rows are
      # only for cross-Org / group / user borrowing.
      await session.commit()
      logger.debug("Default cluster created.")
  async def user_defined_default_cluster(
      self, session: AsyncSession
  ) -> "Cluster | None":
      """Look up the platform Org's existing default cluster.

      Args:
          session: Database session used for the query.

      Returns:
          The non-deleted cluster flagged ``is_default`` and owned by the
          platform principal, or ``None`` when there is none. The caller
          in ``_init_default_cluster`` relies on the ``None`` case, which
          the previous ``-> Cluster`` annotation did not admit.
      """
      # Used during initial bootstrap to decide whether to create a
      # platform-Org default. Only the platform Org slot is checked
      # because per-Org defaults are independent.
      return await Cluster.one_by_fields(
          session=session,
          fields={
              "is_default": True,
              "owner_principal_id": PLATFORM_PRINCIPAL_ID,
              "deleted_at": None,
          },
      )
  def _start_proxy_servers(self, app: FastAPI) -> None:
      """Build the HTTPS proxy server and schedule its ``start()`` coroutine.

      Args:
          app: Application whose ``state.message_server_handler`` supplies
              the connection-manager getter handed to the proxy.
      """

      def _authenticate(headers):
          # Worker header authentication is invoked with validate_proxy=None.
          return authenticate_worker_by_request_headers(
              headers, validate_proxy=None
          )

      config = self._config
      proxy = HTTPSProxyServer(
          host=config.get_proxy_listen_address(),
          port=config.get_proxy_port(),
          connection_manager_getter=(
              app.state.message_server_handler.get_connection_manager
          ),
          authenticator=_authenticate,
          header_router=resolve_instance_address_from_model_header,
      )
      self._create_async_task(proxy.start())
  def _start_extension_plugins(self, app: FastAPI) -> None:
      """Schedule the async tasks exposed by each extension plugin.

      Plugins are read from ``app.state.extension_plugins`` (an empty list
      when the attribute is absent). A plugin that raises while its tasks
      are being collected or scheduled is logged with a traceback and
      skipped, so the remaining plugins are still processed.

      Args:
          app: Application whose state may carry ``extension_plugins``.
      """
      plugins = getattr(app.state, "extension_plugins", [])
      for plugin in plugins:
          try:
              for pending in plugin.async_tasks():
                  self._create_async_task(pending)
          except Exception:
              logger.exception(
                  "Failed to start async tasks from extension plugin %s",
                  type(plugin).__name__,
              )
  async def _init_coordinator(self, app: FastAPI):
      """Select the coordinator, start it, and wire it into bus and cache.

      Plugins attach a ``Coordinator`` to their own ``coordinator``
      attribute inside ``__init__(app, cfg)``. ``app.state.extension_plugins``
      is scanned in order and the first plugin whose ``coordinator`` is not
      ``None`` wins. When no plugin supplies one, a ``LocalCoordinator``
      built from the server config is used instead.

      After the coordinator is started it is handed to the event bus and
      the cache module, and the JWT secret requirement is enforced last.

      Args:
          app: Application whose state may carry ``extension_plugins``.
      """
      for plugin in getattr(app.state, "extension_plugins", []):
          candidate = getattr(plugin, "coordinator", None)
          if candidate is None:
              continue
          coordinator = candidate
          logger.info(f"Coordinator provided by plugin: {type(plugin).__name__}")
          break
      else:
          # No plugin offered a coordinator: fall back to the local one.
          coordinator = LocalCoordinator(self._config)
          logger.debug("Using LocalCoordinator")

      self._coordinator = coordinator
      await self._coordinator.start()

      # Set up bus and cache to use coordinator
      bus.set_coordinator(coordinator)
      await bus.event_bus.start()
      cache_module.set_coordinator(coordinator)

      await self._prepare_jwt_secret_key()
  async def _preload_change_detector_cache(self):
      """Preload the change-detection cache for a fixed list of topics.

      Returns immediately when the active coordinator is a
      ``LocalCoordinator``. Topics for which ``get_model_for_topic`` returns
      ``None`` are skipped, and a failure while preloading one topic is
      logged as a warning without stopping the remaining topics.
      """
      if isinstance(self._coordinator, LocalCoordinator):
          return

      topics = (
          "worker",
          "model",
          "modelinstance",
          "modelroute",
          "modelroutetarget",
          "workerpool",
          "inferencebackend",
      )
      # One session is shared across all topics.
      async with async_session() as session:
          for topic in topics:
              model_cls = get_model_for_topic(topic)
              if model_cls is not None:
                  try:
                      await preload_cache(topic, model_cls, session)
                  except Exception as e:
                      logger.warning(
                          f"Failed to preload change-detection cache for {topic}: {e}"
                      )
  async def _prepare_jwt_secret_key(self):
      """Require an explicit JWT secret when running in distributed mode.

      ``Config`` auto-generates a local ``jwt_secret_key`` file during init
      so early startup paths (e.g. ``initialize_gateway``) have a usable
      key. That generated value is only safe in single-node mode:
      distributed instances must share the SAME secret, otherwise a JWT
      signed by one instance will not verify on another.

      The check relies on the ``_jwt_secret_key_user_provided`` flag (set
      from --jwt-secret-key / GPUSTACK_JWT_SECRET_KEY / config file) rather
      than on the current value, because the value is always populated by
      the time this runs.

      Raises:
          RuntimeError: If no secret was explicitly provided and the
              coordinator is not a ``LocalCoordinator``.
      """
      # The flag is tested first, so the coordinator is only inspected
      # when no explicit secret was supplied.
      if self._config._jwt_secret_key_user_provided or isinstance(
          self._coordinator, LocalCoordinator
      ):
          return

      raise RuntimeError(
          "jwt_secret_key must be explicitly set in distributed mode. "
          "Mount a Kubernetes Secret or pass it via the --jwt-secret-key flag "
          "or set the GPUSTACK_JWT_SECRET_KEY environment variable."
      )
  async def _start_leader_only_tasks(self):
      """Start the tasks that must run only on the Leader instance.

      With a ``LocalCoordinator`` the leader tasks are started directly.
      With any other coordinator the leader election loop is launched as an
      asyncio task and kept on ``self._leader_election_task``.
      """
      if not isinstance(self._coordinator, LocalCoordinator):
          # Distributed mode: start leader election loop
          logger.info("Starting leader election loop...")
          self._leader_election_task = asyncio.create_task(
              self._leader_election_loop()
          )
          return

      # Local mode: start leader tasks directly (always run)
      self._start_leader_tasks()
  async def _leader_election_loop(self):
      """Run the coordinator-backed leader election loop forever.

      Each iteration either renews leadership (when this server is already
      the leader) or tries to acquire it, then sleeps for the coordinator's
      renew interval. Acquiring leadership starts the leader-only tasks.
      Failing to renew terminates the process with ``os._exit(1)``.

      Any ``Exception`` raised inside an iteration is logged and the loop
      retries after a fixed 5 second pause.
      """
      server_id = self._config.server_id
      lease_ttl = self._coordinator.leader_election_ttl
      renew_every = self._coordinator.leader_election_renew_interval
      # Limits the "attempting" / "standby" messages to the first
      # acquisition attempt.
      first_attempt = True

      while True:
          try:
              if self._coordinator.is_leader():
                  # Renew leadership
                  still_leader = await self._coordinator.renew_leadership(lease_ttl)
                  if not still_leader:
                      logger.error(
                          f"Server {server_id} lost leadership, exiting for restart"
                      )
                      # Hard exit to prevent split-brain: os._exit bypasses
                      # cleanup so the process stops immediately and the
                      # container runtime can restart it as a standby.
                      os._exit(1)
              else:
                  # Try to acquire leadership
                  if first_attempt:
                      logger.info(
                          f"Server {server_id} attempting to acquire leadership..."
                      )
                  became_leader = await self._coordinator.acquire_leadership(
                      lease_ttl
                  )
                  if became_leader:
                      logger.info(
                          f"Server {server_id} became leader, starting scheduler and controllers"
                      )
                      # Start leader-only tasks
                      self._start_leader_tasks()
                  elif first_attempt:
                      logger.info(
                          f"Server {server_id} is standby, waiting for leadership..."
                      )
                  first_attempt = False

              await asyncio.sleep(renew_every)
          except Exception as e:
              logger.error(f"Leader election error: {e}")
              await asyncio.sleep(5)
  855. def _start_leader_tasks(self):
  856. """Start tasks that run only on the leader.
  857. Note: If leadership is lost, the process exits directly (os._exit),
  858. so we don't need to track and cancel these tasks.
  859. """
  860. # Scheduler
  861. self._start_scheduler()
  862. # Controllers
  863. self._start_controllers()
  864. # System Load Collector
  865. self._start_system_load_collector()
  866. # Update Checker
  867. self._start_update_checker()
  868. # Worker Instance Cleaner
  869. self._start_worker_instance_cleaner()
  870. # Usage Details Archiver (move aged rows to archive table)
  871. self._start_usage_details_archiver()
  872. # Worker Syncer (checks worker reachability and updates states)
  873. self._start_worker_syncer(self._app)