| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633 |
- import asyncio
- import contextlib
- from datetime import datetime, timezone
- import multiprocessing
- import re
- import threading
- import time
- import requests
- import setproctitle
- import os
- from typing import Dict, Optional, Set, List, Callable
- from pathlib import Path
- import logging
- from gpustack_runtime.deployer import (
- get_workload,
- WorkloadStatusStateEnum,
- delete_workload,
- logs_workload,
- )
- from gpustack_runtime.deployer.__utils__ import compare_versions
- from gpustack.api.exceptions import NotFoundException
- from gpustack.config.config import Config
- from gpustack.config import registration
- from gpustack.logging import (
- RedirectStdoutStderr,
- )
- from gpustack.schemas.inference_backend import (
- InferenceBackend,
- is_built_in_backend,
- is_custom_backend,
- )
- from gpustack.utils import network
- from gpustack.utils.convert import safe_int
- from gpustack.utils.attrs import set_attr
- from gpustack.utils.command import find_int_parameter
- from gpustack.utils.process import terminate_process_tree, add_signal_handlers
- from gpustack.worker.backends.ascend_mindie import AscendMindIEServer
- from gpustack.worker.backends.sglang import SGLangServer
- from gpustack.worker.backends.vllm import VLLMServer
- from gpustack.worker.backends.vox_box import VoxBoxServer
- from gpustack.worker.backends.custom import CustomServer
- from gpustack.routes.worker.logs import (
- extract_container_restart_count,
- extract_restart_count,
- )
- from gpustack.worker.model_meta import get_meta_from_running_instance
- from gpustack.client import ClientSet
- from gpustack.schemas.models import (
- BackendEnum,
- Model,
- ModelUpdate,
- ModelInstance,
- ModelInstanceUpdate,
- ModelInstanceStateEnum,
- get_backend,
- DistributedServerCoordinateModeEnum,
- ModelInstanceSubordinateWorker,
- CategoryEnum,
- )
- from gpustack.server.bus import Event, EventType
- from gpustack.worker.inference_backend_manager import InferenceBackendManager
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# State message written to a model instance when its inference health check
# has failed often enough for the instance to be marked as ERROR.
_INFERENCE_HEALTH_CHECK_FAILED_MESSAGE = "Inference health check failed."

# Global lock for port assignment. It is kept at module level (instead of on
# an instance) to avoid pickle serialization issues.
_port_lock = threading.Lock()

# Server implementation for each built-in backend. Backends that are missing
# from this mapping are served by CustomServer (see the `.get(...)` lookup in
# ServeManager._serve_model_instance).
_SERVER_CLASS_MAPPING: Dict[BackendEnum, type] = {
    BackendEnum.VLLM: VLLMServer,
    BackendEnum.SGLANG: SGLangServer,
    BackendEnum.VOX_BOX: VoxBoxServer,
    BackendEnum.ASCEND_MINDIE: AscendMindIEServer,
}
- class ServeManager:
- @property
- def _worker_id(self) -> int:
- return self._worker_id_getter()
- """
- The ID of current worker.
- """
- _config: Config
- """
- Global configuration.
- """
- _serve_log_dir: str
- """
- The directory to store logs of serving model instances(in subprocess).
- """
@property
def _clientset(self) -> ClientSet:
    """Return the clientset for the API server by invoking the injected getter."""
    resolve_clientset = self._clientset_getter
    return resolve_clientset()
- """
- The clientset to access the API server.
- """
- _inference_backend_manager: InferenceBackendManager
- """
- The inference backend manager.
- """
- _provisioning_processes: Dict[int, multiprocessing.Process]
- """
- The mapping of model instance ID to provisioning (sub)process.
- When the (sub)process is alive, the model instance is provisioning.
- If the (sub)process exited, the model instance is either running or failed.
- """
- _log_persistence_threads: Dict[int, List[threading.Thread]]
- """
- The mapping of model instance ID to log persistence threads.
- Each model instance may have multiple threads (one per loggable container).
- """
- _log_persistence_stop_events: Dict[int, List[threading.Event]]
- """
- The mapping of model instance ID to stop events for log persistence threads.
- Used to signal threads to stop gracefully.
- """
- _error_model_instances: Dict[int, ModelInstance]
- """
- The mapping of model instance ID to error model instances.
- Used to restart error model instances.
- """
- _model_cache_by_instance: Dict[int, Model]
- """
- The cache of models by model instance ID.
- Used to avoid redundant API calls to get model information.
- """
- _model_instance_by_instance_id: Dict[int, ModelInstance]
- _clientset_getter: Callable[[], ClientSet]
- _worker_id_getter: Callable[[], int]
def __init__(
    self,
    worker_id_getter: Callable[[], int],
    clientset_getter: Callable[[], ClientSet],
    cfg: Config,
):
    """
    Create the serve manager and make sure the serve log directory exists.

    Args:
        worker_id_getter: Callable returning the ID of the current worker.
        clientset_getter: Callable returning the clientset for the API server.
        cfg: The global configuration; its ``log_dir`` is the parent of the
            serve log directory.
    """
    # Injected collaborators. The worker ID and clientset are resolved lazily
    # through their getters (see the `_worker_id` / `_clientset` properties).
    self._config = cfg
    self._worker_id_getter = worker_id_getter
    self._clientset_getter = clientset_getter

    # Logs of serving model instances live under "<log_dir>/serve".
    self._serve_log_dir = f"{cfg.log_dir}/serve"

    # Per-model-instance bookkeeping, all keyed by model instance ID.
    self._provisioning_processes = {}
    self._log_persistence_threads = {}
    self._log_persistence_stop_events = {}
    self._error_model_instances = {}
    self._model_cache_by_instance = {}
    self._model_instance_by_instance_id = {}

    # Instance-level port tracking to avoid conflicts.
    self._assigned_ports: Dict[int, Set[int]] = {}
    self._restart_backoff_counts: Dict[int, int] = {}

    # Inference health check failure tracking: {model_instance_id: failure_count}.
    self._inference_health_check_failures: Dict[int, int] = {}
    # Timestamp of the last successful inference per model instance ID
    # (written by record_successful_inference, called by the worker proxy).
    self._last_successful_inference: Dict[int, float] = {}
    # Timestamp of the last health check per model instance ID.
    self._last_health_check_time: Dict[int, float] = {}

    os.makedirs(self._serve_log_dir, exist_ok=True)
def record_successful_inference(self, instance_id: int):
    """
    Record the current time as the last successful inference of an instance.

    Called by the worker proxy on a successful inference response.

    Args:
        instance_id: The ID of the model instance that served the request.
    """
    observed_at = time.time()
    self._last_successful_inference[instance_id] = observed_at
async def watch_models(self):
    """
    Watch models forever so that the clientset cache stays updated.

    The watch is restarted whenever it ends. Errors are logged and followed
    by a 5 second pause; cancellation of the task ends the loop.
    """
    logger.debug("Watching models.")

    while True:
        try:
            # No callback is needed: the watch itself keeps the cache updated.
            await self._clientset.models.awatch(callback=None)
        except asyncio.CancelledError:
            # Nothing follows the loop, so returning is the same as leaving it.
            return
        except Exception as err:
            logger.error(f"Error watching models: {err}")
            await asyncio.sleep(5)
