| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446 |
- from typing import Callable
- from cachetools import TTLCache
- from prometheus_client.core import ( # noqa: F401
- GaugeMetricFamily,
- InfoMetricFamily,
- HistogramMetricFamily,
- CounterMetricFamily,
- SummaryMetricFamily,
- )
- from prometheus_client import CollectorRegistry
- from gpustack.client.generated_clientset import ClientSet
- from gpustack.utils.command import find_parameter
- from gpustack.utils.metrics import (
- get_builtin_metrics_config,
- get_runtime_metrics_config,
- )
- from gpustack.worker.runtime_metrics_client import (
- Config as RunTimeMetricsClientConfig,
- )
- from gpustack.worker.runtime_metrics_client import Client as RuntimeMetricsClient
- from gpustack.schemas.models import (
- BackendEnum,
- Model,
- ModelInstance,
- ModelInstanceStateEnum,
- ModelInstanceUpdate,
- get_backend,
- is_audio_model,
- is_image_model,
- )
- import logging
- import uuid
- from typing import Optional
- from gpustack.utils import version
# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)

# Timeout, in seconds, applied to the HTTP request that fetches the online
# metrics config from the server.
METRICS_CONFIG_FETCH_TIMEOUT_SECONDS = 30

# Two independent Prometheus registries: the first for unified metrics, the
# second for raw metrics.
unified_registry, raw_registry = CollectorRegistry(), CollectorRegistry()
class RuntimeMetricsAggregator:
    """Aggregate runtime metrics of the model instances served by this worker.

    Each call to :meth:`aggregate` fetches metrics from every eligible model
    instance endpoint, attaches worker/model identifying labels, and stores two
    dictionaries of Prometheus metric families in the cache:

    * ``cache["raw"]``: families keyed by their source family name.
    * ``cache["unified"]``: families keyed by the unified name that the
      metrics config maps the source family to.
    """

    def __init__(
        self,
        cache: Optional[dict] = None,
        worker_id_getter: Optional[Callable[[], int]] = None,
        clientset: "ClientSet" = None,
    ):
        """
        Args:
            cache: Dict that receives the ``"unified"`` and ``"raw"`` results.
                A private dict is used when omitted.
            worker_id_getter: Zero-argument callable returning the worker ID.
                A falsy return value makes :meth:`aggregate` a no-op.
            clientset: API client used to list models and model instances, to
                update model instances, and to fetch the online metrics config.
        """
        # Keep the caller's dict even when it is empty, so results remain
        # visible to whoever shares it; only fall back when nothing was passed.
        self._cache = cache if cache is not None else {}
        self._metrics_client_config = RunTimeMetricsClientConfig(
            timeout=5, max_retries=2, insecure_tls=True
        )
        self._metrics_client = RuntimeMetricsClient(self._metrics_client_config)
        self._worker_id_getter = worker_id_getter
        self._clientset = clientset
        # Cache for metrics config (refresh every 300 seconds)
        self._metrics_config_cache = TTLCache(maxsize=1, ttl=300)

    def aggregate(self):
        """
        Fetch metrics from all model instances, normalize and aggregate both
        unified and raw metrics, and write results to cache.

        Nothing is fetched when the worker ID is not available. When no
        endpoint is eligible, the cached results are reset to empty dicts.
        """
        getter = self._worker_id_getter
        worker_id = getter() if getter is not None else None
        if not worker_id:
            logger.trace("Worker ID is not set. Skipping runtime metrics fetch.")
            return

        # 1. Get metrics config
        metrics_config = self._get_metrics_config()

        # 2. Get active model endpoints
        endpoints, endpoint_to_instance, instance_id_to_model = (
            self._find_active_model_endpoints(worker_id, metrics_config)
        )
        if not endpoints:
            logger.trace(
                "No valid endpoints found for model instances. Skipping runtime metrics fetch."
            )
            # Drop results of a previous cycle so metrics of instances that
            # are no longer eligible are not kept in the cache indefinitely.
            self._cache["unified"] = {}
            self._cache["raw"] = {}
            return

        trace_id = uuid.uuid4().hex[:8]
        logger.trace(
            f"trace_id: {trace_id}, fetching runtime metrics from {len(endpoints)} endpoints"
        )

        # 3. Batch fetch metrics from all endpoints
        endpoint_metrics = self._metrics_client.fetch_metrics_from_endpoints(endpoints)

        # 4. Unified and raw aggregation
        unified_metrics = {}
        raw_metrics = {}
        for ep, metrics in endpoint_metrics.items():
            if not metrics:
                continue
            mi = endpoint_to_instance[ep]
            m = instance_id_to_model.get(mi.id)
            runtime = get_backend(m)
            runtime_version = self.fetch_and_update_api_backend_version(mi, ep)
            base_labels = self._build_base_labels(mi, m, runtime)
            self._process_endpoint_metrics(
                metrics,
                base_labels,
                runtime,
                runtime_version,
                unified_metrics,
                raw_metrics,
                metrics_config,
            )

        self._cache["unified"] = unified_metrics
        self._cache["raw"] = raw_metrics
        logger.trace(f"trace_id: {trace_id}, completed fetching runtime metrics.")

    def fetch_and_update_api_backend_version(
        self,
        model_instance: "ModelInstance",
        endpoint: str,
    ) -> Optional[str]:
        """Return the backend version to use for a model instance.

        Resolution order:

        1. ``model_instance.api_detected_backend_version`` if already set.
        2. The version reported by the endpoint; when one is obtained it is
           also written back to the model instance.
        3. ``model_instance.backend_version`` as the fallback.
        """
        if model_instance.api_detected_backend_version is not None:
            return model_instance.api_detected_backend_version

        # Named so that it does not shadow the imported ``version`` module.
        detected_version = self._metrics_client.fetch_runtime_version_from_endpoint(
            endpoint, model_instance.backend
        )
        if detected_version is None:
            return model_instance.backend_version

        self._update_model_instance(
            model_instance.id, api_detected_backend_version=detected_version
        )
        return detected_version

    def _find_active_model_endpoints(
        self, worker_id: int, metrics_config: dict
    ) -> "tuple[set, dict[str, ModelInstance], dict[int, Model]]":
        """
        Get all endpoints and related mappings for the model instances on this
        worker that pass :meth:`_should_skip_endpoint`.

        Returns: (endpoints, endpoint->instance, instance_id->model)
        """
        model_instances, models = self._list_worker_models(worker_id)
        if not model_instances or not models:
            return set(), {}, {}

        model_id_to_model = {m.id: m for m in models.items}
        endpoints = set()
        endpoint_to_instance = {}
        instance_id_to_model = {}
        for mi in model_instances.items:
            model = model_id_to_model.get(mi.model_id)
            if self._should_skip_endpoint(
                model=model,
                model_instance=mi,
                metrics_config=metrics_config,
            ):
                logger.trace(f"Skipping model instance {mi.id} in metrics aggregation.")
                continue
            # The endpoint is addressed by the instance's first port.
            endpoint = f"{mi.worker_ip}:{mi.ports[0]}"
            endpoints.add(endpoint)
            endpoint_to_instance[endpoint] = mi
            instance_id_to_model[mi.id] = model
        return endpoints, endpoint_to_instance, instance_id_to_model

    def _list_worker_models(self, worker_id: int):
        """
        Query the model instances of this worker and the list of models.

        The model list is requested without a worker filter.
        """
        model_instances = self._clientset.model_instances.list(
            params={"worker_id": str(worker_id)}
        )
        models = self._clientset.models.list()
        return model_instances, models

    def _update_model_instance(self, id: int, **kwargs):
        """Fetch a model instance, set the given attributes, and save it.

        This is best effort: any failure is logged and not re-raised.
        """
        try:
            mi_public = self._clientset.model_instances.get(id=id)
            mi = ModelInstanceUpdate(**mi_public.model_dump())
            for key, value in kwargs.items():
                setattr(mi, key, value)
            self._clientset.model_instances.update(id=id, model_update=mi)
        except Exception as e:
            logger.error(f"Failed to update model instance {id}: {e}")

    def _build_base_labels(self, mi, m, runtime):
        """
        Build the base labels attached to every metric of a model instance.

        Missing worker or model information is rendered as an empty string.
        """
        return {
            "worker_id": str(mi.worker_id) if mi.worker_id else "",
            "worker_name": mi.worker_name if mi.worker_name else "",
            "model_id": str(m.id) if m else "",
            "model_name": m.name if m else "",
            "model_instance_id": str(mi.id),
            "model_instance_name": mi.name,
            "runtime": runtime,
        }

    def _process_endpoint_metrics(
        self,
        metrics,
        base_labels,
        runtime,
        runtime_version,
        unified_metrics,
        raw_metrics,
        metrics_config,
    ):
        """
        Process metrics for a single endpoint, aggregate to unified and raw.

        ``unified_metrics`` and ``raw_metrics`` are updated in place. Families
        without samples are ignored.
        """
        for source_family_name, family in metrics.items():
            first_sample = family.samples[0] if family.samples else None
            if not first_sample:
                continue

            # Label names: base labels first, then any additional label names
            # carried by the family's first sample.
            # NOTE(review): a family created while processing an earlier
            # endpoint keeps the label names it was created with.
            label_keys = list(base_labels.keys())
            for k in first_sample.labels.keys():
                if k not in label_keys:
                    label_keys.append(k)

            # raw metrics
            if source_family_name not in raw_metrics:
                raw_metrics[source_family_name] = create_prom_metric_family(
                    name=source_family_name,
                    type=family.type,
                    description=family.documentation,
                    labels=label_keys,
                )
            raw_family = raw_metrics[source_family_name]

            # unified metrics
            unified_family = None
            unified_metric_family_name = get_unified_metric_family_name(
                metrics_config, source_family_name, runtime, runtime_version
            )
            if unified_metric_family_name:
                cfg = get_unified_metric_family_config(
                    metrics_config, unified_metric_family_name
                )
                if cfg:
                    if unified_metric_family_name not in unified_metrics:
                        unified_metrics[unified_metric_family_name] = (
                            create_prom_metric_family(
                                name=unified_metric_family_name,
                                type=cfg.get("type"),
                                description=cfg.get("description"),
                                labels=label_keys,
                            )
                        )
                    unified_family = unified_metrics[unified_metric_family_name]

            for sample in family.samples:
                # Base labels take precedence over same-named sample labels.
                label_values = [
                    base_labels[k] if k in base_labels else sample.labels.get(k, "")
                    for k in label_keys
                ]
                labels = sample.labels.copy()
                labels.update(base_labels)

                if family.type in ("histogram", "summary"):
                    # Histogram/summary samples are copied one by one under
                    # their own sample names instead of going through
                    # add_metric.
                    raw_family.add_sample(
                        name=sample.name,
                        labels=labels,
                        value=sample.value,
                        timestamp=sample.timestamp,
                    )
                    if unified_family:
                        new_name = sample.name.replace(
                            source_family_name, unified_metric_family_name
                        )
                        unified_family.add_sample(
                            name=new_name,
                            labels=labels,
                            value=sample.value,
                            timestamp=sample.timestamp,
                        )
                else:
                    raw_family.add_metric(
                        labels=label_values,
                        value=sample.value,
                        timestamp=sample.timestamp,
                    )
                    if unified_family:
                        unified_family.add_metric(
                            labels=label_values,
                            value=sample.value,
                            timestamp=sample.timestamp,
                        )

    def _should_skip_endpoint(
        self, model: "Model", model_instance: "ModelInstance", metrics_config: dict
    ) -> bool:
        """Return True when a model instance must be left out of aggregation.

        An instance is skipped when its model is missing, is an image or audio
        model, the instance is not RUNNING or has no worker IP / ports, the
        model has no backend or no runtime metrics config, or metrics are
        disabled through backend parameters or the model environment.
        """
        # A missing model is checked first so that the model helpers below
        # are never called with None.
        if not model:
            return True

        # skip image and audio models
        if is_image_model(model) or is_audio_model(model):
            return True

        # model instance must be running and addressable
        if (
            model_instance.state != ModelInstanceStateEnum.RUNNING
            or model_instance.worker_ip is None
            or not model_instance.ports
        ):
            return True

        runtime = model.backend
        if not runtime:
            return True

        # check runtime metrics config
        runtime_cfg = get_runtime_metrics_config(metrics_config, runtime)
        if not runtime_cfg:
            return True

        # check runtime-specific metrics flags
        if runtime == BackendEnum.VLLM:
            disable_metrics = find_parameter(
                model.backend_parameters, ["disable-log-stats"]
            )
            if disable_metrics:
                return True

        if model.env and model.env.get("GPUSTACK_DISABLE_METRICS"):
            return True

        return False

    def _get_online_metrics_config(self):
        """Fetch the metrics config from the server.

        Returns:
            The decoded config dict, or ``None`` when the endpoint answers 404
            or any other non-200 status, the payload is not a dict, or the
            request fails.
        """
        try:
            resp = self._clientset.http_client.get_httpx_client().get(
                f"{self._clientset.base_url}/v2/metrics/config",
                timeout=METRICS_CONFIG_FETCH_TIMEOUT_SECONDS,
            )
            if resp.status_code == 404:
                # A 404 is not logged, unlike other non-200 statuses.
                return None
            elif resp.status_code != 200:
                logger.warning(
                    f"Failed to fetch online metrics config, status: {resp.status_code}"
                )
                return None
            data = resp.json()
            if not isinstance(data, dict):
                logger.warning(
                    "Online metrics config is not a dict, fallback to builtin config."
                )
                return None
            return data
        except Exception as e:
            logger.error(f"Error fetching online metrics config: {e}")
            return None

    def _get_metrics_config(self):
        """Get metrics config with automatic caching (300 seconds TTL).

        On a cache miss the online config is tried first; the builtin config
        is used when the online one is unavailable or empty. Whichever config
        is chosen is cached.
        """
        try:
            return self._metrics_config_cache["config"]
        except KeyError:
            # Cache miss, fetch fresh config
            pass

        config = self._get_online_metrics_config()
        if config:
            logger.debug("Updated online metrics config cache")
        else:
            config = get_builtin_metrics_config()
            logger.debug("Using builtin metrics config")

        self._metrics_config_cache["config"] = config
        return config
