# benchmark_serving.py
  1. import asyncio
  2. from dataclasses import asdict, dataclass, is_dataclass
  3. import time
  4. from typing import List, Optional
  5. import aiohttp
  6. import numpy
  7. import logging
  8. import argparse
  9. import json
  10. import random
  11. from openai import APIConnectionError, AsyncOpenAI
  12. from aiohttp import ClientSession
  13. from httpx_aiohttp import AiohttpTransport
  14. from openai import DefaultAsyncHttpxClient
  15. from openai.types.chat import (
  16. ChatCompletionStreamOptionsParam,
  17. )
  18. from tqdm import tqdm
# Root logger is configured at WARNING, so the logging.info / logging.debug
# calls made elsewhere in this file are not emitted with this configuration.
logging.basicConfig(
    level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Pool of prompts that get_random_prompt() picks from at random.
SAMPLE_PROMPTS = [
    "Explain how blockchain technology works, and provide a real-world example of its application outside of cryptocurrency.",
    "Compare and contrast the philosophies of Nietzsche and Kant, including their views on morality and human nature.",
    "Imagine you're a travel blogger. Write a detailed post describing a week-long adventure through rural Japan.",
    "Write a fictional letter from Albert Einstein to a modern-day physicist, discussing the current state of quantum mechanics.",
    "Provide a comprehensive explanation of how transformers work in machine learning, including attention mechanisms and positional encoding.",
    "Draft a business proposal for launching a new AI-powered productivity app, including target audience, key features, and a monetization strategy.",
    "Simulate a panel discussion between Elon Musk, Marie Curie, and Sun Tzu on the topic of 'Leadership in Times of Crisis'.",
    "Describe the process of photosynthesis in depth, and explain its importance in the global carbon cycle.",
    "Analyze the impact of social media on political polarization, citing relevant studies or historical examples.",
    "Write a short science fiction story where humans discover a parallel universe that operates under different physical laws.",
    "Explain the role of the Federal Reserve in the U.S. economy and how it manages inflation and unemployment.",
    "Describe the architecture of a modern web application, from frontend to backend, including databases, APIs, and deployment.",
    "Write an essay discussing whether artificial general intelligence (AGI) poses an existential threat to humanity.",
    "Summarize the key events and consequences of the Cuban Missile Crisis, and reflect on lessons for modern diplomacy.",
    "Create a guide for beginners on how to train a custom LLM using open-source tools and publicly available datasets.",
]
  39. @dataclass
  40. class PercentileResults:
  41. average: float
  42. p50: float
  43. p95: float
  44. p99: float
@dataclass
class BenchmarkResults:
    # Aggregated outcome of one benchmark run, as assembled by calculate_results().

    # Model name the requests were sent to.
    model: str
    # Number of requests that were scheduled.
    total_requests: int
    # Number of requests that produced a result.
    successful_requests: int
    # successful_requests / total_requests as a fraction (0 when no requests).
    success_rate: float
    # Number of concurrent workers used for the run.
    concurrency: int
    # Per-request timeout, in seconds.
    request_timeout: int
    # max_completion_tokens value the run was configured with.
    max_completion_tokens: int
    # Elapsed time of the whole run, in seconds.
    total_time: float
    # successful_requests / total_time.
    requests_per_second: float
    # Token counts summed over every result that reported usage.
    total_tokens: int
    total_prompt_tokens: int
    total_completion_tokens: int
    # The summed token counts above divided by total_time.
    total_tokens_per_second: float
    total_prompt_tokens_per_second: float
    total_completion_tokens_per_second: float
    # Per-request elapsed time, in seconds.
    latency: PercentileResults
    # Per-request completion tokens divided by that request's elapsed time.
    completion_tokens_per_second: PercentileResults
    # Per-request time until the first streamed chunk, in milliseconds.
    time_to_first_token: PercentileResults
  65. async def process_stream(stream):
  66. first_token_time = None
  67. async for chunk in stream:
  68. if first_token_time is None:
  69. first_token_time = time.time()
  70. if chunk.usage:
  71. return first_token_time, chunk.usage
  72. return first_token_time, None
  73. def get_random_prompt(prompt_multiplier):
  74. """
  75. Returns a random prompt from the SAMPLE_PROMPTS list, repeated prompt_multiplier times.
  76. """
  77. # Add a random prefix to avoid prefix cache hits
  78. random_prefix = str(random.randint(100000, 999999))
  79. return (
  80. random_prefix + " " + (random.choice(SAMPLE_PROMPTS) + " ") * prompt_multiplier
  81. )
  82. async def make_chat_completion_request(
  83. client: AsyncOpenAI,
  84. model,
  85. max_completion_tokens,
  86. ignore_eos,
  87. request_timeout,
  88. prompt_multiplier,
  89. ):
  90. start_time = time.time()
  91. content = get_random_prompt(prompt_multiplier)
  92. try:
  93. stream = await client.chat.completions.create(
  94. model=model,
  95. messages=[{"role": "user", "content": content}],
  96. max_completion_tokens=max_completion_tokens,
  97. stream=True,
  98. stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
  99. extra_body={"ignore_eos": ignore_eos} if ignore_eos else None,
  100. )
  101. first_token_time, usage = await asyncio.wait_for(
  102. process_stream(stream), timeout=request_timeout
  103. )
  104. end_time = time.time()
  105. elapsed_time = end_time - start_time
  106. ttft = (first_token_time - start_time) * 1000 if first_token_time else None
  107. return usage, elapsed_time, ttft
  108. except asyncio.TimeoutError:
  109. logging.warning(f"Request timed out after {request_timeout} seconds")
  110. return None
  111. except APIConnectionError as e:
  112. logging.error(f"API connection error: {str(e)}")
  113. return None
  114. except Exception as e:
  115. logging.error(f"Error during request: {str(e)}")
  116. return None
  117. async def make_embedding_request(
  118. client: AsyncOpenAI,
  119. model,
  120. request_timeout,
  121. prompt_multiplier=1,
  122. ):
  123. start_time = time.time()
  124. content = get_random_prompt(prompt_multiplier)
  125. try:
  126. response = await asyncio.wait_for(
  127. client.embeddings.create(model=model, input=content),
  128. timeout=request_timeout,
  129. )
  130. end_time = time.time()
  131. elapsed_time = end_time - start_time
  132. ttft = None # Embeddings do not have a time to first token in the same way as chat completions
  133. return response.usage, elapsed_time, ttft
  134. except asyncio.TimeoutError:
  135. logging.warning(f"Embedding request timed out after {request_timeout} seconds")
  136. return None
  137. except Exception as e:
  138. logging.error(f"Error during embedding request: {str(e)}")
  139. return None
  140. async def worker(
  141. client,
  142. model,
  143. semaphore,
  144. queue,
  145. results,
  146. max_completion_tokens,
  147. ignore_eos,
  148. request_timeout,
  149. embeddings=False,
  150. prompt_multiplier=1,
  151. pbar=None,
  152. ):
  153. while True:
  154. async with semaphore:
  155. task_id = await queue.get()
  156. if task_id is None:
  157. queue.task_done()
  158. break
  159. logging.debug(f"Starting request {task_id}")
  160. if embeddings:
  161. result = await make_embedding_request(
  162. client, model, request_timeout, prompt_multiplier
  163. )
  164. else:
  165. result = await make_chat_completion_request(
  166. client,
  167. model,
  168. max_completion_tokens,
  169. ignore_eos,
  170. request_timeout,
  171. prompt_multiplier,
  172. )
  173. if result:
  174. results.append(result)
  175. else:
  176. logging.warning(f"Request {task_id} failed")
  177. queue.task_done()
  178. if pbar:
  179. pbar.update(1)
  180. logging.debug(f"Finished request {task_id}")
  181. def calculate_percentile(values, percentile, reverse=False):
  182. if not values:
  183. return None
  184. if reverse:
  185. return numpy.percentile(values, 100 - percentile)
  186. return numpy.percentile(values, percentile)
  187. async def preflight_check(client, model, embeddings=False) -> bool:
  188. if embeddings:
  189. result = await make_embedding_request(client, model, 16)
  190. else:
  191. result = await make_chat_completion_request(client, model, 16, False, 60, 1)
  192. return result is not None
  193. def set_headers(aiohttp_session: ClientSession, headers: Optional[List[str]]):
  194. if headers:
  195. for header in headers:
  196. if ":" not in header:
  197. raise ValueError(f"Invalid header format: {header}. Expected Key:Value")
  198. key, value = header.split(":", 1)
  199. aiohttp_session.headers[key.strip()] = value.strip()
  200. async def main(
  201. model,
  202. num_requests,
  203. concurrency,
  204. request_timeout,
  205. max_completion_tokens,
  206. ignore_eos,
  207. server_url,
  208. api_key,
  209. headers=None,
  210. embeddings=False,
  211. prompt_multiplier=1,
  212. ) -> Optional[BenchmarkResults]:
  213. connector = aiohttp.TCPConnector(
  214. limit=2000,
  215. force_close=True,
  216. )
  217. async with ClientSession(connector=connector, trust_env=True) as aiohttp_session:
  218. if headers:
  219. set_headers(aiohttp_session, headers)
  220. transport = AiohttpTransport(client=aiohttp_session)
  221. httpx_client = DefaultAsyncHttpxClient(
  222. transport=transport, timeout=request_timeout
  223. )
  224. client = AsyncOpenAI(
  225. base_url=f"{server_url}/v1",
  226. api_key=api_key,
  227. http_client=httpx_client,
  228. max_retries=0,
  229. )
  230. if not await preflight_check(client, model, embeddings=embeddings):
  231. raise Exception(
  232. "Preflight check failed. Please check configuration and the service status."
  233. )
  234. semaphore = asyncio.Semaphore(concurrency)
  235. queue = asyncio.Queue()
  236. results = []
  237. # Add tasks to the queue
  238. for i in range(num_requests):
  239. await queue.put(i)
  240. # Add sentinel values to stop workers
  241. for _ in range(concurrency):
  242. await queue.put(None)
  243. pbar = tqdm(
  244. total=num_requests,
  245. desc="Running Benchmark requests",
  246. unit="request",
  247. dynamic_ncols=True,
  248. )
  249. # Create worker tasks
  250. workers = [
  251. asyncio.create_task(
  252. worker(
  253. client,
  254. model,
  255. semaphore,
  256. queue,
  257. results,
  258. max_completion_tokens,
  259. ignore_eos,
  260. request_timeout,
  261. embeddings,
  262. prompt_multiplier,
  263. pbar=pbar,
  264. )
  265. )
  266. for _ in range(concurrency)
  267. ]
  268. start_time = time.time()
  269. # Wait for all tasks to complete
  270. await queue.join()
  271. await asyncio.gather(*workers)
  272. end_time = time.time()
  273. total_elapsed_time = end_time - start_time
  274. return calculate_results(
  275. model,
  276. concurrency,
  277. request_timeout,
  278. max_completion_tokens,
  279. total_elapsed_time,
  280. num_requests,
  281. results,
  282. )
  283. def calculate_results(
  284. model,
  285. concurrency,
  286. request_timeout,
  287. max_completion_tokens,
  288. total_elapsed_time,
  289. num_requests,
  290. results,
  291. ):
  292. # Calculate metrics
  293. total_tokens = 0
  294. prompt_tokens = 0
  295. completion_tokens = 0
  296. tokens_per_second_list = []
  297. prompt_tokens_per_second_list = []
  298. completion_tokens_per_second_list = []
  299. for usage, elapsed_time, _ in results:
  300. if usage is not None:
  301. total_tokens += usage.total_tokens
  302. prompt_tokens += usage.prompt_tokens
  303. completion_tokens += usage.completion_tokens
  304. prompt_tokens_per_second = (
  305. usage.prompt_tokens / elapsed_time if elapsed_time > 0 else 0
  306. )
  307. completion_tokens_per_second = (
  308. usage.completion_tokens / elapsed_time if elapsed_time > 0 else 0
  309. )
  310. tokens_per_second = (
  311. usage.total_tokens / elapsed_time if elapsed_time > 0 else 0
  312. )
  313. tokens_per_second_list.append(tokens_per_second)
  314. prompt_tokens_per_second_list.append(prompt_tokens_per_second)
  315. completion_tokens_per_second_list.append(completion_tokens_per_second)
  316. latencies = [
  317. elapsed_time for _, elapsed_time, _ in results if elapsed_time is not None
  318. ]
  319. ttft_list = [ttft for _, _, ttft in results if ttft is not None]
  320. successful_requests = len(results)
  321. success_rate = successful_requests / num_requests if num_requests > 0 else 0
  322. requests_per_second = (
  323. successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0
  324. )
  325. avg_latency = sum(latencies) / len(latencies) if latencies else 0
  326. avg_completion_tokens_per_second = (
  327. sum(completion_tokens_per_second_list) / len(completion_tokens_per_second_list)
  328. if completion_tokens_per_second_list
  329. else 0
  330. )
  331. total_tokens_per_second = (
  332. total_tokens / total_elapsed_time if total_elapsed_time > 0 else 0
  333. )
  334. total_prompt_tokens_per_second = (
  335. prompt_tokens / total_elapsed_time if total_elapsed_time > 0 else 0
  336. )
  337. total_completion_tokens_per_second = (
  338. completion_tokens / total_elapsed_time if total_elapsed_time > 0 else 0
  339. )
  340. avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0
  341. # Calculate percentiles
  342. percentiles = [50, 95, 99]
  343. latency_percentiles = [calculate_percentile(latencies, p) for p in percentiles]
  344. completion_tps_percentiles = [
  345. calculate_percentile(completion_tokens_per_second_list, p, reverse=True)
  346. for p in percentiles
  347. ]
  348. ttft_percentiles = [calculate_percentile(ttft_list, p) for p in percentiles]
  349. return BenchmarkResults(
  350. model=model,
  351. total_requests=num_requests,
  352. successful_requests=successful_requests,
  353. success_rate=success_rate,
  354. concurrency=concurrency,
  355. request_timeout=request_timeout,
  356. max_completion_tokens=max_completion_tokens,
  357. total_time=total_elapsed_time,
  358. requests_per_second=requests_per_second,
  359. total_tokens=total_tokens,
  360. total_prompt_tokens=prompt_tokens,
  361. total_completion_tokens=completion_tokens,
  362. total_tokens_per_second=total_tokens_per_second,
  363. total_prompt_tokens_per_second=total_prompt_tokens_per_second,
  364. total_completion_tokens_per_second=total_completion_tokens_per_second,
  365. latency=PercentileResults(
  366. average=avg_latency,
  367. p50=latency_percentiles[0],
  368. p95=latency_percentiles[1],
  369. p99=latency_percentiles[2],
  370. ),
  371. completion_tokens_per_second=PercentileResults(
  372. average=avg_completion_tokens_per_second,
  373. p50=completion_tps_percentiles[0],
  374. p95=completion_tps_percentiles[1],
  375. p99=completion_tps_percentiles[2],
  376. ),
  377. time_to_first_token=PercentileResults(
  378. average=avg_ttft,
  379. p50=ttft_percentiles[0],
  380. p95=ttft_percentiles[1],
  381. p99=ttft_percentiles[2],
  382. ),
  383. )
  384. def fmt_line(label, *values, width=40):
  385. label_part = f"{label:<{width}}"
  386. value_part = " ".join(str(v) for v in values)
  387. return f"{label_part}{value_part}"
  388. def fmt_float(v, suffix=""):
  389. return f"{v:.2f}{suffix}"
  390. def output_benchmark_results_pretty(
  391. results: BenchmarkResults, file: str = None, embeddings: bool = False
  392. ):
  393. lines = []
  394. lines.append("============== Serving Benchmark Result ===============")
  395. lines.append(fmt_line("Model:", results.model))
  396. lines.append(
  397. fmt_line(
  398. "Total requests:",
  399. f"{results.successful_requests}/{results.total_requests}({results.success_rate:.2%})",
  400. )
  401. )
  402. lines.append(fmt_line("Concurrency:", results.concurrency))
  403. lines.append(fmt_line("Benchmark duration (s):", fmt_float(results.total_time)))
  404. lines.append(
  405. fmt_line("Request throughput (req/s):", fmt_float(results.requests_per_second))
  406. )
  407. lines.append(fmt_line("Total input tokens:", results.total_prompt_tokens))
  408. if not embeddings:
  409. lines.append(fmt_line("Total output tokens:", results.total_completion_tokens))
  410. output_tok_per_sec = (
  411. results.total_completion_tokens / results.total_time
  412. if results.total_time > 0
  413. else 0
  414. )
  415. total_tok_per_sec = (
  416. results.total_tokens / results.total_time if results.total_time > 0 else 0
  417. )
  418. if not embeddings:
  419. lines.append(
  420. fmt_line("Output token throughput (tok/s):", fmt_float(output_tok_per_sec))
  421. )
  422. lines.append(
  423. fmt_line("Total token throughput (tok/s):", fmt_float(total_tok_per_sec))
  424. )
  425. lines.append("------------------- Request Latency -------------------")
  426. lines.append(fmt_line("Average latency (s):", fmt_float(results.latency.average)))
  427. lines.append(fmt_line("P50 latency (s):", fmt_float(results.latency.p50)))
  428. lines.append(fmt_line("P95 latency (s):", fmt_float(results.latency.p95)))
  429. lines.append(fmt_line("P99 latency (s):", fmt_float(results.latency.p99)))
  430. if not embeddings:
  431. lines.append("--------------- Output Token Per Second ---------------")
  432. lines.append(
  433. fmt_line(
  434. "Average TPS (tok/s):",
  435. fmt_float(results.completion_tokens_per_second.average),
  436. )
  437. )
  438. lines.append(
  439. fmt_line(
  440. "P50 TPS (tok/s):", fmt_float(results.completion_tokens_per_second.p50)
  441. )
  442. )
  443. lines.append(
  444. fmt_line(
  445. "P95 TPS (tok/s):", fmt_float(results.completion_tokens_per_second.p95)
  446. )
  447. )
  448. lines.append(
  449. fmt_line(
  450. "P99 TPS (tok/s):", fmt_float(results.completion_tokens_per_second.p99)
  451. )
  452. )
  453. lines.append("----------------- Time to First Token -----------------")
  454. lines.append(
  455. fmt_line(
  456. "Average TTFT (ms):", fmt_float(results.time_to_first_token.average)
  457. )
  458. )
  459. lines.append(
  460. fmt_line("P50 TTFT (ms):", fmt_float(results.time_to_first_token.p50))
  461. )
  462. lines.append(
  463. fmt_line("P95 TTFT (ms):", fmt_float(results.time_to_first_token.p95))
  464. )
  465. lines.append(
  466. fmt_line("P99 TTFT (ms):", fmt_float(results.time_to_first_token.p99))
  467. )
  468. lines.append("=" * 55)
  469. output = "\n".join(lines)
  470. if file:
  471. with open(file, "w") as f:
  472. f.write(output + "\n")
  473. logging.info(f"Pretty benchmark results saved to {file}")
  474. else:
  475. print(output)
  476. def output_benchmark_results_json(
  477. results: BenchmarkResults, result_file=None, embeddings: bool = False
  478. ):
  479. # Round all floats in results to two decimal places for output
  480. def _round_floats(obj, ndigits=2):
  481. if is_dataclass(obj):
  482. obj = asdict(obj)
  483. if isinstance(obj, dict):
  484. return {k: _round_floats(v, ndigits) for k, v in obj.items()}
  485. if isinstance(obj, list):
  486. return [_round_floats(v, ndigits) for v in obj]
  487. if isinstance(obj, float):
  488. return round(obj, ndigits)
  489. return obj
  490. formatted_results = _round_floats(results, 2)
  491. if result_file:
  492. with open(result_file, "w") as f:
  493. json.dump(formatted_results, f, indent=2)
  494. logging.info(f"Results saved to {result_file}")
  495. else:
  496. print(json.dumps(formatted_results, indent=2))
  497. if __name__ == "__main__":
  498. parser = argparse.ArgumentParser(description="Benchmark Chat Completions API")
  499. parser.add_argument(
  500. "-m", "--model", type=str, required=True, help="Name of the model"
  501. )
  502. parser.add_argument(
  503. "-n",
  504. "--num-requests",
  505. type=int,
  506. default=100,
  507. help="Number of requests to make (default: 100)",
  508. )
  509. parser.add_argument(
  510. "-c",
  511. "--concurrency",
  512. type=int,
  513. default=10,
  514. help="Number of concurrent requests (default: 10)",
  515. )
  516. parser.add_argument(
  517. "--request-timeout",
  518. type=int,
  519. default=300,
  520. help="Timeout for each request in seconds (default: 300)",
  521. )
  522. parser.add_argument(
  523. "--max-completion-tokens",
  524. type=int,
  525. default=1024,
  526. help="Maximum number of tokens in the completion (default: 1024)",
  527. )
  528. parser.add_argument(
  529. "--prompt-multiplier",
  530. type=int,
  531. default=1,
  532. help="Repeat the randomly selected prompt N times to create longer inputs",
  533. )
  534. parser.add_argument(
  535. '--ignore-eos',
  536. action='store_true',
  537. help='Set ignore_eos flag when sending the benchmark request. This will not stop the stream when the model generates an EOS token.',
  538. )
  539. parser.add_argument(
  540. "--server-url",
  541. type=str,
  542. default="http://127.0.0.1",
  543. help="URL of the GPUStack server",
  544. )
  545. parser.add_argument("--api-key", type=str, default="fake", help="GPUStack API key")
  546. parser.add_argument(
  547. "--result-file",
  548. type=str,
  549. help="Result file path to save benchmark json results",
  550. )
  551. parser.add_argument(
  552. "-H",
  553. "--header",
  554. action="append",
  555. dest="headers",
  556. help="Custom HTTP header in Key:Value format. May be specified multiple times.",
  557. )
  558. parser.add_argument(
  559. '--embeddings',
  560. action='store_true',
  561. help='Run embedding benchmark instead of chat completions',
  562. )
  563. parser.add_argument(
  564. '--json',
  565. action='store_true',
  566. help='Output results in JSON format instead of pretty format',
  567. )
  568. args = parser.parse_args()
  569. try:
  570. results = asyncio.run(
  571. main(
  572. args.model,
  573. args.num_requests,
  574. args.concurrency,
  575. args.request_timeout,
  576. args.max_completion_tokens,
  577. args.ignore_eos,
  578. args.server_url,
  579. args.api_key,
  580. args.headers,
  581. args.embeddings,
  582. args.prompt_multiplier,
  583. )
  584. )
  585. if args.json:
  586. output_benchmark_results_json(
  587. results, args.result_file, embeddings=args.embeddings
  588. )
  589. else:
  590. output_benchmark_results_pretty(
  591. results, args.result_file, embeddings=args.embeddings
  592. )
  593. except Exception as e:
  594. logging.error(f"Benchmarking failed: {str(e)}")
  595. exit(1)