inference_backend_manager.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import asyncio
  2. import logging
  3. import threading
  4. from typing import Dict, Optional
  5. from gpustack.client import ClientSet
  6. from gpustack.schemas.inference_backend import InferenceBackend
  7. from gpustack.server.bus import EventType, Event
  # Module-level logger named after this module.
  8. logger = logging.getLogger(__name__)
  9. # Re-entrant lock guarding reads and writes of InferenceBackendManager.backends_cache.
  # It is defined at module scope (a global, not an instance attribute) to avoid
  # pickle serialization issues.
  10. _cache_lock = threading.RLock()
  11. class InferenceBackendManager:
  12. """
  13. Unified singleton manager for InferenceBackend data.
  14. This class provides thread-safe access to InferenceBackend data
  15. across the worker layer, combining database operations and real-time listening.
  16. """
  def __init__(self, clientset: ClientSet):
      """Create the manager and eagerly load the backend cache.

      Args:
          clientset: Client used for the initial HTTP fetch and, once the
              listener is started, for watching InferenceBackend changes.

      Raises:
          Exception: Whatever ``_initialize_cache`` re-raises when the
              initial load fails.
      """
      # Listener state: no watch task exists until start_listener() runs.
      self._watch_task: Optional[asyncio.Task] = None
      self._running = False
      self._clientset: Optional[ClientSet] = clientset

      # Validated backends keyed by their backend_name.
      self.backends_cache: Dict[str, InferenceBackend] = {}

      # Must stay last: it reads self._clientset and fills self.backends_cache.
      self._initialize_cache()
  async def start_listener(self) -> None:
      """Start the background task that watches InferenceBackend changes.

      Only logs a warning (and starts nothing) when no ClientSet is set or
      when the listener is already marked as running.
      """
      if not self._clientset:
          logger.warning("ClientSet not set, cannot start listener")
      elif self._running:
          logger.warning("InferenceBackendManager listener is already running")
      else:
          self._running = True
          logger.info("Starting InferenceBackend listener service")
          # Schedule the watch loop on the currently running event loop.
          watch_coro = self._watch_changes()
          self._watch_task = asyncio.create_task(watch_coro)
  def get_backend_by_name(self, backend_name: str) -> Optional[InferenceBackend]:
      """Look up a cached backend by name.

      Args:
          backend_name: Key under which the backend was cached.

      Returns:
          The cached InferenceBackend, or None when no entry has that name.
      """
      with _cache_lock:
          cached = self.backends_cache.get(backend_name)
      return cached
  def _initialize_cache(self) -> None:
      """Populate the cache from the server's full InferenceBackend list.

      Sends a GET for ``/inference-backends/all`` through the ClientSet's
      httpx client, validates every returned item as an InferenceBackend and
      stores it in ``backends_cache`` under its ``backend_name``.

      Raises:
          Exception: Any failure (transport error, HTTP error status,
              undecodable body, validation error) is logged and re-raised.
      """
      try:
          logger.info("Initializing InferenceBackend cache")
          resp = self._clientset.http_client.get_httpx_client().get(
              "/inference-backends/all"
          )
          # Reject 4xx/5xx responses up front; otherwise an error payload
          # would be fed to model_validate and surface as a misleading
          # validation failure.
          # NOTE(review): assumes resp is an httpx.Response, as the name
          # get_httpx_client() suggests - confirm.
          resp.raise_for_status()

          payload = resp.json()
          if not payload:
              logger.info("No existing InferenceBackends found")
              return

          # Validate everything before touching the shared cache so a bad
          # item cannot leave it half-filled; model_validate either returns
          # a model or raises, so no truthiness check is needed.
          loaded: Dict[str, InferenceBackend] = {}
          for item in payload:
              backend = InferenceBackend.model_validate(item)
              loaded[backend.backend_name] = backend

          with _cache_lock:
              self.backends_cache.update(loaded)
              names = list(self.backends_cache)

          # Log a count plus the names rather than a raw dict_keys(...) repr.
          logger.info(
              f"Initialized cache with {len(names)} InferenceBackends: {names}"
          )
      except Exception as e:
          logger.error(f"Failed to initialize InferenceBackend cache: {e}")
          raise
  async def _watch_changes(self) -> None:
      """Keep an InferenceBackend watch open while the listener is running.

      ``_handle_event`` is passed to the watch as its callback. The loop ends
      when ``self._running`` is cleared or the task is cancelled; any other
      error is logged and the watch is retried after a 5 second pause.
      """
      while self._running:
          try:
              logger.info("Starting to watch InferenceBackend changes")
              backends_api = self._clientset.inference_backends
              await backends_api.awatch(callback=self._handle_event)
          except asyncio.CancelledError:
              # Cancellation is treated as a normal stop and is not re-raised.
              logger.info("InferenceBackend watch cancelled")
              return
          except Exception as e:
              logger.error(f"Error watching InferenceBackend changes: {e}")
              if not self._running:
                  # Stopped while the watch was failing: skip the back-off
                  # and let the loop condition end the task.
                  continue
              # Wait before retrying
              await asyncio.sleep(5)
  79. def _merge_version_configs(
  80. self,
  81. old: Optional[InferenceBackend],
  82. backend: InferenceBackend,
  83. ) -> None:
  84. """
  85. Merge incoming backend version configs into the cached one while
  86. preserving built-in entries.
  87. Args:
  88. old: Previously cached backend for the same name.
  89. backend: Incoming backend to be merged into the cache.
  90. """
  91. if old and backend.is_built_in:
  92. # Snapshot previous and incoming version maps
  93. old_version = old.version_configs.root if old.version_configs else {}
  94. new_version = (
  95. backend.version_configs.root if backend.version_configs else {}
  96. )
  97. # Compute deletions: drop outdated non-built-in entries not present in new map
  98. delete_version = set()
  99. new_version_keys = set(new_version.keys())
  100. for k, v in old_version.items():
  101. if not v.built_in_frameworks and k not in new_version_keys:
  102. delete_version.add(k)
  103. # Start from old (preserves built-ins), then apply incoming updates
  104. merged = old_version
  105. for k, v in new_version.items():
  106. merged[k] = v
  107. # Remove marked entries and finalize
  108. for k in delete_version:
  109. merged.pop(k, None)
  110. backend.version_configs.root = merged
  def _handle_event(self, event: Event):
      """Apply a single InferenceBackend watch event to the cache.

      CREATED and UPDATED events merge version configs with any cached
      entry and then store the backend under its ``backend_name``; DELETED
      events drop that entry. Other event types only produce a warning.
      Every exception raised while handling the event is logged and
      swallowed, so this method never raises to its caller.

      Args:
          event: Event whose ``data`` is validated as an InferenceBackend
              and whose ``type`` selects the cache operation.
      """
      try:
          # Validation happens first, whatever the event type is.
          backend = InferenceBackend.model_validate(event.data)

          if event.type in (EventType.CREATED, EventType.UPDATED):
              with _cache_lock:
                  name = backend.backend_name
                  previous = self.backends_cache.get(name)
                  self._merge_version_configs(previous, backend)
                  self.backends_cache[name] = backend
                  logger.debug(
                      f"Updated InferenceBackend in cache: {backend.id} ({event.type})"
                  )
              return

          if event.type == EventType.DELETED:
              with _cache_lock:
                  # Drop the entry; a name that is not cached is ignored.
                  self.backends_cache.pop(backend.backend_name, None)
                  logger.debug(
                      f"Removed InferenceBackend from cache: {backend.backend_name}"
                  )
              return

          logger.warning(f"Unknown event type: {event.type}")
      except Exception as e:
          logger.error(f"Error handling InferenceBackend event: {e}")