benchmarks.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. from sqlmodel import col
  2. import yaml
  3. from typing import Optional, Sequence
  4. import aiohttp
  5. from fastapi import APIRouter, Depends, Query, Request, status
  6. from fastapi.responses import PlainTextResponse, StreamingResponse
  7. from sqlmodel import func
  8. from gpustack import envs
  9. from gpustack.api.exceptions import (
  10. AlreadyExistsException,
  11. InternalServerErrorException,
  12. NotFoundException,
  13. BadRequestException,
  14. )
  15. from gpustack.api.responses import StreamingResponseWithStatusCode
  16. from gpustack.api.tenant import (
  17. bypass_tenant_filter,
  18. assert_cluster_resource_visible,
  19. cluster_resource_visibility_conditions,
  20. )
  21. from gpustack.schemas.clusters import Cluster
  22. from gpustack.schemas.models import (
  23. Model,
  24. ModelInstance,
  25. ModelInstanceStateEnum,
  26. is_audio_model,
  27. is_embedding_model,
  28. is_image_model,
  29. is_reranker_model,
  30. )
  31. from gpustack.schemas.workers import Worker
  32. from gpustack.server.db import async_session
  33. from gpustack.server.deps import SessionDep, TenantContextDep
  34. from gpustack.schemas.benchmark import (
  35. DATASET_RANDOM,
  36. DATASET_SHAREGPT,
  37. Benchmark,
  38. BenchmarkCreate,
  39. BenchmarkFullPublic,
  40. BenchmarkListParams,
  41. BenchmarkMetrics,
  42. BenchmarkSnapshot,
  43. BenchmarkStateEnum,
  44. BenchmarkStateUpdate,
  45. BenchmarkUpdate,
  46. BenchmarkPublic,
  47. BenchmarksPublic,
  48. )
  49. from gpustack.server.services import (
  50. WorkerService,
  51. )
  52. from gpustack.server.worker_request import stream_to_worker, request_to_worker
  53. from gpustack.utils.gpu import summary_gpu_snapshots
  54. from gpustack.utils.snapshot import (
  55. create_model_instance_snapshot,
  56. create_worker_snapshot,
  57. )
  58. from gpustack.worker.logs import LogOptionsDep
  59. from sqlalchemy.orm import defer
# Upper bound on the number of benchmark ids accepted by one export request
# (checked in export_benchmarks).
MAX_EXPORT_RECORDS = 20

# Keys that are emitted first, in this order, for each exported benchmark
# (consumed by order_benchmark_export_fields).
BENCHMARK_EXPORT_FIELD_ORDER = [
    "name",
    "model_name",
    "model_instance_name",
    "profile",
    "dataset_name",
    "request_rate",
    "total_requests",
    "dataset_input_tokens",
    "dataset_output_tokens",
    "dataset_seed",
]

