runtime_metrics_client.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import logging
  2. from typing import Optional
  3. import requests
  4. from prometheus_client.parser import text_string_to_metric_families
  5. from concurrent.futures import ThreadPoolExecutor, as_completed
  6. from gpustack.schemas.models import BackendEnum
# Module-level logger named after this module.
# NOTE(review): Client.fetch_metrics_from_endpoint calls logger.trace(), which
# the standard logging.Logger does not define -- presumably the project
# registers a TRACE level elsewhere; confirm.
logger = logging.getLogger(__name__)

# Maps a backend name (the ``.value`` of a BackendEnum member) to the URL path
# segments that Client.fetch_runtime_version_from_endpoint probes, in list
# order, to read that backend's version. A backend missing from this mapping
# gets no version lookup (the method returns None for it).
# NOTE(review): SGLang lists two candidate paths -- presumably the route name
# differs between SGLang releases; confirm.
BackendVersionAPI = {
    BackendEnum.VLLM.value: ["version"],
    BackendEnum.SGLANG.value: ["server_info", "get_server_info"],
    BackendEnum.ASCEND_MINDIE.value: ["info"],
}
  13. class Config:
  14. def __init__(
  15. self, timeout=3, max_retries=2, base_delay=1, max_delay=3, insecure_tls=True
  16. ):
  17. self.timeout = timeout
  18. self.max_retries = max_retries
  19. self.base_delay = base_delay
  20. self.max_delay = max_delay
  21. self.insecure_tls = insecure_tls
  22. class Client:
  23. def __init__(self, config=None):
  24. self.config = config or Config()
  25. def fetch_metrics_from_endpoint(self, endpoint):
  26. url = f"http://{endpoint}/metrics"
  27. logger.trace(f"Fetching metrics from {url}")
  28. for attempt in range(self.config.max_retries + 1):
  29. try:
  30. resp = requests.get(
  31. url,
  32. timeout=self.config.timeout,
  33. verify=not self.config.insecure_tls,
  34. )
  35. if resp.status_code == 200:
  36. metrics = {}
  37. for family in text_string_to_metric_families(resp.text):
  38. metrics[family.name] = family
  39. return metrics
  40. else:
  41. logger.warning(
  42. f"[{endpoint}] Attempt {attempt + 1}: Bad status {resp.status_code}"
  43. )
  44. except Exception as e:
  45. logger.error(f"[{endpoint}] Attempt {attempt + 1}: Error {e}")
  46. # Exponential backoff
  47. if attempt < self.config.max_retries:
  48. delay = min(
  49. self.config.base_delay * (2**attempt), self.config.max_delay
  50. )
  51. import time
  52. time.sleep(delay)
  53. return None
  54. def fetch_metrics_from_endpoints(self, endpoints, max_workers=16):
  55. results = {}
  56. with ThreadPoolExecutor(max_workers=max_workers) as pool:
  57. futures = {
  58. pool.submit(self.fetch_metrics_from_endpoint, ep): ep
  59. for ep in endpoints
  60. }
  61. for future in as_completed(futures):
  62. ep = futures[future]
  63. results[ep] = future.result()
  64. return results
  65. def fetch_runtime_version_from_endpoint(
  66. self, endpoint: str, runtime: str
  67. ) -> Optional[str]:
  68. """
  69. Try to fetch the runtime version from all possible API paths. Return on first success.
  70. Log last error or warning for troubleshooting.
  71. """
  72. paths = BackendVersionAPI.get(runtime)
  73. if paths is None:
  74. return None
  75. error_msg = ""
  76. warning_msg = ""
  77. for path in paths:
  78. url = f"http://{endpoint}/{path}"
  79. try:
  80. resp = requests.get(
  81. url,
  82. timeout=self.config.timeout,
  83. verify=not self.config.insecure_tls,
  84. )
  85. if resp.status_code == 200:
  86. data = resp.json()
  87. return data.get("version", None)
  88. else:
  89. warning_msg = f"[{endpoint}] Bad status {resp.status_code} when fetching {runtime} version from {url}"
  90. except Exception as e:
  91. error_msg = (
  92. f"[{endpoint}] Error {e} when fetching {runtime} version from {url}"
  93. )
  94. if error_msg:
  95. logger.error(error_msg)
  96. elif warning_msg:
  97. logger.warning(warning_msg)
  98. return None