| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990 |
- import asyncio
- from multiprocessing import Process
- import os
- import re
- import importlib.util
- import aiohttp
- import uvicorn
- from fastapi import FastAPI
- import logging
- import secrets
- import tenacity
- from sqlmodel.ext.asyncio.session import AsyncSession
- from gpustack.logging import setup_logging
- from gpustack.schemas.users import (
- User,
- UserRole,
- get_default_cluster_user,
- default_cluster_user_name,
- )
- from gpustack.schemas.principals import PLATFORM_PRINCIPAL_ID
- from gpustack.schemas.models import ModelInstance
- from gpustack.schemas.api_keys import ApiKey
- from gpustack.schemas.workers import Worker
- from gpustack.schemas.clusters import Cluster, ClusterProvider, ClusterStateEnum
- from gpustack.schemas.model_routes import ModelRoute, ModelRouteTarget
- from gpustack.schemas.model_provider import ModelProvider
- from gpustack.security import (
- generate_secure_password,
- get_secret_hash,
- API_KEY_PREFIX,
- )
- from gpustack.routes.auth import remove_initial_password_file_if_exists
- from gpustack.server.app import create_app
- from gpustack.server.services import (
- create_user_with_principal,
- provision_bootstrap_admin_orgs,
- )
- from gpustack.config.config import Config
- from gpustack.schemas.config import GatewayModeEnum
- from gpustack.config import registration
- from gpustack.server.catalog import init_model_catalog
- from gpustack.server.controllers import (
- ModelController,
- ModelFileController,
- ModelInstanceController,
- WorkerController,
- ClusterController,
- WorkerPoolController,
- InferenceBackendController,
- ModelRouteController,
- ModelRouteTargetController,
- ModelProviderController,
- )
- from gpustack.server.db import async_session
- from gpustack.server.init_db import init_db, get_query_count
- from gpustack.scheduler.scheduler import Scheduler
- from gpustack.server.system_load import SystemLoadCollector
- from gpustack.server.update_check import UpdateChecker
- from gpustack.server.worker_status_buffer import flush_worker_status_to_db
- from gpustack.server.metrics_collector import flush_gateway_metrics_to_db
- from gpustack.server.usage_details_archiver import UsageDetailsArchiver
- from gpustack.server.worker_instance_cleaner import WorkerInstanceCleaner
- from gpustack.server.worker_syncer import WorkerSyncer
- from gpustack.utils.process import add_signal_handlers_in_loop
- from gpustack.config.registration import write_registration_token
- from gpustack.exporter.exporter import MetricExporter
- from gpustack.gateway.utils import (
- model_ingress_prefix,
- model_route_ingress_prefix,
- model_route_ingress_name,
- fallback_ingress_name,
- cleanup_ingresses,
- cleanup_model_mapper,
- cleanup_fallback_filters,
- cleanup_ai_proxy_config,
- cleanup_generic_route_transformer,
- cleanup_mcpbridge_registry,
- resolve_instance_address_from_model_header,
- )
- from gpustack.gateway import get_async_k8s_config
- from gpustack.envs import (
- GATEWAY_PORT_CHECK_INTERVAL,
- GATEWAY_PORT_CHECK_RETRY_COUNT,
- DEFAULT_CLUSTER_KUBERNETES,
- )
- from gpustack.server.coordinator import LocalCoordinator
- from gpustack.server.coordinator.cache import preload_cache
- from gpustack.server.coordinator.models import get_model_for_topic
- from gpustack.server import bus
- from gpustack.server import cache as cache_module
- from alembic import command
- from alembic.config import Config as AlembicConfig
- from gpustack.websocket_proxy.proxy_server import HTTPSProxyServer
- from gpustack.api.auth import (
- authenticate_worker_by_request_headers,
- )
- logger = logging.getLogger(__name__)
- class Server:
def __init__(self, config: Config, worker_process: Process):
    """Record configuration and create empty bookkeeping containers.

    Nothing is started here; see ``start``.

    Args:
        config: Server configuration object.
        worker_process: Worker process handle. ``start`` registers it as a
            sub process only when the server role is ``BOTH``.
    """
    self._config: Config = config
    self._worker_process = worker_process
    # Everything launched by the server is tracked in these two lists.
    self._sub_processes = []
    self._async_tasks = []
    # Coordination components
    self._coordinator = None
    self._leader_election_task = None
