runner.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import json
  2. import logging
  3. import os
  4. import sys
  5. from typing import Dict, List, Optional
  6. from gpustack.client.generated_clientset import ClientSet
  7. from gpustack.config.config import Config, set_global_config
  8. from gpustack.config.registration import read_worker_token
  9. from gpustack.envs import BENCHMARK_DATASET_SHAREGPT_PATH, BENCHMARK_REQUEST_TIMEOUT
  10. from gpustack.logging import setup_logging
  11. from gpustack.schemas.benchmark import (
  12. DATASET_RANDOM,
  13. DATASET_SHAREGPT,
  14. Benchmark,
  15. BenchmarkDeploymentMetadata,
  16. BenchmarkStateEnum,
  17. ModelInstanceSnapshot,
  18. )
  19. from gpustack.utils.command import find_bool_parameter
  20. from gpustack.utils.config import apply_registry_override_to_image
  21. from gpustack.utils.envs import filter_env_vars, sanitize_env
  22. from gpustack_runtime.logging import setup_logging as setup_runtime_logging
  23. from gpustack_runtime import envs as runtime_envs
  24. from gpustack_runtime.deployer import ContainerMount
  25. from gpustack_runtime.deployer import (
  26. Container,
  27. ContainerEnv,
  28. ContainerExecution,
  29. ContainerProfileEnum,
  30. WorkloadPlan,
  31. create_workload,
  32. ContainerRestartPolicyEnum,
  33. )
  34. from gpustack.utils.profiling import time_decorator
  35. from gpustack.utils.runtime import transform_workload_plan
  36. logger = logging.getLogger(__name__)
  37. class BenchmarkRunner:
  38. _clientset: ClientSet
  39. _config: Config
  40. _benchmark: Benchmark
  41. _model_path: str
  42. _model_endpoint: str
  43. _model_backend_parameters: Optional[List[str]]
  44. _api_url: str
  45. _api_key: str
  46. _benchmark_dir: Optional[str]
  47. _fallback_registry: Optional[str] = None
  48. """The fallback container registry to use if needed."""
  49. @time_decorator
  50. def __init__(
  51. self,
  52. clientset: ClientSet,
  53. benchmark: Benchmark,
  54. cfg: Config,
  55. fallback_registry: Optional[str] = None,
  56. ):
  57. setup_logging(debug=cfg.debug)
  58. setup_runtime_logging()
  59. set_global_config(cfg)
  60. try:
  61. self._clientset = clientset
  62. self._benchmark = benchmark
  63. self._config = cfg
  64. self._fallback_registry = fallback_registry
  65. if (
  66. benchmark.snapshot is None
  67. or benchmark.snapshot.instances is None
  68. or len(benchmark.snapshot.instances) == 0
  69. or benchmark.snapshot.instances.get(benchmark.model_instance_name)
  70. is None
  71. ):
  72. raise ValueError(
  73. f"Benchmark {benchmark.name}(id={benchmark.id}) has no snapshot for model instance {benchmark.model_instance_name}"
  74. )
  75. instance_snapshot: ModelInstanceSnapshot = benchmark.snapshot.instances.get(
  76. benchmark.model_instance_name
  77. )
  78. if instance_snapshot.resolved_path is None:
  79. raise ValueError(
  80. f"Benchmark {benchmark.name}(id={benchmark.id}) snapshot for model instance {benchmark.model_instance_name} has no resolved path"
  81. )
  82. if instance_snapshot.worker_ip is None:
  83. raise ValueError(
  84. f"Benchmark {benchmark.name}(id={benchmark.id}) snapshot for model instance {benchmark.model_instance_name} has no worker IP"
  85. )
  86. if instance_snapshot.ports is None or len(instance_snapshot.ports) == 0:
  87. raise ValueError(
  88. f"Benchmark {benchmark.name}(id={benchmark.id}) snapshot for model instance {benchmark.model_instance_name} has no ports"
  89. )
  90. self._benchmark_dir = self._config.benchmark_dir
  91. self._model_path = instance_snapshot.resolved_path
  92. self._model_endpoint = f"http://{instance_snapshot.worker_ip}:{instance_snapshot.ports[0] if instance_snapshot.ports else ''}"
  93. self._model_backend_parameters = instance_snapshot.backend_parameters
  94. _api_key = read_worker_token(self._config.data_dir)
  95. if _api_key is None:
  96. raise ValueError(
  97. f"Worker token not found for benchmark {benchmark.name}(id={benchmark.id}) progress reporting"
  98. )
  99. self._api_key = _api_key
  100. _server_url = self._clientset.base_url
  101. if not _server_url:
  102. raise ValueError(
  103. f"Server URL not configured for benchmark {benchmark.name}(id={benchmark.id}) progress reporting"
  104. )
  105. self._api_url = (
  106. f"{_server_url.rstrip('/')}/v2/benchmarks/{self._benchmark.id}/state"
  107. )
  108. except Exception as e:
  109. error_message = f"Failed to initialize: {e}"
  110. logger.error(error_message)
  111. try:
  112. patch_dict = {
  113. "state_message": error_message,
  114. "state": BenchmarkStateEnum.ERROR,
  115. }
  116. self._update_benchmark_state(benchmark.id, **patch_dict)
  117. except Exception as ue:
  118. logger.error(
  119. f"Failed to update benchmark {benchmark.name}(id={benchmark.id}) state: {ue}"
  120. )
  121. sys.exit(1)
  122. def start(self):
  123. deployment_metadata = self._benchmark.get_deployment_metadata()
  124. env = {}
  125. if not runtime_envs.GPUSTACK_RUNTIME_DEPLOY_MIRRORED_DEPLOYMENT:
  126. env = filter_env_vars(os.environ)
  127. command_args = self._build_command_args()
  128. self._create_workload(
  129. deployment_metadata=deployment_metadata,
  130. command=["benchmark-runner"],
  131. command_args=command_args,
  132. env=env,
  133. )
  134. def _create_workload(
  135. self,
  136. deployment_metadata: BenchmarkDeploymentMetadata,
  137. command: Optional[List[str]],
  138. command_args: List[str],
  139. env: Dict[str, str],
  140. ):
  141. image = apply_registry_override_to_image(
  142. self._config, self._config.benchmark_image_repo, self._fallback_registry
  143. )
  144. if not image:
  145. raise ValueError("Failed to get image for benchmark runner workload")
  146. mounts = self._get_configured_mounts()
  147. run_container = Container(
  148. image=image,
  149. name="default",
  150. profile=ContainerProfileEnum.RUN,
  151. restart_policy=ContainerRestartPolicyEnum.NEVER,
  152. execution=ContainerExecution(
  153. privileged=True,
  154. args=command_args,
  155. ),
  156. envs=[
  157. ContainerEnv(
  158. name=name,
  159. value=value,
  160. )
  161. for name, value in env.items()
  162. ],
  163. mounts=mounts,
  164. )
  165. logger.info(
  166. f"Creating benchmark container workload: {deployment_metadata.name}"
  167. )
  168. logger.info(
  169. f"With image: {image}, "
  170. f"command: [{' '.join(command) if command else ''}], "
  171. f"arguments: [{' '.join(str(arg) for arg in command_args)}], "
  172. f"envs(inconsistent input items mean unchangeable):{os.linesep}"
  173. f"{os.linesep.join(f'{k}={v}' for k, v in sorted(sanitize_env(env).items()))}"
  174. )
  175. workload_plan = WorkloadPlan(
  176. name=deployment_metadata.name,
  177. host_network=True,
  178. shm_size=10 * 1 << 30, # 10 GiB
  179. containers=[run_container],
  180. labels=deployment_metadata.labels,
  181. )
  182. create_workload(
  183. transform_workload_plan(
  184. self._config, workload_plan, self._fallback_registry
  185. )
  186. )
  187. logger.info(f"Created benchmark container workload: {deployment_metadata.name}")
  188. def _build_command_args(self) -> List[str]:
  189. backend_kwargs = {
  190. "timeout": BENCHMARK_REQUEST_TIMEOUT,
  191. "response_handlers": {
  192. "chat_completions": "chat_completions_with_reasoning"
  193. },
  194. }
  195. command_args = [
  196. "benchmark",
  197. "run",
  198. "--target",
  199. self._model_endpoint,
  200. "--profile",
  201. "constant",
  202. "--rate",
  203. str(self._benchmark.request_rate),
  204. "--sample-requests",
  205. "0",
  206. "--processor",
  207. self._model_path,
  208. "--output-dir",
  209. f"{self._benchmark_dir}",
  210. "--outputs",
  211. f"{self._benchmark.id}.dual_json",
  212. "--progress-url",
  213. self._api_url,
  214. "--progress-auth",
  215. self._api_key,
  216. "--backend-kwargs",
  217. json.dumps(backend_kwargs),
  218. "--backend",
  219. "openai_http_error_detail",
  220. ]
  221. if find_bool_parameter(self._model_backend_parameters, ["trust-remote-code"]):
  222. command_args.extend(
  223. [
  224. "--processor-args",
  225. json.dumps({"trust_remote_code": True}),
  226. ]
  227. )
  228. if self._benchmark.dataset_name == DATASET_SHAREGPT:
  229. data = BENCHMARK_DATASET_SHAREGPT_PATH
  230. command_args.extend(["--data", data])
  231. elif (
  232. self._benchmark.dataset_name == DATASET_RANDOM
  233. and self._benchmark.dataset_input_tokens is not None
  234. and self._benchmark.dataset_output_tokens is not None
  235. ):
  236. data = f"prompt_tokens={self._benchmark.dataset_input_tokens},output_tokens={self._benchmark.dataset_output_tokens}"
  237. command_args.extend(["--data", data])
  238. if self._benchmark.dataset_seed is not None:
  239. command_args.extend(
  240. [
  241. "--random-seed",
  242. f"{self._benchmark.dataset_seed}",
  243. ]
  244. )
  245. if (
  246. self._benchmark.total_requests is not None
  247. and self._benchmark.total_requests > 0
  248. ):
  249. command_args.extend(
  250. [
  251. "--max-requests",
  252. f"{self._benchmark.total_requests}",
  253. ]
  254. )
  255. return command_args
  256. def _update_benchmark_state(self, id: int, **kwargs):
  257. resp = self._clientset.http_client.get_httpx_client().patch(
  258. "/benchmarks/{id}/state".format(id=id), json=kwargs
  259. )
  260. resp.raise_for_status()
  261. def _get_configured_mounts(self) -> List[ContainerMount]:
  262. """
  263. Get the volume mounts for the model instance.
  264. If runtime mirrored deployment is enabled, no mounts will be set up.
  265. Returns:
  266. A list of ContainerMount objects for the model instance.
  267. """
  268. mounts: List[ContainerMount] = []
  269. if (
  270. self._model_path
  271. and self._benchmark_dir
  272. and not runtime_envs.GPUSTACK_RUNTIME_DEPLOY_MIRRORED_DEPLOYMENT
  273. ):
  274. model_dir = os.path.dirname(self._model_path)
  275. mounts.extend(
  276. [
  277. ContainerMount(
  278. path=model_dir,
  279. ),
  280. ContainerMount(
  281. path=self._benchmark_dir,
  282. ),
  283. ]
  284. )
  285. return mounts