strategies.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435
  1. from abc import ABC, abstractmethod
  2. import logging
  3. from typing import Dict, List
  4. import itertools
  5. from gpustack.schemas.models import ModelInstance
  # Module-scoped logger; its name is this module's ``__name__``.
  logger: logging.Logger = logging.getLogger(__name__)
  class LoadBalancingStrategy(ABC):
      """Interface for policies that choose one ``ModelInstance`` from candidates.

      Subclasses must implement :meth:`select_instance`; the class cannot be
      instantiated directly because that method is abstract.
      """

      @abstractmethod
      async def select_instance(self, instances: List[ModelInstance]) -> ModelInstance:
          """Pick and return one element of *instances*.

          The selection policy is defined entirely by the concrete subclass.
          """
  class RoundRobinStrategy(LoadBalancingStrategy):
      """Hand out a model's instances in turn, one per call.

      Rotation state is a single integer cursor per ``model_id`` instead of a
      cached copy of the instance list. Selection is always made from the list
      passed to the current call, so:

      * a freshly fetched list (new objects, possibly unequal to earlier ones)
        does not reset the rotation back to the first instance;
      * the returned object is the caller's current one, not a stale snapshot;
      * no instance objects are retained between calls.
      """

      def __init__(self):
          # model_id -> cursor for the next selection (reduced modulo the
          # current list length at selection time).
          self._next_index: Dict[int, int] = {}

      async def select_instance(self, instances: List[ModelInstance]) -> ModelInstance:
          """Return the next instance for the model in round-robin order.

          Args:
              instances: Candidate instances. The model is identified by the
                  ``model_id`` of the first element; presumably every element
                  shares that ``model_id`` -- confirm with callers.

          Returns:
              The element of *instances* at the model's cursor, taken modulo
              ``len(instances)``. On the first call for a model this is
              ``instances[0]``.

          Raises:
              Exception: If *instances* is empty.
          """
          if not instances:
              raise Exception("No instances available")

          model_id = instances[0].model_id
          if model_id not in self._next_index:
              logger.debug("Starting round-robin rotation for model %s", model_id)

          # Reduce modulo the current length so the cursor stays valid when
          # instances are added or removed between calls.
          index = self._next_index.get(model_id, 0) % len(instances)
          self._next_index[model_id] = index + 1
          # No ``await`` occurs in this method, so this read-modify-write of the
          # cursor cannot interleave with another coroutine on the same loop.
          return instances[index]