detector_factory.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import logging
  2. from typing import Dict, Optional, List
  3. from gpustack.detectors.base import (
  4. GPUDetector,
  5. GPUDevicesStatus,
  6. SystemInfoDetector,
  7. )
  8. from gpustack.detectors.runtime.runtime import Runtime
  9. from gpustack.schemas.workers import SystemInfo
  10. from gpustack.detectors.fastfetch.fastfetch import Fastfetch
  11. logger = logging.getLogger(__name__)
  12. class DetectorFactory:
  13. def __init__(
  14. self,
  15. device: Optional[str] = None,
  16. gpu_detectors: Optional[Dict[str, List[GPUDetector]]] = None,
  17. system_info_detector: Optional[SystemInfoDetector] = None,
  18. ):
  19. self.system_info_detector = system_info_detector or Fastfetch()
  20. self.device = device
  21. if device:
  22. self.gpu_detectors = gpu_detectors.get(device) or []
  23. else:
  24. self.gpu_detectors = [Runtime()]
  25. def detect_gpus(self) -> GPUDevicesStatus:
  26. for detector in self.gpu_detectors:
  27. if detector.is_available():
  28. gpus = detector.gather_gpu_info()
  29. if gpus:
  30. return self._filter_gpu_devices(gpus)
  31. return []
  32. def detect_system_info(self) -> SystemInfo:
  33. return self.system_info_detector.gather_system_info()
  34. @staticmethod
  35. def _filter_gpu_devices(gpu_devices: GPUDevicesStatus) -> GPUDevicesStatus:
  36. filtered: GPUDevicesStatus = []
  37. for device in gpu_devices:
  38. if not device.memory or not device.memory.total or device.memory.total <= 0:
  39. logger.debug(
  40. f"Skipping GPU device {device.name} ({device.device_index}, {device.device_chip_index}) due to invalid memory info"
  41. )
  42. continue
  43. filtered.append(device)
  44. return filtered