async def watch_model_instances_event(self):
    """
    Watch model instance events forever and dispatch each one to the handler.

    The watch is restarted whenever it ends. Errors are logged and followed
    by a 5 second pause; cancellation of the task ends the loop.
    """
    logger.debug("Watching model instances event.")

    on_event = self._handle_model_instance_event
    while True:
        try:
            await self._clientset.model_instances.awatch(callback=on_event)
        except asyncio.CancelledError:
            # Nothing follows the loop, so returning is the same as leaving it.
            return
        except Exception as err:
            logger.error(f"Error watching model instances: {err}")
            await asyncio.sleep(5)
async def watch_model_instances(self):
    """
    Periodically post-process model instances, i.e. restart cached error instances.

    Every 10 seconds all cached error model instances are handed to
    `_restart_error_model_instance`. If a pass raises, the error is logged
    and the next pass starts after 5 seconds.
    """
    logger.debug("Watching model instances.")

    while True:
        try:
            # Iterate over a snapshot so the cache may change during the pass.
            pending = list(self._error_model_instances.values())
            for errored in pending:
                self._restart_error_model_instance(errored)
            await asyncio.sleep(10)
        except Exception as err:
            logger.error(f"Error restarting model instances: {err}")
            await asyncio.sleep(5)
def sync_model_instances_state(self):  # noqa: C901
    """
    Reconcile the state of this worker's model instances with their workloads.

    A model instance is considered when this worker is its main worker and
    the instance is past the SCHEDULED state, or when this worker is one of
    its subordinate workers. For each of them:

    - If the provision process is still alive, skip.
    - If the workload is still launching (pending/initializing), skip.
    - If the workload does not exist, or is unknown, unhealthy, inactive or
      failed, set the state (main worker) or the subordinate worker entry
      to ERROR, unless the model instance is in ERROR already.
    - If everything is fine, set the state to RUNNING.

    Updates that hit a model instance which no longer exists
    (NotFoundException) are ignored.
    """
    launching_states = (
        WorkloadStatusStateEnum.PENDING,
        WorkloadStatusStateEnum.INITIALIZING,
    )
    broken_states = (
        WorkloadStatusStateEnum.UNKNOWN,  # Rare, but possible, for example, leaving pause container.
        WorkloadStatusStateEnum.UNHEALTHY,
        WorkloadStatusStateEnum.INACTIVE,
        WorkloadStatusStateEnum.FAILED,
    )

    # Get all model instances assigned to this worker.
    #
    # FIXME(thxCode): This may cause performance issues when there are many model instances in the system.
    #                 A mechanism is needed to improve efficiency here.
    page = self._clientset.model_instances.list(use_cache=False)
    if not page.items:
        return

    model_instances: List[ModelInstance] = []
    for candidate in page.items:
        # An instance assigned to this worker is always scheduled, but there
        # is nothing to sync while it has not been initialized yet.
        if (
            candidate.worker_id == self._worker_id
            and candidate.state != ModelInstanceStateEnum.SCHEDULED
        ):
            model_instances.append(candidate)

        # Instances this worker joins as a subordinate worker.
        distributed = candidate.distributed_servers
        if distributed and distributed.subordinate_workers:
            for member in distributed.subordinate_workers:
                if member.worker_id == self._worker_id:
                    model_instances.append(candidate)
                    break

    for model_instance in model_instances:
        # Skip if the provision process has not exited yet.
        if self._is_provisioning(model_instance):
            logger.trace(
                f"Model instance {model_instance.name} is provisioning. Skipping sync."
            )
            continue

        is_main_worker = model_instance.worker_id == self._worker_id

        # Subordinate workers look their workload up by the deployment
        # metadata name (e.g., "model-f0"), which differs from the model
        # instance name used by the main worker.
        if is_main_worker:
            workload = get_workload(model_instance.name)
        else:
            deployment_metadata = model_instance.get_deployment_metadata(
                self._worker_id
            )
            if deployment_metadata:
                workload_name = deployment_metadata.name
            else:
                workload_name = model_instance.name
            workload = get_workload(workload_name)

        # Skip if the workload is still launching.
        if workload and workload.state in launching_states:
            logger.trace(
                f"Model instance {model_instance.name} workload is still launching. Skipping sync."
            )
            continue

        # Workload missing or broken: report ERROR, but only once.
        if not workload or workload.state in broken_states:
            if model_instance.state != ModelInstanceStateEnum.ERROR:
                with contextlib.suppress(NotFoundException):
                    if is_main_worker:
                        patch = {
                            "state": ModelInstanceStateEnum.ERROR,
                            "state_message": "Inference server exited or unhealthy.",
                        }
                    else:
                        # Patch only this worker's entry in the subordinate list.
                        position = next(
                            idx
                            for idx, member in enumerate(
                                model_instance.distributed_servers.subordinate_workers
                            )
                            if member.worker_id == self._worker_id
                        )
                        own_entry = (
                            model_instance.distributed_servers.subordinate_workers[
                                position
                            ]
                        )
                        own_entry.state = ModelInstanceStateEnum.ERROR
                        own_entry.state_message = (
                            "Inference server exited or unhealthy."
                        )
                        patch = {
                            f"distributed_servers.subordinate_workers.{position}": own_entry,
                        }
                    self._update_model_instance(model_instance.id, **patch)
            continue

        # The workload looks fine from here on: move towards RUNNING.
        model = self._get_model(model_instance)
        if not model.backend_version:
            # backend version may be empty on initialization.
            # try to refresh to get updated model info on syncs.
            model = self._refresh_model(model_instance)
        backend = get_backend(model)

        health_check_path = self._get_health_check_path(backend)
        if model.env and "GPUSTACK_MODEL_HEALTH_CHECK_PATH" in model.env:
            # NOTE: There is no known use case for now. Keep this in case the built-in backends
            # introduce breaking changes and the default health check path no longer works.
            health_check_path = model.env["GPUSTACK_MODEL_HEALTH_CHECK_PATH"]

        with contextlib.suppress(NotFoundException):
            if is_main_worker:
                distributed_state = self._get_main_worker_distributed_state(
                    model_instance
                )
                if distributed_state is not None:
                    # Subordinate workers dictate the state (or ask to wait).
                    if not distributed_state["should_update"]:
                        continue
                    patch = {
                        "state": distributed_state["state"],
                        "state_message": distributed_state["state_message"],
                    }
                else:
                    if model_instance.state == ModelInstanceStateEnum.RUNNING:
                        # Already running: just forget any restart backoff.
                        self._restart_backoff_counts.pop(model_instance.id, None)
                        continue
                    if (
                        model_instance.state == ModelInstanceStateEnum.ERROR
                        or not is_ready(
                            backend, model_instance, health_check_path, model
                        )
                    ):
                        continue

                    self._restart_backoff_counts.pop(model_instance.id, None)
                    patch = {
                        "state": ModelInstanceStateEnum.RUNNING,
                        "state_message": "",
                    }

                    # Fetch model meta once running.
                    meta = get_meta_from_running_instance(
                        model_instance, backend, model
                    )
                    if meta:
                        # Some meta is set in server evaluation and should be preserved,
                        # so the fetched meta is merged in instead of overwriting.
                        merged_meta = dict(model.meta or {})
                        merged_meta.update(meta)
                        if merged_meta != model.meta:
                            self._update_model(model.id, meta=merged_meta)
            else:
                # For initialize later mode, the state is set to RUNNING directly,
                # which means the subordinate worker doesn't need to wait for the main worker to be healthy.
                if (
                    model_instance.distributed_servers.mode
                    == DistributedServerCoordinateModeEnum.INITIALIZE_LATER
                ):
                    continue

                # Otherwise, update this worker's subordinate entry to RUNNING.
                position = next(
                    idx
                    for idx, member in enumerate(
                        model_instance.distributed_servers.subordinate_workers
                    )
                    if member.worker_id == self._worker_id
                )
                own_entry = model_instance.distributed_servers.subordinate_workers[
                    position
                ]
                if own_entry.state == ModelInstanceStateEnum.RUNNING:
                    continue
                own_entry.state = ModelInstanceStateEnum.RUNNING
                own_entry.state_message = ""
                patch = {
                    f"distributed_servers.subordinate_workers.{position}": own_entry,
                }

            # Update model instance.
            self._update_model_instance(model_instance.id, **patch)
