plot_latency_comparison.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
r"""
Plot a benchmark metric vs. request rate for baseline and optimized results.

Each input JSON file yields one data point. Both series are sorted by request
rate and drawn on one chart with per-point labels. At every request rate
present in both series, an arrow from the baseline value to the optimized
value is labeled with the speed-up (or slow-down) factor. Choose the y-axis
metric with ``--metric`` (``latency``, ``ttft``, ``itl`` or ``tpot``).

Examples:
  uv run python hack/perf/plot_latency_comparison.py \
    --metric latency \
    --baseline baseline_r1.json baseline_r4.json baseline_r8.json baseline_r16.json \
    --optimized optimized_r1.json optimized_r4.json optimized_r8.json optimized_r16.json \
    --output latency_comparison.png

  uv run python hack/perf/plot_latency_comparison.py \
    --metric ttft \
    --baseline baseline_r1.json baseline_r4.json baseline_r8.json baseline_r16.json \
    --optimized optimized_r1.json optimized_r4.json optimized_r8.json optimized_r16.json \
    --output ttft_comparison.png
"""
  15. from __future__ import annotations
  16. import argparse
  17. import json
  18. from dataclasses import dataclass
  19. from pathlib import Path
  20. import matplotlib.pyplot as plt
  21. from matplotlib.axes import Axes
  22. METRIC_FIELDS = {
  23. "latency": ("request_latency_mean", "Latency", "s"),
  24. "ttft": ("time_to_first_token_mean", "TTFT", "ms"),
  25. "itl": ("inter_token_latency_mean", "ITL", "ms"),
  26. "tpot": ("time_per_output_token_mean", "TPOT", "ms"),
  27. }
  28. @dataclass(frozen=True)
  29. class BenchmarkPoint:
  30. request_rate: float
  31. request_latency_mean: float
  32. request_concurrency_mean: float
  33. time_to_first_token_mean: float
  34. inter_token_latency_mean: float
  35. time_per_output_token_mean: float
  36. source: Path
  37. def parse_args() -> argparse.Namespace:
  38. parser = argparse.ArgumentParser(
  39. formatter_class=argparse.RawDescriptionHelpFormatter,
  40. description=(
  41. "Compare two groups of benchmark result JSON files and plot "
  42. "request rate vs. mean latency."
  43. ),
  44. epilog="""\
  45. Examples:
  46. uv run python hack/perf/plot_latency_comparison.py \
  47. --metric latency \
  48. --baseline baseline_r1.json baseline_r4.json baseline_r8.json baseline_r16.json \
  49. --optimized optimized_r1.json optimized_r4.json optimized_r8.json optimized_r16.json \
  50. --output latency_comparison.png
  51. uv run python hack/perf/plot_latency_comparison.py \
  52. --metric ttft \
  53. --baseline baseline_r1.json baseline_r4.json baseline_r8.json baseline_r16.json \
  54. --optimized optimized_r1.json optimized_r4.json optimized_r8.json optimized_r16.json \
  55. --output ttft_comparison.png
  56. """,
  57. )
  58. parser.add_argument(
  59. "--baseline",
  60. nargs="+",
  61. required=True,
  62. help="Baseline benchmark JSON files.",
  63. )
  64. parser.add_argument(
  65. "--optimized",
  66. nargs="+",
  67. required=True,
  68. help="GPUStack-Optimized benchmark JSON files.",
  69. )
  70. parser.add_argument(
  71. "--output",
  72. default="latency_comparison.png",
  73. help="Output image path. Default: latency_comparison.png",
  74. )
  75. parser.add_argument(
  76. "--title",
  77. default="Latency Comparison: Baseline vs. GPUStack-Optimized",
  78. help="Plot title.",
  79. )
  80. parser.add_argument(
  81. "--figsize",
  82. nargs=2,
  83. type=float,
  84. metavar=("WIDTH", "HEIGHT"),
  85. default=(11, 7),
  86. help="Figure size in inches. Default: 11 7",
  87. )
  88. parser.add_argument(
  89. "--metric",
  90. choices=sorted(METRIC_FIELDS.keys()),
  91. default="latency",
  92. help="Metric to plot on y-axis. Default: latency",
  93. )
  94. return parser.parse_args()
  95. def load_point(path_str: str) -> BenchmarkPoint:
  96. path = Path(path_str)
  97. with path.open() as f:
  98. payload = json.load(f)
  99. required_fields = [
  100. "request_rate",
  101. "request_latency_mean",
  102. "request_concurrency_mean",
  103. "time_to_first_token_mean",
  104. "inter_token_latency_mean",
  105. "time_per_output_token_mean",
  106. ]
  107. missing = [field for field in required_fields if payload.get(field) is None]
  108. if missing:
  109. missing_str = ", ".join(missing)
  110. raise ValueError(f"{path} is missing required fields: {missing_str}")
  111. return BenchmarkPoint(
  112. request_rate=float(payload["request_rate"]),
  113. request_latency_mean=float(payload["request_latency_mean"]),
  114. request_concurrency_mean=float(payload["request_concurrency_mean"]),
  115. time_to_first_token_mean=float(payload["time_to_first_token_mean"]),
  116. inter_token_latency_mean=float(payload["inter_token_latency_mean"]),
  117. time_per_output_token_mean=float(payload["time_per_output_token_mean"]),
  118. source=path,
  119. )
  120. def load_series(paths: list[str]) -> list[BenchmarkPoint]:
  121. points = [load_point(path) for path in paths]
  122. return sorted(points, key=lambda point: point.request_rate)
  123. def metric_value(point: BenchmarkPoint, metric: str) -> float:
  124. field_name = METRIC_FIELDS[metric][0]
  125. return float(getattr(point, field_name))
  126. def metric_label(metric: str) -> str:
  127. title, unit = METRIC_FIELDS[metric][1], METRIC_FIELDS[metric][2]
  128. return f"{title} ({unit})"
  129. def annotate_series(
  130. ax: Axes,
  131. points: list[BenchmarkPoint],
  132. color: str,
  133. direction: int,
  134. metric: str,
  135. ) -> None:
  136. x_offsets = [10, 18, -54, -62, 14, -48]
  137. y_magnitudes = [18, 30, 22, 34, 26, 38]
  138. short_name, unit = METRIC_FIELDS[metric][1].lower(), METRIC_FIELDS[metric][2]
  139. for index, point in enumerate(points):
  140. x_offset = x_offsets[index % len(x_offsets)]
  141. y_offset = y_magnitudes[index % len(y_magnitudes)] * direction
  142. label = (
  143. f"{short_name}={metric_value(point, metric):.2f}{unit}\n"
  144. f"rps={point.request_rate:.0f}\n"
  145. f"conc={point.request_concurrency_mean:.2f}"
  146. )
  147. ax.annotate(
  148. label,
  149. xy=(point.request_rate, metric_value(point, metric)),
  150. xytext=(x_offset, y_offset),
  151. textcoords="offset points",
  152. ha="left" if x_offset >= 0 else "right",
  153. va="bottom" if direction > 0 else "top",
  154. fontsize=9,
  155. color=color,
  156. bbox={
  157. "boxstyle": "round,pad=0.25",
  158. "fc": "white",
  159. "ec": color,
  160. "alpha": 0.85,
  161. },
  162. arrowprops={"arrowstyle": "-", "color": color, "alpha": 0.5},
  163. )
  164. def annotate_speedup_arrows(
  165. ax: Axes,
  166. baseline_points: list[BenchmarkPoint],
  167. optimized_points: list[BenchmarkPoint],
  168. metric: str,
  169. ) -> None:
  170. baseline_by_rate = {point.request_rate: point for point in baseline_points}
  171. optimized_by_rate = {point.request_rate: point for point in optimized_points}
  172. shared_rates = sorted(set(baseline_by_rate) & set(optimized_by_rate))
  173. for request_rate in shared_rates:
  174. baseline_point = baseline_by_rate[request_rate]
  175. optimized_point = optimized_by_rate[request_rate]
  176. optimized_metric_value = metric_value(optimized_point, metric)
  177. baseline_metric_value = metric_value(baseline_point, metric)
  178. if optimized_metric_value <= 0:
  179. continue
  180. speedup = baseline_metric_value / optimized_metric_value
  181. mid_y = (baseline_metric_value + optimized_metric_value) / 2
  182. if speedup >= 1:
  183. speedup_label = f"x{speedup:.2f} faster"
  184. else:
  185. speedup_label = f"x{(1 / speedup):.2f} slower"
  186. ax.annotate(
  187. "",
  188. xy=(request_rate, optimized_metric_value),
  189. xytext=(request_rate, baseline_metric_value),
  190. arrowprops={
  191. "arrowstyle": "->",
  192. "color": "#54A24B",
  193. "lw": 2,
  194. "alpha": 0.9,
  195. },
  196. )
  197. ax.text(
  198. request_rate,
  199. mid_y,
  200. speedup_label,
  201. ha="left",
  202. va="center",
  203. fontsize=9,
  204. color="#54A24B",
  205. bbox={
  206. "boxstyle": "round,pad=0.2",
  207. "fc": "white",
  208. "ec": "#54A24B",
  209. "alpha": 0.9,
  210. },
  211. )
  212. def plot_series(
  213. baseline_points: list[BenchmarkPoint],
  214. optimized_points: list[BenchmarkPoint],
  215. title: str,
  216. output_path: str,
  217. figsize: tuple[float, float],
  218. metric: str,
  219. ) -> None:
  220. fig, ax = plt.subplots(figsize=figsize)
  221. baseline_x = [point.request_rate for point in baseline_points]
  222. baseline_y = [metric_value(point, metric) for point in baseline_points]
  223. optimized_x = [point.request_rate for point in optimized_points]
  224. optimized_y = [metric_value(point, metric) for point in optimized_points]
  225. ax.plot(
  226. baseline_x,
  227. baseline_y,
  228. marker="o",
  229. linewidth=2,
  230. color="#4C78A8",
  231. label="Baseline",
  232. )
  233. ax.plot(
  234. optimized_x,
  235. optimized_y,
  236. marker="o",
  237. linewidth=2,
  238. color="#F58518",
  239. label="GPUStack-Optimized",
  240. )
  241. annotate_series(ax, baseline_points, color="#4C78A8", direction=1, metric=metric)
  242. annotate_series(ax, optimized_points, color="#F58518", direction=-1, metric=metric)
  243. annotate_speedup_arrows(ax, baseline_points, optimized_points, metric=metric)
  244. all_latencies = baseline_y + optimized_y
  245. if all_latencies:
  246. min_latency = min(all_latencies)
  247. max_latency = max(all_latencies)
  248. latency_span = max_latency - min_latency
  249. padding = max(1.0, latency_span * 0.35)
  250. lower_bound = max(0, min_latency - padding * 0.45)
  251. upper_bound = max_latency + padding
  252. ax.set_ylim(lower_bound, upper_bound)
  253. ax.set_xlabel("Request/s")
  254. ax.set_ylabel(metric_label(metric))
  255. ax.set_title(title)
  256. ax.grid(True, linestyle="--", alpha=0.4)
  257. ax.legend(loc="upper left")
  258. fig.tight_layout(rect=(0, 0, 1, 0.96))
  259. fig.savefig(output_path, dpi=300)
  260. print(f"Saved plot to {output_path}")
  261. def main() -> None:
  262. args = parse_args()
  263. baseline_points = load_series(args.baseline)
  264. optimized_points = load_series(args.optimized)
  265. plot_series(
  266. baseline_points=baseline_points,
  267. optimized_points=optimized_points,
  268. title=args.title,
  269. output_path=args.output,
  270. figsize=tuple(args.figsize),
  271. metric=args.metric,
  272. )
  273. if __name__ == "__main__":
  274. main()