- @property
- def all_processes(self):
- return self._sub_processes
- def _create_async_task(self, coro):
- self._async_tasks.append(asyncio.create_task(coro))
- @property
- def config(self):
- return self._config
async def start(self):
    """Run the startup sequence, then wait on every registered async task.

    The order is significant: migrations and data bootstrap come first, the
    FastAPI app is created next, the coordinator is initialised before any
    background task is started, and the uvicorn server is scheduled last.
    """
    logger.info("Starting GPUStack server.")
    add_signal_handlers_in_loop()

    self._run_migrations()
    await self._prepare_data()
    init_model_catalog(self._config.model_catalog_file)

    # it's safe to determine server_role after migration
    runs_worker_too = self._config.server_role() == Config.ServerRole.BOTH
    if runs_worker_too:
        self._sub_processes.append(self._worker_process)

    # Create FastAPI app. Plugin ``__init__(app, cfg)`` runs here and
    # may attach a distributed-mode coordinator to the plugin instance.
    app = create_app(self._config)
    self._app = app

    # Initialize coordinator from plugin instances (LocalCoordinator if
    # none supplied). Must run before the event bus goes online so any
    # early publishes are routed correctly.
    await self._init_coordinator(app)

    # Preload change-detection cache after the coordinator is up.
    # Required in distributed mode so the first cross-instance event
    # on each topic carries accurate ``changed_fields``.
    await self._preload_change_detector_cache()

    self._start_sub_processes()

    # Start Leader-Only tasks (includes scheduler and controllers)
    # In single-node mode, they start immediately.
    # In distributed mode, they start only when this node becomes leader.
    await self._start_leader_only_tasks()

    # These tasks can run on all instances
    self._start_worker_status_flusher()
    self._start_gateway_metrics_flusher()
    self._start_metrics_exporter()
    self._start_query_count_logger()
    self._start_default_registry_checker()
    self._start_proxy_servers(app)
    self._start_extension_plugins(app)

    # With an embedded gateway the API only listens on loopback.
    embedded_gateway = self._config.gateway_mode == GatewayModeEnum.embedded
    serving_host = "127.0.0.1" if embedded_gateway else "0.0.0.0"
    uvicorn_config = uvicorn.Config(
        app,
        host=serving_host,
        port=self._config.get_api_port(),
        access_log=False,
        log_level="error",
    )

    setup_logging()

    logger.info(f"Gateway mode: {self._config.gateway_mode.value}.")
    serving_api_message = (
        f"Serving GPUStack API on {uvicorn_config.host}:{uvicorn_config.port}."
    )
    if not embedded_gateway:
        logger.info(serving_api_message)
    else:
        logger.debug(serving_api_message)
        logger.info(
            f"GPUStack Server will serve on 0.0.0.0:{self._config.get_gateway_port()}."
        )
        if self._config.get_tls_secret_name() is not None:
            logger.info(
                f"GPUStack Server will serve TLS on 0.0.0.0:{self._config.tls_port}."
            )

    api_server = uvicorn.Server(uvicorn_config)
    self._create_async_task(api_server.serve())

    await asyncio.gather(*self._async_tasks)
def _start_default_registry_checker(self):
    """Invoke default-registry determination with the configured registry.

    Passes ``system_default_container_registry`` from the config to
    ``registration.determine_default_registry``.
    """
    # NOTE(review): the return value is discarded; if the helper ever becomes
    # a coroutine function it would need to be scheduled instead — confirm.
    registration.determine_default_registry(
        self._config.system_default_container_registry,
    )
def _run_migrations(self):
    """Upgrade the database schema to Alembic ``head``.

    The migration scripts are located inside the installed ``gpustack``
    package.

    Raises:
        ImportError: If the ``gpustack`` package cannot be located.
        RuntimeError: If the Alembic upgrade fails (original error chained).
    """
    logger.info("Running database migration.")

    package_spec = importlib.util.find_spec("gpustack")
    if package_spec is None:
        raise ImportError("The 'gpustack' package is not found.")
    migrations_dir = os.path.join(
        package_spec.submodule_search_locations[0], "migrations"
    )

    alembic_cfg = AlembicConfig()
    alembic_cfg.set_main_option("script_location", migrations_dir)

    url = self._config.get_database_url()
    # Use the pymysql driver to execute migrations to avoid compatibility issues between asynchronous drivers and Alembic.
    if url.startswith("mysql://"):
        url = re.sub(r'^mysql://', 'mysql+pymysql://', url)
    # A literal "%" must be doubled before it is stored as a config option.
    alembic_cfg.set_main_option("sqlalchemy.url", url.replace("%", "%%"))

    try:
        command.upgrade(alembic_cfg, "head")
    except Exception as e:
        raise RuntimeError(f"Database migration failed: {e}") from e

    logger.info("Database migration completed.")
async def _prepare_data(self):
    """Create the data dir, initialise the DB layer and seed initial data."""
    cfg = self._config
    self._setup_data_dir(cfg.data_dir)
    await init_db(cfg.get_database_url())

    # All seeding steps share a single session.
    async with async_session() as db_session:
        await self._init_data(db_session)

    logger.debug("Data initialization completed.")
