test_backend.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. import types
  2. import pytest
  3. from gpustack.schemas.inference_backend import (
  4. InferenceBackend,
  5. VersionConfig,
  6. VersionConfigDict,
  7. )
  8. from gpustack.schemas.models import BackendEnum
  9. from gpustack.utils.config import apply_registry_override_to_image
  10. from gpustack.worker.backends.custom import CustomServer
  11. from gpustack.worker.backends.sglang import (
  12. SGLangServer,
  13. get_access_log_arguments as get_sglang_access_log_arguments,
  14. get_cache_report_arguments as get_sglang_cache_report_arguments,
  15. )
  16. from gpustack.worker.backends.vllm import (
  17. VLLMServer,
  18. get_access_log_arguments as get_vllm_access_log_arguments,
  19. get_cache_report_arguments as get_vllm_cache_report_arguments,
  20. )
  21. from gpustack.worker.backends.vox_box import VoxBoxServer
  22. @pytest.mark.parametrize(
  23. "image_name, container_registry, expect_image_name, fallback_registry",
  24. [
  25. (
  26. "ghcr.io/ggml-org/llama.cpp:server",
  27. "test-registry.io",
  28. "ghcr.io/ggml-org/llama.cpp:server",
  29. None,
  30. ),
  31. (
  32. "gpustack/runner:cuda12.8-vllm0.10.2",
  33. "test-registry.io",
  34. "test-registry.io/gpustack/runner:cuda12.8-vllm0.10.2",
  35. None,
  36. ),
  37. (
  38. "foo/bar",
  39. "test-registry.io",
  40. "test-registry.io/foo/bar",
  41. None,
  42. ),
  43. ("ubuntu:24.04", "test-registry.io", "test-registry.io/ubuntu:24.04", None),
  44. (
  45. "gpustack/runner:cuda12.8-vllm0.10.2",
  46. None,
  47. "quay.io/gpustack/runner:cuda12.8-vllm0.10.2",
  48. "quay.io",
  49. ),
  50. (
  51. "lmsysorg/sglang:v0.5.5",
  52. "",
  53. "lmsysorg/sglang:v0.5.5",
  54. None,
  55. ),
  56. ],
  57. )
  58. @pytest.mark.asyncio
  59. async def test_apply_registry_override(
  60. image_name,
  61. container_registry,
  62. expect_image_name,
  63. fallback_registry,
  64. monkeypatch,
  65. ):
  66. backend = CustomServer.__new__(CustomServer)
  67. # CustomServer inherits _apply_registry_override from InferenceServer,
  68. # and _apply_registry_override accesses self._config.system_default_container_registry.
  69. # Since we constructed the instance via __new__ (without __init__),
  70. # the _config attribute does not exist. We attach a minimal stub config here.
  71. backend._config = types.SimpleNamespace(
  72. system_default_container_registry=container_registry,
  73. )
  74. backend._fallback_registry = fallback_registry
  75. assert (
  76. apply_registry_override_to_image(
  77. backend._config, image_name, backend._fallback_registry
  78. )
  79. == expect_image_name
  80. )
  81. if container_registry:
  82. backend._config = types.SimpleNamespace(system_default_container_registry=None)
  83. assert (
  84. apply_registry_override_to_image(
  85. backend._config, image_name, backend._fallback_registry
  86. )
  87. == image_name
  88. )
  89. @pytest.mark.parametrize(
  90. "backend_parameters, expected",
  91. [
  92. (
  93. ["--ctx-size 1024"],
  94. ["--ctx-size", "1024"],
  95. ),
  96. (
  97. ["--served-model-name foo"],
  98. ["--served-model-name", "foo"],
  99. ),
  100. (
  101. ['--served-model-name "foo bar"'],
  102. ["--served-model-name", "foo bar"],
  103. ),
  104. (
  105. ['--arg1', '--arg2 "val with spaces"'],
  106. ['--arg1', '--arg2', 'val with spaces'],
  107. ),
  108. (
  109. ['--arg1 "val with spaces"', '--arg2="val with spaces"'],
  110. ['--arg1', 'val with spaces', '--arg2="val with spaces"'],
  111. ),
  112. (
  113. [
  114. """--hf-overrides '{"architectures": ["NewModel"]}'""",
  115. """--hf-overrides={"architectures": ["NewModel"]}""",
  116. ],
  117. [
  118. '--hf-overrides',
  119. '{"architectures": ["NewModel"]}',
  120. """--hf-overrides={"architectures": ["NewModel"]}""",
  121. ],
  122. ),
  123. # Test cases for whitespace handling
  124. (
  125. [" --ctx-size=1024"],
  126. ["--ctx-size=1024"],
  127. ),
  128. (
  129. ["--ctx-size =1024"],
  130. ["--ctx-size=1024"],
  131. ),
  132. (
  133. [" --ctx-size =1024"],
  134. ["--ctx-size=1024"],
  135. ),
  136. (
  137. ["--ctx-size = 1024"],
  138. ["--ctx-size=1024"],
  139. ),
  140. (
  141. [" --ctx-size 1024"],
  142. ["--ctx-size", "1024"],
  143. ),
  144. (
  145. [" --max-model-len=8192"],
  146. ["--max-model-len=8192"],
  147. ),
  148. (
  149. ["--foo =bar", " --baz = qux"],
  150. ["--foo=bar", "--baz=qux"],
  151. ),
  152. (
  153. None,
  154. [],
  155. ),
  156. ],
  157. )
  158. def test_flatten_backend_param(backend_parameters, expected):
  159. backend = CustomServer.__new__(CustomServer)
  160. backend._model = types.SimpleNamespace(backend_parameters=backend_parameters)
  161. assert backend._flatten_backend_param() == expected
  162. @pytest.mark.parametrize(
  163. "backend_parameters, backend_version, expected",
  164. [
  165. (None, None, []),
  166. ([], "0.15.2", []),
  167. ([], "0.16.0", ["--disable-access-log-for-endpoints", "/metrics"]),
  168. (
  169. ["--disable-access-log-for-endpoints=/health,/metrics"],
  170. "0.16.0",
  171. [],
  172. ),
  173. (
  174. ["--disable-access-log-for-endpoints", "/health,/metrics"],
  175. "0.16.0",
  176. [],
  177. ),
  178. ],
  179. )
  180. def test_vllm_access_log_arguments(backend_parameters, backend_version, expected):
  181. assert (
  182. get_vllm_access_log_arguments(backend_parameters, backend_version) == expected
  183. )
  184. @pytest.mark.parametrize(
  185. "backend_parameters, backend_version, expected",
  186. [
  187. (None, None, []),
  188. ([], "0.5.8", []),
  189. ([], "0.5.8.post1", ["--uvicorn-access-log-exclude-prefixes", "/metrics"]),
  190. (
  191. ["--uvicorn-access-log-exclude-prefixes=/health"],
  192. "0.5.8.post1",
  193. [],
  194. ),
  195. (
  196. ["--uvicorn-access-log-exclude-prefixes", "/health"],
  197. "0.5.8.post1",
  198. [],
  199. ),
  200. ],
  201. )
  202. def test_sglang_access_log_arguments(backend_parameters, backend_version, expected):
  203. assert (
  204. get_sglang_access_log_arguments(backend_parameters, backend_version) == expected
  205. )
  206. @pytest.mark.parametrize(
  207. "backend_parameters, backend_version, expected",
  208. [
  209. # Unknown version: do not inject (we cannot version-gate it).
  210. (None, None, []),
  211. # Below the v0.9.0.1 cutoff: skipped (V1 silently dropped the field).
  212. ([], "0.9.0", []),
  213. # At/after the cutoff: injected.
  214. ([], "0.9.0.1", ["--enable-prompt-tokens-details"]),
  215. ([], "0.10.0", ["--enable-prompt-tokens-details"]),
  216. # User explicitly opted in: do not duplicate.
  217. (["--enable-prompt-tokens-details"], "0.10.0", []),
  218. # User explicitly opted out: respect their choice.
  219. (["--no-enable-prompt-tokens-details"], "0.10.0", []),
  220. # Prefix-caching flags are not GPUStack's responsibility — left to the user.
  221. (["--enable-prefix-caching"], "0.10.0", ["--enable-prompt-tokens-details"]),
  222. ],
  223. )
  224. def test_vllm_cache_report_arguments(backend_parameters, backend_version, expected):
  225. assert (
  226. get_vllm_cache_report_arguments(backend_parameters, backend_version) == expected
  227. )
  228. @pytest.mark.parametrize(
  229. "backend_parameters, backend_version, expected",
  230. [
  231. # Unknown version: do not inject (we cannot version-gate it).
  232. (None, None, []),
  233. # Below the v0.3.4 cutoff: skipped.
  234. ([], "0.3.3", []),
  235. # At/after the cutoff: injected.
  236. ([], "0.3.4", ["--enable-cache-report"]),
  237. ([], "0.5.8.post1", ["--enable-cache-report"]),
  238. # User already passed it: do not duplicate.
  239. (["--enable-cache-report"], "0.5.8.post1", []),
  240. ],
  241. )
  242. def test_sglang_cache_report_arguments(backend_parameters, backend_version, expected):
  243. assert (
  244. get_sglang_cache_report_arguments(backend_parameters, backend_version)
  245. == expected
  246. )
  247. def test_vllm_set_cache_env_defaults_to_config_cache_dir(tmp_path):
  248. backend = VLLMServer.__new__(VLLMServer)
  249. backend._config = types.SimpleNamespace(cache_dir=str(tmp_path))
  250. env = {}
  251. backend._set_cache_env(env)
  252. expected = tmp_path / "vllm"
  253. assert env["VLLM_CACHE_ROOT"] == str(expected)
  254. assert expected.is_dir()
  255. def test_vllm_set_cache_env_respects_user_override(tmp_path):
  256. backend = VLLMServer.__new__(VLLMServer)
  257. backend._config = types.SimpleNamespace(cache_dir=str(tmp_path))
  258. env = {"VLLM_CACHE_ROOT": "/custom/cache"}
  259. backend._set_cache_env(env)
  260. assert env["VLLM_CACHE_ROOT"] == "/custom/cache"
  261. # Default cache dir should not be created when the user overrode it.
  262. assert not (tmp_path / "vllm").exists()
  263. def test_vllm_command_args_include_late_system_flags_as_injected():
  264. backend = VLLMServer.__new__(VLLMServer)
  265. backend.inference_backend = None
  266. backend._model_path = "/models/llm"
  267. backend._worker = types.SimpleNamespace(ip="192.168.50.10")
  268. backend._model_instance = types.SimpleNamespace(
  269. model_name="llm",
  270. gpu_indexes=[],
  271. ports=[4000],
  272. computed_resource_claim=None,
  273. )
  274. backend._model = types.SimpleNamespace(
  275. backend=BackendEnum.VLLM,
  276. backend_parameters=[],
  277. backend_version=None,
  278. categories=[],
  279. extended_kv_cache=None,
  280. speculative_config=None,
  281. )
  282. backend._derive_max_model_len = lambda: None
  283. backend._get_speculative_arguments = lambda: []
  284. backend._get_selected_gpu_devices = lambda: [
  285. types.SimpleNamespace(vendor="NVIDIA", arch_family=None)
  286. ]
  287. arguments, injected = backend._build_command_args(port=4000, is_distributed=False)
  288. assert arguments[-6:] == [
  289. "--host",
  290. "192.168.50.10",
  291. "--port",
  292. "4000",
  293. "--served-model-name",
  294. "llm",
  295. ]
  296. assert injected == [
  297. "--host",
  298. "192.168.50.10",
  299. "--port",
  300. "4000",
  301. "--served-model-name",
  302. "llm",
  303. ]
  304. def test_vllm_command_args_exclude_user_backend_parameters_from_injected():
  305. backend = VLLMServer.__new__(VLLMServer)
  306. backend.inference_backend = None
  307. backend._model_path = "/models/llm"
  308. backend._worker = types.SimpleNamespace(ip="192.168.50.10")
  309. backend._model_instance = types.SimpleNamespace(
  310. model_name="llm",
  311. gpu_indexes=[],
  312. ports=[4000],
  313. computed_resource_claim=None,
  314. )
  315. backend._model = types.SimpleNamespace(
  316. backend=BackendEnum.VLLM,
  317. backend_parameters=["--host", "0.0.0.0", "--temperature", "0.2"],
  318. backend_version=None,
  319. categories=[],
  320. extended_kv_cache=None,
  321. speculative_config=None,
  322. )
  323. backend._derive_max_model_len = lambda: None
  324. backend._get_speculative_arguments = lambda: []
  325. backend._get_selected_gpu_devices = lambda: [
  326. types.SimpleNamespace(vendor="NVIDIA", arch_family=None)
  327. ]
  328. arguments, injected = backend._build_command_args(port=4000, is_distributed=False)
  329. assert "--temperature" in arguments
  330. assert "--temperature" not in injected
  331. assert "--host" not in injected
  332. assert injected == ["--port", "4000", "--served-model-name", "llm"]
  333. def test_sglang_command_args_include_model_and_late_system_flags_as_injected():
  334. backend = SGLangServer.__new__(SGLangServer)
  335. backend.inference_backend = None
  336. backend._model_path = "/models/llm"
  337. backend._worker = types.SimpleNamespace(ip="192.168.50.10")
  338. backend._model_instance = types.SimpleNamespace(
  339. gpu_indexes=[],
  340. ports=[4000],
  341. computed_resource_claim=None,
  342. )
  343. backend._model = types.SimpleNamespace(
  344. backend_parameters=[],
  345. backend_version=None,
  346. env={"GPUSTACK_DISABLE_METRICS": "1"},
  347. extended_kv_cache=None,
  348. speculative_config=None,
  349. )
  350. backend._derive_max_model_len = lambda: None
  351. backend._get_model_architecture = lambda: []
  352. backend._get_speculative_arguments = lambda: []
  353. backend._get_hicache_arguments = lambda: []
  354. backend._get_selected_gpu_devices = lambda: [
  355. types.SimpleNamespace(vendor="NVIDIA", arch_family=None)
  356. ]
  357. _, injected = backend._build_command_args(
  358. port=4000,
  359. is_distributed=False,
  360. is_distributed_leader=False,
  361. )
  362. assert injected == [
  363. "--model-path",
  364. "/models/llm",
  365. "--host",
  366. "192.168.50.10",
  367. "--port",
  368. "4000",
  369. ]
  370. def test_vox_box_command_args_return_injected_parameters():
  371. backend = VoxBoxServer.__new__(VoxBoxServer)
  372. backend.inference_backend = None
  373. backend._model_path = "/models/audio"
  374. backend._config = types.SimpleNamespace(data_dir="/var/lib/gpustack")
  375. backend._worker = types.SimpleNamespace(ip="192.168.50.10")
  376. backend._model_instance = types.SimpleNamespace(gpu_indexes=[1])
  377. backend._model = types.SimpleNamespace(backend_parameters=[], backend_version=None)
  378. _, injected = backend._build_command_args(port=4000)
  379. assert injected == [
  380. "--model",
  381. "/models/audio",
  382. "--data-dir",
  383. "/var/lib/gpustack",
  384. "--host",
  385. "192.168.50.10",
  386. "--port",
  387. "4000",
  388. "--device",
  389. "cuda:1",
  390. ]
  391. def test_custom_command_args_return_injected_parameters_after_entrypoint():
  392. backend = CustomServer.__new__(CustomServer)
  393. backend._model_path = "/models/custom"
  394. backend._worker = types.SimpleNamespace(ip="192.168.50.10")
  395. backend._model_instance = types.SimpleNamespace(ports=[4000])
  396. backend._model = types.SimpleNamespace(
  397. backend_parameters=["--temperature", "0.2"],
  398. backend_version=None,
  399. env={},
  400. name="custom-model",
  401. run_command="python -m custom.launch --model-path {{model_path}} --port {{port}}",
  402. )
  403. backend.inference_backend = types.SimpleNamespace(
  404. replace_command_param=lambda **_: (
  405. "python -m custom.launch --model-path /models/custom --port 4000"
  406. )
  407. )
  408. arguments, injected = backend._build_command_args()
  409. assert arguments[-2:] == ["--temperature", "0.2"]
  410. assert injected == ["--model-path", "/models/custom", "--port", "4000"]
  411. def test_custom_command_args_include_short_flags_as_injected():
  412. backend = CustomServer.__new__(CustomServer)
  413. backend._model_path = "/models/custom"
  414. backend._worker = types.SimpleNamespace(ip="192.168.50.10")
  415. backend._model_instance = types.SimpleNamespace(ports=[4000])
  416. backend._model = types.SimpleNamespace(
  417. backend_parameters=["-u", "1"],
  418. backend_version=None,
  419. env={},
  420. name="custom-model",
  421. run_command="custom-server -s 0.0.0.0 -t 4",
  422. )
  423. backend.inference_backend = types.SimpleNamespace(
  424. replace_command_param=lambda **_: "custom-server -s 0.0.0.0 -t 4"
  425. )
  426. _, injected = backend._build_command_args()
  427. assert injected == ["-s", "0.0.0.0", "-t", "4"]
  428. def test_injected_parameters_start_at_zero_with_explicit_container_entrypoint():
  429. backend = CustomServer.__new__(CustomServer)
  430. backend._model_path = "/models/custom"
  431. backend._worker = types.SimpleNamespace(ip="192.168.50.10")
  432. backend._model_instance = types.SimpleNamespace(ports=[4000])
  433. backend._model = types.SimpleNamespace(
  434. backend_parameters=["-u", "1"],
  435. backend_version=None,
  436. env={},
  437. name="custom-model",
  438. run_command="-m /models/custom -t 4",
  439. )
  440. backend.inference_backend = types.SimpleNamespace(
  441. replace_command_param=lambda **_: "-m /models/custom -t 4"
  442. )
  443. _, injected = backend._build_command_args(entrypoint=["llama-server"])
  444. assert injected == ["-m", "/models/custom", "-t", "4"]
  445. @pytest.mark.parametrize(
  446. "default_entrypoint, version_entrypoint, default_run_command, expected_entrypoint, expected_injected",
  447. [
  448. (
  449. "llama-server",
  450. None,
  451. "-m {{model_path}} -p {{port}}",
  452. ["llama-server"],
  453. ["-m", "/models/custom", "-p", "4000"],
  454. ),
  455. (
  456. "unused-entrypoint",
  457. "python -m custom.launch",
  458. "--model-path {{model_path}} --port {{port}}",
  459. ["python", "-m", "custom.launch"],
  460. ["--model-path", "/models/custom", "--port", "4000"],
  461. ),
  462. ],
  463. )
  464. def test_custom_backend_configured_entrypoint_injected_parameters(
  465. default_entrypoint,
  466. version_entrypoint,
  467. default_run_command,
  468. expected_entrypoint,
  469. expected_injected,
  470. ):
  471. backend = CustomServer.__new__(CustomServer)
  472. backend._model_path = "/models/custom"
  473. backend._worker = types.SimpleNamespace(ip="192.168.50.10")
  474. backend._model_instance = types.SimpleNamespace(ports=[4000])
  475. backend._model = types.SimpleNamespace(
  476. backend_parameters=["--user-param", "1"],
  477. backend_version="cpu",
  478. env={},
  479. name="custom-model",
  480. run_command=None,
  481. )
  482. backend.inference_backend = InferenceBackend(
  483. backend_name="custom-entrypoint-backend",
  484. default_version="cpu",
  485. default_entrypoint=default_entrypoint,
  486. default_run_command=default_run_command,
  487. version_configs=VersionConfigDict(
  488. root={
  489. "cpu": VersionConfig(
  490. image_name="custom/backend:cpu",
  491. entrypoint=version_entrypoint,
  492. custom_framework="cpu",
  493. )
  494. }
  495. ),
  496. )
  497. entrypoint = backend.inference_backend.get_container_entrypoint("cpu")
  498. arguments, injected = backend._build_command_args(entrypoint=entrypoint)
  499. assert entrypoint == expected_entrypoint
  500. assert arguments[-2:] == ["--user-param", "1"]
  501. assert injected == expected_injected