| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- import asyncio
- import logging
- import aiohttp
- from typing import Callable, Optional
- from gpustack.schemas.workers import Worker, WorkerStateEnum
- from gpustack.server.db import async_session
- from gpustack.server.services import WorkerService
- from gpustack.server.worker_request import is_worker_reachable
- from gpustack import envs
- logger = logging.getLogger(__name__)
class WorkerSyncer:
    """
    WorkerSyncer syncs worker status periodically.

    1. Performs worker reachability checks.
    2. Performs readiness checks based on heartbeats.
    """

    # In "auto" mode, reachability probes are skipped once there are more
    # workers than this. Override on a subclass or instance to tune it.
    AUTO_CHECK_MAX_WORKERS = 50

    def __init__(
        self,
        http_client_getter: Callable[[], Optional[aiohttp.ClientSession]],
        http_client_no_proxy_getter: Callable[[], Optional[aiohttp.ClientSession]],
        interval: float = 15,
        worker_unreachable_timeout: float = 20,
    ):
        """
        Args:
            http_client_getter: Returns the HTTP client passed to the
                reachability probe as ``proxy_client``, or None if it is not
                available yet.
            http_client_no_proxy_getter: Returns the HTTP client passed to the
                reachability probe as ``no_proxy_client``.
            interval: Seconds to sleep before each sync cycle.
            worker_unreachable_timeout: Timeout, in seconds, handed to each
                reachability probe.
        """
        self._interval = interval
        self._worker_unreachable_timeout = worker_unreachable_timeout
        self._http_client_getter = http_client_getter
        self._http_client_no_proxy_getter = http_client_no_proxy_getter
        logger.debug(
            f"WorkerSyncer initialized with unreachable check mode: {envs.WORKER_UNREACHABLE_CHECK_MODE}"
        )

    async def start(self) -> None:
        """
        Run the sync loop forever: sleep for the interval, then sync once.

        Any ``Exception`` raised by a cycle is logged and the loop continues,
        so one failed cycle does not stop later ones.
        """
        while True:
            await asyncio.sleep(self._interval)
            try:
                # Ask the getter on every cycle instead of caching the first
                # client seen, so the guard also fires if the client becomes
                # unavailable again later.
                if self._http_client_getter() is None:
                    logger.debug("HTTP client not available, skipping worker sync")
                    continue
                await self._sync_workers_states()
            except Exception as e:
                logger.error(f"Failed to sync workers: {e}")

    async def _sync_workers_states(self) -> None:
        """
        Sync workers' states by checking their reachability and readiness.

        Loads all workers, optionally probes reachability, recomputes each
        worker's state, and persists the workers whose state or state message
        changed and whose new state is one of NOT_READY, UNREACHABLE, READY
        or MAINTENANCE.
        """
        # Load in a short-lived session so no DB session is held open while
        # the (potentially slow) network probes below are running.
        async with async_session() as session:
            all_workers = await Worker.all(session)
        if not all_workers:
            return

        if self._should_check_unreachable(len(all_workers)):
            await asyncio.gather(
                *(
                    self._set_worker_unreachable(worker)
                    for worker in all_workers
                    if not worker.state.is_provisioning
                )
            )

        # Only these states are written back; the dict order is also the
        # order in which the summary lines are logged.
        updated_names_by_state = {
            WorkerStateEnum.NOT_READY: [],
            WorkerStateEnum.UNREACHABLE: [],
            WorkerStateEnum.READY: [],
            WorkerStateEnum.MAINTENANCE: [],
        }
        should_update_workers = [
            worker
            for worker in self.filter_state_change_workers(all_workers)
            if worker and worker.state in updated_names_by_state
        ]
        if not should_update_workers:
            # Nothing changed: skip opening a second session.
            return

        async with async_session() as session:
            worker_service = WorkerService(session)
            for worker in should_update_workers:
                # reload from DB and update states only
                to_update_worker = await worker_service.get_by_id(worker.id)
                if not to_update_worker:
                    # The worker disappeared between load and update; do not
                    # report it as marked.
                    continue
                to_update_worker.unreachable = worker.unreachable
                to_update_worker.state = worker.state
                to_update_worker.state_message = worker.state_message
                await worker_service.update(to_update_worker)
                updated_names_by_state[worker.state].append(worker.name)

        for state, worker_names in updated_names_by_state.items():
            if worker_names:
                logger.info(f"Marked worker {', '.join(worker_names)} as {state}")

    def _should_check_unreachable(self, worker_count: int) -> bool:
        """
        Determine if unreachable check should be performed based on mode and worker count.

        Args:
            worker_count: Total number of workers

        Returns:
            True if unreachable check should be performed, False otherwise
        """
        mode = envs.WORKER_UNREACHABLE_CHECK_MODE
        if mode == "disabled":
            return False
        if mode == "enabled":
            return True
        if mode != "auto":
            logger.warning(
                f"Invalid WORKER_UNREACHABLE_CHECK_MODE: {mode}, defaulting to 'auto'"
            )
        # "auto" (and the fallback for unknown modes): probe only while the
        # number of workers is at or below the threshold.
        return worker_count <= self.AUTO_CHECK_MAX_WORKERS

    async def _set_worker_unreachable(self, worker: Worker) -> None:
        """
        Probe one worker and store the negated result on ``worker.unreachable``.

        Args:
            worker: Worker to probe; its ``unreachable`` attribute is overwritten.
        """
        reachable = await is_worker_reachable(
            worker=worker,
            proxy_client=self._http_client_getter(),
            no_proxy_client=self._http_client_no_proxy_getter(),
            timeout_in_second=self._worker_unreachable_timeout,
        )
        worker.unreachable = not reachable

    @staticmethod
    def filter_state_change_workers(workers: list[Worker]) -> list[Worker]:
        """
        Filter workers whose state has changed.

        ``compute_state()`` is called on every worker in the list, including
        the ones that are not returned; it presumably refreshes ``state`` and
        ``state_message`` in place.

        Args:
            workers: List of workers to check

        Returns:
            List of workers whose state or state message differs after
            ``compute_state()`` from what it was before, in input order
        """
        state_changed_workers = []
        for worker in workers:
            previous = (worker.state, worker.state_message)
            worker.compute_state()
            if (worker.state, worker.state_message) != previous:
                state_changed_workers.append(worker)
        return state_changed_workers
|