def _start_scheduler(self):
    """Create the scheduler, schedule its ``start`` coroutine, return the task."""
    scheduler_task = asyncio.create_task(Scheduler(self._config).start())
    logger.debug("Scheduler started.")
    return scheduler_task
def _start_controllers(self):
    """Instantiate every controller, schedule its ``start``, return the tasks.

    Returns:
        A list of asyncio tasks, one per controller, in start order.
    """
    cfg = self._config
    # (controller class, constructor args) in the order they are started.
    controller_specs = (
        (ModelProviderController, (cfg,)),
        (ModelRouteTargetController, (cfg,)),
        (ModelRouteController, (cfg,)),
        (ModelController, (cfg,)),
        (ModelInstanceController, (cfg,)),
        (WorkerController, (cfg,)),
        (ModelFileController, ()),
        (ClusterController, (cfg,)),
        (WorkerPoolController, ()),
        (InferenceBackendController, ()),
    )

    tasks = []
    for controller_cls, ctor_args in controller_specs:
        controller = controller_cls(*ctor_args)
        tasks.append(asyncio.create_task(controller.start()))

    logger.debug("Controllers started.")
    return tasks
def _start_system_load_collector(self):
    """Schedule the system load collector as a tracked async task."""
    self._create_async_task(SystemLoadCollector().start())
    logger.debug("System load collector started.")
def _start_worker_syncer(self, app: FastAPI):
    """Schedule the worker syncer as a tracked async task.

    The syncer receives two zero-argument getters rather than the clients
    themselves, so ``app.state`` is read each time a getter is called; a
    missing attribute yields ``None``.
    """

    def get_http_client():
        return getattr(app.state, "http_client", None)

    def get_http_client_no_proxy():
        return getattr(app.state, "http_client_no_proxy", None)

    syncer = WorkerSyncer(get_http_client, get_http_client_no_proxy)
    self._create_async_task(syncer.start())
    logger.debug("Worker syncer started.")
def _start_worker_status_flusher(self):
    """Schedule ``flush_worker_status_to_db`` as a tracked async task."""
    flusher = flush_worker_status_to_db()
    self._create_async_task(flusher)
    logger.debug("Worker status flusher started.")
def _start_gateway_metrics_flusher(self):
    """Schedule ``flush_gateway_metrics_to_db`` as a tracked async task."""
    # Always start — both the gateway report endpoint and the in-process
    # ModelUsageMiddleware feed the same buffer, so the flusher must run
    # even when the external gateway is disabled.
    flusher = flush_gateway_metrics_to_db()
    self._create_async_task(flusher)
    logger.debug("Gateway metrics flusher started.")
def _start_worker_instance_cleaner(self):
    """Schedule the worker instance cleaner as a tracked async task."""
    self._create_async_task(WorkerInstanceCleaner().start())
    logger.debug("Worker instance cleaner started.")
def _start_usage_details_archiver(self):
    """Schedule the usage details archiver unless it cannot be constructed.

    A construction failure is logged at CRITICAL level with the traceback
    and the method returns without scheduling anything.
    """
    # Construction can fail on schema drift between hot/archive tables or
    # an invalid cron expression. Surface that loudly and skip launching
    # the loop so the rest of the leader tasks (and the leader-election
    # retry) aren't taken down with it. Without the archiver the
    # model_usage_details hot table will grow unbounded — operators must
    # see this in logs rather than have it buried as "Leader election
    # error" by the outer election handler.
    try:
        archiver = UsageDetailsArchiver()
    except Exception:
        logger.critical(
            "Usage details archiver failed to initialize — archival is "
            "DISABLED. The model_usage_details hot table will grow "
            "unbounded until this is resolved.",
            exc_info=True,
        )
    else:
        self._create_async_task(archiver.start())
        logger.debug("Usage details archiver started.")
- def _start_update_checker(self):
- """Start update checker."""
- if self._config.disable_update_check:
- return
- update_checker = UpdateChecker(update_check_url=self._config.update_check_url)
- self._create_async_task(update_checker.start())
- logger.debug("Update checker started.")
- async def _monitor_sub_processes(self):
- while self._sub_processes:
- for process in self._sub_processes[:]:
- if not process.is_alive():
- if process.exitcode != 0:
- raise RuntimeError(
- f"Sub process {process.name} died with exit code {process.exitcode}"
- )
- self._sub_processes.remove(process)
- await asyncio.sleep(5)
def _start_sub_processes(self):
    """Start the registered sub processes once the local API answers.

    Does nothing when no sub process is registered. Otherwise a tracked
    async task polls ``/healthz`` on loopback every 2 seconds until it
    returns HTTP 200, starts every sub process and then monitors them.
    """

    async def start_process_after_api_ready():
        # Probe the port uvicorn is actually bound to (``start`` binds
        # ``get_api_port()``) so the probe and the listener cannot diverge.
        api_url = f"http://127.0.0.1:{self._config.get_api_port()}/healthz"
        async with aiohttp.ClientSession() as session:
            while True:
                try:
                    await asyncio.sleep(2)
                    async with session.get(api_url) as response:
                        if response.status == 200:
                            break
                except aiohttp.ClientError:
                    # API not accepting connections yet; keep polling.
                    pass
                except asyncio.CancelledError:
                    # Cancelled while waiting: give up without starting anything.
                    return

        for process in self._sub_processes:
            process.start()

        await self._monitor_sub_processes()

    if not self._sub_processes:
        return

    self._create_async_task(start_process_after_api_ready())
