import asyncio
import contextlib
from datetime import datetime, timezone
import multiprocessing
import re
import threading
import time
import requests
import setproctitle
import os
from typing import Dict, Optional, Set, List, Callable
from pathlib import Path
import logging

from gpustack_runtime.deployer import (
    get_workload,
    WorkloadStatusStateEnum,
    delete_workload,
    logs_workload,
)
from gpustack_runtime.deployer.__utils__ import compare_versions

from gpustack.api.exceptions import NotFoundException
from gpustack.config.config import Config
from gpustack.config import registration
from gpustack.logging import (
    RedirectStdoutStderr,
)
from gpustack.schemas.inference_backend import (
    InferenceBackend,
    is_built_in_backend,
    is_custom_backend,
)
from gpustack.utils import network
from gpustack.utils.convert import safe_int
from gpustack.utils.attrs import set_attr
from gpustack.utils.command import find_int_parameter
from gpustack.utils.process import terminate_process_tree, add_signal_handlers
from gpustack.worker.backends.ascend_mindie import AscendMindIEServer
from gpustack.worker.backends.sglang import SGLangServer
from gpustack.worker.backends.vllm import VLLMServer
from gpustack.worker.backends.vox_box import VoxBoxServer
from gpustack.worker.backends.custom import CustomServer
from gpustack.routes.worker.logs import (
    extract_container_restart_count,
    extract_restart_count,
)
from gpustack.worker.model_meta import get_meta_from_running_instance
from gpustack.client import ClientSet
from gpustack.schemas.models import (
    BackendEnum,
    Model,
    ModelUpdate,
    ModelInstance,
    ModelInstanceUpdate,
    ModelInstanceStateEnum,
    get_backend,
    DistributedServerCoordinateModeEnum,
    ModelInstanceSubordinateWorker,
    CategoryEnum,
)
from gpustack.server.bus import Event, EventType
from gpustack.worker.inference_backend_manager import InferenceBackendManager

logger = logging.getLogger(__name__)

# Inference health check error message
# (used as the state_message when repeated inference health checks fail).
_INFERENCE_HEALTH_CHECK_FAILED_MESSAGE = "Inference health check failed."

# Global lock for port assignment to avoid pickle serialization issues
# (kept at module level rather than on the instance).
_port_lock = threading.Lock()

# Built-in backend -> server implementation. Backends not listed here are
# served by CustomServer (see the .get(backend, CustomServer) lookup).
_SERVER_CLASS_MAPPING = {
    BackendEnum.VLLM: VLLMServer,
    BackendEnum.SGLANG: SGLangServer,
    BackendEnum.VOX_BOX: VoxBoxServer,
    BackendEnum.ASCEND_MINDIE: AscendMindIEServer,
}


