model_instance_workers.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from dataclasses import dataclass
  2. from typing import Optional
  3. from gpustack.schemas.models import ModelInstance
  4. @dataclass(frozen=True)
  5. class ModelInstanceWorkerMatch:
  6. is_main_worker: bool = False
  7. subordinate_worker_indexes: tuple[int, ...] = ()
  8. @property
  9. def matched(self) -> bool:
  10. return self.is_main_worker or bool(self.subordinate_worker_indexes)
  11. def get_model_instance_worker_match(
  12. instance: ModelInstance,
  13. *,
  14. worker_name: Optional[str] = None,
  15. worker_id: Optional[int] = None,
  16. ) -> ModelInstanceWorkerMatch:
  17. is_main_worker = False
  18. if worker_name is not None and instance.worker_name == worker_name:
  19. is_main_worker = True
  20. if worker_id is not None and instance.worker_id == worker_id:
  21. is_main_worker = True
  22. subordinate_worker_indexes = []
  23. subordinate_workers = (
  24. instance.distributed_servers.subordinate_workers
  25. if instance.distributed_servers
  26. and instance.distributed_servers.subordinate_workers
  27. else []
  28. )
  29. for index, subordinate_worker in enumerate(subordinate_workers):
  30. if worker_name is not None and subordinate_worker.worker_name == worker_name:
  31. subordinate_worker_indexes.append(index)
  32. continue
  33. if worker_id is not None and subordinate_worker.worker_id == worker_id:
  34. subordinate_worker_indexes.append(index)
  35. return ModelInstanceWorkerMatch(
  36. is_main_worker=is_main_worker,
  37. subordinate_worker_indexes=tuple(subordinate_worker_indexes),
  38. )