async def _wait_for_gateway_ready(self):
    """Block until the embedded gateway's ports accept connections.

    Returns immediately when the gateway mode is not ``embedded``. The TLS
    port is only checked when a TLS secret name is configured.
    """
    cfg = self._config
    if cfg.gateway_mode != GatewayModeEnum.embedded:
        return

    # http port is always started
    ports = [cfg.port]
    tls_enabled = cfg.get_tls_secret_name() is not None
    if tls_enabled:
        ports.append(cfg.tls_port)

    logger.info(f"Waiting for ports {ports} of GPUStack to be ready...")
    # wait for gateway ready for about 60s
    await self._check_ports_ready(*ports)
    logger.info("GPUStack Server is ready.")
@tenacity.retry(
    stop=tenacity.stop_after_attempt(GATEWAY_PORT_CHECK_RETRY_COUNT),
    wait=tenacity.wait_fixed(GATEWAY_PORT_CHECK_INTERVAL),
    reraise=True,
    # ``args[0]`` is ``self``; every remaining positional argument is a port,
    # so log all of them rather than only the first.
    before_sleep=lambda retry_state: logger.debug(
        f"Waiting for ports {list(retry_state.args[1:])} to be healthy (attempt {retry_state.attempt_number}) due to: {retry_state.outcome.exception()}"
    ),
)
async def _check_ports_ready(self, *ports: int):
    """Verify that every given port accepts a TCP connection on loopback.

    Retried by tenacity with a fixed wait; the last error is re-raised once
    the attempt limit is reached.

    Args:
        *ports: TCP ports to probe on ``127.0.0.1``.

    Raises:
        RuntimeError: If any port cannot be connected to. The underlying
            error is chained as the cause.
    """
    for port in ports:
        try:
            _, writer = await asyncio.open_connection("127.0.0.1", port)
            writer.close()
            await writer.wait_closed()
        except Exception as e:
            raise RuntimeError(
                f"Port {port} is not healthy or not listening"
            ) from e
- def _start_metrics_exporter(self):
- if self._config.disable_metrics:
- return
- exporter = MetricExporter(cfg=self._config)
- self._create_async_task(exporter.generate_metrics_cache())
- self._create_async_task(exporter.start())
def _start_query_count_logger(self):
    """Schedule a tracked task that logs the DB query count once a minute."""

    async def report_forever():
        while True:
            await asyncio.sleep(60)  # Log every minute
            total = get_query_count()
            logger.debug(f"[DB QUERY COUNT] Total queries since startup: {total}")

    self._create_async_task(report_forever())
- @staticmethod
- def _setup_data_dir(data_dir: str):
- if not os.path.exists(data_dir):
- os.makedirs(data_dir)
- async def _init_data(self, session: AsyncSession):
- init_data_funcs = [
- self._init_user,
- self._init_default_cluster,
- self._migrate_legacy_token,
- self._migrate_legacy_workers,
- self._ensure_registration_token,
- self._cleanup_orphaned_gateway_data,
- ]
- for init_data_func in init_data_funcs:
- await init_data_func(session)
async def _init_user(self, session: AsyncSession):
    """Create the bootstrap ``admin`` user when no active admin exists.

    If no bootstrap password is configured, one is generated, written to
    ``<data_dir>/initial_admin_password`` and the user is flagged to change
    it.
    """
    # Skip bootstrap when any non-system admin already exists, so that
    # renaming the default "admin" account does not cause a duplicate
    # admin to be regenerated on master restart.
    admin_filter = {"is_admin": True, "is_system": False, "is_active": True}
    if await User.first_by_fields(session=session, fields=admin_filter):
        return

    # Drop any stale initial password file from a prior bootstrap before
    # generating a new one, so the login page does not show an outdated
    # "retrieve initial password" hint.
    remove_initial_password_file_if_exists(self._config)

    password = self._config.bootstrap_password
    # A generated password must be rotated by the user; a configured one need not.
    require_password_change = not password
    if require_password_change:
        password = generate_secure_password()
        password_path = os.path.join(
            self._config.data_dir, "initial_admin_password"
        )
        # NOTE(review): the file is created with default permissions — confirm
        # that is acceptable for a plaintext credential.
        with open(password_path, "w") as file:
            file.write(password + "\n")
        logger.info(
            "Generated initial admin password. "
            f"You can get it from {password_path}"
        )

    admin = User(
        username="admin",
        full_name="Default System Admin",
        hashed_password=get_secret_hash(password),
        is_admin=True,
        require_password_change=require_password_change,
    )
    admin = await create_user_with_principal(session, admin)
    await provision_bootstrap_admin_orgs(session, admin)
    await session.commit()