# Lookup table from a lowercase metric type name to the prometheus_client
# metric family class that represents it.
_METRIC_FAMILY_CLASS = dict(
    gauge=GaugeMetricFamily,
    info=InfoMetricFamily,
    histogram=HistogramMetricFamily,
    counter=CounterMetricFamily,
    summary=SummaryMetricFamily,
)
def create_prom_metric_family(type: str, name: str, description: str, labels=None):
    """Instantiate the metric family class registered for ``type``.

    Args:
        type: Metric type name; it is stringified and matched
            case-insensitively against ``_METRIC_FAMILY_CLASS``.
        name: Name passed to the metric family constructor.
        description: Documentation string passed to the constructor.
        labels: Optional label names. They are forwarded as ``labels=`` only
            when not ``None``.

    Raises:
        ValueError: If no metric family class is registered for ``type``.
    """
    family_cls = _METRIC_FAMILY_CLASS.get(str(type).lower())
    if not family_cls:
        raise ValueError(f"Unknown metric family type: {type}")

    # Leave the constructor's own default in effect when no labels are given.
    optional_kwargs = {} if labels is None else {"labels": labels}
    return family_cls(name, description, **optional_kwargs)
def get_unified_metric_family_name(
    config: dict,
    source_metric_family_name: str,
    runtime: str,
    runtime_version: Optional[str],
) -> Optional[str]:
    """
    Resolve the unified (normalized) name of a runtime's source metric family.

    A version-specific mapping wins when ``runtime_version`` matches its key
    and the mapping contains the source family; otherwise the default ``'*'``
    mapping is used. A version string that is a valid version is matched with
    a range check against each key; any other version string must equal the
    key exactly.

    Returns:
        The unified family name, or ``None`` when the runtime has no metrics
        config or no mapping contains the source family.
    """
    runtime_cfg = get_runtime_metrics_config(config, runtime)
    if not runtime_cfg:
        return None

    default_name = runtime_cfg.get("*", {}).get(source_metric_family_name, None)
    if not runtime_version:
        return default_name

    parseable = version.is_valid_version_str(runtime_version)
    for ver_range, mapping in runtime_cfg.items():
        if ver_range == "*":
            continue

        if parseable:
            matched = version.in_range(runtime_version, ver_range)
        else:
            matched = runtime_version == ver_range
        if not matched:
            continue

        # A matching range without this family does not end the search;
        # later ranges are still examined.
        versioned_name = mapping.get(source_metric_family_name)
        if versioned_name is not None:
            return versioned_name

    return default_name
def get_unified_metric_family_config(
    config: dict, unified_metric_family_name: str
) -> dict:
    """Return the config entry of a unified metric family.

    The entry is looked up under the ``"gpustack_metrics"`` section of
    ``config``. An empty dict is returned when either the section or the
    family is absent.
    """
    unified_section = config.get("gpustack_metrics", {})
    return unified_section.get(unified_metric_family_name, {})
|