class ServeManager:
    """
    Manage the model instances served by this worker.

    Watches model / model instance events, starts and stops provisioning
    subprocesses, persists container logs, and periodically syncs each
    instance's state (and optional inference health) back to the server.
    """

    # NOTE: the bare strings after each member below are informal attribute
    # docs; Python does not attach them to the member.

    @property
    def _worker_id(self) -> int:
        return self._worker_id_getter()

    """
    The ID of current worker.
    """

    _config: Config
    """
    Global configuration.
    """

    _serve_log_dir: str
    """
    The directory to store logs of serving model instances(in subprocess).
    """

    @property
    def _clientset(self) -> ClientSet:
        return self._clientset_getter()

    """
    The clientset to access the API server.
    """

    # NOTE(review): not assigned in __init__ in the visible code; presumably
    # set by the owner after construction -- confirm.
    _inference_backend_manager: InferenceBackendManager
    """
    The inference backend manager.
    """

    _provisioning_processes: Dict[int, multiprocessing.Process]
    """
    The mapping of model instance ID to provisioning (sub)process.
    When the (sub)process is alive, the model instance is provisioning.
    If the (sub)process exited, the model instance is either running or failed.
    """

    _log_persistence_threads: Dict[int, List[threading.Thread]]
    """
    The mapping of model instance ID to log persistence threads.
    Each model instance may have multiple threads (one per loggable container).
    """

    _log_persistence_stop_events: Dict[int, List[threading.Event]]
    """
    The mapping of model instance ID to stop events for log persistence threads.
    Used to signal threads to stop gracefully.
    """

    _error_model_instances: Dict[int, ModelInstance]
    """
    The mapping of model instance ID to error model instances.
    Used to restart error model instances.
    """

    _model_cache_by_instance: Dict[int, Model]
    """
    The cache of models by model instance ID.
    Used to avoid redundant API calls to get model information.
""" _model_instance_by_instance_id: Dict[int, ModelInstance] _clientset_getter: Callable[[], ClientSet] _worker_id_getter: Callable[[], int] def __init__( self, worker_id_getter: Callable[[], int], clientset_getter: Callable[[], ClientSet], cfg: Config, ): self._worker_id_getter = worker_id_getter self._config = cfg self._serve_log_dir = f"{cfg.log_dir}/serve" self._clientset_getter = clientset_getter self._provisioning_processes = {} self._log_persistence_threads = {} self._log_persistence_stop_events = {} self._error_model_instances = {} self._model_cache_by_instance = {} self._model_instance_by_instance_id = {} # Instance-level port tracking to avoid conflicts self._assigned_ports: Dict[int, Set[int]] = {} self._restart_backoff_counts: Dict[int, int] = {} # Inference health check failure tracking # {model_instance_id: failure_count} self._inference_health_check_failures: Dict[int, int] = {} # Track last successful inference per port (set by worker proxy) self._last_successful_inference: Dict[int, float] = {} # Track last health check time per model instance self._last_health_check_time: Dict[int, float] = {} os.makedirs(self._serve_log_dir, exist_ok=True) def record_successful_inference(self, instance_id: int): """Called by worker proxy on successful inference response.""" self._last_successful_inference[instance_id] = time.time() async def watch_models(self): """ Loop to watch models to keep the cache updated. """ logger.debug("Watching models.") while True: try: # Watch models without callback to keep the cache updated. await self._clientset.models.awatch(callback=None) except asyncio.CancelledError: break except Exception as e: logger.error(f"Error watching models: {e}") await asyncio.sleep(5) async def watch_model_instances_event(self): """ Loop to watch model instances' event and handle. 
""" logger.debug("Watching model instances event.") while True: try: await self._clientset.model_instances.awatch( callback=self._handle_model_instance_event ) except asyncio.CancelledError: break except Exception as e: logger.error(f"Error watching model instances: {e}") await asyncio.sleep(5) async def watch_model_instances(self): """ Loop to post process model instances, for example, restarting error instances. """ logger.debug("Watching model instances.") while True: try: for mi in list(self._error_model_instances.values()): self._restart_error_model_instance(mi) await asyncio.sleep(10) except Exception as e: logger.error(f"Error restarting model instances: {e}") await asyncio.sleep(5) def sync_model_instances_state(self): # noqa: C901 """ Synchronize model instances' state. - If the model instance is scheduled but not initialized, skip. - If the provision process is still alive, skip. - If the workload is still launching, skip. - If the workload is not existed, unhealthy, inactive or failed, update the model instance state to ERROR. - If everything is fine, update the model instance state to RUNNING. """ # Get all model instances assigned to this worker. # # FIXME(thxCode): This may cause performance issues when there are many model instances in the system. # A mechanism is needed to improve efficiency here. model_instances_page = self._clientset.model_instances.list(use_cache=False) if not model_instances_page.items: return model_instances: List[ModelInstance] = [] for model_instance in model_instances_page.items: # if the model instance is assigned to this worker, it must be scheduled. # But we don't need to sync the scheduled model when it is not initialized yet. 
if ( model_instance.worker_id == self._worker_id and model_instance.state != ModelInstanceStateEnum.SCHEDULED ): model_instances.append(model_instance) if ( model_instance.distributed_servers and model_instance.distributed_servers.subordinate_workers ): for sw in model_instance.distributed_servers.subordinate_workers: if sw.worker_id == self._worker_id: model_instances.append(model_instance) break for model_instance in model_instances: # Skip if the provision process has not exited yet. if self._is_provisioning(model_instance): logger.trace( f"Model instance {model_instance.name} is provisioning. Skipping sync." ) continue is_main_worker = model_instance.worker_id == self._worker_id # Skip if the workload is still launching. # Use deployment metadata name for subordinate workers (e.g., "model-f0") # since their workload name differs from the model instance name. if is_main_worker: workload = get_workload(model_instance.name) else: deployment_metadata = model_instance.get_deployment_metadata( self._worker_id ) workload_name = ( deployment_metadata.name if deployment_metadata else model_instance.name ) workload = get_workload(workload_name) if workload and workload.state in [ WorkloadStatusStateEnum.PENDING, WorkloadStatusStateEnum.INITIALIZING, ]: logger.trace( f"Model instance {model_instance.name} workload is still launching. Skipping sync." ) continue # Update model instance state to ERROR if the workload is not existed, unhealthy, inactive or failed. if not workload or workload.state in [ WorkloadStatusStateEnum.UNKNOWN, # Rare, but possible, for example, leaving pause container. WorkloadStatusStateEnum.UNHEALTHY, WorkloadStatusStateEnum.INACTIVE, WorkloadStatusStateEnum.FAILED, ]: # Only if not in ERROR state yet. if model_instance.state != ModelInstanceStateEnum.ERROR: with contextlib.suppress(NotFoundException): # Get patch dict for main worker. 
if is_main_worker: patch_dict = { "state": ModelInstanceStateEnum.ERROR, "state_message": "Inference server exited or unhealthy.", } # Get patch dict for subordinate worker. else: sw_pos = next( ( i for i, sw in enumerate( model_instance.distributed_servers.subordinate_workers ) if sw.worker_id == self._worker_id ), ) sw = model_instance.distributed_servers.subordinate_workers[ sw_pos ] sw.state = ModelInstanceStateEnum.ERROR sw.state_message = "Inference server exited or unhealthy." patch_dict = { f"distributed_servers.subordinate_workers.{sw_pos}": sw, } # Update model instance. self._update_model_instance(model_instance.id, **patch_dict) continue # Otherwise, update model instance state to RUNNING if everything is fine. model = self._get_model(model_instance) if not model.backend_version: # backend version may be empty on initialization. # try to refresh to get updated model info on syncs. model = self._refresh_model(model_instance) backend = get_backend(model) health_check_path = self._get_health_check_path(backend) if model.env and 'GPUSTACK_MODEL_HEALTH_CHECK_PATH' in model.env: # NOTE: There is no known use case for now. Keep this in case the built-in backends # introduce breaking changes and the default health check path no longer works. health_check_path = model.env['GPUSTACK_MODEL_HEALTH_CHECK_PATH'] with contextlib.suppress(NotFoundException): # Get patch dict for main worker. if is_main_worker: subordinate_state = self._get_main_worker_distributed_state( model_instance ) if subordinate_state is None: if model_instance.state == ModelInstanceStateEnum.RUNNING: self._restart_backoff_counts.pop(model_instance.id, None) continue if ( model_instance.state == ModelInstanceStateEnum.ERROR or not is_ready( backend, model_instance, health_check_path, model ) ): continue self._restart_backoff_counts.pop(model_instance.id, None) patch_dict = { "state": ModelInstanceStateEnum.RUNNING, "state_message": "", } # Fetch model meta once running. 
meta = get_meta_from_running_instance( model_instance, backend, model ) if meta: # Some meta is set in server evaluation and should be preserved, so we update meta instead of overwrite. merged_meta = dict(model.meta or {}) merged_meta.update(meta) if merged_meta != model.meta: self._update_model(model.id, meta=merged_meta) elif subordinate_state["should_update"]: patch_dict = { "state": subordinate_state["state"], "state_message": subordinate_state["state_message"], } else: continue # Get patch dict for subordinate worker. else: # For initialize later mode, the state is set to RUNNING directly, # which means the subordinate worker doesn't need to wait for the main worker to be healthy. if ( model_instance.distributed_servers.mode == DistributedServerCoordinateModeEnum.INITIALIZE_LATER ): continue # Otherwise, update subordinate worker state to RUNNING. sw_pos = next( ( i for i, sw in enumerate( model_instance.distributed_servers.subordinate_workers ) if sw.worker_id == self._worker_id ), ) sw = model_instance.distributed_servers.subordinate_workers[sw_pos] if sw.state == ModelInstanceStateEnum.RUNNING: continue sw.state = ModelInstanceStateEnum.RUNNING sw.state_message = "" patch_dict = { f"distributed_servers.subordinate_workers.{sw_pos}": sw, } # Update model instance. 
self._update_model_instance(model_instance.id, **patch_dict) @staticmethod def _get_main_worker_distributed_state( model_instance: ModelInstance, ) -> Optional[dict]: subordinate_workers = ( model_instance.distributed_servers.subordinate_workers if ( model_instance.distributed_servers and model_instance.distributed_servers.subordinate_workers ) else [] ) if not subordinate_workers: return None error_sw = None unreachable_sw = None all_running = True for sw in subordinate_workers: if sw.state == ModelInstanceStateEnum.ERROR: error_sw = sw break if ( sw.state == ModelInstanceStateEnum.UNREACHABLE and unreachable_sw is None ): unreachable_sw = sw if sw.state != ModelInstanceStateEnum.RUNNING: all_running = False if error_sw: return { "should_update": model_instance.state != ModelInstanceStateEnum.ERROR, "state": ModelInstanceStateEnum.ERROR, "state_message": ( f"Distributed serving error in subordinate worker " f"{error_sw.worker_ip}: {error_sw.state_message}." ), } if unreachable_sw: return { "should_update": model_instance.state != ModelInstanceStateEnum.UNREACHABLE, "state": ModelInstanceStateEnum.UNREACHABLE, "state_message": ( f"Distributed serving unreachable in subordinate worker " f"{unreachable_sw.worker_ip}: {unreachable_sw.state_message}." ), } if not all_running: return {"should_update": False} return None @staticmethod def _serve_model_instance( mi: ModelInstance, backend: BackendEnum, client_headers: dict, log_file_path: str, cfg: Config, worker_id: int, inference_backend: InferenceBackend, fallback_registry: Optional[str] = None, ): """ Serve model instance in a subprocess. Exits the subprocess when serving ends. Args: mi: The model instance to serve. backend: The backend of the model instance. client_headers: The headers for the clientset. log_file_path: The path to the log file. cfg: The configuration. worker_id: The ID of the worker. inference_backend: The inference backend configuration. 
fallback_registry: The fallback container registry to use if needed. """ setproctitle.setproctitle(f"gpustack_model_instance_{mi.id}") add_signal_handlers() clientset = ClientSet( base_url=cfg.get_server_url(), headers=client_headers, ) with open(log_file_path, "w", buffering=1, encoding="utf-8") as log_file: with RedirectStdoutStderr(log_file): try: server_cls = _SERVER_CLASS_MAPPING.get(backend, CustomServer) server_ins = server_cls( clientset, mi, cfg, worker_id, inference_backend, fallback_registry, ) logger.info(f"Provisioning model instance {mi.name}") server_ins.start() logger.info(f"Finished provisioning model instance {mi.name}") except Exception as e: logger.exception( f"Error provisioning model instance {mi.name}: {e}" ) raise e def sync_model_instances_inference_health(self): """ Synchronize model instances' inference health by sending actual inference requests. Per-model configuration is read from model.env: - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_ENABLED: "true"/"false" (default: false) - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_INTERVAL: seconds (default: global env) - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_TIMEOUT: seconds (default: 15) - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD: count (default: global env) If the model has received successful inference traffic recently (within the configured interval), the active health check is skipped. """ # Use the event-driven local cache instead of an API call. model_instances = [ mi for mi in self._model_instance_by_instance_id.values() if mi.state == ModelInstanceStateEnum.RUNNING ] if not model_instances: return now = time.time() for model_instance in model_instances: model = self._get_model(model_instance) if not model: continue # Read per-model config from model.env. config = _get_inference_health_check_config(model) if not config["enabled"]: continue interval = config["interval"] timeout = config["timeout"] threshold = config["threshold"] # Skip if the model is still provisioning. 
if self._is_provisioning(model_instance): continue # Skip if not enough time has passed since last check. last_check = self._last_health_check_time.get(model_instance.id, 0) if now - last_check < interval: continue self._last_health_check_time[model_instance.id] = now # Skip if recent successful inference was observed for this instance. last_success = self._last_successful_inference.get(model_instance.id, 0) if last_success > now - interval: logger.debug( f"Model instance {model_instance.name} had recent successful " f"inference, skipping health check." ) # Reset failure count since real traffic is succeeding. self._inference_health_check_failures.pop(model_instance.id, None) continue # Perform inference health check. if not is_inference_ready(model_instance, model, timeout=timeout): failure_count = self._inference_health_check_failures.get( model_instance.id, 0 ) failure_count += 1 self._inference_health_check_failures[model_instance.id] = failure_count if failure_count >= threshold: logger.warning( f"Model instance {model_instance.name} inference health check failed " f"{failure_count} times, updating state to ERROR." ) patch_dict = { "state": ModelInstanceStateEnum.ERROR, "state_message": _INFERENCE_HEALTH_CHECK_FAILED_MESSAGE, } self._update_model_instance(model_instance.id, **patch_dict) # Reset failure count after marking as error. del self._inference_health_check_failures[model_instance.id] else: logger.debug( f"Model instance {model_instance.name} inference health check failed " f"{failure_count}/{threshold} times." ) else: # Reset failure count on success. self._inference_health_check_failures.pop(model_instance.id, None) def _handle_model_instance_event(self, event: Event): # noqa: C901 """ Handle model instance events. Args: event: The model instance event to handle. 
""" mi = ModelInstance.model_validate(event.data) logger.trace( f"Received event: {str(event.type)}, id: {mi.id}, name: {mi.name}, state: {str(mi.state)}" ) is_main_worker = mi.worker_id == self._worker_id if is_main_worker: self._model_instance_by_instance_id[mi.id] = mi # Return if all subordinate workers aren't running. if ( mi.distributed_servers and mi.distributed_servers.mode == DistributedServerCoordinateModeEnum.RUN_FIRST and mi.distributed_servers.subordinate_workers ): ready = all( sw.state == ModelInstanceStateEnum.RUNNING for sw in mi.distributed_servers.subordinate_workers ) if not ready: logger.info( f"Model instance {mi.name} waits for all subordinate workers to be ready." ) return else: # Return if it isn't a distribution serving. if not mi.distributed_servers: return # Return if it's a delegated distribution, # which means the main worker is responsible for serving. if ( mi.distributed_servers.mode == DistributedServerCoordinateModeEnum.DELEGATED ): return # Return if it isn't the member of the distribution serving. joined = any( sw.worker_id == self._worker_id for sw in mi.distributed_servers.subordinate_workers or [] ) if not joined: return # Return if the main worker isn't initialized. if ( mi.distributed_servers.mode == DistributedServerCoordinateModeEnum.INITIALIZE_LATER and ( mi.state not in [ ModelInstanceStateEnum.STARTING, ModelInstanceStateEnum.RUNNING, ModelInstanceStateEnum.ERROR, ] ) ): logger.info( f"Model instance {mi.name} waits for main worker {mi.worker_ip} to be initialized." ) return # FIXME: This is a temporary solution to prevent the main worker from being unable to start due to phantom reads. # We confirm whether the operation should be performed by checking the state of the earlier subordinate worker. 
for sw in mi.distributed_servers.subordinate_workers: if sw.worker_id == self._worker_id: break if sw.state not in [ ModelInstanceStateEnum.RUNNING, ModelInstanceStateEnum.ERROR, ]: logger.info( f"Model instance {mi.name} waits for previous subordinate worker {sw.worker_ip} to be ready." ) return if event.type == EventType.DELETED: self._stop_model_instance(mi) logger.trace(f"DELETED event: stopped deleted model instance {mi.name}.") return if event.type == EventType.UPDATED: # Caching matched ERROR instances for restart handling. if mi.state == ModelInstanceStateEnum.ERROR: model = self._get_model(mi) if model.restart_on_error: self._error_model_instances[mi.id] = mi logger.trace( f"UPDATED event: cached error model instance {mi.name} for restart." ) return # Restart if scheduled and this is the assigned worker. if is_main_worker and mi.state == ModelInstanceStateEnum.SCHEDULED: self._restart_model_instance(mi) logger.trace( f"UPDATED event: restarted scheduled model instance {mi.name}." ) # Start on subordinate worker if not started yet, or restart if failed. if not is_main_worker: deployment_metadata = mi.get_deployment_metadata(self._worker_id) workload_name = ( deployment_metadata.name if deployment_metadata else mi.name ) workload = get_workload(workload_name) if not workload: self._start_model_instance(mi) logger.trace( f"UPDATED event: started model instance {mi.name} on subordinate worker." ) elif workload.state in [ WorkloadStatusStateEnum.UNKNOWN, WorkloadStatusStateEnum.UNHEALTHY, WorkloadStatusStateEnum.INACTIVE, WorkloadStatusStateEnum.FAILED, ]: self._stop_model_instance(mi, clear_restart_backoff=False) self._start_model_instance(mi) logger.trace( f"UPDATED event: restarted failed model instance {mi.name} on subordinate worker." 
) return if event.type == EventType.CREATED: # Only handle CREATED if this is the assigned worker if not is_main_worker: return if mi.state == ModelInstanceStateEnum.RUNNING: logger.warning( f"Model instance {mi.name} is already running. Skipping start." ) return self._start_model_instance(mi) logger.trace(f"CREATED event: started created model instance {mi.name}.") def _get_numbered_log_path(self, mi: ModelInstance) -> str: """Get log file path with restart count. Args: mi: The model instance. Returns: Log file path with format: {log_dir}/{model_instance_id}.{restart_count}.log """ restart_count = mi.restart_count or 0 return f"{self._serve_log_dir}/{mi.id}.{restart_count}.log" def _persist_container_logs( self, workload_name: str, log_path: str, stop_event: threading.Event, token: Optional[str] = None, ): """Persist container logs to local file. This is a blocking operation that runs in a separate thread. Retries indefinitely until container is created. Args: workload_name: Name of the container workload log_path: Path to save container logs stop_event: Event to signal thread to stop token: Operation token identifying a specific container in the workload. If None, logs from the default (index=0) container are fetched. 
""" retry_count = 0 while not stop_event.is_set(): try: log_stream = logs_workload( name=workload_name, token=token, tail=-1, follow=True, ) if hasattr(log_stream, '__iter__'): with open(log_path, 'w', buffering=1, encoding='utf-8') as f: for line in log_stream: if stop_event.is_set(): break if isinstance(line, bytes): f.write(line.decode('utf-8', errors='replace')) else: f.write(str(line)) f.flush() break except Exception as e: if stop_event.is_set(): break retry_count += 1 logger.debug( f"Container not ready for {workload_name}, retrying " f"(attempt {retry_count}): {e}" ) stop_event.wait(timeout=2) logger.debug(f"Log persistence thread for {workload_name} exiting") def _discover_sidecar_logs( self, mi_id: int, workload_name: str, restart_count: int, stop_event: threading.Event, ): """Background thread that waits for sidecar containers to appear. Polls get_workload() until sidecar containers are found in the loggable list, then starts log persistence threads for each. Exits when sidecars are found or stop_event is set. Args: mi_id: Model instance ID workload_name: Workload name restart_count: Current restart count for log file naming stop_event: Event to signal thread to stop """ while not stop_event.is_set(): try: workload = get_workload(workload_name) if workload and workload.loggable: sidecars = [op for op in workload.loggable if op.name != "default"] if sidecars: self._start_sidecar_log_threads( mi_id, workload_name, workload.loggable, restart_count, ) logger.debug(f"Sidecar discovery for {workload_name} complete") return except Exception: pass stop_event.wait(timeout=2) def _start_sidecar_log_threads( self, mi_id: int, workload_name: str, loggable_ops: list, restart_count: int, ): """Start additional log persistence threads for sidecar containers. Called from the main log persistence thread once the workload is available and multiple loggable containers are discovered. 
Args: mi_id: Model instance ID workload_name: Workload name loggable_ops: List of WorkloadStatusOperation from workload.loggable restart_count: Current restart count for log file naming """ names = [] for op in loggable_ops: if op.name == "default": continue # Main container handled by caller thread log_path = ( f"{self._serve_log_dir}/{mi_id}.container." f"{op.name}.{restart_count}.log" ) stop_event = threading.Event() thread = threading.Thread( target=self._persist_container_logs, args=(workload_name, log_path, stop_event, op.token), daemon=True, name=f"log-persist-{workload_name}-{op.name}", ) thread.start() # Append to existing tracking lists. self._log_persistence_threads.setdefault(mi_id, []).append(thread) self._log_persistence_stop_events.setdefault(mi_id, []).append(stop_event) names.append(op.name) if names: logger.debug( f"Started sidecar log persistence threads for {workload_name}: " f"{names}" ) def _start_container_log_persistence(self, mi: ModelInstance): """Start a background thread to persist container logs. Starts a single "main" log persistence thread. The thread will automatically discover sidecar containers (e.g., Ray head) once the workload is created, and spawn additional threads for each. Args: mi: The model instance. """ # Stop and clean up existing threads if any self._stop_container_log_persistence(mi.id) # Use deployment metadata name for the actual workload name, # which differs for subordinate workers (e.g., "model-f0"). deployment_metadata = mi.get_deployment_metadata(self._worker_id) workload_name = deployment_metadata.name if deployment_metadata else mi.name restart_count = mi.restart_count or 0 log_path = f"{self._serve_log_dir}/{mi.id}.container.{restart_count}.log" stop_event = threading.Event() # Main container log thread. 
thread = threading.Thread( target=self._persist_container_logs, args=(workload_name, log_path, stop_event), daemon=True, name=f"log-persist-{workload_name}", ) thread.start() # Sidecar discovery thread — polls until sidecar containers appear, # then starts additional log threads for each. discovery_thread = threading.Thread( target=self._discover_sidecar_logs, args=(mi.id, workload_name, restart_count, stop_event), daemon=True, name=f"log-discover-{workload_name}", ) discovery_thread.start() self._log_persistence_threads[mi.id] = [thread, discovery_thread] self._log_persistence_stop_events[mi.id] = [stop_event] logger.debug(f"Started container log persistence thread for {mi.name}") def _stop_container_log_persistence( self, model_instance_id: int, timeout: float = 2.0 ): """Stop all container log persistence threads for a model instance. Args: model_instance_id: The model instance ID timeout: Maximum time to wait for each thread to stop (seconds) """ # Signal all threads to stop stop_events = self._log_persistence_stop_events.pop(model_instance_id, []) for stop_event in stop_events: stop_event.set() # Wait for all threads to finish threads = self._log_persistence_threads.pop(model_instance_id, []) for thread in threads: if thread and thread.is_alive(): thread.join(timeout=timeout) if thread.is_alive(): logger.warning( f"Log persistence thread {thread.name} for model instance " f"{model_instance_id} did not stop within {timeout}s" ) def _cleanup_old_logs(self, model_instance_id: int, current_restart_count: int): """Remove serve log files except the current and previous restart_count. Keeps files for restart_count in {R, R-1} where R is current_restart_count; when R is 0, only R is kept. Args: model_instance_id: Model instance ID current_restart_count: Restart count for the upcoming run (same as log path). 
""" try: log_dir = Path(self._serve_log_dir) # Separate main logs, container logs, and sidecar container logs main_log_pattern = f"{model_instance_id}.*.log" all_main_logs = [ f for f in log_dir.glob(main_log_pattern) if '.container.' not in f.name ] container_log_pattern = f"{model_instance_id}.container.*.log" all_container_files = list(log_dir.glob(container_log_pattern)) # Split into default container logs (e.g., 42.container.0.log) # and sidecar container logs (e.g., 42.container.ray-head.0.log) default_container_logs = [ f for f in all_container_files if extract_container_restart_count(f.name) > 0 or re.match(rf'{model_instance_id}\.container\.\d+\.log', f.name) ] sidecar_container_logs = [ f for f in all_container_files if f not in default_container_logs ] self._cleanup_log_type(all_main_logs, current_restart_count, "main") self._cleanup_log_type( default_container_logs, current_restart_count, "container" ) self._cleanup_log_type( sidecar_container_logs, current_restart_count, "sidecar_container" ) except Exception as e: logger.error(f"Failed to cleanup old logs for {model_instance_id}: {e}") def _cleanup_log_type( self, log_files: List[Path], current_restart_count: int, log_type: str, ): """Delete log files whose restart_count is not current or previous.""" keep = {current_restart_count} if current_restart_count > 0: keep.add(current_restart_count - 1) def _extract_sidecar_restart_count(filename: str) -> int: """Extract restart count from {id}.container.{name}.{restart_count}.log""" match = re.match(r'\d+\.container\.[^.]+\.(\d+)\.log', filename) return int(match.group(1)) if match else 0 extract_fns = { "main": extract_restart_count, "container": extract_container_restart_count, "sidecar_container": _extract_sidecar_restart_count, } extract_fn = extract_fns.get(log_type, extract_container_restart_count) for f in log_files: rc = extract_fn(f.name) if rc in keep: continue try: f.unlink() logger.info(f"Deleted old {log_type} log file: {f}") except Exception 
as e: logger.warning(f"Failed to delete {log_type} log file {f}: {e}") def _start_model_instance(self, mi: ModelInstance): # noqa: C901 """ Start model instance through a subprocess. Args: mi: The model instance to start. """ if self._is_provisioning(mi): logger.warning(f"Model instance {mi.name} is provisioning. Skipping start.") return # Clean up old log files before starting self._cleanup_old_logs(mi.id, mi.restart_count or 0) is_main_worker = mi.worker_id == self._worker_id log_file_path = self._get_numbered_log_path(mi) sw_pos: Optional[int] = None sw: Optional[ModelInstanceSubordinateWorker] = None if not is_main_worker: sw_pos = next( ( i for i, sw in enumerate(mi.distributed_servers.subordinate_workers) if sw.worker_id == self._worker_id ), ) sw = mi.distributed_servers.subordinate_workers[sw_pos] try: model = self._get_model(mi) backend = get_backend(model) self._assign_ports(mi, model, backend) logger.debug( f"Starting model instance {mi.name}" f"{'' if not is_main_worker else f' on ports {mi.ports if mi.ports else [mi.port]}'}" ) fallback_registry = ( registration.determine_default_registry( self._config.system_default_container_registry ) if is_built_in_backend(backend) else None ) process = multiprocessing.Process( target=ServeManager._serve_model_instance, args=( mi, backend, self._clientset.headers, log_file_path, self._config, self._worker_id, self._inference_backend_manager.get_backend_by_name(backend), fallback_registry, ), ) process.daemon = False process.start() self._provisioning_processes[mi.id] = process # Start container log persistence for containerized backends self._start_container_log_persistence(mi) # Get patch dict for main worker. if is_main_worker: patch_dict = { "state": ModelInstanceStateEnum.INITIALIZING, "port": mi.port, "ports": mi.ports, "pid": process.pid, } # Get patch dict for subordinate worker. 
else: sw.state = ModelInstanceStateEnum.INITIALIZING # For initialize later mode, the state is set to RUNNING directly, # which means the subordinate worker doesn't need to wait for the main worker to be healthy. if ( mi.distributed_servers.mode == DistributedServerCoordinateModeEnum.INITIALIZE_LATER ): sw.state = ModelInstanceStateEnum.RUNNING sw.pid = process.pid patch_dict = { f"distributed_servers.subordinate_workers.{sw_pos}": sw, } self._update_model_instance(mi.id, **patch_dict) logger.info( f"Started model instance {mi.name}" f"{'' if not is_main_worker else f' on ports {mi.ports if mi.ports else [mi.port]}'}" ) except Exception as e: # Clean up provisioning process if started. if mi.id in self._provisioning_processes: self._stop_model_instance(mi) # Get patch dict for main worker. if is_main_worker: patch_dict = { "state": ModelInstanceStateEnum.ERROR, "state_message": f"Failed to start model instance: {e}", } # Get patch dict for subordinate worker. else: sw.state = ModelInstanceStateEnum.ERROR sw.state_message = f"Failed to start model instance: {e}" patch_dict = { f"distributed_servers.subordinate_workers.{sw_pos}": sw, } self._update_model_instance(mi.id, **patch_dict) logger.error(f"Failed to start model instance {mi.name}: {e}") def _assign_ports( self, mi: ModelInstance, model: Model, backend: BackendEnum, ) -> None: """ Assign ports to the model instance. This method is thread-safe and allocates ports for: - Main serving port - RPC port for vLLM DP communication (if applicable) - Connecting port for subordinate workers (if applicable) Args: mi: The model instance to assign ports to. model: The model associated with the instance. backend: The backend type (e.g., vLLM, SGLang). """ if mi.port: # Port already assigned, skip. return with _port_lock: if mi.port: # Port already assigned, skip. 
return if self._assigned_ports: unavailable_ports = set.union(*self._assigned_ports.values()) else: unavailable_ports = set() # Main serving port mi.port = network.get_free_port( port_range=self._config.service_port_range, unavailable_ports=unavailable_ports, host=mi.worker_ip, ) mi.ports = [mi.port] unavailable_ports.add(mi.port) # Additional ports for distributed servers if mi.distributed_servers and mi.distributed_servers.subordinate_workers: # RPC port for DP communication in vLLM backend if backend == BackendEnum.VLLM: dps = find_int_parameter( model.backend_parameters, ["data-parallel-size", "dp"], ) if dps and dps > 1: dp_connecting_port = network.get_free_port( port_range=self._config.service_port_range, unavailable_ports=unavailable_ports, host=mi.worker_ip, ) mi.ports.append(dp_connecting_port) unavailable_ports.add(dp_connecting_port) # Connecting port for subordinate workers communication connecting_port = network.get_free_port( port_range=self._config.service_port_range, unavailable_ports=unavailable_ports, host=mi.worker_ip, ) mi.ports.append(connecting_port) unavailable_ports.add(connecting_port) self._assigned_ports[mi.id] = set(mi.ports) def _restart_model_instance(self, mi: ModelInstance): """ Restart model instance. Args: mi: The model instance to restart. """ self._stop_model_instance(mi, clear_restart_backoff=False) self._start_model_instance(mi) def _update_model(self, id: int, **kwargs): """ Update model instance with given fields. Args: id: The ID of the model instance to update. **kwargs: The fields to update, group by field name and value. """ try: m_public = self._clientset.models.get(id=id) m = ModelUpdate(**m_public.model_dump()) for key, value in kwargs.items(): set_attr(m, key, value) self._clientset.models.update(id=id, model_update=m) except NotFoundException: logger.warning(f"Model with ID {id} not found when trying to update.") def _update_model_instance(self, id: int, **kwargs): """ Update model instance with given fields. 
Args: id: The ID of the model instance to update. **kwargs: The fields to update, group by field name and value. """ try: mi_public = self._clientset.model_instances.get(id=id) mi = ModelInstanceUpdate(**mi_public.model_dump()) for key, value in kwargs.items(): set_attr(mi, key, value) self._clientset.model_instances.update(id=id, model_update=mi) except NotFoundException: logger.warning( f"Model instance with ID {id} not found when trying to update." ) def _stop_model_instance( self, mi: ModelInstance, clear_restart_backoff: bool = True ): """ Stop model instance and clean up. Args: mi: The model instance to stop. clear_restart_backoff: Whether to clear transient restart backoff state. """ logger.debug(f"Stopping model instance {mi.name or mi.id}") # Stop container log persistence thread self._stop_container_log_persistence(mi.id) # Teardown provisioning process if still alive. if self._is_provisioning(mi): terminate_process_tree(self._provisioning_processes[mi.id].pid) # Delete workload. deployment_metadata = mi.get_deployment_metadata(self._worker_id) if deployment_metadata: delete_workload(deployment_metadata.name) # Cleanup internal states. self._provisioning_processes.pop(mi.id, None) self._assigned_ports.pop(mi.id, None) self._error_model_instances.pop(mi.id, None) self._model_cache_by_instance.pop(mi.id, None) self._model_instance_by_instance_id.pop(mi.id, None) if clear_restart_backoff: self._restart_backoff_counts.pop(mi.id, None) self._inference_health_check_failures.pop(mi.id, None) self._last_health_check_time.pop(mi.id, None) self._last_successful_inference.pop(mi.id, None) logger.info(f"Stopped model instance {mi.name or mi.id}") def _restart_error_model_instance(self, mi: ModelInstance): """ Restart error model instance with exponential backoff, maximum delay 5 minutes. Args: mi: The model instance to restart. """ if self._is_provisioning(mi): logger.debug(f"Model instance {mi.name} is provisioning. 
Skipping restart.") return restart_count = mi.restart_count or 0 backoff_count = self._restart_backoff_counts.get(mi.id, 0) last_restart_time = mi.last_restart_time or mi.updated_at current_time = datetime.now(timezone.utc) delay = min(10 * (2 ** (backoff_count - 1)), 300) if backoff_count > 0 else 0 if backoff_count > 0 and last_restart_time: elapsed_time = (current_time - last_restart_time).total_seconds() if elapsed_time < delay: logger.trace( f"Delaying restart of {mi.name} for {delay - elapsed_time:.2f} seconds." ) return logger.info( f"Restarting model instance {mi.name} " f"(attempt {backoff_count + 1}) after {delay} seconds delay." ) with contextlib.suppress(NotFoundException): self._restart_backoff_counts[mi.id] = backoff_count + 1 self._update_model_instance( mi.id, restart_count=restart_count + 1, last_restart_time=current_time, state=ModelInstanceStateEnum.SCHEDULED, state_message="", ) # Pop from error model instances, # if failed to restart next time, it will be added again in watch_model_instance_events(). self._error_model_instances.pop(mi.id, None) def _get_model(self, mi: ModelInstance) -> Model: """ Efficiently get model related to the model instance with caching. Args: mi: The model instance whose model to get. """ if model := self._model_cache_by_instance.get(mi.id): return model model = self._clientset.models.get(mi.model_id) self._model_cache_by_instance[mi.id] = model return model def _refresh_model(self, mi: ModelInstance) -> Model: """ Refresh the model information from the server. Args: mi: The model instance whose model to refresh. Returns: The refreshed model. """ logger.debug(f"Refreshing model {mi.model_name} information from server.") refreshed_model = self._clientset.models.get(mi.model_id) self._model_cache_by_instance[mi.id] = refreshed_model return refreshed_model def _is_provisioning(self, mi: ModelInstance) -> bool: """ Check if the model instance is still provisioning. Args: mi: The model instance to check. 
""" if process := self._provisioning_processes.get(mi.id): if process.is_alive(): process.join(timeout=0) return process.is_alive() return False def _get_health_check_path(self, backend: str) -> Optional[str]: """ Get health check path for the given backend. Args: backend: The backend name. Returns: The health check path if exists, else None. """ inference_backend = self._inference_backend_manager.get_backend_by_name(backend) return inference_backend.health_check_path if inference_backend else None def get_instance_port_by_model_instance_id( self, model_instance_id: int ) -> Optional[int]: """ Get the port of the model instance related to the given model instance ID. Args: model_instance_id: The model instance ID to get the port for. Returns: The port of the model instance if it exists and is running, else None. """ instance = self._model_instance_by_instance_id.get( model_instance_id ) # Ensure the model instance is cached. return ( instance.ports[0] if instance and instance.state == ModelInstanceStateEnum.RUNNING else None ) def is_ready( backend: str, mi: ModelInstance, health_check_path: Optional[str] = None, model: Model = None, ) -> bool: """ Access the health endpoint of the given model instance to check if it is servable. """ is_built_in = is_built_in_backend(backend) if (not is_built_in or backend == BackendEnum.CUSTOM) and (not health_check_path): # If custom backend does not have health check path, consider it always ready. return True if backend == BackendEnum.ASCEND_MINDIE and not health_check_path: # Ref: https://www.hiascend.com/document/detail/zh/mindie/21RC2/mindieservice/servicedev/mindie_service0066.html # /info provides metadata information and requires more time to respond. Use it for health check. health_check_path = "/info" elif ( backend == BackendEnum.SGLANG and model and CategoryEnum.IMAGE in model.categories ): if not model.backend_version: # version may be empty at initialization, consider it not ready. 
return False elif compare_versions(model.backend_version, "0.5.5.post3") >= 0: # SGLang Diffusion supported health check path at v0.5.5.post3 health_check_path = "/health" else: # Older versions do not support health check, consider it always ready. return True elif is_built_in and backend != BackendEnum.CUSTOM and not health_check_path: # Built-in backends (vLLM, SGLang, vox-box) except (Custom, MindIE) use /v1/models as health check path. health_check_path = "/v1/models" try: # Use the worker IP instead of localhost for health check. # Reasons: # 1. Connectivity to the loopback address does not work with Ascend MindIE. # 2. More adaptable to container networks. health_check_url = f"http://{mi.worker_ip}:{mi.port}{health_check_path}" response = requests.get(health_check_url, timeout=1) if response.status_code == 200: return True except Exception as e: logger.debug(f"Error checking model instance {mi.name} health: {e}") pass return False def _get_inference_endpoint_and_payload(model: Model) -> tuple[str, dict] | None: """ Get inference endpoint and payload for the model. Returns None if the model type should skip health check. 
""" skip_categories = { CategoryEnum.IMAGE, CategoryEnum.SPEECH_TO_TEXT, CategoryEnum.TEXT_TO_SPEECH, CategoryEnum.UNKNOWN, } if not skip_categories.isdisjoint(model.categories): return None # Return endpoint and payload based on model type (priority order) if CategoryEnum.EMBEDDING in model.categories: return "/v1/embeddings", {"model": model.name, "input": "test"} if CategoryEnum.RERANKER in model.categories: return "/v1/rerank", { "model": model.name, "query": "test", "documents": ["test"], } return "/v1/chat/completions", { "model": model.name, "messages": [{"role": "user", "content": "ping"}], "max_tokens": 1, "max_completion_tokens": 1, } def _get_inference_health_check_config(model: Model) -> dict: """Read per-model inference health check config from model.env.""" env = model.env or {} enabled = env.get( "GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_ENABLED", "false" ).lower() in ( "true", "1", ) interval = safe_int( env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_INTERVAL"), 300, ) timeout = safe_int( env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_TIMEOUT"), 15, ) threshold = safe_int( env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"), 3, ) return { "enabled": enabled, "interval": interval, "timeout": timeout, "threshold": threshold, } def is_inference_ready(mi: ModelInstance, model: Model, timeout: int = 15) -> bool: """ Send a minimal inference request to verify the inference capability is working. 
""" # Check Custom backend (no standard inference API) if is_custom_backend(model.backend): return True # Check port assignment if not mi.port: logger.debug(f"Model instance {mi.name} does not have port assigned yet.") return False # Get endpoint and payload, None means skip health check result = _get_inference_endpoint_and_payload(model) if not result: logger.debug(f"Skipping inference check for {mi.name}") return True endpoint_path, payload = result inference_url = f"http://{mi.worker_ip}:{mi.port}{endpoint_path}" try: response = requests.post(inference_url, json=payload, timeout=timeout) if response.status_code == 200: return True else: logger.warning( f"Model instance {mi.name} inference health check failed " f"with status {response.status_code} for endpoint {endpoint_path}" ) except Exception as e: logger.debug( f"Error checking model instance {mi.name} inference at {endpoint_path}: {e}" ) return False