@staticmethod
def _get_main_worker_distributed_state(
    model_instance: ModelInstance,
) -> Optional[dict]:
    """
    Derive the main worker's state from its subordinate workers.

    Args:
        model_instance: The model instance whose subordinate workers are inspected.

    Returns:
        None when there are no subordinate workers or all of them are RUNNING.
        Otherwise a dict with a ``should_update`` flag; for an ERROR or
        UNREACHABLE subordinate worker (ERROR wins) it also carries the
        ``state`` and ``state_message`` to apply, and ``should_update`` is
        False when the model instance is in that state already. When some
        subordinate worker is merely not RUNNING yet, only
        ``{"should_update": False}`` is returned.
    """
    distributed = model_instance.distributed_servers
    workers = (distributed.subordinate_workers if distributed else None) or []
    if not workers:
        return None

    # An errored subordinate worker takes precedence over everything else.
    error_sw = next(
        (w for w in workers if w.state == ModelInstanceStateEnum.ERROR), None
    )
    if error_sw:
        return {
            "should_update": model_instance.state != ModelInstanceStateEnum.ERROR,
            "state": ModelInstanceStateEnum.ERROR,
            "state_message": (
                f"Distributed serving error in subordinate worker "
                f"{error_sw.worker_ip}: {error_sw.state_message}."
            ),
        }

    # Next, the first unreachable subordinate worker.
    unreachable_sw = next(
        (w for w in workers if w.state == ModelInstanceStateEnum.UNREACHABLE), None
    )
    if unreachable_sw:
        return {
            "should_update": model_instance.state
            != ModelInstanceStateEnum.UNREACHABLE,
            "state": ModelInstanceStateEnum.UNREACHABLE,
            "state_message": (
                f"Distributed serving unreachable in subordinate worker "
                f"{unreachable_sw.worker_ip}: {unreachable_sw.state_message}."
            ),
        }

    # Some subordinate worker is still on its way: nothing to update yet.
    if any(w.state != ModelInstanceStateEnum.RUNNING for w in workers):
        return {"should_update": False}

    return None
@staticmethod
def _serve_model_instance(
    mi: ModelInstance,
    backend: BackendEnum,
    client_headers: dict,
    log_file_path: str,
    cfg: Config,
    worker_id: int,
    inference_backend: InferenceBackend,
    fallback_registry: Optional[str] = None,
):
    """
    Serve a model instance; meant to run as the target of a subprocess.

    The process title is set, signal handlers are installed, and stdout and
    stderr are redirected to the (truncated) log file while the backend
    server is created and started.

    Args:
        mi: The model instance to serve.
        backend: The backend of the model instance; backends without a
            dedicated server class are served by CustomServer.
        client_headers: The headers for the clientset.
        log_file_path: The path to the log file.
        cfg: The configuration.
        worker_id: The ID of the worker.
        inference_backend: The inference backend configuration.
        fallback_registry: The fallback container registry to use if needed.

    Raises:
        Exception: Any error raised while provisioning is logged and re-raised.
    """
    setproctitle.setproctitle(f"gpustack_model_instance_{mi.id}")
    add_signal_handlers()

    clientset = ClientSet(
        base_url=cfg.get_server_url(),
        headers=client_headers,
    )

    with open(
        log_file_path, "w", buffering=1, encoding="utf-8"
    ) as log_file, RedirectStdoutStderr(log_file):
        try:
            server_cls = _SERVER_CLASS_MAPPING.get(backend, CustomServer)
            server = server_cls(
                clientset,
                mi,
                cfg,
                worker_id,
                inference_backend,
                fallback_registry,
            )
            logger.info(f"Provisioning model instance {mi.name}")
            server.start()
            logger.info(f"Finished provisioning model instance {mi.name}")
        except Exception as e:
            logger.exception(f"Error provisioning model instance {mi.name}: {e}")
            raise e
def sync_model_instances_inference_health(self):
    """
    Check the inference health of RUNNING model instances with real requests.

    Per-model configuration is read from model.env:
    - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_ENABLED: "true"/"false" (default: false)
    - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_INTERVAL: seconds (default: global env)
    - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_TIMEOUT: seconds (default: 15)
    - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD: count (default: global env)

    If the model has received successful inference traffic recently
    (within the configured interval), the active health check is skipped
    and the failure count is reset. Once the number of consecutive failed
    checks reaches the threshold, the model instance is set to ERROR.
    """
    # Use the event-driven local cache instead of an API call.
    running = [
        cached
        for cached in self._model_instance_by_instance_id.values()
        if cached.state == ModelInstanceStateEnum.RUNNING
    ]
    if not running:
        return

    # One timestamp is shared by all instances handled in this pass.
    now = time.time()
    failures = self._inference_health_check_failures

    for model_instance in running:
        instance_id = model_instance.id

        model = self._get_model(model_instance)
        if not model:
            continue

        # Read per-model config from model.env.
        config = _get_inference_health_check_config(model)
        if not config["enabled"]:
            continue
        interval = config["interval"]
        timeout = config["timeout"]
        threshold = config["threshold"]

        # Skip if the model is still provisioning.
        if self._is_provisioning(model_instance):
            continue

        # Skip if not enough time has passed since last check.
        if now - self._last_health_check_time.get(instance_id, 0) < interval:
            continue
        self._last_health_check_time[instance_id] = now

        # Skip if recent successful inference was observed for this instance.
        if self._last_successful_inference.get(instance_id, 0) > now - interval:
            logger.debug(
                f"Model instance {model_instance.name} had recent successful "
                f"inference, skipping health check."
            )
            # Reset failure count since real traffic is succeeding.
            failures.pop(instance_id, None)
            continue

        # Perform inference health check.
        if is_inference_ready(model_instance, model, timeout=timeout):
            # Reset failure count on success.
            failures.pop(instance_id, None)
            continue

        failure_count = failures.get(instance_id, 0) + 1
        failures[instance_id] = failure_count

        if failure_count < threshold:
            logger.debug(
                f"Model instance {model_instance.name} inference health check failed "
                f"{failure_count}/{threshold} times."
            )
            continue

        logger.warning(
            f"Model instance {model_instance.name} inference health check failed "
            f"{failure_count} times, updating state to ERROR."
        )
        self._update_model_instance(
            instance_id,
            state=ModelInstanceStateEnum.ERROR,
            state_message=_INFERENCE_HEALTH_CHECK_FAILED_MESSAGE,
        )
        # Reset failure count after marking as error.
        del failures[instance_id]
