status_scorer.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import logging
  2. from typing import List, Optional
  3. from gpustack.policies.base import ModelInstanceScore, ModelInstanceScorer
  4. from gpustack.schemas.models import Model, ModelInstance, ModelInstanceStateEnum
  5. from gpustack.schemas.workers import Worker, WorkerStateEnum
  6. from gpustack.server.db import async_session
  7. logger = logging.getLogger(__name__)
  8. class StatusScorer(ModelInstanceScorer):
  9. def __init__(
  10. self,
  11. model: Model,
  12. model_instance: Optional[ModelInstance] = None,
  13. max_score: float = 100.0,
  14. ):
  15. self._model = model
  16. self._model_instance = model_instance
  17. self._max_score = max_score
  18. async def score_instances(
  19. self, instances: List[ModelInstance]
  20. ) -> List[ModelInstanceScore]:
  21. """
  22. Score the instances with the worker and instance status.
  23. """
  24. logger.debug(f"model {self._model.name}, score instances with status policy")
  25. scored_instances = []
  26. async with async_session() as session:
  27. workers = await Worker.all(session)
  28. worker_map = {worker.id: worker for worker in workers}
  29. for instance in instances:
  30. if instance.worker_id is None:
  31. scored_instances.append(
  32. ModelInstanceScore(model_instance=instance, score=0)
  33. )
  34. continue
  35. score = 0
  36. worker = worker_map.get(instance.worker_id)
  37. if worker is None:
  38. scored_instances.append(
  39. ModelInstanceScore(model_instance=instance, score=0)
  40. )
  41. continue
  42. if worker.state == WorkerStateEnum.NOT_READY:
  43. score = 0
  44. elif instance.state == ModelInstanceStateEnum.ERROR:
  45. score = 0
  46. elif (
  47. worker.state == WorkerStateEnum.READY
  48. and instance.state == ModelInstanceStateEnum.RUNNING
  49. ):
  50. score = self._max_score
  51. else:
  52. score = self._max_score * 0.5
  53. scored_instances.append(
  54. ModelInstanceScore(model_instance=instance, score=score)
  55. )
  56. return scored_instances