run_model_benchmark.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219
  1. #!/usr/bin/env python3
  2. """
  3. Automated GPUStack serving benchmark runner.
  4. This script reads:
  5. 1. A model/run configuration YAML
  6. 2. A benchmark profile YAML
  7. Then it drives the full benchmark lifecycle through GPUStack's HTTP API:
  8. 1. Create a model deployment for one run
  9. 2. Wait until the model instance is `running`
  10. 3. Optionally warm up the OpenAI-compatible endpoint
  11. 4. Create one or more benchmark jobs for the selected test cases
  12. 5. Watch benchmark state over SSE until completion
  13. 6. Save the final benchmark payload as JSON
  14. 7. Scale the model back to zero replicas before moving to the next run
  15. Typical usage:
  16. ```bash
  17. python3 hack/perf/run_model_benchmark.py \
  18. --config .cache/plan/benchmark/high-throughput/qwen_3.5_35b_fp8.yaml \
  19. --profile gpustack/assets/profiles_config/profiles_config.yaml \
  20. --gpustack-url https://YOUR_GPUSTACK \
  21. --gpustack-token $GPUSTACK_TOKEN \
  22. --cluster-id 1 \
  23. --output-dir benchmark_results
  24. ```
  25. Run only a subset of runs:
  26. ```bash
  27. python3 hack/perf/run_model_benchmark.py \
  28. --config .../qwen_3.5_9b.yaml \
  29. --profile .../profiles_config.yaml \
  30. --gpustack-url https://YOUR_GPUSTACK \
  31. --gpustack-token $GPUSTACK_TOKEN \
  32. --cluster-id 1 \
  33. --run-names vllm-standard,sgl-throughput-bundle
  34. ```
  35. Override test cases or request rates from the profile:
  36. ```bash
  37. python3 hack/perf/run_model_benchmark.py \
  38. --config .../qwen_3.5_122b_a10b_fp8.yaml \
  39. --profile .../profiles_config.yaml \
  40. --gpustack-url https://YOUR_GPUSTACK \
  41. --gpustack-token $GPUSTACK_TOKEN \
  42. --cluster-id 1 \
  43. --test-cases Throughput,Long\\ Context \
  44. --request-rates 1,4,8
  45. ```
  46. Expected config YAML shape:
  47. ```yaml
  48. model: "Qwen/Qwen3.5-35B-A3B-FP8"
  49. source: "model_scope" # or "huggingface"
  50. health_check:
  51. init_delay: 60
  52. timeout: 1800
  53. interval: 5.0
  54. warmup:
  55. num_requests: 10
  56. test_cases:
  57. - name: Throughput
  58. runs:
  59. - name: vllm-standard
  60. backend: vLLM
  61. backend_version: 0.17.1
  62. backend_parameters:
  63. - --reasoning-parser=qwen3
  64. - --max-model-len=32768
  65. ```
  66. Expected profile YAML shape:
  67. ```yaml
  68. profiles:
  69. - name: Throughput
  70. request_rate: 4
  71. total_requests: 100
  72. dataset_name: sharegpt
  73. ```
  74. """
  75. import json
  76. import logging
  77. import re
  78. import ssl
  79. import sys
  80. import time
  81. import urllib.error
  82. import urllib.parse
  83. import urllib.request
  84. import hashlib
  85. from dataclasses import dataclass
  86. from enum import Enum
  87. from pathlib import Path
  88. from typing import Any, Dict, Iterator, List, Optional
  89. import yaml
  90. logging.basicConfig(
  91. level=logging.INFO,
  92. format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  93. )
  94. logger = logging.getLogger("llm_benchmark")
  95. class EngineType(Enum):
  96. """Supported inference engine types"""
  97. VLLM = "vLLM"
  98. SGLANG = "SGLang"
  99. TRTLLM = "TRT-LLM"
  100. class Source(Enum):
  101. Huggingface = "huggingface"
  102. ModelScope = "model_scope"
  103. @dataclass
  104. class HealthCheck:
  105. init_delay: int = 60
  106. timeout: int = 30
  107. interval: float = 1.0
  108. @dataclass
  109. class Model:
  110. """Configuration for a model test run"""
  111. name: str
  112. test_cases: List[str]
  113. backend: EngineType
  114. backend_version: Optional[str] = None
  115. backend_parameters: Optional[List[str]] = None
  116. envs: Optional[Dict[str, str]] = None
  117. args: Optional[List[str]] = None
  118. health_check: Optional[HealthCheck] = None
  119. warmup_num_requests: Optional[int] = None
  120. stop_model_after_run: bool = True
  121. instance_name: Optional[str] = None
  122. model_id: Optional[int] = None
  123. model_name: Optional[str] = None
  124. benchmark_id: Optional[int] = None
  125. benchmark_name: Optional[str] = None
  126. request_rates: Optional[List[int]] = None
  127. class EngineManager:
  128. """
  129. Translate one benchmark run into concrete GPUStack API operations.
  130. Important behavior:
  131. - one `Model` dataclass instance corresponds to one deployment/run in the YAML
  132. - each run may execute multiple benchmark profiles (`test_cases`)
  133. - each benchmark result is written to one JSON file under `output_dir`
  134. """
  135. def __init__(
  136. self,
  137. model: str,
  138. source: str,
  139. gpustack_url: str,
  140. gpustack_token: str,
  141. cluster_id: int,
  142. output_dir: str = "benchmark_results",
  143. ):
  144. self.model_repo_id = model
  145. self.source = source
  146. self.gpustack_url = gpustack_url.rstrip("/")
  147. self.gpustack_token = gpustack_token
  148. self.cluster_id = int(cluster_id)
  149. self.output_dir = Path(output_dir)
  150. self.output_dir.mkdir(parents=True, exist_ok=True)
  151. self.ssl_context = ssl.create_default_context()
  152. self.ssl_context.check_hostname = False
  153. self.ssl_context.verify_mode = ssl.CERT_NONE
  154. def _headers(self) -> Dict[str, str]:
  155. return {
  156. "Accept": "application/json, text/plain, */*",
  157. "Content-Type": "application/json",
  158. "Authorization": f"Bearer {self.gpustack_token}",
  159. }
  160. def _request(
  161. self,
  162. method: str,
  163. path: str,
  164. *,
  165. json_body: Optional[Dict[str, Any]] = None,
  166. params: Optional[Dict[str, Any]] = None,
  167. stream: bool = False,
  168. timeout: int = 60,
  169. ) -> Any:
  170. url = f"{self.gpustack_url}{path}"
  171. if params:
  172. query_string = urllib.parse.urlencode(params)
  173. url = f"{url}?{query_string}"
  174. logger.debug("HTTP %s %s", method, url)
  175. body = None
  176. headers = self._headers()
  177. if json_body is not None:
  178. body = json.dumps(json_body).encode("utf-8")
  179. request = urllib.request.Request(
  180. url=url,
  181. data=body,
  182. headers=headers,
  183. method=method.upper(),
  184. )
  185. try:
  186. return urllib.request.urlopen(
  187. request,
  188. timeout=timeout,
  189. context=self.ssl_context,
  190. )
  191. except urllib.error.HTTPError as exc:
  192. details = exc.read().decode("utf-8", errors="replace").strip()
  193. raise RuntimeError(f"{method} {url} failed: {exc.code} {details}") from exc
  194. def _slugify(self, value: str) -> str:
  195. slug = re.sub(r"[^a-z0-9]+", "-", value.lower()).strip("-")
  196. return slug or "benchmark"
  197. def _repo_model_name(self) -> str:
  198. return self._slugify(self.model_repo_id.split("/")[-1])
  199. def _timestamp(self) -> str:
  200. return time.strftime("%Y%m%d%H%M%S")
  201. def _bounded_name(self, *parts: str, max_length: int = 63) -> str:
  202. """
  203. Build a readable, stable name that always fits Kubernetes-style label limits.
  204. Keeps a human-readable prefix and appends a short hash when truncation is needed.
  205. """
  206. base = "_".join(part for part in parts if part)
  207. if len(base) <= max_length:
  208. return base
  209. digest = hashlib.sha1(base.encode("utf-8")).hexdigest()[:8]
  210. reserve = len(digest) + 1
  211. prefix = base[: max_length - reserve].rstrip("_.-")
  212. if not prefix:
  213. prefix = "benchmark"
  214. return f"{prefix}_{digest}"
  215. def _build_benchmark_name(
  216. self,
  217. model_name: str,
  218. profile_slug: str,
  219. request_rate: int,
  220. ts: str,
  221. max_length: int = 63,
  222. ) -> str:
  223. """
  224. Prefer a readable timestamped name when it fits.
  225. Fall back to the readable name without timestamp, then to a bounded hashed name.
  226. """
  227. readable_name = f"{model_name}_{profile_slug}_r{request_rate}"
  228. timestamped_name = f"{readable_name}_{ts}"
  229. if len(timestamped_name) <= max_length:
  230. return timestamped_name
  231. if len(readable_name) <= max_length:
  232. return readable_name
  233. return self._bounded_name(model_name, profile_slug, f"r{request_rate}", ts)
  234. def _iter_sse_payloads(self, response: Any) -> Iterator[Dict[str, Any]]:
  235. event_lines: List[str] = []
  236. for raw_line in response:
  237. line = raw_line.decode("utf-8", errors="replace").strip()
  238. if not line:
  239. if event_lines:
  240. payload = "\n".join(event_lines)
  241. event_lines = []
  242. try:
  243. yield json.loads(payload)
  244. except json.JSONDecodeError:
  245. logger.debug("Skip non-JSON SSE payload: %s", payload)
  246. continue
  247. if line.startswith("data:"):
  248. event_lines.append(line[5:].strip())
  249. continue
  250. try:
  251. yield json.loads(line)
  252. except json.JSONDecodeError:
  253. logger.debug("Skip non-JSON stream line: %s", line)
  254. if event_lines:
  255. payload = "\n".join(event_lines)
  256. try:
  257. yield json.loads(payload)
  258. except json.JSONDecodeError:
  259. logger.debug("Skip trailing non-JSON SSE payload: %s", payload)
  260. def _get_model(self, model_id: int) -> Dict[str, Any]:
  261. with self._request("GET", f"/v2/models/{model_id}") as response:
  262. return json.loads(response.read().decode("utf-8"))
  263. def _list_models(
  264. self,
  265. *,
  266. search: Optional[str] = None,
  267. timeout: int = 60,
  268. ) -> List[Dict[str, Any]]:
  269. params: Dict[str, Any] = {
  270. "perPage": 100,
  271. "page": 1,
  272. "cluster_id": self.cluster_id,
  273. }
  274. if search:
  275. params["search"] = search
  276. with self._request(
  277. "GET", "/v2/models", params=params, timeout=timeout
  278. ) as response:
  279. payload = json.loads(response.read().decode("utf-8"))
  280. return payload.get("items", [])
  281. def _list_model_instances(
  282. self, model_id: int, timeout: int = 60
  283. ) -> List[Dict[str, Any]]:
  284. with self._request(
  285. "GET",
  286. "/v2/model-instances",
  287. params={"model_id": model_id, "perPage": 100, "page": 1},
  288. timeout=timeout,
  289. ) as response:
  290. response = json.loads(response.read().decode("utf-8"))
  291. return response.get("items", [])
  292. def _result_path(self, benchmark_name: str) -> Path:
  293. filename = f"{self._slugify(benchmark_name)}.json"
  294. return self.output_dir / filename
  295. def _dump_result(self, path: Path, payload: Dict[str, Any]) -> None:
  296. path.parent.mkdir(parents=True, exist_ok=True)
  297. with path.open("w", encoding="utf-8") as f:
  298. json.dump(payload, f, indent=2, ensure_ascii=False)
  299. def _is_retryable_request_error(self, exc: Exception) -> bool:
  300. return isinstance(exc, (urllib.error.URLError, TimeoutError, OSError))
  301. def _matches_model_payload(
  302. self,
  303. existing: Dict[str, Any],
  304. expected_name: str,
  305. payload: Dict[str, Any],
  306. ) -> bool:
  307. if existing.get("name") != expected_name:
  308. return False
  309. if existing.get("cluster_id") != self.cluster_id:
  310. return False
  311. if existing.get("backend") != payload.get("backend"):
  312. return False
  313. if existing.get("source") != payload.get("source"):
  314. return False
  315. if payload.get("source") == Source.Huggingface.value:
  316. return existing.get("huggingface_repo_id") == payload.get(
  317. "huggingface_repo_id"
  318. )
  319. if payload.get("source") == Source.ModelScope.value:
  320. return existing.get("model_scope_model_id") == payload.get(
  321. "model_scope_model_id"
  322. )
  323. return True
  324. def _find_existing_model(
  325. self,
  326. model_name: str,
  327. payload: Dict[str, Any],
  328. *,
  329. timeout: int = 30,
  330. ) -> Optional[Dict[str, Any]]:
  331. try:
  332. candidates = self._list_models(search=model_name, timeout=timeout)
  333. except Exception as exc:
  334. if self._is_retryable_request_error(exc):
  335. logger.warning("Failed to query existing model %s: %s", model_name, exc)
  336. return None
  337. raise
  338. for candidate in candidates:
  339. if self._matches_model_payload(candidate, model_name, payload):
  340. return candidate
  341. return None
  342. def _apply_model_identity(self, config: Model, payload: Dict[str, Any]) -> None:
  343. config.model_id = payload["id"]
  344. config.model_name = payload["name"]
  345. def _update_existing_model(self, config: Model, payload: Dict[str, Any]) -> None:
  346. if config.model_id is None:
  347. raise RuntimeError("Cannot update model without model id")
  348. with self._request(
  349. "PUT", f"/v2/models/{config.model_id}", json_body=payload, timeout=120
  350. ) as response:
  351. updated = json.loads(response.read().decode("utf-8"))
  352. self._apply_model_identity(config, updated)
  353. def _create_or_reuse_model(
  354. self,
  355. config: Model,
  356. payload: Dict[str, Any],
  357. *,
  358. retries: int = 3,
  359. retry_delay: float = 5.0,
  360. ) -> None:
  361. model_name = payload["name"]
  362. existing = self._find_existing_model(model_name, payload)
  363. if existing:
  364. self._apply_model_identity(config, existing)
  365. logger.info(
  366. "Reusing existing model %s (id=%s) for run %s",
  367. config.model_name,
  368. config.model_id,
  369. config.name,
  370. )
  371. self._update_existing_model(config, payload)
  372. return
  373. last_error: Optional[Exception] = None
  374. for attempt in range(1, retries + 1):
  375. try:
  376. with self._request(
  377. "POST", "/v2/models", json_body=payload, timeout=120
  378. ) as response:
  379. created = json.loads(response.read().decode("utf-8"))
  380. self._apply_model_identity(config, created)
  381. logger.info(
  382. "Created model %s (id=%s) for run %s",
  383. config.model_name,
  384. config.model_id,
  385. config.name,
  386. )
  387. return
  388. except RuntimeError as exc:
  389. last_error = exc
  390. if "already exists" in str(exc).lower():
  391. existing = self._find_existing_model(model_name, payload)
  392. if existing:
  393. self._apply_model_identity(config, existing)
  394. logger.info(
  395. "Detected existing model %s (id=%s) after create conflict",
  396. config.model_name,
  397. config.model_id,
  398. )
  399. self._update_existing_model(config, payload)
  400. return
  401. raise
  402. except Exception as exc:
  403. last_error = exc
  404. if not self._is_retryable_request_error(exc):
  405. raise
  406. existing = self._find_existing_model(model_name, payload)
  407. if existing:
  408. self._apply_model_identity(config, existing)
  409. logger.info(
  410. "Found existing model %s (id=%s) after create timeout/error",
  411. config.model_name,
  412. config.model_id,
  413. )
  414. self._update_existing_model(config, payload)
  415. return
  416. if attempt == retries:
  417. raise
  418. logger.warning(
  419. "Create model request failed for %s (attempt %s/%s): %s; retrying in %.1fs",
  420. model_name,
  421. attempt,
  422. retries,
  423. exc,
  424. retry_delay,
  425. )
  426. time.sleep(retry_delay)
  427. if last_error is not None:
  428. raise last_error
  429. def _benchmark_metrics_ready(self, payload: Dict[str, Any]) -> bool:
  430. return bool(
  431. payload.get("raw_metrics") is not None
  432. or payload.get("request_latency_mean") is not None
  433. or payload.get("tokens_per_second_mean") is not None
  434. or payload.get("requests_per_second_mean") is not None
  435. )
  436. def _wait_for_benchmark_result(
  437. self,
  438. benchmark_id: int,
  439. *,
  440. timeout: int = 180,
  441. poll_interval: float = 2.0,
  442. ) -> Dict[str, Any]:
  443. """
  444. Wait for the benchmark detail endpoint to include synced metrics.
  445. The worker marks a benchmark as `completed` before it uploads parsed metrics
  446. back to `/v2/benchmarks/{id}/metrics`, so a detail fetch immediately after the
  447. completion event can still return empty metric fields.
  448. """
  449. deadline = time.time() + timeout
  450. last_payload: Dict[str, Any] = {}
  451. while time.time() < deadline:
  452. with self._request(
  453. "GET", f"/v2/benchmarks/{benchmark_id}", timeout=120
  454. ) as response:
  455. last_payload = json.loads(response.read().decode("utf-8"))
  456. if self._benchmark_metrics_ready(last_payload):
  457. return last_payload
  458. time.sleep(poll_interval)
  459. logger.warning(
  460. "Timed out waiting for benchmark %s metrics to sync; saving latest payload without raw metrics",
  461. benchmark_id,
  462. )
  463. return last_payload
  464. def setup_model(self, model: Model):
  465. """
  466. Create a GPUStack model deployment for one run definition.
  467. The YAML `backend_parameters`, `envs`, and backend/version fields are passed
  468. through almost directly to GPUStack's `/v2/models` API.
  469. """
  470. name = f"{self._repo_model_name()}-{model.name}"
  471. source = self.source
  472. payload = {
  473. "source": self.source,
  474. "huggingface_repo_id": self.model_repo_id,
  475. "huggingface_filename": None,
  476. "model_scope_model_id": None,
  477. "model_scope_file_path": None,
  478. "local_path": None,
  479. "description": None,
  480. "meta": {},
  481. "replicas": 1,
  482. "ready_replicas": 0,
  483. "categories": ["llm"],
  484. "placement_strategy": "spread",
  485. "cpu_offloading": None,
  486. "distributed_inference_across_workers": True,
  487. "worker_selector": {},
  488. "gpu_selector": None,
  489. "backend": model.backend.value,
  490. "backend_version": model.backend_version,
  491. "backend_parameters": model.backend_parameters or [],
  492. "image_name": None,
  493. "run_command": None,
  494. "env": model.envs or None,
  495. "restart_on_error": False,
  496. "distributable": False,
  497. "extended_kv_cache": {},
  498. "speculative_config": {},
  499. "generic_proxy": False,
  500. "cluster_id": self.cluster_id,
  501. "name": name,
  502. "enable_model_route": True,
  503. }
  504. if source == Source.Huggingface.value:
  505. payload["huggingface_repo_id"] = self.model_repo_id
  506. elif source == Source.ModelScope.value:
  507. payload["model_scope_model_id"] = self.model_repo_id
  508. self._create_or_reuse_model(model, payload)
  509. def monitor_model_startup(self, config: Model):
  510. """Poll the first model instance until it becomes `running` or fails."""
  511. health_check = config.health_check or HealthCheck()
  512. if health_check.init_delay > 0:
  513. logger.info(
  514. "Waiting %ss before polling model startup", health_check.init_delay
  515. )
  516. time.sleep(health_check.init_delay)
  517. deadline = time.time() + health_check.timeout
  518. last_state = None
  519. logged_waiting_instances = False
  520. poll_timeout = min(max(int(health_check.interval * 2), 10), 30)
  521. while time.time() < deadline:
  522. try:
  523. # Treat transient API/proxy timeouts as retryable during startup.
  524. instances = self._list_model_instances(
  525. config.model_id, timeout=poll_timeout
  526. )
  527. except (
  528. urllib.error.URLError,
  529. TimeoutError,
  530. OSError,
  531. ) as exc:
  532. logger.warning(
  533. "Polling model startup failed: %s; retrying in %ss",
  534. exc,
  535. health_check.interval,
  536. )
  537. time.sleep(health_check.interval)
  538. continue
  539. if not instances:
  540. if not logged_waiting_instances:
  541. logger.info(
  542. "Model %s (id=%s) has no instances yet; waiting...",
  543. config.model_name,
  544. config.model_id,
  545. )
  546. logged_waiting_instances = True
  547. else:
  548. logged_waiting_instances = False
  549. instance = instances[0]
  550. state = instance.get("state")
  551. if state != last_state:
  552. logger.info(
  553. "Model instance %s state: %s (%s)",
  554. instance.get("name"),
  555. state,
  556. instance.get("state_message", ""),
  557. )
  558. last_state = state
  559. if state == "running":
  560. config.instance_name = instance["name"]
  561. return
  562. if state in {"error", "unreachable"}:
  563. raise RuntimeError(
  564. f"Model instance {instance.get('name')} failed: {instance.get('state_message')}"
  565. )
  566. time.sleep(health_check.interval)
  567. raise TimeoutError(
  568. f"Timed out waiting for model {config.model_name} to become running"
  569. )
  570. def stop_model(self, config: Model): # noqa: C901
  571. """Scale the model to zero replicas and wait until instances are gone."""
  572. if config.model_id is None:
  573. return
  574. payload: Optional[Dict[str, Any]] = None
  575. for attempt in range(1, 4):
  576. try:
  577. payload = self._get_model(config.model_id)
  578. break
  579. except Exception as exc:
  580. if not self._is_retryable_request_error(exc) or attempt == 3:
  581. logger.error(
  582. "Failed to fetch model %s (id=%s) before scale down: %s",
  583. config.model_name,
  584. config.model_id,
  585. exc,
  586. )
  587. return
  588. logger.warning(
  589. "Fetching model %s (id=%s) before scale down failed (attempt %s/3): %s; retrying",
  590. config.model_name,
  591. config.model_id,
  592. attempt,
  593. exc,
  594. )
  595. time.sleep(5)
  596. if payload is None:
  597. return
  598. payload["replicas"] = 0
  599. payload.pop("id", None)
  600. payload.pop("created_at", None)
  601. payload.pop("updated_at", None)
  602. payload.pop("ready_replicas", None)
  603. scale_down_sent = False
  604. for attempt in range(1, 4):
  605. try:
  606. with self._request(
  607. "PUT",
  608. f"/v2/models/{config.model_id}",
  609. json_body=payload,
  610. timeout=120,
  611. ):
  612. pass
  613. scale_down_sent = True
  614. break
  615. except Exception as exc:
  616. if not self._is_retryable_request_error(exc) or attempt == 3:
  617. logger.error(
  618. "Failed to scale down model %s (id=%s): %s",
  619. config.model_name,
  620. config.model_id,
  621. exc,
  622. )
  623. return
  624. logger.warning(
  625. "Scale down request for model %s (id=%s) failed (attempt %s/3): %s; retrying",
  626. config.model_name,
  627. config.model_id,
  628. attempt,
  629. exc,
  630. )
  631. time.sleep(5)
  632. if not scale_down_sent:
  633. return
  634. deadline = time.time() + 300
  635. while time.time() < deadline:
  636. try:
  637. model = self._get_model(config.model_id)
  638. instances = self._list_model_instances(config.model_id, timeout=15)
  639. except (
  640. urllib.error.URLError,
  641. TimeoutError,
  642. OSError,
  643. ) as exc:
  644. logger.warning("Polling model scale down failed: %s; retrying", exc)
  645. time.sleep(5)
  646. continue
  647. replicas = model.get("replicas")
  648. ready_replicas = model.get("ready_replicas")
  649. if replicas == 0 and not instances:
  650. logger.info(
  651. "Stopped model %s (id=%s)", config.model_name, config.model_id
  652. )
  653. return
  654. logger.info(
  655. "Waiting for model %s to scale down: replicas=%s ready_replicas=%s instances=%s",
  656. config.model_name,
  657. replicas,
  658. ready_replicas,
  659. len(instances),
  660. )
  661. time.sleep(5)
  662. logger.warning(
  663. "Scale down request sent for model %s (id=%s), but instances still exist after timeout",
  664. config.model_name,
  665. config.model_id,
  666. )
  667. def warmup_service(self, config: Model):
  668. """
  669. Send a few small chat-completions requests before benchmarking.
  670. Warmup helps reduce noise from first-request effects such as lazy kernel
  671. initialization or cold tokenizer/model paths.
  672. """
  673. if not config.model_name or not config.warmup_num_requests:
  674. return
  675. payload = {
  676. "model": config.model_name,
  677. "messages": [{"role": "user", "content": "Reply with OK."}],
  678. "temperature": 0,
  679. "top_p": 1,
  680. "max_tokens": 8,
  681. "stream": False,
  682. }
  683. warmup_errors = 0
  684. for _ in range(config.warmup_num_requests):
  685. try:
  686. with self._request(
  687. "POST", "/v1/chat/completions", json_body=payload, timeout=120
  688. ):
  689. pass
  690. except Exception as exc:
  691. warmup_errors += 1
  692. logger.warning("Warmup request failed: %s", exc)
  693. if warmup_errors >= 3:
  694. raise
  695. time.sleep(2)
  696. else:
  697. time.sleep(0.2)
  698. def create_benchmark(
  699. self,
  700. config: Model,
  701. profile: Dict[str, Any],
  702. request_rate: Optional[int] = None,
  703. ):
  704. """
  705. Create one GPUStack benchmark job from one profile definition.
  706. `request_rate` can be overridden from the CLI. If not provided, the script
  707. uses the `request_rate` defined inside the selected profile YAML.
  708. """
  709. if not config.instance_name or config.model_id is None or not config.model_name:
  710. raise RuntimeError("Model is not ready for benchmark creation")
  711. effective_request_rate = (
  712. request_rate
  713. if request_rate is not None
  714. else profile.get("request_rate", 10)
  715. )
  716. profile_slug = self._slugify(profile["name"])
  717. ts = f"{self._timestamp()}"[-4:]
  718. benchmark_name = self._build_benchmark_name(
  719. config.model_name,
  720. profile_slug,
  721. effective_request_rate,
  722. ts,
  723. )
  724. payload = {
  725. "name": benchmark_name,
  726. "cluster_id": self.cluster_id,
  727. "model_name": config.model_name,
  728. "model_id": config.model_id,
  729. "model_instance_name": config.instance_name,
  730. "profile": profile["name"],
  731. "dataset_name": profile.get("dataset_name"),
  732. "dataset_input_tokens": profile.get("dataset_input_tokens"),
  733. "dataset_output_tokens": profile.get("dataset_output_tokens"),
  734. "dataset_seed": profile.get("dataset_seed"),
  735. "dataset_shared_prefix_tokens": profile.get("dataset_shared_prefix_tokens"),
  736. "request_rate": effective_request_rate,
  737. "total_requests": profile.get("total_requests"),
  738. "max_concurrency": profile.get("max_concurrency"),
  739. }
  740. with self._request(
  741. "POST", "/v2/benchmarks", json_body=payload, timeout=120
  742. ) as response:
  743. created = json.loads(response.read().decode("utf-8"))
  744. config.benchmark_id = created["id"]
  745. config.benchmark_name = created["name"]
  746. logger.info(
  747. "Created benchmark %s (id=%s) for test case %s",
  748. config.benchmark_name,
  749. config.benchmark_id,
  750. profile["name"],
  751. )
  752. def monitor_benchmark(
  753. self,
  754. config: Model,
  755. test_case: str,
  756. request_rate: Optional[int] = None,
  757. ) -> Dict[str, Any]:
  758. """
  759. Watch the benchmark SSE stream until completion and dump the final payload.
  760. The resulting JSON is the full `/v2/benchmarks/{id}` response, which makes
  761. it suitable for later offline analysis without re-querying GPUStack.
  762. """
  763. if config.benchmark_id is None:
  764. raise RuntimeError("Benchmark has not been created")
  765. if not config.benchmark_name:
  766. raise RuntimeError("Benchmark name is missing")
  767. result_path = self._result_path(config.benchmark_name)
  768. last_state = None
  769. watch_response = self._request(
  770. "GET",
  771. "/v2/benchmarks",
  772. params={"watch": "true"},
  773. timeout=3600,
  774. )
  775. with watch_response:
  776. for event in self._iter_sse_payloads(watch_response):
  777. payload = event.get("data", event)
  778. if not isinstance(payload, dict):
  779. continue
  780. if payload.get("id") != config.benchmark_id:
  781. continue
  782. state = payload.get("state")
  783. if state != last_state:
  784. logger.info(
  785. "Benchmark %s state: %s (%s)",
  786. payload.get("name"),
  787. state,
  788. payload.get("state_message"),
  789. )
  790. last_state = state
  791. if state == "completed":
  792. final_result = self._wait_for_benchmark_result(config.benchmark_id)
  793. self._dump_result(result_path, final_result)
  794. return final_result
  795. if state in {"error", "stopped", "unreachable"}:
  796. raise RuntimeError(
  797. f"Benchmark {payload.get('name')} failed: {payload.get('state_message')}"
  798. )
  799. raise RuntimeError(
  800. f"Benchmark stream ended before completion: {config.benchmark_name}"
  801. )
  def parse_results(self, result_file: str) -> Dict[str, Any]:
      """Load a benchmark result file from disk.

      Args:
          result_file: Path of a JSON file (read as UTF-8).

      Returns:
          The decoded JSON content, or an empty dict when the file does not
          exist.
      """
      source = Path(result_file)
      if not source.exists():
          return {}
      return json.loads(source.read_text(encoding="utf-8"))
  def run_engine_test(
      self,
      config: Model,
      profiles: Dict[str, Dict[str, Any]],
      output_dir: str,
  ):
      """
      Drive a single run from start to finish.

      Sequence:
          setup_model -> monitor_model_startup -> warmup_service ->
          (create_benchmark + monitor_benchmark) for every test case and
          request rate -> stop_model

      A run fans out over `config.test_cases`, and for each test case over
      `config.request_rates` (a single pass with `None` when no rates are
      set). Test cases without a matching entry in `profiles` are skipped
      with a warning.

      Args:
          config: The run to execute.
          profiles: Benchmark profiles keyed by test case name.
          output_dir: Directory for result files; created if missing and
              stored on `self.output_dir`.

      Raises:
          Exception: Any error from setup, warm-up or benchmarking is logged
              and re-raised. Errors while stopping the model are only logged.
      """
      logger.info("Starting test for %s", config.name)

      self.output_dir = Path(output_dir)
      self.output_dir.mkdir(parents=True, exist_ok=True)

      try:
          self.setup_model(config)
          self.monitor_model_startup(config)
          self.warmup_service(config)

          for case_name in config.test_cases:
              case_profile = profiles.get(case_name)
              if case_profile:
                  # No explicit rates means one pass with request_rate=None.
                  for rate in config.request_rates or [None]:
                      self.create_benchmark(config, case_profile, rate)
                      self.monitor_benchmark(config, case_name, rate)
                      if rate is not None:
                          logger.info(
                              "Completed test case: %s with request_rate=%s",
                              case_name,
                              rate,
                          )
                      else:
                          logger.info("Completed test case: %s", case_name)
              else:
                  logger.warning(
                      "Profile for test case '%s' not found, skipping", case_name
                  )
      except Exception as e:
          logger.error("Error running test %s: %s", config.name, e)
          raise
      finally:
          if config.stop_model_after_run:
              try:
                  self.stop_model(config)
              except Exception as exc:
                  # Teardown failures must not mask the outcome of the run.
                  logger.error(
                      "Unexpected error while stopping model %s (id=%s): %s",
                      config.model_name,
                      config.model_id,
                      exc,
                  )
              # Pause after the stop attempt before handing control back.
              time.sleep(15)
          else:
              logger.info(
                  "Skipping stop model for %s (id=%s) because stop_model_after_run=false",
                  config.model_name,
                  config.model_id,
              )
  def load_yaml(config_file: str) -> Dict[str, Any]:
      """Load a YAML file.

      Args:
          config_file: Path of the YAML file to read (decoded as UTF-8).

      Returns:
          The parsed document. An empty document, which `yaml.safe_load`
          parses to `None`, is returned as an empty dict so that the result
          always supports `.get(...)` as the return annotation promises.
      """
      with open(config_file, "r", encoding="utf-8") as f:
          document = yaml.safe_load(f)
      # Only the "no document" case is normalised; any other value is passed
      # through untouched.
      return {} if document is None else document
  def load_profile(profile_file: str) -> Dict[str, Any]:
      """Read a profile YAML file and index its profiles by name.

      Args:
          profile_file: Path of the YAML file; its `profiles` key, when
              present, holds a list of mappings that each carry a `name`.

      Returns:
          A dict mapping each profile's `name` to the profile mapping itself.
          When two profiles share a name, the later one wins. Empty when the
          file has no `profiles` key.
      """
      entries = load_yaml(profile_file).get("profiles", [])
      return {entry["name"]: entry for entry in entries}
  def load_config(config_file: str) -> Dict[str, Any]:
      """Read the benchmark configuration.

      Args:
          config_file: Path of the configuration YAML file.

      Returns:
          Whatever `load_yaml` returns for that file.
      """
      settings = load_yaml(config_file)
      return settings
  def create_engine_configs_from_config(
      config: Dict[str, Any],
      run_names: Optional[List[str]],
      test_cases: Optional[List[str]],
      request_rates: Optional[List[int]],
  ) -> List[Model]:
      """
      Turn the YAML `runs` list into the `Model` objects the executor consumes.

      Selection and override rules:
      - a non-empty `run_names` keeps only the runs whose name is listed
      - a non-empty `test_cases` replaces run-level and top-level `test_cases`
      - otherwise run-level `test_cases` replace top-level `test_cases`
      - `request_rates` is handed unchanged to every `Model`
      - per-run `health_check` fields, `warmup_num_requests` and
        `stop_model_after_run` fall back to the top-level values

      Args:
          config: Parsed configuration mapping.
          run_names: Run names to keep; None or empty keeps all runs.
          test_cases: Test case names that override the configured ones.
          request_rates: Request rates stored on every created `Model`.

      Returns:
          One `Model` per selected run, in configuration order.
      """

      def _case_names(entries):
          # A test case is listed either as a bare name or as a mapping
          # carrying a "name" key.
          return [
              entry["name"] if isinstance(entry, dict) else str(entry)
              for entry in entries
          ]

      top_health = config.get("health_check", {})
      fallback_health = HealthCheck(
          timeout=top_health.get("timeout", 30),
          interval=top_health.get("interval", 1.0),
          init_delay=top_health.get("init_delay", 60),
      )
      fallback_warmup = config.get("warmup", {}).get("num_requests", 10)
      fallback_stop = config.get("stop_model_after_run", True)
      fallback_cases = _case_names(config.get("test_cases", []))
      wanted_runs = set(run_names or [])

      models = []
      for run_config in config.get("runs", []):
          if wanted_runs and run_config["name"] not in wanted_runs:
              logger.info(
                  "Skipping run %s as it's not in specified run names", run_config["name"]
              )
              continue

          # CLI selection beats run-level, which beats top-level.
          if test_cases:
              cases_for_run = test_cases
          elif "test_cases" in run_config:
              cases_for_run = _case_names(run_config.get("test_cases", []))
          else:
              cases_for_run = fallback_cases

          run_health = run_config.get("health_check", {})
          health_for_run = HealthCheck(
              timeout=run_health.get("timeout", fallback_health.timeout),
              interval=run_health.get("interval", fallback_health.interval),
              init_delay=run_health.get("init_delay", fallback_health.init_delay),
          )

          version = run_config.get("backend_version")
          models.append(
              Model(
                  name=run_config["name"],
                  test_cases=cases_for_run,
                  backend=EngineType(run_config["backend"]),
                  # YAML may parse a version such as 0.8 as a number.
                  backend_version=None if version is None else str(version),
                  backend_parameters=run_config.get("backend_parameters", []),
                  envs=run_config.get("envs", {}),
                  args=run_config.get("args", []),
                  health_check=health_for_run,
                  warmup_num_requests=run_config.get(
                      "warmup_num_requests", fallback_warmup
                  ),
                  stop_model_after_run=run_config.get(
                      "stop_model_after_run", fallback_stop
                  ),
                  request_rates=request_rates,
              )
          )
      return models
  def main():
      """CLI entry point.

      Parses the command line, loads the configuration and profile YAML
      files, builds one `Model` per selected run and executes the runs one
      after another. A run that raises is logged and the remaining runs are
      still executed.

      The output directory is resolved in this order: `--output-dir` when
      given, else the configuration's `output_dir`, else "benchmark_results".

      Raises:
          SystemExit: If no run matches the selection, or when argparse
              rejects the command line.
      """
      import argparse

      def comma_list(raw):
          # "a, b,,c" -> ["a", "b", "c"]; blank items are dropped.
          return [item.strip() for item in raw.split(",") if item.strip()]

      def comma_int_list(raw):
          return [int(item) for item in comma_list(raw)]

      parser = argparse.ArgumentParser(
          description="LLM Inference Engine Automated Performance Testing"
      )
      parser.add_argument(
          "--config", default="config.yaml", help="Path to configuration YAML file"
      )
      parser.add_argument(
          "--profile", default="profile.yaml", help="Path to profile YAML file"
      )
      parser.add_argument("--model", help="Override the model repo id from config")
      parser.add_argument("--gpustack-url", required=True, help="GPUStack URL")
      parser.add_argument("--gpustack-token", required=True, help="GPUStack token")
      parser.add_argument(
          "--cluster-id",
          "--gpustack-cluster-id",
          dest="cluster_id",
          type=int,
          required=True,
          help="GPUStack cluster id",
      )
      parser.add_argument(
          "--output-dir",
          # No argparse default: a truthy default here would always win over
          # the configuration's `output_dir` in the fallback chain below.
          default=None,
          help=(
              "Output directory for results "
              "(default: config output_dir, else benchmark_results)"
          ),
      )
      parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
      parser.add_argument(
          "--run-names",
          type=comma_list,
          default=[],
          help="Specific run names to execute, comma-separated",
      )
      parser.add_argument(
          "--test-cases",
          type=comma_list,
          default=[],
          help=(
              "Specific test case names to execute, comma-separated. "
              "Overrides run-level and config-level test_cases."
          ),
      )
      parser.add_argument(
          "--request-rates",
          type=comma_int_list,
          default=[],
          help="Override profile request_rate with one or more comma-separated values, e.g. 1,4,8,16",
      )
      # Both flags share one destination; it stays None when neither is given,
      # which means "keep what the configuration says".
      parser.add_argument(
          "--stop-model-after-run",
          dest="stop_model_after_run",
          action="store_true",
          default=None,
          help="Stop the model after each run completes. Overrides config when provided.",
      )
      parser.add_argument(
          "--no-stop-model-after-run",
          dest="stop_model_after_run",
          action="store_false",
          help="Keep the model running after each run completes. Overrides config when provided.",
      )
      args = parser.parse_args()

      if args.verbose:
          logging.getLogger().setLevel(logging.DEBUG)

      config = load_config(args.config)
      profile = load_profile(args.profile)

      model = args.model or config["model"]
      source = config["source"]
      output_dir = args.output_dir or config.get("output_dir") or "benchmark_results"

      manager = EngineManager(
          model,
          source,
          args.gpustack_url,
          args.gpustack_token,
          args.cluster_id,
          output_dir,
      )

      engine_configs = create_engine_configs_from_config(
          config,
          args.run_names,
          args.test_cases or None,
          args.request_rates or None,
      )
      if not engine_configs:
          raise SystemExit("No matching runs found")

      if args.stop_model_after_run is not None:
          for engine_config in engine_configs:
              engine_config.stop_model_after_run = args.stop_model_after_run

      for engine_config in engine_configs:
          try:
              manager.run_engine_test(engine_config, profile, output_dir)
              logger.info("Successfully completed test: %s", engine_config.name)
          except Exception as e:
              # One failing run must not prevent the remaining runs.
              logger.error("Failed to run test %s: %s", engine_config.name, e)

      logger.info("All tests completed")
  class RedirectStdoutStderr:
      """Context manager that points `sys.stdout` and `sys.stderr` at one target.

      On entry the current streams are remembered and both are replaced by
      `target`; on exit the remembered streams are put back. Exceptions raised
      inside the `with` body are not suppressed.
      """

      def __init__(self, target):
          # Object that stands in for both streams while the context is active.
          self.target = target

      def __enter__(self):
          self.original_stdout, self.original_stderr = sys.stdout, sys.stderr
          sys.stdout = sys.stderr = self.target

      def __exit__(self, exc_type, exc_value, traceback):
          sys.stdout, sys.stderr = self.original_stdout, self.original_stderr
  # Run the CLI only when this file is executed as a script, not on import.
  if __name__ == "__main__":
      main()