benchmark.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. from dataclasses import dataclass
  2. from datetime import datetime
  3. from enum import Enum
  4. from typing import Any, ClassVar, Dict, List, Optional
  5. from pydantic import BaseModel
  6. from sqlalchemy import JSON, Column
  7. from sqlmodel import Field, ForeignKey, Integer, SQLModel, Text
  8. from gpustack.schemas.common import (
  9. ListParams,
  10. PaginatedList,
  11. pydantic_column_type,
  12. )
  13. from gpustack.mixins import BaseModelMixin
  14. from gpustack.schemas.models import (
  15. ComputedResourceClaim,
  16. ExtendedKVCacheConfig,
  17. SpeculativeConfig,
  18. )
  19. from gpustack.schemas.workers import GPUDeviceInfo, OperatingSystemInfo
# Dataset name constants.
# NOTE(review): presumably the recognised values for `dataset_name` on
# benchmark records — confirm against callers.
DATASET_RANDOM = "Random"
DATASET_SHAREGPT = "ShareGPT"
  22. class BenchmarkStateEnum(str, Enum):
  23. r"""
  24. Enum for Benchmark State
  25. Transitions:
  26. |- - Server - -|- - - - - - - Worker - - - - - - -|
  27. | | |
  28. PENDING ---> ---> ---> QUEUED ---> RUNNING ---> COMPLETED/STOPPED/ERROR
  29. ^ ^
  30. | |
  31. |----------|
  32. |
  33. |(Worker unreachable)
  34. v
  35. UNREACHABLE
  36. """
  37. PENDING = "pending"
  38. QUEUED = "queued"
  39. RUNNING = "running"
  40. COMPLETED = "completed"
  41. STOPPED = "stopped"
  42. ERROR = "error"
  43. UNREACHABLE = "unreachable"
  44. def __str__(self):
  45. return self.value
# Pydantic model carrying a computed resource claim, a port list and the
# worker / GPU identifiers recorded for a model instance.
class ModelInstanceRuntimeInfo(BaseModel):
    # NOTE(review): no default is declared for the next two fields; under
    # pydantic v2 that makes them required-but-nullable — confirm the pydantic
    # version in use.
    computed_resource_claim: Optional[ComputedResourceClaim]
    ports: Optional[List[int]]

    # Worker identification; all optional.
    worker_id: Optional[int] = None
    worker_name: Optional[str] = None
    worker_ip: Optional[str] = None

    # GPU identification; all optional.
    gpu_type: Optional[str] = None
    gpu_indexes: Optional[List[int]] = None
    gpu_ids: Optional[List[str]] = None
# Snapshot of a model instance: identity, state, backend settings, plus the
# runtime fields inherited from ModelInstanceRuntimeInfo.
class ModelInstanceSnapshot(ModelInstanceRuntimeInfo):
    id: int
    name: str
    resolved_path: Optional[str] = None

    # resource info
    state: Optional[str] = None
    state_message: Optional[str] = None

    # backend info
    backend: Optional[str] = None
    backend_version: Optional[str] = None
    api_detected_backend_version: Optional[str] = None
    # NOTE(review): `sa_type` is passed on several fields below although this
    # class derives from pydantic BaseModel rather than a SQLModel table —
    # presumably it has no effect here; confirm.
    backend_parameters: Optional[List[str]] = Field(sa_type=JSON, default=None)
    injected_backend_parameters: Optional[List[str]] = Field(sa_type=JSON, default=None)
    image_name: Optional[str] = None
    run_command: Optional[str] = Field(sa_type=Text, default=None)
    env: Optional[Dict[str, str]] = Field(sa_type=JSON, default=None)

    # Extended KV Cache configuration. Currently maps to LMCache config in vLLM and SGLang.
    extended_kv_cache: Optional[ExtendedKVCacheConfig] = Field(
        sa_type=pydantic_column_type(ExtendedKVCacheConfig), default=None
    )

    speculative_config: Optional[SpeculativeConfig] = Field(
        sa_type=pydantic_column_type(SpeculativeConfig), default=None
    )

    # subordinate workers info
    subordinate_workers: Optional[List[ModelInstanceRuntimeInfo]] = None
# Pydantic model holding a worker's id and name plus optional CPU, memory and
# operating-system details.
class WorkerSnapshot(BaseModel):
    id: int
    name: str
    # NOTE(review): units of the totals are not visible in this file — confirm.
    cpu_total: Optional[int] = None
    memory_total: Optional[int] = None
    os: Optional[OperatingSystemInfo] = None
