runtime_metrics_aggregator.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
import logging
import uuid
from typing import Callable
from typing import Optional

from cachetools import TTLCache
from prometheus_client.core import (  # noqa: F401
    GaugeMetricFamily,
    InfoMetricFamily,
    HistogramMetricFamily,
    CounterMetricFamily,
    SummaryMetricFamily,
    UnknownMetricFamily,
)
from prometheus_client import CollectorRegistry

from gpustack.client.generated_clientset import ClientSet
from gpustack.utils.command import find_parameter
from gpustack.utils.metrics import (
    get_builtin_metrics_config,
    get_runtime_metrics_config,
)
from gpustack.worker.runtime_metrics_client import (
    Config as RunTimeMetricsClientConfig,
)
from gpustack.worker.runtime_metrics_client import Client as RuntimeMetricsClient
from gpustack.schemas.models import (
    BackendEnum,
    Model,
    ModelInstance,
    ModelInstanceStateEnum,
    ModelInstanceUpdate,
    get_backend,
    is_audio_model,
    is_image_model,
)
from gpustack.utils import version
  35. logger = logging.getLogger(__name__)
  36. METRICS_CONFIG_FETCH_TIMEOUT_SECONDS = 30
  37. # unified registry
  38. unified_registry = CollectorRegistry()
  39. # raw metrics registry
  40. raw_registry = CollectorRegistry()
  41. class RuntimeMetricsAggregator:
  42. def __init__(
  43. self,
  44. cache: dict = None,
  45. worker_id_getter=Callable[[], int],
  46. clientset: ClientSet = None,
  47. ):
  48. self._cache = cache
  49. self._metrics_client_config = RunTimeMetricsClientConfig(
  50. timeout=5, max_retries=2, insecure_tls=True
  51. )
  52. self._metrics_client = RuntimeMetricsClient(self._metrics_client_config)
  53. self._worker_id_getter = worker_id_getter
  54. self._clientset = clientset
  55. # Cache for metrics config (refresh every 300 seconds)
  56. self._metrics_config_cache = TTLCache(maxsize=1, ttl=300)
  57. def aggregate(self):
  58. """
  59. Fetch metrics from all model instances, normalize and aggregate both unified and raw metrics, and write results to cache.
  60. """
  61. worker_id = self._worker_id_getter()
  62. if not worker_id:
  63. logger.trace("Worker ID is not set. Skipping runtime metrics fetch.")
  64. return
  65. # 1. Get metrics config
  66. metrics_config = self._get_metrics_config()
  67. # 2. Get active model endpoints
  68. endpoints, endpoint_to_instance, instance_id_to_model = (
  69. self._find_active_model_endpoints(worker_id, metrics_config)
  70. )
  71. if not endpoints:
  72. logger.trace(
  73. "No valid endpoints found for model instances. Skipping runtime metrics fetch."
  74. )
  75. return
  76. trace_id = uuid.uuid4().hex[:8]
  77. logger.trace(
  78. f"trace_id: {trace_id}, fetching runtime metrics from {len(endpoints)} endpoints"
  79. )
  80. # 3. Batch fetch metrics from all endpoints
  81. endpoint_metrics = self._metrics_client.fetch_metrics_from_endpoints(endpoints)
  82. # 4. Unified and raw aggregation
  83. unified_metrics = {}
  84. raw_metrics = {}
  85. for ep, metrics in endpoint_metrics.items():
  86. if not metrics:
  87. continue
  88. mi = endpoint_to_instance[ep]
  89. m = instance_id_to_model.get(mi.id)
  90. runtime = get_backend(m)
  91. runtime_version = self.fetch_and_update_api_backend_version(mi, ep)
  92. base_labels = self._build_base_labels(mi, m, runtime)
  93. self._process_endpoint_metrics(
  94. metrics,
  95. base_labels,
  96. runtime,
  97. runtime_version,
  98. unified_metrics,
  99. raw_metrics,
  100. metrics_config,
  101. )
  102. self._cache["unified"] = unified_metrics
  103. self._cache["raw"] = raw_metrics
  104. logger.trace(f"trace_id: {trace_id}, completed fetching runtime metrics.")
  105. def fetch_and_update_api_backend_version(
  106. self,
  107. model_instance: ModelInstance,
  108. endpoint: str,
  109. ) -> Optional[str]:
  110. if model_instance.api_detected_backend_version is not None:
  111. return model_instance.api_detected_backend_version
  112. version = self._metrics_client.fetch_runtime_version_from_endpoint(
  113. endpoint, model_instance.backend
  114. )
  115. if version is not None:
  116. self._update_model_instance(
  117. model_instance.id, api_detected_backend_version=version
  118. )
  119. return version
  120. return model_instance.backend_version
  121. def _find_active_model_endpoints(
  122. self, worker_id: int, metrics_config: dict
  123. ) -> tuple[set, dict[str, ModelInstance], dict[int, Model]]:
  124. """
  125. Get all endpoints and related mappings for RUNNING model instances on this worker.
  126. Returns: (endpoints, endpoint->instance, instance_id->model)
  127. """
  128. model_instances, models = self._list_worker_models(worker_id)
  129. if not model_instances or not models:
  130. return set(), {}, {}
  131. model_id_to_model = {m.id: m for m in models.items}
  132. endpoints = set()
  133. endpoint_to_instance = {}
  134. instance_id_to_model = {}
  135. for mi in model_instances.items:
  136. model = model_id_to_model.get(mi.model_id)
  137. if self._should_skip_endpoint(
  138. model=model,
  139. model_instance=mi,
  140. metrics_config=metrics_config,
  141. ):
  142. logger.trace(f"Skipping model instance {mi.id} in metrics aggregation.")
  143. continue
  144. endpoint = f"{mi.worker_ip}:{mi.ports[0]}"
  145. endpoints.add(endpoint)
  146. endpoint_to_instance[endpoint] = mi
  147. instance_id_to_model[mi.id] = model
  148. return endpoints, endpoint_to_instance, instance_id_to_model
  149. def _list_worker_models(self, worker_id: int):
  150. """
  151. Query all model instances and model objects on this worker.
  152. """
  153. model_instances = self._clientset.model_instances.list(
  154. params={"worker_id": str(worker_id)}
  155. )
  156. models = self._clientset.models.list()
  157. return model_instances, models
  158. def _update_model_instance(self, id: int, **kwargs):
  159. try:
  160. mi_public = self._clientset.model_instances.get(id=id)
  161. mi = ModelInstanceUpdate(**mi_public.model_dump())
  162. for key, value in kwargs.items():
  163. setattr(mi, key, value)
  164. self._clientset.model_instances.update(id=id, model_update=mi)
  165. except Exception as e:
  166. logger.error(f"Failed to update model instance {id}: {e}")
  167. def _build_base_labels(self, mi, m, runtime):
  168. """
  169. Build base labels for each metric.
  170. """
  171. return {
  172. "worker_id": str(mi.worker_id) if mi.worker_id else "",
  173. "worker_name": mi.worker_name if mi.worker_name else "",
  174. "model_id": str(m.id) if m else "",
  175. "model_name": m.name if m else "",
  176. "model_instance_id": str(mi.id),
  177. "model_instance_name": mi.name,
  178. "runtime": runtime,
  179. }
  180. def _process_endpoint_metrics(
  181. self,
  182. metrics,
  183. base_labels,
  184. runtime,
  185. runtime_version,
  186. unified_metrics,
  187. raw_metrics,
  188. metrics_config,
  189. ):
  190. """
  191. Process metrics for a single endpoint, aggregate to unified and raw.
  192. """
  193. for source_family_name, family in metrics.items():
  194. first_sample = family.samples[0] if family.samples else None
  195. if not first_sample:
  196. continue
  197. label_keys = list(base_labels.keys())
  198. for k in first_sample.labels.keys():
  199. if k not in label_keys:
  200. label_keys.append(k)
  201. # raw metrics
  202. if source_family_name not in raw_metrics:
  203. raw_metrics[source_family_name] = create_prom_metric_family(
  204. name=source_family_name,
  205. type=family.type,
  206. description=family.documentation,
  207. labels=label_keys,
  208. )
  209. raw_family = raw_metrics[source_family_name]
  210. # unified metrics
  211. unified_family = None
  212. unified_metric_family_name = get_unified_metric_family_name(
  213. metrics_config, source_family_name, runtime, runtime_version
  214. )
  215. if unified_metric_family_name:
  216. cfg = get_unified_metric_family_config(
  217. metrics_config, unified_metric_family_name
  218. )
  219. if cfg:
  220. if unified_metric_family_name not in unified_metrics:
  221. unified_metrics[unified_metric_family_name] = (
  222. create_prom_metric_family(
  223. name=unified_metric_family_name,
  224. type=cfg.get("type"),
  225. description=cfg.get("description"),
  226. labels=label_keys,
  227. )
  228. )
  229. unified_family = unified_metrics[unified_metric_family_name]
  230. for sample in family.samples:
  231. label_values = [
  232. (
  233. base_labels.get(k, sample.labels.get(k, ""))
  234. if k in base_labels
  235. else sample.labels.get(k, "")
  236. )
  237. for k in label_keys
  238. ]
  239. labels = sample.labels.copy()
  240. labels.update(base_labels)
  241. if family.type in ("histogram", "summary"):
  242. raw_family.add_sample(
  243. name=sample.name,
  244. labels=labels,
  245. value=sample.value,
  246. timestamp=sample.timestamp,
  247. )
  248. if unified_family:
  249. new_name = sample.name.replace(
  250. source_family_name, unified_metric_family_name
  251. )
  252. unified_family.add_sample(
  253. name=new_name,
  254. labels=labels,
  255. value=sample.value,
  256. timestamp=sample.timestamp,
  257. )
  258. else:
  259. raw_family.add_metric(
  260. labels=label_values,
  261. value=sample.value,
  262. timestamp=sample.timestamp,
  263. )
  264. if unified_family:
  265. unified_family.add_metric(
  266. labels=label_values,
  267. value=sample.value,
  268. timestamp=sample.timestamp,
  269. )
  270. def _should_skip_endpoint(
  271. self, model: Model, model_instance: ModelInstance, metrics_config: dict
  272. ) -> bool:
  273. # skip image and audio models
  274. if is_image_model(model) or is_audio_model(model):
  275. return True
  276. # model and model instance must be valid
  277. if (
  278. model_instance.state != ModelInstanceStateEnum.RUNNING
  279. or model_instance.worker_ip is None
  280. or not model_instance.ports
  281. ):
  282. return True
  283. if not model:
  284. return True
  285. runtime = model.backend
  286. if not runtime:
  287. return True
  288. # check runtime metrics config
  289. runtime_cfg = get_runtime_metrics_config(metrics_config, runtime)
  290. if not runtime_cfg:
  291. return True
  292. # check runtime-specific metrics flags
  293. if runtime == BackendEnum.VLLM:
  294. disable_metrics = find_parameter(
  295. model.backend_parameters, ["disable-log-stats"]
  296. )
  297. if disable_metrics:
  298. return True
  299. if model.env and model.env.get("GPUSTACK_DISABLE_METRICS"):
  300. return True
  301. return False
  302. def _get_online_metrics_config(self):
  303. try:
  304. resp = self._clientset.http_client.get_httpx_client().get(
  305. f"{self._clientset.base_url}/v2/metrics/config",
  306. timeout=METRICS_CONFIG_FETCH_TIMEOUT_SECONDS,
  307. )
  308. if resp.status_code == 404:
  309. return None
  310. elif resp.status_code != 200:
  311. logger.warning(
  312. f"Failed to fetch online metrics config, status: {resp.status_code}"
  313. )
  314. return None
  315. data = resp.json()
  316. if not isinstance(data, dict):
  317. logger.warning(
  318. "Online metrics config is not a dict, fallback to builtin config."
  319. )
  320. return None
  321. return data
  322. except Exception as e:
  323. logger.error(f"Error fetching online metrics config: {e}")
  324. return None
  325. def _get_metrics_config(self):
  326. """Get metrics config with automatic caching (300 seconds TTL)."""
  327. try:
  328. return self._metrics_config_cache["config"]
  329. except KeyError:
  330. # Cache miss, fetch fresh config
  331. pass
  332. online_config = self._get_online_metrics_config()
  333. if online_config:
  334. logger.debug("Updated online metrics config cache")
  335. self._metrics_config_cache["config"] = online_config
  336. return online_config
  337. else:
  338. builtin_config = get_builtin_metrics_config()
  339. logger.debug("Using builtin metrics config")
  340. # Cache for 300 seconds
  341. self._metrics_config_cache["config"] = builtin_config
  342. return builtin_config
  343. _METRIC_FAMILY_CLASS = {
  344. "gauge": GaugeMetricFamily,
  345. "info": InfoMetricFamily,
  346. "histogram": HistogramMetricFamily,
  347. "counter": CounterMetricFamily,
  348. "summary": SummaryMetricFamily,
  349. }
  350. def create_prom_metric_family(type: str, name: str, description: str, labels=None):
  351. cls = _METRIC_FAMILY_CLASS.get(str(type).lower())
  352. if not cls:
  353. raise ValueError(f"Unknown metric family type: {type}")
  354. if labels is not None:
  355. return cls(name, description, labels=labels)
  356. else:
  357. return cls(name, description)
  358. def get_unified_metric_family_name(
  359. config: dict,
  360. source_metric_family_name: str,
  361. runtime: str,
  362. runtime_version: Optional[str],
  363. ) -> Optional[str]:
  364. """
  365. Return the unified (normalized) metric family name as a string. If not found, return an empty string.
  366. Prefer version-specific mapping if matched, otherwise use the default '*'.
  367. """
  368. runtime_cfg = get_runtime_metrics_config(config, runtime)
  369. if not runtime_cfg:
  370. return None
  371. name = runtime_cfg.get("*", {}).get(source_metric_family_name, None)
  372. if runtime_version:
  373. is_valid_version = version.is_valid_version_str(runtime_version)
  374. for ver_range, mapping in runtime_cfg.items():
  375. if ver_range == "*":
  376. continue
  377. if (is_valid_version and version.in_range(runtime_version, ver_range)) or (
  378. not is_valid_version and runtime_version == ver_range
  379. ):
  380. old_version_name = mapping.get(source_metric_family_name)
  381. if old_version_name is not None:
  382. return old_version_name
  383. return name
  384. def get_unified_metric_family_config(
  385. config: dict, unified_metric_family_name: str
  386. ) -> dict:
  387. return config.get("gpustack_metrics", {}).get(unified_metric_family_name, {})