async def _migrate_legacy_token(self, session: AsyncSession):
    """Adopt a legacy ``config.token`` as the default cluster's token.

    No-op when no token is configured, when the default cluster user or its
    cluster is missing, or when the cluster already has a registration
    token. Otherwise the token is stored on the cluster and a matching API
    key is created for the cluster's system user, in one commit.

    Raises:
        RuntimeError: If the cluster's system user is missing or already
            owns API keys. Any failure rolls the session back and is
            re-raised.
    """
    if not self._config.token:
        return

    # this should be created from sql migration script.
    cluster_user = await get_default_cluster_user(session)
    if cluster_user is None or cluster_user.cluster is None:
        logger.debug(
            "Default cluster user not exist, skipping legacy token migration."
        )
        return

    # The guard above already established that the cluster is present, so
    # no second existence check is needed.
    default_cluster = cluster_user.cluster
    if default_cluster.registration_token:
        return

    try:
        default_cluster.registration_token = self._config.token
        await default_cluster.update(session=session, auto_commit=False)

        default_cluster_user = await User.one_by_fields(
            session=session,
            fields={
                "cluster_id": default_cluster.id,
                "is_system": True,
                "role": UserRole.Cluster,
            },
        )
        if default_cluster_user is None:
            raise RuntimeError("Default cluster user does not exist.")
        if len(default_cluster_user.api_keys) > 0:
            raise RuntimeError(
                "Default cluster user already has API keys, cannot migrate legacy token."
            )

        new_key = ApiKey(
            name="Legacy Cluster Token",
            access_key="",
            hashed_secret_key=get_secret_hash(self._config.token),
            user_id=default_cluster_user.id,
            user=default_cluster_user,
        )
        await ApiKey.create(session, new_key, auto_commit=False)
        await session.commit()
    except Exception as e:
        logger.error(f"Failed to migrate legacy token: {e}")
        await session.rollback()
        # Bare ``raise`` keeps the original traceback intact.
        raise
async def _migrate_legacy_workers(self, session: AsyncSession):
    """Give every token-less worker of cluster 1 a system user and token.

    For each such worker a system user is created if missing, an API key is
    generated for it, and the worker's ``token`` is set; each worker is
    committed on its own. The first failure rolls back and is re-raised.
    """
    # Use hardcode cluster 1 to make sure the cluster is created in migration step
    cluster = await Cluster.one_by_id(session=session, id=1)
    if not cluster:
        logger.debug(
            "Default cluster does not exist, skipping legacy worker migration."
        )
        return

    legacy_workers = await Worker.all_by_fields(
        session=session,
        fields={
            "cluster_id": cluster.id,
            "token": None,
        },
    )
    if len(legacy_workers) == 0:
        return

    username_prefix = "system/worker"
    # Look up already-existing worker users in a single query.
    existing_users = await User.all_by_fields(
        session=session,
        fields={
            "cluster_id": cluster.id,
            "is_system": True,
            "role": UserRole.Worker,
        },
        extra_conditions=[
            User.worker_id.in_([legacy.id for legacy in legacy_workers])
        ],
    )
    users_by_worker = {u.worker_id: u for u in existing_users}

    for worker in legacy_workers:
        try:
            owner = users_by_worker.get(worker.id)
            if not owner:
                owner = await create_user_with_principal(
                    session,
                    User(
                        username=f'{username_prefix}-{worker.id}',
                        is_system=True,
                        role=UserRole.Worker,
                        hashed_password="",
                        cluster=cluster,
                        cluster_id=cluster.id,
                        worker=worker,
                        worker_id=worker.id,
                    ),
                )

            access_key = secrets.token_hex(8)
            secret_key = secrets.token_hex(16)
            api_key = ApiKey(
                name=owner.username,
                access_key=access_key,
                hashed_secret_key=get_secret_hash(secret_key),
                user=owner,
                user_id=owner.id,
            )
            await ApiKey.create(session, api_key, auto_commit=False)

            await worker.update(
                session=session,
                source={"token": f"{API_KEY_PREFIX}_{access_key}_{secret_key}"},
                auto_commit=False,
            )
            await session.commit()
        except Exception as e:
            logger.error(
                f"Failed to migrate worker {worker.id} ({worker.name}): {e}"
            )
            await session.rollback()
            raise e