def _handle_model_instance_event(self, event: Event):  # noqa: C901
    """
    Handle a model instance event.

    The event is first filtered by this worker's role. The main worker
    caches the instance and, in run-first mode, waits for all subordinate
    workers to be RUNNING. A subordinate worker only proceeds for
    non-delegated distributed deployments it is a member of, once the main
    worker (initialize-later mode) and all earlier subordinate workers are
    far enough along. Events that pass the filter are dispatched by type:
    DELETED stops the instance, UPDATED caches error instances for restart
    or (re)starts the instance, CREATED starts it on the main worker.

    Args:
        event: The model instance event to handle.
    """
    mi = ModelInstance.model_validate(event.data)

    logger.trace(
        f"Received event: {str(event.type)}, id: {mi.id}, name: {mi.name}, state: {str(mi.state)}"
    )

    is_main_worker = mi.worker_id == self._worker_id

    if is_main_worker:
        self._model_instance_by_instance_id[mi.id] = mi

        # Return if all subordinate workers aren't running.
        if (
            mi.distributed_servers
            and mi.distributed_servers.mode
            == DistributedServerCoordinateModeEnum.RUN_FIRST
            and mi.distributed_servers.subordinate_workers
        ):
            for member in mi.distributed_servers.subordinate_workers:
                if member.state != ModelInstanceStateEnum.RUNNING:
                    logger.info(
                        f"Model instance {mi.name} waits for all subordinate workers to be ready."
                    )
                    return
    else:
        # Return if it isn't a distribution serving.
        if not mi.distributed_servers:
            return

        # Return if it's a delegated distribution,
        # which means the main worker is responsible for serving.
        if (
            mi.distributed_servers.mode
            == DistributedServerCoordinateModeEnum.DELEGATED
        ):
            return

        # Return if it isn't the member of the distribution serving.
        if not any(
            member.worker_id == self._worker_id
            for member in mi.distributed_servers.subordinate_workers or []
        ):
            return

        # Return if the main worker isn't initialized.
        main_worker_initialized_states = (
            ModelInstanceStateEnum.STARTING,
            ModelInstanceStateEnum.RUNNING,
            ModelInstanceStateEnum.ERROR,
        )
        if (
            mi.distributed_servers.mode
            == DistributedServerCoordinateModeEnum.INITIALIZE_LATER
            and mi.state not in main_worker_initialized_states
        ):
            logger.info(
                f"Model instance {mi.name} waits for main worker {mi.worker_ip} to be initialized."
            )
            return

        # FIXME: This is a temporary solution to prevent the main worker from being unable to start due to phantom reads.
        #        We confirm whether the operation should be performed by checking the state of the earlier subordinate worker.
        settled_states = (
            ModelInstanceStateEnum.RUNNING,
            ModelInstanceStateEnum.ERROR,
        )
        for earlier in mi.distributed_servers.subordinate_workers:
            if earlier.worker_id == self._worker_id:
                break
            if earlier.state not in settled_states:
                logger.info(
                    f"Model instance {mi.name} waits for previous subordinate worker {earlier.worker_ip} to be ready."
                )
                return

    if event.type == EventType.DELETED:
        self._stop_model_instance(mi)
        logger.trace(f"DELETED event: stopped deleted model instance {mi.name}.")
        return

    if event.type == EventType.UPDATED:
        # Caching matched ERROR instances for restart handling.
        if mi.state == ModelInstanceStateEnum.ERROR:
            model = self._get_model(mi)
            if model.restart_on_error:
                self._error_model_instances[mi.id] = mi
                logger.trace(
                    f"UPDATED event: cached error model instance {mi.name} for restart."
                )
            return

        if is_main_worker:
            # Restart if scheduled and this is the assigned worker.
            if mi.state == ModelInstanceStateEnum.SCHEDULED:
                self._restart_model_instance(mi)
                logger.trace(
                    f"UPDATED event: restarted scheduled model instance {mi.name}."
                )
        else:
            # Start on subordinate worker if not started yet, or restart if failed.
            deployment_metadata = mi.get_deployment_metadata(self._worker_id)
            if deployment_metadata:
                workload_name = deployment_metadata.name
            else:
                workload_name = mi.name
            workload = get_workload(workload_name)

            failed_states = (
                WorkloadStatusStateEnum.UNKNOWN,
                WorkloadStatusStateEnum.UNHEALTHY,
                WorkloadStatusStateEnum.INACTIVE,
                WorkloadStatusStateEnum.FAILED,
            )
            if not workload:
                self._start_model_instance(mi)
                logger.trace(
                    f"UPDATED event: started model instance {mi.name} on subordinate worker."
                )
            elif workload.state in failed_states:
                self._stop_model_instance(mi, clear_restart_backoff=False)
                self._start_model_instance(mi)
                logger.trace(
                    f"UPDATED event: restarted failed model instance {mi.name} on subordinate worker."
                )
        return

    if event.type == EventType.CREATED:
        # Only handle CREATED if this is the assigned worker
        if not is_main_worker:
            return

        if mi.state == ModelInstanceStateEnum.RUNNING:
            logger.warning(
                f"Model instance {mi.name} is already running. Skipping start."
            )
            return

        self._start_model_instance(mi)
        logger.trace(f"CREATED event: started created model instance {mi.name}.")
def _get_numbered_log_path(self, mi: ModelInstance) -> str:
    """Build the serve log file path of a model instance for its current restart.

    Args:
        mi: The model instance.

    Returns:
        Log file path with format: {log_dir}/{model_instance_id}.{restart_count}.log,
        where a missing (or zero) restart count is rendered as 0.
    """
    attempt = mi.restart_count if mi.restart_count else 0
    file_name = f"{mi.id}.{attempt}.log"
    return f"{self._serve_log_dir}/{file_name}"