# GPUDeviceInfo extended with the GPU's id, its owning worker's id and name,
# and optional memory / core totals.
class GPUSnapshot(GPUDeviceInfo):
    id: str
    worker_id: int
    worker_name: str
    # NOTE(review): units of the totals are not visible in this file — confirm.
    memory_total: Optional[int] = None
    core_total: Optional[int] = None
  92. @dataclass
  93. class BenchmarkDeploymentMetadata:
  94. name: str
  95. labels: dict[str, str]
  96. class BenchmarkBase(SQLModel):
  97. name: str = Field(index=True, unique=True)
  98. description: Optional[str] = Field(
  99. sa_type=Text,
  100. nullable=True,
  101. default=None,
  102. )
  103. profile: Optional[str] = Field(default="Custom")
  104. dataset_name: Optional[str] = Field(
  105. default=None
  106. ) # denormalized field for easier query
  107. dataset_input_tokens: Optional[int] = Field(default=None)
  108. dataset_output_tokens: Optional[int] = Field(default=None)
  109. dataset_seed: Optional[int] = Field(default=None)
  110. cluster_id: int = Field(default=None)
  111. model_id: Optional[int] = Field(default=None)
  112. model_name: Optional[str] = Field(
  113. default=None
  114. ) # denormalized field for easier query
  115. model_instance_name: str
  116. request_rate: int = Field(default=10) # requests per second
  117. total_requests: Optional[int] = Field(
  118. default=None
  119. ) # total number of requests to send
  120. # Benchmark state fields
  121. state: BenchmarkStateEnum = Field(
  122. default=BenchmarkStateEnum.PENDING,
  123. index=True,
  124. )
  125. state_message: Optional[str] = Field(
  126. default=None, sa_column=Column(Text, nullable=True)
  127. )
  128. progress: Optional[float] = Field(default=None)
  129. worker_id: Optional[int] = Field(default=None)
  130. pid: Optional[int] = Field(default=None)
  131. def get_deployment_metadata(
  132. self,
  133. ) -> Optional[BenchmarkDeploymentMetadata]:
  134. """
  135. Get the deployment metadata for the benchmark.
  136. """
  137. return BenchmarkDeploymentMetadata(
  138. name=self.name,
  139. labels={
  140. "benchmark-name": self.name,
  141. "model-instance-name": self.model_instance_name or "",
  142. "type": "benchmark",
  143. },
  144. )
# Type aliases: string-keyed mappings to the snapshot models above.
# NOTE(review): what the string keys identify (ids vs. names) is not visible
# in this file — confirm against the code that builds the snapshot.
ModelInstanceSnapshots = Dict[str, ModelInstanceSnapshot]
WorkerSnapshots = Dict[str, WorkerSnapshot]
GPUSnapshots = Dict[str, GPUSnapshot]
# Container grouping the optional instance, worker and GPU snapshot mappings.
class BenchmarkSnapshot(BaseModel):
    instances: Optional[ModelInstanceSnapshots] = None
    workers: Optional[WorkerSnapshots] = None
    gpus: Optional[GPUSnapshots] = None
# Summary metrics of a benchmark run. Every field is optional and defaults to
# None; units are stated in each field's description.
class BenchmarkMetricsLite(SQLModel):
    # Throughput and latency means.
    requests_per_second_mean: Optional[float] = Field(
        default=None, description="Mean requests per second (unit: req/s)"
    )
    request_latency_mean: Optional[float] = Field(
        default=None, description="Mean request latency (unit: seconds)"
    )
    time_per_output_token_mean: Optional[float] = Field(
        default=None, description="Mean time per output token (unit: ms)"
    )
    inter_token_latency_mean: Optional[float] = Field(
        default=None, description="Mean inter-token latency (unit: ms)"
    )
    time_to_first_token_mean: Optional[float] = Field(
        default=None, description="Mean time to first token (unit: ms)"
    )
    tokens_per_second_mean: Optional[float] = Field(
        default=None, description="Mean tokens per second (unit: tok/s)"
    )
    output_tokens_per_second_mean: Optional[float] = Field(
        default=None, description="Mean output tokens per second (unit: tok/s)"
    )
    input_tokens_per_second_mean: Optional[float] = Field(
        default=None, description="Mean prompt tokens per second (unit: tok/s)"
    )

    # Request concurrency.
    request_concurrency_mean: Optional[float] = Field(
        default=None,
        description="Mean request concurrency (unit: number of concurrent requests)",
    )
    request_concurrency_max: Optional[float] = Field(
        default=None,
        description="Max request concurrency (unit: number of concurrent requests)",
    )

    # Request counters.
    request_total: Optional[int] = Field(
        default=None, description="Total number of requests made"
    )
    request_successful: Optional[int] = Field(
        default=None, description="Total number of successful requests"
    )
    request_errored: Optional[int] = Field(
        default=None, description="Total number of errored requests"
    )
    request_incomplete: Optional[int] = Field(
        default=None, description="Total number of incomplete requests"
    )
