worker_syncer.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import asyncio
  2. import logging
  3. import aiohttp
  4. from typing import Callable, Optional
  5. from gpustack.schemas.workers import Worker, WorkerStateEnum
  6. from gpustack.server.db import async_session
  7. from gpustack.server.services import WorkerService
  8. from gpustack.server.worker_request import is_worker_reachable
  9. from gpustack import envs
  10. logger = logging.getLogger(__name__)
  11. class WorkerSyncer:
  12. """
  13. WorkerSyncer syncs worker status periodically.
  14. 1. Performs worker reachability checks.
  15. 2. Performs readiness checks based on heartbeats.
  16. """
  17. def __init__(
  18. self,
  19. http_client_getter: Callable[[], Optional[aiohttp.ClientSession]],
  20. http_client_no_proxy_getter: Callable[[], Optional[aiohttp.ClientSession]],
  21. interval=15,
  22. worker_unreachable_timeout=20,
  23. ):
  24. self._interval = interval
  25. self._worker_unreachable_timeout = worker_unreachable_timeout
  26. self._http_client_getter = http_client_getter
  27. self._http_client_no_proxy_getter = http_client_no_proxy_getter
  28. logger.debug(
  29. f"WorkerSyncer initialized with unreachable check mode: {envs.WORKER_UNREACHABLE_CHECK_MODE}"
  30. )
  31. async def start(self):
  32. client = self._http_client_getter()
  33. while True:
  34. await asyncio.sleep(self._interval)
  35. try:
  36. client = client or self._http_client_getter()
  37. if client is None:
  38. logger.debug("HTTP client not available, skipping worker sync")
  39. continue
  40. await self._sync_workers_states()
  41. except Exception as e:
  42. logger.error(f"Failed to sync workers: {e}")
  43. async def _sync_workers_states(self):
  44. """
  45. Sync workers' states by checking their reachability and readiness.
  46. """
  47. async with async_session() as session:
  48. all_workers = await Worker.all(session)
  49. if not all_workers:
  50. return
  51. if self._should_check_unreachable(len(all_workers)):
  52. tasks = [
  53. self._set_worker_unreachable(worker)
  54. for worker in all_workers
  55. if not worker.state.is_provisioning
  56. ]
  57. await asyncio.gather(*tasks)
  58. state_changed_workers = self.filter_state_change_workers(all_workers)
  59. should_update_workers = []
  60. state_to_worker_name = {
  61. WorkerStateEnum.NOT_READY: [],
  62. WorkerStateEnum.UNREACHABLE: [],
  63. WorkerStateEnum.READY: [],
  64. WorkerStateEnum.MAINTENANCE: [],
  65. }
  66. for worker in state_changed_workers:
  67. if worker and worker.state in state_to_worker_name:
  68. should_update_workers.append(worker)
  69. state_to_worker_name[worker.state].append(worker.name)
  70. async with async_session() as session:
  71. for worker in should_update_workers:
  72. # reload from DB and update states only
  73. to_update_worker = await WorkerService(session).get_by_id(worker.id)
  74. if to_update_worker:
  75. to_update_worker.unreachable = worker.unreachable
  76. to_update_worker.state = worker.state
  77. to_update_worker.state_message = worker.state_message
  78. await WorkerService(session).update(to_update_worker)
  79. for state, worker_names in state_to_worker_name.items():
  80. if worker_names:
  81. logger.info(f"Marked worker {', '.join(worker_names)} as {state}")
  82. def _should_check_unreachable(self, worker_count: int) -> bool:
  83. """
  84. Determine if unreachable check should be performed based on mode and worker count.
  85. Args:
  86. worker_count: Total number of workers
  87. Returns:
  88. True if unreachable check should be performed, False otherwise
  89. """
  90. mode = envs.WORKER_UNREACHABLE_CHECK_MODE
  91. auto_threshold = 50 # Auto-disable threshold for worker count
  92. if mode == "disabled":
  93. return False
  94. elif mode == "enabled":
  95. return True
  96. elif mode == "auto":
  97. if worker_count > auto_threshold:
  98. return False
  99. return True
  100. else:
  101. logger.warning(
  102. f"Invalid WORKER_UNREACHABLE_CHECK_MODE: {mode}, defaulting to 'auto'"
  103. )
  104. # Default to auto behavior
  105. return worker_count <= auto_threshold
  106. async def _set_worker_unreachable(self, worker: Worker):
  107. worker.unreachable = not await is_worker_reachable(
  108. worker=worker,
  109. proxy_client=self._http_client_getter(),
  110. no_proxy_client=self._http_client_no_proxy_getter(),
  111. timeout_in_second=self._worker_unreachable_timeout,
  112. )
  113. @staticmethod
  114. def filter_state_change_workers(workers: list[Worker]) -> list[Worker]:
  115. """
  116. Filter workers whose state has changed.
  117. Args:
  118. workers: List of workers to check
  119. Returns:
  120. List of workers whose state has changed
  121. """
  122. state_changed_workers = []
  123. for worker in workers:
  124. original_worker_state = worker.state
  125. original_worker_state_message = worker.state_message
  126. worker.compute_state()
  127. if (
  128. worker.state != original_worker_state
  129. or worker.state_message != original_worker_state_message
  130. ):
  131. state_changed_workers.append(worker)
  132. return state_changed_workers