def _persist_container_logs(
    self,
    workload_name: str,
    log_path: str,
    stop_event: threading.Event,
    token: Optional[str] = None,
):
    """Follow the logs of a workload container and write them to a local file.

    This is a blocking operation that runs in a separate thread. While the
    log stream cannot be obtained or read, it retries every 2 seconds until
    the stop event is set. Each successful attempt truncates the target file
    before writing.

    Args:
        workload_name: Name of the container workload
        log_path: Path to save container logs
        stop_event: Event to signal thread to stop
        token: Operation token identifying a specific container in the workload.
               If None, logs from the default (index=0) container are fetched.
    """
    attempts = 0

    while not stop_event.is_set():
        try:
            stream = logs_workload(
                name=workload_name,
                token=token,
                tail=-1,
                follow=True,
            )

            if hasattr(stream, "__iter__"):
                with open(log_path, "w", buffering=1, encoding="utf-8") as sink:
                    for chunk in stream:
                        if stop_event.is_set():
                            break
                        # The stream may yield bytes or already decoded text.
                        if isinstance(chunk, bytes):
                            text = chunk.decode("utf-8", errors="replace")
                        else:
                            text = str(chunk)
                        sink.write(text)
                        sink.flush()
            # The stream ended (or was not iterable): nothing left to follow.
            break
        except Exception as e:
            if stop_event.is_set():
                break
            attempts += 1
            logger.debug(
                f"Container not ready for {workload_name}, retrying "
                f"(attempt {attempts}): {e}"
            )
            stop_event.wait(timeout=2)

    logger.debug(f"Log persistence thread for {workload_name} exiting")
def _discover_sidecar_logs(
    self,
    mi_id: int,
    workload_name: str,
    restart_count: int,
    stop_event: threading.Event,
):
    """Background thread body that waits for sidecar containers to appear.

    Polls get_workload() every 2 seconds until the workload's loggable list
    contains a container other than "default", then starts log persistence
    threads for the sidecars and exits. Also exits when stop_event is set.

    Discovery is best-effort: any error raised during a poll is logged at
    debug level and the poll is retried, so a failure never ends the thread.

    Args:
        mi_id: Model instance ID
        workload_name: Workload name
        restart_count: Current restart count for log file naming
        stop_event: Event to signal thread to stop
    """
    while not stop_event.is_set():
        try:
            workload = get_workload(workload_name)
            if (
                workload
                and workload.loggable
                and any(op.name != "default" for op in workload.loggable)
            ):
                # The whole loggable list is passed on; the callee skips the
                # "default" container itself.
                self._start_sidecar_log_threads(
                    mi_id,
                    workload_name,
                    workload.loggable,
                    restart_count,
                )
                logger.debug(f"Sidecar discovery for {workload_name} complete")
                return
        except Exception as e:
            # Keep retrying, but leave a trace so that a persistent failure
            # can be diagnosed instead of being silently swallowed.
            logger.debug(
                f"Sidecar discovery for {workload_name} failed, retrying: {e}"
            )
        stop_event.wait(timeout=2)
- def _start_sidecar_log_threads(
- self,
- mi_id: int,
- workload_name: str,
- loggable_ops: list,
- restart_count: int,
- ):
- """Start additional log persistence threads for sidecar containers.
- Called from the main log persistence thread once the workload is available
- and multiple loggable containers are discovered.
- Args:
- mi_id: Model instance ID
- workload_name: Workload name
- loggable_ops: List of WorkloadStatusOperation from workload.loggable
- restart_count: Current restart count for log file naming
- """
- names = []
- for op in loggable_ops:
- if op.name == "default":
- continue # Main container handled by caller thread
- log_path = (
- f"{self._serve_log_dir}/{mi_id}.container."
- f"{op.name}.{restart_count}.log"
- )
- stop_event = threading.Event()
- thread = threading.Thread(
- target=self._persist_container_logs,
- args=(workload_name, log_path, stop_event, op.token),
- daemon=True,
- name=f"log-persist-{workload_name}-{op.name}",
- )
- thread.start()
- # Append to existing tracking lists.
- self._log_persistence_threads.setdefault(mi_id, []).append(thread)
- self._log_persistence_stop_events.setdefault(mi_id, []).append(stop_event)
- names.append(op.name)
- if names:
- logger.debug(
- f"Started sidecar log persistence threads for {workload_name}: "
- f"{names}"
- )
def _start_container_log_persistence(self, mi: ModelInstance):
    """Start background threads that persist container logs for an instance.

    Any threads already tracked for the instance are stopped first. Two
    daemon threads sharing one stop event are then started: one persisting
    the main container log, and one polling for sidecar containers, which
    spawns further log threads once sidecars show up.

    Args:
        mi: The model instance.
    """
    # Drop whatever is still running for this instance.
    self._stop_container_log_persistence(mi.id)

    # The deployment metadata carries the actual workload name, which
    # differs from the instance name for subordinate workers
    # (e.g., "model-f0").
    metadata = mi.get_deployment_metadata(self._worker_id)
    workload_name = metadata.name if metadata else mi.name

    restart_count = mi.restart_count or 0
    log_path = f"{self._serve_log_dir}/{mi.id}.container.{restart_count}.log"
    stop_event = threading.Event()

    workers = [
        # Main container log thread.
        threading.Thread(
            target=self._persist_container_logs,
            args=(workload_name, log_path, stop_event),
            daemon=True,
            name=f"log-persist-{workload_name}",
        ),
        # Sidecar discovery thread.
        threading.Thread(
            target=self._discover_sidecar_logs,
            args=(mi.id, workload_name, restart_count, stop_event),
            daemon=True,
            name=f"log-discover-{workload_name}",
        ),
    ]
    for worker in workers:
        worker.start()

    self._log_persistence_threads[mi.id] = workers
    self._log_persistence_stop_events[mi.id] = [stop_event]
    logger.debug(f"Started container log persistence thread for {mi.name}")