# Router on which every benchmark endpoint in this module is registered.
router = APIRouter()
  74. def order_benchmark_export_fields(benchmark: dict) -> dict:
  75. ordered = {}
  76. for field in BENCHMARK_EXPORT_FIELD_ORDER:
  77. if field in benchmark:
  78. ordered[field] = benchmark[field]
  79. for field, value in benchmark.items():
  80. if field not in ordered:
  81. ordered[field] = value
  82. return ordered
  83. @router.get("", response_model=BenchmarksPublic)
  84. async def get_benchmarks(
  85. ctx: TenantContextDep,
  86. params: BenchmarkListParams = Depends(),
  87. search: str = None,
  88. state: Optional[BenchmarkStateEnum] = Query(
  89. default=None,
  90. description="Filter by benchmark state.",
  91. ),
  92. model_name: Optional[str] = Query(None, description="Filter by model name."),
  93. gpu_summary: Optional[str] = Query(None, description="Filter by GPU summary."),
  94. dataset_name: Optional[str] = Query(None, description="Filter by dataset name."),
  95. profile: Optional[str] = Query(None, description="Filter by profile."),
  96. ):
  97. return await _get_benchmarks(
  98. ctx=ctx,
  99. params=params,
  100. state=state,
  101. search=search,
  102. model_name=model_name,
  103. gpu_summary=gpu_summary,
  104. dataset_name=dataset_name,
  105. profile=profile,
  106. )
  107. def gpu_summary_filter(data: Benchmark, gpu_summary: Optional[str]) -> bool:
  108. if (
  109. gpu_summary
  110. and data.gpu_summary
  111. and gpu_summary.lower() not in data.gpu_summary.lower()
  112. ):
  113. return False
  114. return True
  115. async def _get_benchmarks(
  116. ctx,
  117. params: BenchmarkListParams,
  118. search: str = None,
  119. state: Optional[BenchmarkStateEnum] = None,
  120. model_name: Optional[str] = None,
  121. gpu_summary: Optional[str] = None,
  122. dataset_name: Optional[str] = None,
  123. profile: Optional[str] = None,
  124. ):
  125. fuzzy_fields = {}
  126. if search:
  127. fuzzy_fields["name"] = search
  128. if profile:
  129. fuzzy_fields["profile"] = profile
  130. fields = {}
  131. if state:
  132. fields["state"] = state
  133. if model_name:
  134. fields["model_name"] = model_name
  135. if dataset_name:
  136. fields["dataset_name"] = dataset_name
  137. extra_conditions = list(cluster_resource_visibility_conditions(ctx, Benchmark))
  138. if gpu_summary:
  139. extra_conditions.append(
  140. func.lower(Benchmark.gpu_summary).like(f"%{gpu_summary}%")
  141. )
  142. def _benchmark_visible(b: Benchmark) -> bool:
  143. if bypass_tenant_filter(ctx):
  144. return True
  145. org_id = getattr(b, "owner_principal_id", None)
  146. if (
  147. ctx.current_principal_id is not None
  148. and org_id is not None
  149. and org_id == ctx.current_principal_id
  150. ):
  151. return True
  152. if getattr(b, "cluster_id", None) in ctx.accessible_cluster_ids:
  153. return True
  154. return False
  155. if params.watch:
  156. return StreamingResponse(
  157. Benchmark.streaming(
  158. fields=fields,
  159. fuzzy_fields=fuzzy_fields,
  160. filter_func=lambda data: _benchmark_visible(data)
  161. and gpu_summary_filter(data, gpu_summary),
  162. ),
  163. media_type="text/event-stream",
  164. )
  165. order_by = params.order_by
  166. if order_by:
  167. new_order_by = []
  168. for field, direction in order_by:
  169. new_order_by.append((field, direction))
  170. if field in [
  171. "dataset_name",
  172. "cluster_id",
  173. "model_id",
  174. "model_name",
  175. "state",
  176. ]:
  177. # add additional sorting fields for deterministic ordering
  178. new_order_by.append(("created_at", direction))
  179. order_by = new_order_by
  180. async with async_session() as session:
  181. return await Benchmark.paginated_by_query(
  182. session=session,
  183. fields=fields,
  184. fuzzy_fields=fuzzy_fields,
  185. page=params.page,
  186. per_page=params.perPage,
  187. order_by=order_by,
  188. extra_conditions=extra_conditions,
  189. options=[defer(Benchmark.raw_metrics)],
  190. )
  191. @router.get("/{id}", response_model=BenchmarkFullPublic)
  192. async def get_benchmark(
  193. session: SessionDep,
  194. ctx: TenantContextDep,
  195. id: int,
  196. ):
  197. benchmark = await Benchmark.one_by_id(session, id)
  198. assert_cluster_resource_visible(
  199. ctx, benchmark, not_found_message=f"Benchmark {id} not found"
  200. )
  201. return benchmark
  202. async def validate_and_mutate_benchmark_in(
  203. session: SessionDep, benchmark_in: BenchmarkCreate
  204. ) -> Benchmark:
  205. if not benchmark_in.model_instance_name.strip():
  206. raise BadRequestException(message="Field model_instance_name must be specified")
  207. mutated = Benchmark(**benchmark_in.model_dump())
  208. instance = await ModelInstance.one_by_field(
  209. session, "name", benchmark_in.model_instance_name
  210. )
  211. if not instance:
  212. raise BadRequestException(
  213. message=f"Model instance '{benchmark_in.model_instance_name}' not found"
  214. )
  215. if instance.state != ModelInstanceStateEnum.RUNNING:
  216. raise BadRequestException(
  217. message=f"Model instance '{benchmark_in.model_instance_name}' not in RUNNING state"
  218. )
  219. if benchmark_in.model_id is None:
  220. mutated.model_id = instance.model_id
  221. mutated.model_name = instance.model_name
  222. if benchmark_in.dataset_name is None:
  223. raise BadRequestException(message="Field dataset_name must be specified")
  224. if benchmark_in.dataset_name not in [DATASET_RANDOM, DATASET_SHAREGPT]:
  225. raise BadRequestException(
  226. message=f"Dataset '{benchmark_in.dataset_name}' is not supported. Supported datasets are '{DATASET_RANDOM}' and '{DATASET_SHAREGPT}'."
  227. )
  228. if benchmark_in.dataset_name == DATASET_RANDOM and (
  229. benchmark_in.dataset_input_tokens is None
  230. or benchmark_in.dataset_output_tokens is None
  231. ):
  232. raise BadRequestException(
  233. message="Fields dataset_input_tokens and dataset_output_tokens must be specified for 'Random' dataset"
  234. )
  235. model = await Model.one_by_id(session, mutated.model_id)
  236. if not model:
  237. raise BadRequestException(message=f"Model {mutated.model_id} not found")
  238. if (
  239. is_image_model(model)
  240. or is_audio_model(model)
  241. or is_embedding_model(model)
  242. or is_reranker_model(model)
  243. ):
  244. raise BadRequestException(
  245. message=f"Benchmarking is not supported for model type '{model.type.value}'"
  246. )
  247. if benchmark_in.request_rate <= 0:
  248. mutated.request_rate = (
  249. benchmark_in.total_requests
  250. if benchmark_in.total_requests is not None
  251. else 1000
  252. ) # treat non-positive request_rate as unlimited
  253. snapshot = await get_benchmark_snapshot(session, instance, model)
  254. mutated.snapshot = snapshot
  255. mutated.gpu_summary, mutated.gpu_vendor_summary = summary_gpu_snapshots(
  256. snapshot.gpus
  257. )
  258. mutated.worker_id = instance.worker_id
  259. # Derive tenant scope from the benchmark's cluster.
  260. if mutated.cluster_id is not None:
  261. cluster = await Cluster.one_by_id(session, mutated.cluster_id)
  262. if cluster is not None:
  263. mutated.owner_principal_id = cluster.owner_principal_id
  264. return mutated
  265. @router.post("", response_model=BenchmarkPublic)
  266. async def create_benchmark(
  267. session: SessionDep, ctx: TenantContextDep, benchmark_in: BenchmarkCreate
  268. ):
  269. existing = await Benchmark.one_by_field(session, "name", benchmark_in.name)
  270. if existing:
  271. raise AlreadyExistsException(
  272. message=f"Benchmark '{benchmark_in.name}' already exists. "
  273. "Please choose a different name or check the existing benchmark."
  274. )
  275. mutated = await validate_and_mutate_benchmark_in(session, benchmark_in)
  276. try:
  277. benchmark = await Benchmark.create(session, mutated)
  278. except Exception as e:
  279. raise InternalServerErrorException(message=f"Failed to create benchmark: {e}")
  280. return benchmark
  281. @router.put("/{id}", response_model=BenchmarkPublic)
  282. async def update_benchmark(
  283. session: SessionDep,
  284. ctx: TenantContextDep,
  285. id: int,
  286. benchmark_in: BenchmarkUpdate,
  287. ):
  288. benchmark = await Benchmark.one_by_id(session, id)
  289. assert_cluster_resource_visible(
  290. ctx, benchmark, not_found_message="Benchmark not found"
  291. )
  292. try:
  293. await benchmark.update(session, benchmark_in)
  294. except Exception as e:
  295. raise InternalServerErrorException(message=f"Failed to update benchmark: {e}")
  296. return benchmark
  297. @router.patch("/{id}/state", response_model=BenchmarkPublic)
  298. async def update_benchmark_state(
  299. session: SessionDep,
  300. ctx: TenantContextDep,
  301. id: int,
  302. state_update: BenchmarkStateUpdate,
  303. ):
  304. benchmark = await Benchmark.one_by_id(session, id)
  305. assert_cluster_resource_visible(
  306. ctx, benchmark, not_found_message="Benchmark not found"
  307. )
  308. if (
  309. state_update.state is not None
  310. and state_update.state == BenchmarkStateEnum.STOPPED
  311. and benchmark.state
  312. not in [
  313. BenchmarkStateEnum.QUEUED,
  314. BenchmarkStateEnum.PENDING,
  315. BenchmarkStateEnum.RUNNING,
  316. ]
  317. ):
  318. raise BadRequestException(
  319. message="Only benchmarks in QUEUED, PENDING, or RUNNING state can be stopped."
  320. )
  321. try:
  322. await benchmark.update(session, state_update)
  323. except Exception as e:
  324. raise InternalServerErrorException(
  325. message=f"Failed to update benchmark state: {e}"
  326. )
  327. return benchmark
  328. async def get_benchmark_snapshot(
  329. session: SessionDep, mi: ModelInstance, model: Model
  330. ) -> BenchmarkSnapshot:
  331. # instance snapshot
  332. worker_snapshots = {}
  333. gpu_snapshots = {}
  334. instance_snapshots = {}
  335. instance_snapshots[mi.name] = create_model_instance_snapshot(mi, model)
  336. w: Worker = await WorkerService(session).get_by_id(mi.worker_id)
  337. w_snapshot, gpus_snapshots = create_worker_snapshot(w, mi.gpu_type, mi.gpu_indexes)
  338. if w_snapshot is not None:
  339. worker_snapshots[w.name] = w_snapshot
  340. if gpus_snapshots is not None:
  341. gpu_snapshots.update(gpus_snapshots)
  342. if mi.distributed_servers and mi.distributed_servers.subordinate_workers:
  343. for sub in mi.distributed_servers.subordinate_workers:
  344. sw: Worker = await WorkerService(session).get_by_id(sub.worker_id)
  345. w_snapshot, gpus_snapshots = create_worker_snapshot(
  346. sw, sub.gpu_type, sub.gpu_indexes
  347. )
  348. if w_snapshot is not None:
  349. worker_snapshots[sw.name] = w_snapshot
  350. if gpus_snapshots is not None:
  351. gpu_snapshots.update(gpus_snapshots)
  352. return BenchmarkSnapshot(
  353. instances=instance_snapshots,
  354. workers=worker_snapshots,
  355. gpus=gpu_snapshots,
  356. )
  357. @router.post("/{id}/metrics", response_model=BenchmarkPublic)
  358. async def update_benchmark_metrics(
  359. session: SessionDep, ctx: TenantContextDep, id: int, metrics: BenchmarkMetrics
  360. ):
  361. benchmark = await Benchmark.one_by_id(session, id)
  362. assert_cluster_resource_visible(
  363. ctx, benchmark, not_found_message="Benchmark not found"
  364. )
  365. try:
  366. await benchmark.update(session, metrics)
  367. except Exception as e:
  368. raise InternalServerErrorException(
  369. message=f"Failed to update benchmark metrics: {e}"
  370. )
  371. return benchmark
  372. @router.delete("/{id}")
  373. async def delete_benchmark(session: SessionDep, ctx: TenantContextDep, id: int):
  374. benchmark = await Benchmark.one_by_id(session, id)
  375. assert_cluster_resource_visible(
  376. ctx, benchmark, not_found_message="Benchmark not found"
  377. )
  378. try:
  379. await benchmark.delete(session)
  380. except Exception as e:
  381. raise InternalServerErrorException(message=f"Failed to delete benchmark: {e}")
  382. @router.get("/{id}/logs")
  383. async def get_benchmark_logs( # noqa: C901
  384. request: Request,
  385. session: SessionDep,
  386. ctx: TenantContextDep,
  387. id: int,
  388. log_options: LogOptionsDep,
  389. ):
  390. benchmark = await Benchmark.one_by_id(session, id)
  391. assert_cluster_resource_visible(
  392. ctx, benchmark, not_found_message="Benchmark not found"
  393. )
  394. worker = await Worker.one_by_id(session, benchmark.worker_id)
  395. if not worker:
  396. raise NotFoundException(message="Benchmark's worker not found")
  397. if benchmark.state in [
  398. BenchmarkStateEnum.ERROR,
  399. BenchmarkStateEnum.STOPPED,
  400. BenchmarkStateEnum.COMPLETED,
  401. ]:
  402. log_options.follow = False
  403. timeout = aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT, sock_connect=5)
  404. if log_options.follow:
  405. def on_exception(e: Exception, t: aiohttp.ClientTimeout) -> tuple[str, int]:
  406. msg = (
  407. str(e)
  408. if not isinstance(e, TimeoutError)
  409. else f"Log stream timed out ({t.total} seconds). Please reopen the log page."
  410. )
  411. return f"\x1b[999;1H{msg}\n", status.HTTP_500_INTERNAL_SERVER_ERROR
  412. return StreamingResponseWithStatusCode(
  413. stream_to_worker(
  414. worker=worker,
  415. method="GET",
  416. path=f"benchmark_logs/{benchmark.id}",
  417. proxy_client=request.app.state.http_client,
  418. no_proxy_client=request.app.state.http_client_no_proxy,
  419. params={
  420. "tail": log_options.tail,
  421. "follow": log_options.follow,
  422. "benchmark_name": benchmark.name,
  423. },
  424. timeout=timeout,
  425. on_exception=on_exception,
  426. raw=True,
  427. ),
  428. media_type="application/octet-stream",
  429. )
  430. else:
  431. resp, body = await request_to_worker(
  432. worker=worker,
  433. method="GET",
  434. path=f"benchmark_logs/{benchmark.id}",
  435. proxy_client=request.app.state.http_client,
  436. no_proxy_client=request.app.state.http_client_no_proxy,
  437. params={
  438. "tail": log_options.tail,
  439. "follow": log_options.follow,
  440. "benchmark_name": benchmark.name,
  441. },
  442. timeout=timeout,
  443. )
  444. return PlainTextResponse(
  445. content=body.decode() if body else "", status_code=resp.status
  446. )
  447. @router.post("/export")
  448. async def export_benchmarks(
  449. session: SessionDep,
  450. ctx: TenantContextDep,
  451. ids: list[int],
  452. ):
  453. if not ids:
  454. raise BadRequestException(message="No benchmark ids provided.")
  455. if len(ids) > MAX_EXPORT_RECORDS:
  456. raise BadRequestException(
  457. message=f"Export up to {MAX_EXPORT_RECORDS} records at most."
  458. )
  459. exclude_fields = [
  460. "id",
  461. "cluster_id",
  462. "owner_principal_id",
  463. "model_id",
  464. "worker_id",
  465. "created_at",
  466. "updated_at",
  467. "pid",
  468. "progress",
  469. "state_message",
  470. "state",
  471. "deleted_at",
  472. ]
  473. extra_conditions = [
  474. col(Benchmark.id).in_(ids),
  475. *cluster_resource_visibility_conditions(ctx, Benchmark),
  476. ]
  477. benchmarks: Sequence[Benchmark] = await Benchmark.all_by_fields(
  478. session, fields={}, extra_conditions=extra_conditions
  479. )
  480. exported_benchmarks = []
  481. for b in benchmarks:
  482. eb = b.model_dump(exclude=set(exclude_fields))
  483. exported_benchmarks.append(order_benchmark_export_fields(eb))
  484. export_data = {"benchmarks": exported_benchmarks}
  485. yaml_str = yaml.safe_dump(export_data, allow_unicode=True, sort_keys=False)
  486. return PlainTextResponse(content=yaml_str, media_type="application/x-yaml")