| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- import logging
- from typing import Optional
- import requests
- from prometheus_client.parser import text_string_to_metric_families
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from gpustack.schemas.models import BackendEnum
logger = logging.getLogger(__name__)

# Candidate HTTP paths (relative to the server root) that may report the
# runtime version, keyed by backend value. Paths are probed in list order.
BackendVersionAPI = dict(
    [
        (BackendEnum.VLLM.value, ["version"]),
        (BackendEnum.SGLANG.value, ["server_info", "get_server_info"]),
        (BackendEnum.ASCEND_MINDIE.value, ["info"]),
    ]
)
class Config:
    """Request and retry settings used by ``Client``.

    Attributes:
        timeout: Per-request timeout passed to ``requests.get``.
        max_retries: Extra attempts made after a failed metrics scrape.
        base_delay: Initial backoff delay (seconds) between retries.
        max_delay: Upper bound (seconds) on any single backoff delay.
        insecure_tls: When true, TLS certificate verification is disabled.
    """

    timeout: float
    max_retries: int
    base_delay: float
    max_delay: float
    insecure_tls: bool

    def __init__(
        self,
        timeout: float = 3,
        max_retries: int = 2,
        base_delay: float = 1,
        max_delay: float = 3,
        insecure_tls: bool = True,
    ):
        self.timeout = timeout
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.insecure_tls = insecure_tls
class Client:
    """Best-effort HTTP client for inference-server endpoints.

    Scrapes Prometheus ``/metrics`` and probes runtime-specific version APIs.
    Every request uses the timeout and TLS-verification settings from
    ``self.config``. Failures are logged and reported as ``None`` rather than
    raised.
    """

    def __init__(self, config=None):
        """Create a client.

        Args:
            config: Request/retry settings; a default ``Config()`` is used
                when omitted.
        """
        self.config = config if config is not None else Config()

    def _get(self, url):
        """Issue a GET to *url* with the configured timeout and TLS policy."""
        return requests.get(
            url,
            timeout=self.config.timeout,
            # insecure_tls=True turns certificate verification off.
            verify=not self.config.insecure_tls,
        )

    def _backoff_delay(self, attempt):
        """Return the sleep before retrying after zero-based *attempt*.

        Exponential backoff: ``base_delay * 2**attempt``, capped at
        ``max_delay``.
        """
        return min(self.config.base_delay * (2**attempt), self.config.max_delay)

    def fetch_metrics_from_endpoint(self, endpoint):
        """Scrape ``http://<endpoint>/metrics`` and parse the Prometheus text.

        Retries up to ``config.max_retries`` extra times on a non-200 status
        or on any request/parse error. It sleeps with exponential backoff
        between attempts.

        Args:
            endpoint: ``host:port`` of the server (no scheme).

        Returns:
            A dict mapping metric family name to the parsed family, or
            ``None`` if every attempt failed.
        """
        import time

        url = f"http://{endpoint}/metrics"
        # NOTE(review): ``trace`` is not a stdlib ``logging.Logger`` method;
        # this assumes the project's logging setup registers it.
        logger.trace(f"Fetching metrics from {url}")
        # Clamp so a negative max_retries still makes one attempt.
        attempts = max(self.config.max_retries, 0) + 1
        for attempt in range(attempts):
            try:
                resp = self._get(url)
                if resp.status_code == 200:
                    return {
                        family.name: family
                        for family in text_string_to_metric_families(resp.text)
                    }
                logger.warning(
                    f"[{endpoint}] Attempt {attempt + 1}: Bad status {resp.status_code}"
                )
            except Exception as e:
                logger.error(f"[{endpoint}] Attempt {attempt + 1}: Error {e}")
            # No sleep after the final attempt.
            if attempt < attempts - 1:
                time.sleep(self._backoff_delay(attempt))
        return None

    def fetch_metrics_from_endpoints(self, endpoints, max_workers=16):
        """Scrape several endpoints concurrently.

        Args:
            endpoints: Iterable of ``host:port`` strings. Duplicates are
                fetched once.
            max_workers: Maximum number of worker threads.

        Returns:
            A dict mapping each endpoint to its metrics dict, or to ``None``
            when that endpoint could not be scraped.
        """
        # dict.fromkeys de-duplicates while preserving the original order.
        unique_endpoints = list(dict.fromkeys(endpoints))
        if not unique_endpoints:
            return {}

        results = {}
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = {
                pool.submit(self.fetch_metrics_from_endpoint, ep): ep
                for ep in unique_endpoints
            }
            for future in as_completed(futures):
                ep = futures[future]
                try:
                    results[ep] = future.result()
                except Exception as e:
                    # One failing worker must not discard the other results.
                    logger.error(f"[{ep}] Unexpected error fetching metrics: {e}")
                    results[ep] = None
        return results

    def fetch_runtime_version_from_endpoint(
        self, endpoint: str, runtime: str
    ) -> Optional[str]:
        """Probe the runtime's known version APIs and return the first version found.

        Paths from ``BackendVersionAPI[runtime]`` are tried in order. A path
        succeeds when it answers 200 with a JSON object containing a non-null
        ``"version"``. Otherwise the next path is tried. If no path succeeds,
        the last failure is logged (error for exceptions, warning for bad
        responses) and ``None`` is returned.

        Args:
            endpoint: ``host:port`` of the server (no scheme).
            runtime: Backend identifier. Runtimes without known paths return
                ``None`` without making any request.

        Returns:
            The reported version, or ``None``.
        """
        paths = BackendVersionAPI.get(runtime)
        if not paths:
            return None

        # (log function, message) describing the most recent failed path.
        last_failure = None
        for path in paths:
            url = f"http://{endpoint}/{path}"
            try:
                resp = self._get(url)
                if resp.status_code != 200:
                    last_failure = (
                        logger.warning,
                        f"[{endpoint}] Bad status {resp.status_code} when fetching {runtime} version from {url}",
                    )
                    continue
                data = resp.json()
            except Exception as e:
                last_failure = (
                    logger.error,
                    f"[{endpoint}] Error {e} when fetching {runtime} version from {url}",
                )
                continue

            version = data.get("version") if isinstance(data, dict) else None
            if version is not None:
                return version
            # A 200 without a version is not a success; try the next path.
            last_failure = (
                logger.warning,
                f"[{endpoint}] No version in response when fetching {runtime} version from {url}",
            )

        if last_failure is not None:
            log, message = last_failure
            log(message)
        return None
|