test_gpu.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import pytest
  2. from gpustack.utils.gpu import parse_gpu_id, compare_compute_capability
  3. expected_matched_inputs = {
  4. "worker1:cuda:0": {"worker_name": "worker1", "device": "cuda", "gpu_index": "0"},
  5. "worker_name:npu:12": {
  6. "worker_name": "worker_name",
  7. "device": "npu",
  8. "gpu_index": "12",
  9. },
  10. "test_worker:rocm:3": {
  11. "worker_name": "test_worker",
  12. "device": "rocm",
  13. "gpu_index": "3",
  14. },
  15. "example:musa:7": {"worker_name": "example", "device": "musa", "gpu_index": "7"},
  16. "name:example:musa:7": {
  17. "worker_name": "name:example",
  18. "device": "musa",
  19. "gpu_index": "7",
  20. },
  21. "name:example:mps:100": {
  22. "worker_name": "name:example",
  23. "device": "mps",
  24. "gpu_index": "100",
  25. },
  26. }
  27. expected_not_matched_inputs = [
  28. "worker1:cuda:not_a_number",
  29. ]
  30. @pytest.mark.unit
  31. def test_parse_gpu_id():
  32. for input, expected_output in expected_matched_inputs.items():
  33. is_matched, result = parse_gpu_id(input)
  34. assert is_matched, f"Expected {input} to be matched but it was not."
  35. assert result.get("worker_name") == expected_output.get(
  36. "worker_name"
  37. ), f"Expected worker_name to be {expected_output.get('worker_name')} but got {result.get('worker_name')}"
  38. assert result.get("device") == expected_output.get(
  39. "device"
  40. ), f"Expected device to be {expected_output.get('device')} but got {result.get('device')}"
  41. assert result.get("gpu_index") == expected_output.get(
  42. "gpu_index"
  43. ), f"Expected gpu_index to be {expected_output.get('gpu_index')} but got {result.get('gpu_index')}"
  44. for input in expected_not_matched_inputs:
  45. is_matched, result = parse_gpu_id(input)
  46. assert not is_matched, f"Expected {input} to not be matched but it was."
  47. assert result is None, f"Expected result to be None but got {result}"
  48. @pytest.mark.parametrize(
  49. "current, target, expected",
  50. [
  51. # Equal cases
  52. ("8.0", "8.0", 0),
  53. ("7.5", "7.5", 0),
  54. # Greater cases
  55. ("8.0", "7.5", 1),
  56. ("8.6", "8.0", 1),
  57. ("9.0", "8.9", 1),
  58. ("10.0", "9.0", 1),
  59. # Less cases
  60. ("7.5", "8.0", -1),
  61. ("8.0", "8.6", -1),
  62. ("8.9", "9.0", -1),
  63. # Invalid current, valid target -> -1
  64. (None, "8.0", -1),
  65. ("", "8.0", -1),
  66. (" ", "8.0", -1),
  67. ("invalid", "8.0", -1),
  68. ("8", "8.0", -1),
  69. ("8.", "8.0", -1),
  70. (".0", "8.0", -1),
  71. ("8.0.1", "8.0", -1),
  72. # Valid current, invalid target -> 1
  73. ("8.0", None, 1),
  74. ("8.0", "", 1),
  75. ("8.0", " ", 1),
  76. ("8.0", "invalid", 1),
  77. ("8.0", "8", 1),
  78. ("8.0", "8.", 1),
  79. ("8.0", ".0", 1),
  80. ("8.0", "8.0.1", 1),
  81. # Both invalid -> 0
  82. (None, None, 0),
  83. ("", "", 0),
  84. (" ", " ", 0),
  85. ("invalid", "invalid", 0),
  86. (None, "", 0),
  87. ("8", "invalid", 0),
  88. ("8.", ".0", 0),
  89. ("-1.0", "-2.0", 0),
  90. # Whitespace normalization
  91. (" 8.0 ", "8.0", 0),
  92. ("8.0", " 8.0 ", 0),
  93. (" 8.6 ", " 8.0 ", 1),
  94. ],
  95. )
  96. def test_compare_compute_capability(current, target, expected):
  97. assert compare_compute_capability(current, target) == expected