- def _stop_container_log_persistence(
- self, model_instance_id: int, timeout: float = 2.0
- ):
- """Stop all container log persistence threads for a model instance.
- Args:
- model_instance_id: The model instance ID
- timeout: Maximum time to wait for each thread to stop (seconds)
- """
- # Signal all threads to stop
- stop_events = self._log_persistence_stop_events.pop(model_instance_id, [])
- for stop_event in stop_events:
- stop_event.set()
- # Wait for all threads to finish
- threads = self._log_persistence_threads.pop(model_instance_id, [])
- for thread in threads:
- if thread and thread.is_alive():
- thread.join(timeout=timeout)
- if thread.is_alive():
- logger.warning(
- f"Log persistence thread {thread.name} for model instance "
- f"{model_instance_id} did not stop within {timeout}s"
- )
- def _cleanup_old_logs(self, model_instance_id: int, current_restart_count: int):
- """Remove serve log files except the current and previous restart_count.
- Keeps files for restart_count in {R, R-1} where R is current_restart_count;
- when R is 0, only R is kept.
- Args:
- model_instance_id: Model instance ID
- current_restart_count: Restart count for the upcoming run (same as log path).
- """
- try:
- log_dir = Path(self._serve_log_dir)
- # Separate main logs, container logs, and sidecar container logs
- main_log_pattern = f"{model_instance_id}.*.log"
- all_main_logs = [
- f for f in log_dir.glob(main_log_pattern) if '.container.' not in f.name
- ]
- container_log_pattern = f"{model_instance_id}.container.*.log"
- all_container_files = list(log_dir.glob(container_log_pattern))
- # Split into default container logs (e.g., 42.container.0.log)
- # and sidecar container logs (e.g., 42.container.ray-head.0.log)
- default_container_logs = [
- f
- for f in all_container_files
- if extract_container_restart_count(f.name) > 0
- or re.match(rf'{model_instance_id}\.container\.\d+\.log', f.name)
- ]
- sidecar_container_logs = [
- f for f in all_container_files if f not in default_container_logs
- ]
- self._cleanup_log_type(all_main_logs, current_restart_count, "main")
- self._cleanup_log_type(
- default_container_logs, current_restart_count, "container"
- )
- self._cleanup_log_type(
- sidecar_container_logs, current_restart_count, "sidecar_container"
- )
- except Exception as e:
- logger.error(f"Failed to cleanup old logs for {model_instance_id}: {e}")
def _cleanup_log_type(
    self,
    log_files: List[Path],
    current_restart_count: int,
    log_type: str,
):
    """Delete log files whose restart count is neither current nor previous.

    Args:
        log_files: Candidate log files, all of the same ``log_type``.
        current_restart_count: Restart count of the upcoming run.
        log_type: One of "main", "container" or "sidecar_container"; selects
            how the restart count is parsed from the file name. Any other
            value is parsed like "container".
    """
    if current_restart_count > 0:
        retained = {current_restart_count, current_restart_count - 1}
    else:
        retained = {current_restart_count}

    def _sidecar_restart_count(filename: str) -> int:
        """Extract restart count from {id}.container.{name}.{restart_count}.log"""
        found = re.match(r'\d+\.container\.[^.]+\.(\d+)\.log', filename)
        if not found:
            return 0
        return int(found.group(1))

    parsers = {
        "main": extract_restart_count,
        "container": extract_container_restart_count,
        "sidecar_container": _sidecar_restart_count,
    }
    parse_restart_count = parsers.get(log_type, extract_container_restart_count)

    for path in log_files:
        if parse_restart_count(path.name) in retained:
            continue
        try:
            path.unlink()
            logger.info(f"Deleted old {log_type} log file: {path}")
        except Exception as e:
            logger.warning(f"Failed to delete {log_type} log file {path}: {e}")
def _start_model_instance(self, mi: ModelInstance):  # noqa: C901
    """
    Start model instance through a subprocess.

    Does nothing if a provisioning process for the instance is still alive.
    Otherwise removes stale log files, assigns ports, spawns a non-daemon
    ``multiprocessing.Process`` running ``ServeManager._serve_model_instance``,
    starts container log persistence and patches the instance on the server.
    The main worker patches the instance itself (state INITIALIZING, port,
    ports, pid); a subordinate worker patches only its own entry under
    ``distributed_servers.subordinate_workers``.

    If anything inside the ``try`` block raises, the instance (or the
    subordinate entry) is patched to ERROR with the failure message and the
    error is logged rather than re-raised.

    Args:
        mi: The model instance to start.
    """
    if self._is_provisioning(mi):
        logger.warning(f"Model instance {mi.name} is provisioning. Skipping start.")
        return

    # Clean up old log files before starting
    self._cleanup_old_logs(mi.id, mi.restart_count or 0)

    is_main_worker = mi.worker_id == self._worker_id
    log_file_path = self._get_numbered_log_path(mi)

    # Position and entry of this worker among the subordinate workers;
    # both stay None on the main worker.
    sw_pos: Optional[int] = None
    sw: Optional[ModelInstanceSubordinateWorker] = None
    if not is_main_worker:
        # NOTE: next() is called without a default, so StopIteration
        # propagates (outside the try block below) if this worker is not
        # listed among the subordinate workers.
        sw_pos = next(
            (
                i
                for i, sw in enumerate(mi.distributed_servers.subordinate_workers)
                if sw.worker_id == self._worker_id
            ),
        )
        sw = mi.distributed_servers.subordinate_workers[sw_pos]

    try:
        model = self._get_model(mi)
        backend = get_backend(model)
        self._assign_ports(mi, model, backend)

        logger.debug(
            f"Starting model instance {mi.name}"
            f"{'' if not is_main_worker else f' on ports {mi.ports if mi.ports else [mi.port]}'}"
        )

        # Only built-in backends get a fallback container registry.
        fallback_registry = (
            registration.determine_default_registry(
                self._config.system_default_container_registry
            )
            if is_built_in_backend(backend)
            else None
        )

        process = multiprocessing.Process(
            target=ServeManager._serve_model_instance,
            args=(
                mi,
                backend,
                self._clientset.headers,
                log_file_path,
                self._config,
                self._worker_id,
                self._inference_backend_manager.get_backend_by_name(backend),
                fallback_registry,
            ),
        )
        # NOTE(review): presumably non-daemon so the serving process may
        # spawn children of its own (daemonic processes cannot) — confirm.
        process.daemon = False
        process.start()
        # Track the process so _is_provisioning() can observe it.
        self._provisioning_processes[mi.id] = process

        # Start container log persistence for containerized backends
        self._start_container_log_persistence(mi)

        # Get patch dict for main worker.
        if is_main_worker:
            patch_dict = {
                "state": ModelInstanceStateEnum.INITIALIZING,
                "port": mi.port,
                "ports": mi.ports,
                "pid": process.pid,
            }
        # Get patch dict for subordinate worker.
        else:
            sw.state = ModelInstanceStateEnum.INITIALIZING
            # For initialize later mode, the state is set to RUNNING directly,
            # which means the subordinate worker doesn't need to wait for the main worker to be healthy.
            if (
                mi.distributed_servers.mode
                == DistributedServerCoordinateModeEnum.INITIALIZE_LATER
            ):
                sw.state = ModelInstanceStateEnum.RUNNING
            sw.pid = process.pid
            patch_dict = {
                f"distributed_servers.subordinate_workers.{sw_pos}": sw,
            }

        self._update_model_instance(mi.id, **patch_dict)
        logger.info(
            f"Started model instance {mi.name}"
            f"{'' if not is_main_worker else f' on ports {mi.ports if mi.ports else [mi.port]}'}"
        )
    except Exception as e:
        # Clean up provisioning process if started.
        if mi.id in self._provisioning_processes:
            self._stop_model_instance(mi)

        # Get patch dict for main worker.
        if is_main_worker:
            patch_dict = {
                "state": ModelInstanceStateEnum.ERROR,
                "state_message": f"Failed to start model instance: {e}",
            }
        # Get patch dict for subordinate worker.
        else:
            sw.state = ModelInstanceStateEnum.ERROR
            sw.state_message = f"Failed to start model instance: {e}"
            patch_dict = {
                f"distributed_servers.subordinate_workers.{sw_pos}": sw,
            }

        self._update_model_instance(mi.id, **patch_dict)
        logger.error(f"Failed to start model instance {mi.name}: {e}")
