test_serve_manager.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. from datetime import datetime, timezone
  2. from pathlib import Path
  3. from types import SimpleNamespace
  4. from unittest.mock import ANY, MagicMock, patch
  5. from gpustack.schemas.models import (
  6. BackendEnum,
  7. DistributedServerCoordinateModeEnum,
  8. DistributedServers,
  9. ModelInstanceSubordinateWorker,
  10. ModelInstanceStateEnum,
  11. )
  12. from gpustack.worker.serve_manager import ServeManager
  13. from tests.utils.model import new_model, new_model_instance
  14. def _build_serve_manager(worker_id: int = 1):
  15. clientset = MagicMock()
  16. clientset.model_instances.list.return_value = SimpleNamespace(items=[])
  17. cfg = SimpleNamespace(log_dir="/tmp")
  18. manager = ServeManager(lambda: worker_id, lambda: clientset, cfg)
  19. manager._inference_backend_manager = MagicMock()
  20. return manager, clientset
  21. def test_sync_model_instances_state_marks_main_unreachable_when_subordinate_unreachable():
  22. manager, clientset = _build_serve_manager()
  23. model_instance = new_model_instance(
  24. 1,
  25. "distributed-instance",
  26. 1,
  27. worker_id=1,
  28. state=ModelInstanceStateEnum.RUNNING,
  29. )
  30. model_instance.worker_ip = "127.0.0.1"
  31. model_instance.port = 8000
  32. model_instance.distributed_servers = DistributedServers(
  33. mode=DistributedServerCoordinateModeEnum.RUN_FIRST,
  34. subordinate_workers=[
  35. ModelInstanceSubordinateWorker(
  36. worker_id=2,
  37. worker_name="worker-2",
  38. worker_ip="10.0.0.2",
  39. state=ModelInstanceStateEnum.UNREACHABLE,
  40. state_message="Worker is unreachable from the server",
  41. )
  42. ],
  43. )
  44. clientset.model_instances.list.return_value = SimpleNamespace(
  45. items=[model_instance]
  46. )
  47. model = new_model(1, "test", 1, huggingface_repo_id="Qwen/Qwen2.5-0.5B-Instruct")
  48. model.backend = BackendEnum.VLLM
  49. model.backend_version = "0.8.0"
  50. with (
  51. patch(
  52. "gpustack.worker.serve_manager.get_workload",
  53. return_value=SimpleNamespace(state="running"),
  54. ),
  55. patch.object(manager, "_is_provisioning", return_value=False),
  56. patch.object(manager, "_get_model", return_value=model),
  57. patch.object(manager, "_update_model_instance") as update_model_instance,
  58. ):
  59. manager.sync_model_instances_state()
  60. update_model_instance.assert_called_once_with(
  61. model_instance.id,
  62. state=ModelInstanceStateEnum.UNREACHABLE,
  63. state_message=(
  64. "Distributed serving unreachable in subordinate worker "
  65. "10.0.0.2: Worker is unreachable from the server."
  66. ),
  67. )
  68. def test_restart_error_model_instance_uses_transient_backoff_count():
  69. manager, _ = _build_serve_manager()
  70. model_instance = new_model_instance(
  71. 1,
  72. "restarted-instance",
  73. 1,
  74. worker_id=1,
  75. state=ModelInstanceStateEnum.ERROR,
  76. )
  77. model_instance.restart_count = 20
  78. model_instance.last_restart_time = datetime.now(timezone.utc)
  79. with (
  80. patch.object(manager, "_is_provisioning", return_value=False),
  81. patch.object(manager, "_update_model_instance") as update_model_instance,
  82. patch("gpustack.worker.serve_manager.logger"),
  83. ):
  84. manager._restart_error_model_instance(model_instance)
  85. update_model_instance.assert_called_once_with(
  86. model_instance.id,
  87. restart_count=21,
  88. last_restart_time=ANY,
  89. state=ModelInstanceStateEnum.SCHEDULED,
  90. state_message="",
  91. )
  92. def test_restart_model_instance_preserves_transient_backoff_count():
  93. manager, _ = _build_serve_manager()
  94. model_instance = new_model_instance(
  95. 1,
  96. "restarted-instance",
  97. 1,
  98. worker_id=1,
  99. state=ModelInstanceStateEnum.SCHEDULED,
  100. )
  101. manager._restart_backoff_counts[model_instance.id] = 1
  102. with (
  103. patch.object(manager, "_is_provisioning", return_value=False),
  104. patch.object(manager, "_start_model_instance"),
  105. ):
  106. manager._restart_model_instance(model_instance)
  107. assert manager._restart_backoff_counts[model_instance.id] == 1
  108. def test_cleanup_old_logs_keeps_only_current_and_previous_restart(tmp_path: Path):
  109. """Keep main/container logs for R and R-1; delete older restart_count files."""
  110. serve_dir = tmp_path / "serve"
  111. serve_dir.mkdir(parents=True)
  112. mid = 42
  113. for name in (
  114. f"{mid}.0.log",
  115. f"{mid}.1.log",
  116. f"{mid}.2.log",
  117. f"{mid}.container.0.log",
  118. f"{mid}.container.1.log",
  119. f"{mid}.container.2.log",
  120. ):
  121. (serve_dir / name).write_text("x", encoding="utf-8")
  122. manager, _clients = _build_serve_manager()
  123. manager._serve_log_dir = str(serve_dir)
  124. manager._cleanup_old_logs(mid, 2)
  125. remaining = sorted(p.name for p in serve_dir.iterdir())
  126. assert remaining == [
  127. f"{mid}.1.log",
  128. f"{mid}.2.log",
  129. f"{mid}.container.1.log",
  130. f"{mid}.container.2.log",
  131. ]
  132. def test_cleanup_old_logs_restart_zero_keeps_only_zero(tmp_path: Path):
  133. serve_dir = tmp_path / "serve"
  134. serve_dir.mkdir(parents=True)
  135. mid = 7
  136. for name in (f"{mid}.0.log", f"{mid}.1.log", f"{mid}.container.1.log"):
  137. (serve_dir / name).write_text("x", encoding="utf-8")
  138. manager, _clients = _build_serve_manager()
  139. manager._serve_log_dir = str(serve_dir)
  140. manager._cleanup_old_logs(mid, 0)
  141. remaining = sorted(p.name for p in serve_dir.iterdir())
  142. assert remaining == [f"{mid}.0.log"]