base.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from abc import ABC, abstractmethod
  2. from dataclasses import dataclass
  3. import logging
  4. from typing import Dict, List, Optional, Tuple
  5. from gpustack.schemas.models import (
  6. ComputedResourceClaim,
  7. ModelInstance,
  8. ModelInstanceSubordinateWorker,
  9. )
  10. from gpustack.schemas.workers import Worker
  11. logger = logging.getLogger(__name__)
  12. @dataclass
  13. class ModelInstanceScore:
  14. model_instance: ModelInstance
  15. score: Optional[float] = None
  16. @dataclass
  17. class ModelInstanceScheduleCandidate:
  18. worker: Worker
  19. gpu_indexes: Optional[List[int]]
  20. computed_resource_claim: ComputedResourceClaim
  21. gpu_type: Optional[str] = None
  22. gpu_addresses: Optional[List[str]] = None
  23. score: Optional[float] = None
  24. overcommit: Optional[bool] = None
  25. # for multi-worker distributed scheduling
  26. subordinate_workers: Optional[List[ModelInstanceSubordinateWorker]] = None
  27. def to_log_string(self) -> str:
  28. log_entries = [
  29. f"worker: '{self.worker.name}'",
  30. ]
  31. if self.gpu_indexes:
  32. log_entries.append(f"gpu_indexes: {self.gpu_indexes}")
  33. if self.gpu_addresses:
  34. log_entries.append(f"gpu_addresses: {self.gpu_addresses}")
  35. if self.computed_resource_claim.offload_layers:
  36. log_entries.append(
  37. f"offload_layers: {self.computed_resource_claim.offload_layers}"
  38. )
  39. if self.computed_resource_claim.tensor_split:
  40. log_entries.append(
  41. f"tensor_split: {self.computed_resource_claim.tensor_split}"
  42. )
  43. if self.overcommit:
  44. log_entries.append("overcommit: true")
  45. if self.subordinate_workers:
  46. sw_str = '), ('.join(
  47. [
  48. f"worker_id: {sw.worker_id}, "
  49. f"worker_name: {sw.worker_name}, "
  50. f"worker_ip: {sw.worker_ip}, "
  51. f"worker_ifname {sw.worker_ifname}, "
  52. f"total_gpus: {sw.total_gpus}, "
  53. f"gpu_indexes: {sw.gpu_indexes}, "
  54. f"gpu_addresses: {sw.gpu_addresses}"
  55. for sw in self.subordinate_workers
  56. ]
  57. )
  58. log_entries.append(f"subordinate_workers: [{sw_str}]")
  59. return ', '.join(log_entries)
  60. @dataclass
  61. class AllocationResource:
  62. ram: int
  63. vram: Dict[int, int]
  64. @dataclass
  65. class Allocatable(AllocationResource):
  66. pass
  67. @dataclass
  68. class Allocated(AllocationResource):
  69. pass
  70. class WorkerFilter(ABC):
  71. @abstractmethod
  72. def filter(self, workers: List[Worker]) -> Tuple[List[Worker], List[str]]:
  73. """
  74. Filter workers suitable for scheduling.
  75. :return: A tuple containing:
  76. - A list of workers that pass the filter.
  77. - A list of messages why certain workers were filtered out.
  78. """
  79. pass
  80. class WorkerFilterChain:
  81. def __init__(self, filters: List[WorkerFilter]):
  82. self.filters = filters
  83. async def filter(self, workers) -> Tuple[List[Worker], List[str]]:
  84. """
  85. Applies all filters sequentially to the list of workers.
  86. :param workers: The initial list of workers.
  87. :return: A tuple containing:
  88. - The final list of workers that pass all filters.
  89. - A list of messages for all workers filtered out across all filters.
  90. """
  91. messages = []
  92. for policy in self.filters:
  93. workers, filter_messages = await policy.filter(workers)
  94. messages.extend(filter_messages)
  95. if not workers:
  96. break
  97. return workers, messages
  98. class ModelInstanceScorer(ABC):
  99. @property
  100. def max_score(self) -> Optional[float]:
  101. return getattr(self, "_max_score", None)
  102. @abstractmethod
  103. async def score_instances(
  104. self, instances: List[ModelInstance]
  105. ) -> List[ModelInstanceScore]:
  106. """
  107. Score the instances.
  108. :param instances: The list of instances to score.
  109. :return: A list of scored instances.
  110. """
  111. pass
  112. class ScheduleCandidatesScorer(ABC):
  113. @abstractmethod
  114. async def score(
  115. self, candidates: List[ModelInstanceScheduleCandidate]
  116. ) -> List[ModelInstanceScheduleCandidate]:
  117. """
  118. Score the candidates.
  119. :param candidates: The list of candidates to score.
  120. :return: A list of scored candidates.
  121. """
  122. pass