async def _ensure_registration_token(self, session: AsyncSession):
    """Ensure the default cluster has a registration token and persist it.

    When the cluster has no token, a new API key is created for the default
    cluster user and the derived token is stored on the cluster. The token
    is then passed to ``write_registration_token`` together with the data
    dir.
    """
    cluster_user = await get_default_cluster_user(session)
    if cluster_user is None or cluster_user.cluster is None:
        logger.debug(
            "Default cluster user not exist, skipping registration token generation."
        )
        return

    # Hold a local reference: ``ApiKey.create`` triggers
    # ``ActiveRecordMixin._refresh_related_objects`` which calls
    # ``session.refresh(cluster_user)``, expiring its eagerly-loaded
    # ``cluster`` attribute. With ``User.cluster`` set to
    # ``lazy="noload"``, accessing ``cluster_user.cluster``
    # afterwards returns ``None`` and the subsequent update would
    # blow up.
    cluster = cluster_user.cluster
    token = cluster.registration_token

    if not token:
        try:
            access_key = secrets.token_hex(8)
            secret_key = secrets.token_hex(16)
            await ApiKey.create(
                session,
                ApiKey(
                    name="Default Cluster Token",
                    access_key=access_key,
                    hashed_secret_key=get_secret_hash(secret_key),
                    user_id=cluster_user.id,
                    user=cluster_user,
                ),
                auto_commit=False,
            )
            token = f"{API_KEY_PREFIX}_{access_key}_{secret_key}"
            await cluster.update(
                session=session,
                source={"registration_token": token},
                auto_commit=False,
            )
            await session.commit()
        except Exception as e:
            logger.error(f"Failed to ensure registration token: {e}")
            await session.rollback()
            raise e

    write_registration_token(data_dir=self._config.data_dir, token=token)
async def _cleanup_orphaned_gateway_data(self, session: AsyncSession):
    """Remove gateway resources that no longer match non-deleted DB rows.

    Skipped entirely when the gateway mode is ``disabled``. The expected
    resource names are derived from the model routes, route targets,
    providers, model instances and workers whose ``deleted_at`` is ``None``.
    """
    if self.config.gateway_mode == GatewayModeEnum.disabled:
        return

    # Remove the orphaned ingresses of model routes
    live_routes = await ModelRoute.all_by_field(
        session=session, field="deleted_at", value=None
    )
    live_targets = await ModelRouteTarget.all_by_fields(
        session=session, fields={"deleted_at": None}
    )
    live_providers = await ModelProvider.all_by_fields(
        session=session, fields={"deleted_at": None}
    )
    live_instances = await ModelInstance.all_by_fields(
        session=session, fields={"deleted_at": None}
    )
    live_workers = await Worker.all_by_fields(
        session=session, fields={"deleted_at": None}
    )

    # Routes having at least one target with fallback status codes.
    fallback_route_ids = [
        target.route_id
        for target in live_targets
        if target.fallback_status_codes is not None
        and len(target.fallback_status_codes) > 0
    ]

    expected_ingress_names = [
        model_route_ingress_name(route.id) for route in live_routes
    ]
    fallback_names = [
        fallback_ingress_name(model_route_ingress_name(route_id))
        for route_id in fallback_route_ids
    ]
    expected_names = expected_ingress_names + fallback_names

    k8s_config = get_async_k8s_config(cfg=self.config)

    # Same expected set for both prefixes; only the prefix and reason differ.
    for prefix, reason in (
        (model_route_ingress_prefix, "orphaned"),
        (model_ingress_prefix, "legacy"),
    ):
        await cleanup_ingresses(
            namespace=self.config.get_namespace(),
            expected_names=expected_names,
            config=k8s_config,
            cleanup_prefix=prefix,
            reason=reason,
        )

    await cleanup_model_mapper(
        namespace=self.config.gateway_namespace,
        expected_ingresses=expected_ingress_names,
        config=k8s_config,
    )
    await cleanup_fallback_filters(
        namespace=self.config.get_namespace(),
        expected_names=expected_names,
        cleanup_prefix=model_route_ingress_prefix,
        reason="orphaned",
        k8s_config=k8s_config,
    )
    await cleanup_ai_proxy_config(
        namespace=self.config.gateway_namespace,
        providers=live_providers,
        routes=live_routes,
        k8s_config=k8s_config,
    )
    await cleanup_generic_route_transformer(
        routes=live_routes,
        k8s_config=k8s_config,
        namespace=self.config.gateway_namespace,
    )
    await cleanup_mcpbridge_registry(
        providers=live_providers,
        namespace=self.config.gateway_namespace,
        model_instances=live_instances,
        workers=live_workers,
        k8s_config=k8s_config,
    )