# BenchmarkMetricsLite plus the raw metrics payload stored as a JSON column.
class BenchmarkMetrics(BenchmarkMetricsLite):
    # NOTE(review): the column is a plain Column(JSON); nothing in this
    # declaration defers loading, so the deferral mentioned below is presumably
    # applied where queries are built — confirm.
    raw_metrics: Optional[Dict[str, Any]] = Field(
        sa_column=Column(JSON), default=None
    )  # deferred loading of potentially large field
# BenchmarkBase extended with the environment snapshot and two text summaries
# of the GPUs involved.
class BenchmarkWithSnapshots(BenchmarkBase):
    # Stored through the column type produced by pydantic_column_type.
    snapshot: Optional[BenchmarkSnapshot] = Field(
        default=None,
        sa_column=Column(pydantic_column_type(BenchmarkSnapshot)),
    )

    # Nullable text columns.
    # NOTE(review): the format of these summaries is not visible in this file.
    gpu_summary: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
    gpu_vendor_summary: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
# Table model for the "benchmarks" table; combines the snapshot fields, the
# full metrics and BaseModelMixin.
class Benchmark(BenchmarkWithSnapshots, BenchmarkMetrics, BaseModelMixin, table=True):
    # Primary key; None until a value is assigned.
    id: Optional[int] = Field(default=None, primary_key=True)

    # Tenant scope. Server-derived from cluster on creation.
    # Nullable foreign key to principals.id.
    owner_principal_id: Optional[int] = Field(
        default=None,
        sa_column=Column(Integer, ForeignKey("principals.id"), nullable=True),
    )

    __tablename__ = 'benchmarks'
# List parameters for benchmarks; declares the field names offered for sorting.
# NOTE(review): how `sortable_fields` is enforced lives in ListParams, which is
# not visible here.
class BenchmarkListParams(ListParams):
    sortable_fields: ClassVar[List[str]] = [
        "name",
        "dataset_name",
        "model_name",
        "state",
        "created_at",
        "updated_at",
        # metrics fields
        "requests_per_second_mean",
        "request_latency_mean",
        "time_per_output_token_mean",
        "inter_token_latency_mean",
        "time_to_first_token_mean",
        "tokens_per_second_mean",
        "output_tokens_per_second_mean",
        "input_tokens_per_second_mean",
        "request_concurrency_mean",
        "request_concurrency_max",
        "request_total",
        "request_successful",
        "request_errored",
        "request_incomplete",
    ]
# Creation model: exactly the BenchmarkBase fields, with nothing added.
class BenchmarkCreate(BenchmarkBase):
    pass
# Update model limited to a benchmark's name and description.
class BenchmarkUpdate(SQLModel):
    # NOTE(review): `index` / `unique` are declared here although this class is
    # not a table model — presumably carried over from BenchmarkBase; confirm
    # they are needed.
    name: str = Field(index=True, unique=True)
    description: Optional[str] = Field(
        sa_type=Text,
        nullable=True,
        default=None,
    )
# Update model for a benchmark's state, message, pid and progress; every field
# is optional and defaults to None.
class BenchmarkStateUpdate(SQLModel):
    state: Optional[BenchmarkStateEnum] = None
    state_message: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
    pid: Optional[int] = Field(default=None)
    # NOTE(review): the scale of `progress` (0-1 vs 0-100) is not visible here.
    progress: Optional[float] = None
# Public view with the full metrics (BenchmarkMetrics, including raw_metrics),
# the snapshot fields, and required id / timestamps.
class BenchmarkFullPublic(
    BenchmarkWithSnapshots,
    BenchmarkMetrics,
):
    id: int
    created_at: datetime
    updated_at: datetime

    # Re-declared with the same definition BenchmarkWithSnapshots already gives
    # these two fields.
    gpu_summary: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
    gpu_vendor_summary: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
# Public view with the summary metrics only (BenchmarkMetricsLite, so no
# raw_metrics field), the snapshot fields, and required id / timestamps.
class BenchmarkPublic(
    BenchmarkWithSnapshots,
    BenchmarkMetricsLite,
):
    id: int
    created_at: datetime
    updated_at: datetime

    # Re-declared with the same definition BenchmarkWithSnapshots already gives
    # these two fields.
    gpu_summary: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
    gpu_vendor_summary: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
# Paginated list of BenchmarkPublic items.
BenchmarksPublic = PaginatedList[BenchmarkPublic]