def _assign_ports(
    self,
    mi: ModelInstance,
    model: Model,
    backend: BackendEnum,
) -> None:
    """
    Allocate the ports a model instance needs, unless it already has one.

    Allocation happens under ``_port_lock`` and avoids every port already
    recorded in ``self._assigned_ports``. Ports are stored on ``mi.port`` /
    ``mi.ports`` and recorded under the instance ID:
    - Main serving port
    - RPC port for vLLM DP communication (if applicable)
    - Connecting port for subordinate workers (if applicable)

    Args:
        mi: The model instance to assign ports to.
        model: The model associated with the instance.
        backend: The backend type (e.g., vLLM, SGLang).
    """
    if mi.port:
        # Nothing to do, a port is already assigned.
        return

    with _port_lock:
        if mi.port:
            # Another caller assigned the port while we waited for the lock.
            return

        taken = (
            set.union(*self._assigned_ports.values())
            if self._assigned_ports
            else set()
        )

        def reserve() -> int:
            """Pick a free port and mark it as taken for later picks."""
            port = network.get_free_port(
                port_range=self._config.service_port_range,
                unavailable_ports=taken,
                host=mi.worker_ip,
            )
            taken.add(port)
            return port

        # Main serving port
        mi.port = reserve()
        mi.ports = [mi.port]

        # Additional ports for distributed servers
        distributed = mi.distributed_servers
        if distributed and distributed.subordinate_workers:
            # RPC port for DP communication in vLLM backend
            if backend == BackendEnum.VLLM:
                dp_size = find_int_parameter(
                    model.backend_parameters,
                    ["data-parallel-size", "dp"],
                )
                if dp_size and dp_size > 1:
                    mi.ports.append(reserve())
            # Connecting port for subordinate workers communication
            mi.ports.append(reserve())

        self._assigned_ports[mi.id] = set(mi.ports)
- def _restart_model_instance(self, mi: ModelInstance):
- """
- Restart model instance.
- Args:
- mi: The model instance to restart.
- """
- self._stop_model_instance(mi, clear_restart_backoff=False)
- self._start_model_instance(mi)
def _update_model(self, id: int, **kwargs):
    """
    Update a model with the given fields.

    Fetches the current model from the server, applies each keyword
    argument via ``set_attr`` and sends the result back. A missing model is
    logged as a warning.

    Args:
        id: The ID of the model to update.
        **kwargs: The fields to update, group by field name and value.
    """
    try:
        current = self._clientset.models.get(id=id)
        payload = ModelUpdate(**current.model_dump())
        for field_name, new_value in kwargs.items():
            set_attr(payload, field_name, new_value)
        self._clientset.models.update(id=id, model_update=payload)
    except NotFoundException:
        logger.warning(f"Model with ID {id} not found when trying to update.")
def _update_model_instance(self, id: int, **kwargs):
    """
    Update a model instance with the given fields.

    Fetches the current instance from the server, applies each keyword
    argument via ``set_attr`` and sends the result back. A missing instance
    is logged as a warning.

    Args:
        id: The ID of the model instance to update.
        **kwargs: The fields to update, group by field name and value.
    """
    try:
        current = self._clientset.model_instances.get(id=id)
        payload = ModelInstanceUpdate(**current.model_dump())
        for field_name, new_value in kwargs.items():
            set_attr(payload, field_name, new_value)
        self._clientset.model_instances.update(id=id, model_update=payload)
    except NotFoundException:
        logger.warning(
            f"Model instance with ID {id} not found when trying to update."
        )
def _stop_model_instance(
    self, mi: ModelInstance, clear_restart_backoff: bool = True
):
    """
    Stop a model instance and drop all local state kept for it.

    Stops log persistence, terminates a still-running provisioning process
    tree, deletes the deployed workload (if any) and clears the per-instance
    bookkeeping dictionaries.

    Args:
        mi: The model instance to stop.
        clear_restart_backoff: Whether to clear transient restart backoff state.
    """
    logger.debug(f"Stopping model instance {mi.name or mi.id}")

    # Stop container log persistence thread
    self._stop_container_log_persistence(mi.id)

    # Teardown provisioning process if still alive.
    if self._is_provisioning(mi):
        terminate_process_tree(self._provisioning_processes[mi.id].pid)

    # Delete workload.
    metadata = mi.get_deployment_metadata(self._worker_id)
    if metadata:
        delete_workload(metadata.name)

    # Cleanup internal states.
    state_attrs = [
        "_provisioning_processes",
        "_assigned_ports",
        "_error_model_instances",
        "_model_cache_by_instance",
        "_model_instance_by_instance_id",
    ]
    if clear_restart_backoff:
        state_attrs.append("_restart_backoff_counts")
    state_attrs += [
        "_inference_health_check_failures",
        "_last_health_check_time",
        "_last_successful_inference",
    ]
    for attr in state_attrs:
        getattr(self, attr).pop(mi.id, None)

    logger.info(f"Stopped model instance {mi.name or mi.id}")
def _restart_error_model_instance(self, mi: ModelInstance):
    """
    Reschedule a model instance in error state, with exponential backoff.

    The first attempt is immediate; later attempts wait 10s, 20s, 40s, ...
    capped at 300 seconds (5 minutes), measured from the last restart time
    (or the last update time if there is none). Nothing happens while the
    instance is still provisioning or the delay has not elapsed.

    Args:
        mi: The model instance to restart.
    """
    if self._is_provisioning(mi):
        logger.debug(f"Model instance {mi.name} is provisioning. Skipping restart.")
        return

    restart_count = mi.restart_count or 0
    backoff_count = self._restart_backoff_counts.get(mi.id, 0)
    last_restart_time = mi.last_restart_time or mi.updated_at
    current_time = datetime.now(timezone.utc)

    if backoff_count > 0:
        delay = min(10 * (2 ** (backoff_count - 1)), 300)
        if last_restart_time:
            elapsed = (current_time - last_restart_time).total_seconds()
            if elapsed < delay:
                logger.trace(
                    f"Delaying restart of {mi.name} for {delay - elapsed:.2f} seconds."
                )
                return
    else:
        delay = 0

    logger.info(
        f"Restarting model instance {mi.name} "
        f"(attempt {backoff_count + 1}) after {delay} seconds delay."
    )

    with contextlib.suppress(NotFoundException):
        self._restart_backoff_counts[mi.id] = backoff_count + 1
        self._update_model_instance(
            mi.id,
            restart_count=restart_count + 1,
            last_restart_time=current_time,
            state=ModelInstanceStateEnum.SCHEDULED,
            state_message="",
        )
        # Forget the instance as errored; if the next start fails too, it is
        # added again in watch_model_instance_events().
        self._error_model_instances.pop(mi.id, None)
- def _get_model(self, mi: ModelInstance) -> Model:
- """
- Efficiently get model related to the model instance with caching.
- Args:
- mi: The model instance whose model to get.
- """
- if model := self._model_cache_by_instance.get(mi.id):
- return model
- model = self._clientset.models.get(mi.model_id)
- self._model_cache_by_instance[mi.id] = model
- return model
def _refresh_model(self, mi: ModelInstance) -> Model:
    """
    Re-fetch the model of a model instance and replace the cached copy.

    Args:
        mi: The model instance whose model to refresh.

    Returns:
        The refreshed model.
    """
    logger.debug(f"Refreshing model {mi.model_name} information from server.")
    latest = self._clientset.models.get(mi.model_id)
    self._model_cache_by_instance[mi.id] = latest
    return latest