def _should_create_default_cluster(self) -> bool:
    """Return True when the role is ``BOTH`` or a token is configured."""
    # only server or both will get into this logic
    if self._config.server_role() == Config.ServerRole.BOTH:
        return True
    return bool(self._config.token)
async def _init_default_cluster(self, session: AsyncSession):
    """Create the default cluster and its system user when appropriate.

    Skipped when ``_should_create_default_cluster`` is false or a default
    cluster user already exists. The new cluster is flagged as default only
    if the platform principal has no default cluster yet.
    """
    if not self._should_create_default_cluster():
        return
    if await get_default_cluster_user(session):
        return

    existing_default = await self.user_defined_default_cluster(session)
    set_default = existing_default is None

    logger.info("Creating default cluster...")
    provider = (
        ClusterProvider.Kubernetes
        if DEFAULT_CLUSTER_KUBERNETES
        else ClusterProvider.Docker
    )
    cluster = await Cluster.create(
        session,
        Cluster(
            name="Default Cluster",
            description="The default cluster for GPUStack",
            provider=provider,
            state=ClusterStateEnum.READY,
            hashed_suffix=secrets.token_hex(6),
            registration_token="",
            is_default=set_default,
            owner_principal_id=PLATFORM_PRINCIPAL_ID,
        ),
        auto_commit=False,
    )

    cluster_system_user = User(
        username=default_cluster_user_name,
        is_system=True,
        is_admin=False,
        require_password_change=False,
        role=UserRole.Cluster,
        hashed_password="",
        cluster=cluster,
    )
    await create_user_with_principal(session, cluster_system_user)

    # No cluster_access grant needed: the cluster's `owner_principal_id`
    # already binds it to the platform Org, whose members are
    # implicit USER-level consumers. cluster_access rows are only
    # for cross-Org / group / user borrowing.
    await session.commit()
    logger.debug("Default cluster created.")
async def user_defined_default_cluster(self, session: AsyncSession) -> Cluster:
    """Look up the non-deleted default cluster owned by the platform principal.

    NOTE(review): the caller in this class compares the result with
    ``None``, so the lookup presumably returns ``None`` when nothing
    matches — confirm against ``Cluster.one_by_fields``.
    """
    # Used during initial bootstrap to decide whether to create a
    # platform-Org default — only need to check the platform Org slot
    # since per-Org defaults are independent.
    criteria = {
        "is_default": True,
        "owner_principal_id": PLATFORM_PRINCIPAL_ID,
        "deleted_at": None,
    }
    return await Cluster.one_by_fields(session=session, fields=criteria)
def _start_proxy_servers(self, app: FastAPI) -> None:
    """Create the HTTPS proxy server and schedule it as a tracked task.

    Listen address and port come from the config; the connection manager
    getter is taken from ``app.state.message_server_handler``.
    """

    def authenticate(headers):
        # Worker authentication based on the request headers.
        return authenticate_worker_by_request_headers(headers, validate_proxy=None)

    listen_host = self._config.get_proxy_listen_address()
    listen_port = self._config.get_proxy_port()
    proxy = HTTPSProxyServer(
        host=listen_host,
        port=listen_port,
        connection_manager_getter=app.state.message_server_handler.get_connection_manager,
        authenticator=authenticate,
        header_router=resolve_instance_address_from_model_header,
    )
    self._create_async_task(proxy.start())
def _start_extension_plugins(self, app: FastAPI) -> None:
    """Schedule the async tasks supplied by each extension plugin.

    A plugin whose ``async_tasks()`` raises is logged with its traceback and
    skipped; the remaining plugins are still processed.
    """
    plugins = getattr(app.state, "extension_plugins", [])
    for plugin in plugins:
        try:
            for plugin_coro in plugin.async_tasks():
                self._create_async_task(plugin_coro)
        except Exception:
            logger.exception(
                "Failed to start async tasks from extension plugin %s",
                type(plugin).__name__,
            )
async def _init_coordinator(self, app: FastAPI):
    """Select, start and wire up the coordinator.

    The first extension plugin in ``app.state.extension_plugins`` exposing
    a non-None ``coordinator`` attribute wins; without one a
    ``LocalCoordinator`` is used. After the coordinator is started, the
    event bus and the cache module are pointed at it and the JWT secret
    requirement is checked.
    """
    for plugin in getattr(app.state, "extension_plugins", []):
        candidate = getattr(plugin, "coordinator", None)
        if candidate is None:
            continue
        coordinator = candidate
        logger.info(f"Coordinator provided by plugin: {type(plugin).__name__}")
        break
    else:
        # No plugin supplied a coordinator.
        coordinator = LocalCoordinator(self._config)
        logger.debug("Using LocalCoordinator")

    self._coordinator = coordinator
    await self._coordinator.start()

    # Set up bus and cache to use coordinator
    bus.set_coordinator(coordinator)
    await bus.event_bus.start()
    cache_module.set_coordinator(coordinator)

    await self._prepare_jwt_secret_key()
async def _preload_change_detector_cache(self):
    """Preload the change-detection cache for a fixed set of topics.

    Skipped when the coordinator is a ``LocalCoordinator``. A topic with no
    mapped model class is skipped, and a failure on one topic is logged as
    a warning without stopping the others.
    """
    if isinstance(self._coordinator, LocalCoordinator):
        return

    topics = (
        "worker",
        "model",
        "modelinstance",
        "modelroute",
        "modelroutetarget",
        "workerpool",
        "inferencebackend",
    )
    async with async_session() as db_session:
        for topic in topics:
            model_cls = get_model_for_topic(topic)
            if model_cls is not None:
                try:
                    await preload_cache(topic, model_cls, db_session)
                except Exception as e:
                    logger.warning(
                        f"Failed to preload change-detection cache for {topic}: {e}"
                    )
- async def _prepare_jwt_secret_key(self):
- """Enforce that distributed deployments use an explicit JWT secret.
- ``Config`` auto-generates a local ``jwt_secret_key`` file during init
- so early startup paths (e.g. ``initialize_gateway``) have a usable key.
- That auto-generated value is safe only in single-node mode; distributed
- instances must share the SAME secret or JWTs signed by one instance
- won't verify on another. We rely on the ``_jwt_secret_key_user_provided``
- flag (set from --jwt-secret-key / GPUSTACK_JWT_SECRET_KEY / config file)
- rather than the current value, since the value is always populated by
- the time this runs.
- """
- if self._config._jwt_secret_key_user_provided:
- return
- if isinstance(self._coordinator, LocalCoordinator):
- return
- raise RuntimeError(
- "jwt_secret_key must be explicitly set in distributed mode. "
- "Mount a Kubernetes Secret or pass it via the --jwt-secret-key flag "
- "or set the GPUSTACK_JWT_SECRET_KEY environment variable."
- )
async def _start_leader_only_tasks(self):
    """Start leader tasks now (local) or begin leader election (distributed)."""
    if not isinstance(self._coordinator, LocalCoordinator):
        # Distributed mode: start leader election loop
        logger.info("Starting leader election loop...")
        election = self._leader_election_loop()
        self._leader_election_task = asyncio.create_task(election)
        return

    # Local mode: start leader tasks directly (always run)
    self._start_leader_tasks()
async def _leader_election_loop(self):
    """Acquire or renew leadership through the coordinator, forever.

    While not leader, each round tries to acquire leadership and starts the
    leader tasks on success. While leader, each round renews; a failed
    renewal terminates the process via ``os._exit(1)``. Any exception is
    logged and the loop retries after 5 seconds.
    """
    coordinator = self._coordinator
    server_id = self._config.server_id
    ttl = coordinator.leader_election_ttl
    renew_interval = coordinator.leader_election_renew_interval
    # Limits the "attempting"/"standby" messages to the first round.
    is_first_attempt = True

    while True:
        try:
            if self._coordinator.is_leader():
                # Renew leadership
                if not await self._coordinator.renew_leadership(ttl):
                    logger.error(
                        f"Server {server_id} lost leadership, exiting for restart"
                    )
                    # Hard exit to prevent split-brain: os._exit bypasses
                    # cleanup so the process stops immediately and the
                    # container runtime can restart it as a standby.
                    os._exit(1)
            else:
                # Try to acquire leadership
                if is_first_attempt:
                    logger.info(
                        f"Server {server_id} attempting to acquire leadership..."
                    )
                if await self._coordinator.acquire_leadership(ttl):
                    logger.info(
                        f"Server {server_id} became leader, starting scheduler and controllers"
                    )
                    # Start leader-only tasks
                    self._start_leader_tasks()
                elif is_first_attempt:
                    logger.info(
                        f"Server {server_id} is standby, waiting for leadership..."
                    )
                is_first_attempt = False

            await asyncio.sleep(renew_interval)
        except Exception as e:
            logger.error(f"Leader election error: {e}")
            await asyncio.sleep(5)
- def _start_leader_tasks(self):
- """Start tasks that run only on the leader.
- Note: If leadership is lost, the process exits directly (os._exit),
- so we don't need to track and cancel these tasks.
- """
- # Scheduler
- self._start_scheduler()
- # Controllers
- self._start_controllers()
- # System Load Collector
- self._start_system_load_collector()
- # Update Checker
- self._start_update_checker()
- # Worker Instance Cleaner
- self._start_worker_instance_cleaner()
- # Usage Details Archiver (move aged rows to archive table)
- self._start_usage_details_archiver()
- # Worker Syncer (checks worker reachability and updates states)
- self._start_worker_syncer(self._app)
|