- def _is_provisioning(self, mi: ModelInstance) -> bool:
- """
- Check if the model instance is still provisioning.
- Args:
- mi: The model instance to check.
- """
- if process := self._provisioning_processes.get(mi.id):
- if process.is_alive():
- process.join(timeout=0)
- return process.is_alive()
- return False
- def _get_health_check_path(self, backend: str) -> Optional[str]:
- """
- Get health check path for the given backend.
- Args:
- backend: The backend name.
- Returns:
- The health check path if exists, else None.
- """
- inference_backend = self._inference_backend_manager.get_backend_by_name(backend)
- return inference_backend.health_check_path if inference_backend else None
def get_instance_port_by_model_instance_id(
    self, model_instance_id: int
) -> Optional[int]:
    """
    Get the port of the model instance related to the given model instance ID.

    Only instances present in the local instance cache and in RUNNING state
    are considered. The first entry of ``ports`` is preferred; if ``ports``
    is empty or unset, the single ``port`` attribute is used instead, so an
    instance without a port list does not raise.

    Args:
        model_instance_id: The model instance ID to get the port for.

    Returns:
        The port of the model instance if it exists and is running, else None.
    """
    instance = self._model_instance_by_instance_id.get(model_instance_id)
    if not instance or instance.state != ModelInstanceStateEnum.RUNNING:
        return None

    # ``ports`` may be empty or None while ``port`` is set; other code in
    # this file falls back the same way (``mi.ports if mi.ports else [mi.port]``).
    ports = instance.ports
    return ports[0] if ports else instance.port
def is_ready(
    backend: str,
    mi: ModelInstance,
    health_check_path: Optional[str] = None,
    model: Model = None,
) -> bool:
    """
    Probe the health endpoint of a model instance to decide whether it is servable.

    Backends without any usable health endpoint are reported as ready
    without a request; otherwise the instance is ready only if a GET on the
    health endpoint answers with status 200 within one second.
    """
    built_in = is_built_in_backend(backend)
    custom_like = not built_in or backend == BackendEnum.CUSTOM

    if custom_like and not health_check_path:
        # A custom backend without a health check path is always ready.
        return True

    if backend == BackendEnum.ASCEND_MINDIE and not health_check_path:
        # Ref: https://www.hiascend.com/document/detail/zh/mindie/21RC2/mindieservice/servicedev/mindie_service0066.html
        # /info returns metadata and takes longer to respond, so it serves
        # as the health check.
        health_check_path = "/info"
    elif (
        backend == BackendEnum.SGLANG
        and model
        and CategoryEnum.IMAGE in model.categories
    ):
        if not model.backend_version:
            # The version may be empty during initialization: not ready yet.
            return False
        if compare_versions(model.backend_version, "0.5.5.post3") < 0:
            # Older versions have no health check, consider them always ready.
            return True
        # SGLang Diffusion supports a health check path since v0.5.5.post3.
        health_check_path = "/health"
    elif built_in and backend != BackendEnum.CUSTOM and not health_check_path:
        # Built-in backends (vLLM, SGLang, vox-box) except (Custom, MindIE)
        # use /v1/models as health check path.
        health_check_path = "/v1/models"

    try:
        # The worker IP is used instead of localhost because:
        # 1. Connectivity to the loopback address does not work with Ascend MindIE.
        # 2. It is more adaptable to container networks.
        probe_url = f"http://{mi.worker_ip}:{mi.port}{health_check_path}"
        if requests.get(probe_url, timeout=1).status_code == 200:
            return True
    except Exception as e:
        logger.debug(f"Error checking model instance {mi.name} health: {e}")

    return False
def _get_inference_endpoint_and_payload(model: Model) -> tuple[str, dict] | None:
    """
    Choose the endpoint and minimal request body for an inference probe.

    Embedding models take precedence over rerankers; everything else is
    probed through chat completions. Returns None for model categories that
    skip the inference health check (image, speech-to-text, text-to-speech,
    unknown).
    """
    skip_categories = {
        CategoryEnum.IMAGE,
        CategoryEnum.SPEECH_TO_TEXT,
        CategoryEnum.TEXT_TO_SPEECH,
        CategoryEnum.UNKNOWN,
    }
    if not skip_categories.isdisjoint(model.categories):
        return None

    # Endpoint selection in priority order.
    if CategoryEnum.EMBEDDING in model.categories:
        embedding_payload = {"model": model.name, "input": "test"}
        return "/v1/embeddings", embedding_payload

    if CategoryEnum.RERANKER in model.categories:
        rerank_payload = {
            "model": model.name,
            "query": "test",
            "documents": ["test"],
        }
        return "/v1/rerank", rerank_payload

    chat_payload = {
        "model": model.name,
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 1,
        "max_completion_tokens": 1,
    }
    return "/v1/chat/completions", chat_payload
def _get_inference_health_check_config(model: Model) -> dict:
    """Build the inference health check settings from ``model.env``.

    The check is enabled only when the ENABLED variable is "true" or "1"
    (case-insensitive). Interval, timeout and failure threshold are parsed
    with ``safe_int`` using 300, 15 and 3 as fallback values.
    """
    env = model.env or {}
    enabled_raw = env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_ENABLED", "false")
    return {
        "enabled": enabled_raw.lower() in ("true", "1"),
        "interval": safe_int(
            env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_INTERVAL"),
            300,
        ),
        "timeout": safe_int(
            env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_TIMEOUT"),
            15,
        ),
        "threshold": safe_int(
            env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"),
            3,
        ),
    }
def is_inference_ready(mi: ModelInstance, model: Model, timeout: int = 15) -> bool:
    """
    Verify inference capability by sending a minimal inference request.

    Custom backends and model types without an inference probe count as
    ready. An instance without a port is not ready. Otherwise the instance
    is ready only if the probe request answers with status 200.
    """
    # Custom backends expose no standard inference API.
    if is_custom_backend(model.backend):
        return True

    # A port must have been assigned.
    if not mi.port:
        logger.debug(f"Model instance {mi.name} does not have port assigned yet.")
        return False

    # No probe means this model type skips the check.
    probe = _get_inference_endpoint_and_payload(model)
    if not probe:
        logger.debug(f"Skipping inference check for {mi.name}")
        return True

    endpoint_path, payload = probe
    inference_url = f"http://{mi.worker_ip}:{mi.port}{endpoint_path}"
    try:
        status = requests.post(inference_url, json=payload, timeout=timeout).status_code
        if status == 200:
            return True
        logger.warning(
            f"Model instance {mi.name} inference health check failed "
            f"with status {status} for endpoint {endpoint_path}"
        )
    except Exception as e:
        logger.debug(
            f"Error checking model instance {mi.name} inference at {endpoint_path}: {e}"
        )
    return False
|