# bus.py
  1. import asyncio
  2. import logging
  3. from enum import Enum
  4. from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
  5. from gpustack.envs import EVENT_BUS_SUBSCRIBER_QUEUE_SIZE
  6. # Re-export from coordinator.base for backward compatibility
  7. from gpustack.server.coordinator.base import Event, EventType
  8. from gpustack.server.coordinator.cache import get_change_detector
  9. from gpustack.server.coordinator.models import get_model_for_topic
  10. logger = logging.getLogger(__name__)
  class EventCountKind(Enum):
      """Per-subscriber counter buckets, surfaced as Prometheus labels.

      When an enqueue runs to completion the buckets satisfy
      ``RECEIVED = FILTERED + COALESCED + ENQUEUED``, and every event
      counted as BACKPRESSURED is also counted as ENQUEUED. A ``put`` that
      is cancelled while it waits can leave BACKPRESSURED incremented with
      no matching ENQUEUED: the ``latest_by_key`` rollback in ``enqueue``
      protects the queue/dict invariant, but the counter invariant is only
      best-effort.
      """

      # Every event handed to ``Subscriber.enqueue``.
      RECEIVED = "received"
      # Rejected by ``Subscriber.should_enqueue``.
      FILTERED = "filtered"
      # UPDATED event folded into an entry already in ``latest_by_key``.
      COALESCED = "coalesced"
      # Event whose ``queue.put`` completed.
      ENQUEUED = "enqueued"
      # Event that found the queue full before its ``put``.
      BACKPRESSURED = "backpressured"
  # Public names of this module. ``Event`` and ``EventType`` are imported
  # from ``coordinator.base`` and re-exported for backward compatibility.
  __all__ = [
      "Event",
      "EventType",
      "EventCountKind",
      "Subscriber",
      "EventBus",
      "event_bus",
      "set_coordinator",
      "event_decoder",
  ]
  def event_decoder(obj):
      """Replace the ``"type"`` entry of ``obj`` with its ``EventType`` member.

      The value stored under ``"type"`` is looked up by member name
      (``EventType[name]``) and written back in place. A mapping without a
      ``"type"`` key is left untouched.

      Returns the same ``obj`` that was passed in.

      NOTE(review): the signature looks like a ``json`` ``object_hook`` --
      confirm against the callers.
      """
      if "type" not in obj:
          return obj
      obj["type"] = EventType[obj["type"]]
      return obj
  class Subscriber:
      """Bus subscriber backed by its own bounded ``asyncio.Queue``.

      UPDATED events that share an id are coalesced through
      ``latest_by_key``. Invariant: an id is in ``latest_by_key`` iff a
      queue token exists (or is about to exist) whose ``receive()`` will
      pop it. When the queue is full the producer awaits ``put`` instead of
      dropping the event. Publish paths run ``enqueue`` in dedicated tasks
      (see ``EventBus._route_event``), so backpressure stalls only that
      per-event task and not the caller.
      """

      def __init__(
          self,
          topic: Optional[str] = None,
          source: Optional[str] = None,
          event_types: Optional[Iterable[EventType]] = None,
          queue_size: Optional[int] = None,
      ):
          """Create a subscriber.

          Args:
              topic: Topic label, reported in queue-full log lines.
              source: Free-form consumer label, reported in queue-full
                  log lines.
              event_types: Event types to accept. ``None`` or an empty
                  collection disables the filter.
              queue_size: Queue capacity; ``None`` selects
                  ``EVENT_BUS_SUBSCRIBER_QUEUE_SIZE``.
          """
          self.topic = topic
          self.source = source

          allowed: Optional[Set[EventType]] = None
          if event_types:
              allowed = set(event_types)
          self.event_types: Optional[Set[EventType]] = allowed

          maxsize = (
              EVENT_BUS_SUBSCRIBER_QUEUE_SIZE if queue_size is None else queue_size
          )
          self.queue: asyncio.Queue = asyncio.Queue(maxsize=maxsize)

          # Newest UPDATED event per id that still awaits delivery.
          self.latest_by_key: Dict[Any, Event] = {}
          self.lock = asyncio.Lock()
          # Read by ``BusMetricsCollector`` and reflected as Prometheus counters.
          self.event_counts: Dict[Tuple[EventCountKind, str], int] = {}

      def _bump(self, kind: EventCountKind, event_type: EventType) -> None:
          """Add one to the counter keyed by ``(kind, event_type.name)``."""
          bucket = (kind, event_type.name)
          counts = self.event_counts
          counts[bucket] = counts.get(bucket, 0) + 1

      def should_enqueue(self, event: Event) -> bool:
          """Pre-enqueue filter. Rejects events the subscriber opted out of."""
          allowed = self.event_types
          return allowed is None or event.type in allowed

      async def enqueue(self, event: Event):
          """Count, filter, coalesce and finally queue ``event``.

          UPDATED events with an id share one queue token per id: the first
          one claims the ``latest_by_key`` slot and is queued, later ones
          only overwrite the slot. Any other event is queued directly.
          Waits while the queue is full.
          """
          self._bump(EventCountKind.RECEIVED, event.type)
          if not self.should_enqueue(event):
              self._bump(EventCountKind.FILTERED, event.type)
              return

          coalescible = event.type == EventType.UPDATED and event.id is not None
          if not coalescible:
              await self._put_with_backpressure(event)
              return

          async with self.lock:
              if event.id in self.latest_by_key:
                  # A token for this id is already pending; refresh its payload.
                  self.latest_by_key[event.id] = event
                  self._bump(EventCountKind.COALESCED, event.type)
                  return
              self.latest_by_key[event.id] = event

          # The lock is released before awaiting put: a full queue would
          # otherwise serialize unrelated ids behind it.
          try:
              await self._put_with_backpressure(event)
          except BaseException:
              # If the put was cancelled or errored before a token reached
              # the queue, neither this event nor any later UPDATED that
              # piggybacked on the same dict entry will ever be popped.
              # Roll back so the next UPDATED for this id can re-enter
              # the queue -- otherwise we'd reproduce the #4794 stranded-id
              # bug, just triggered by cancellation instead of QueueFull.
              async with self.lock:
                  self.latest_by_key.pop(event.id, None)
              raise

      async def _put_with_backpressure(self, event: Event):
          """Put ``event`` on the queue, waiting (and logging) when it is full."""
          queue = self.queue
          if queue.full():
              logger.warning(
                  "Subscriber queue full, applying backpressure: "
                  "source=%s topic=%s event_type=%s id=%s queue_size=%s",
                  self.source,
                  self.topic,
                  event.type.name,
                  event.id,
                  queue.qsize(),
              )
              self._bump(EventCountKind.BACKPRESSURED, event.type)
          await queue.put(event)
          self._bump(EventCountKind.ENQUEUED, event.type)

      async def receive(self) -> Any:
          """Wait for the next event.

          For an UPDATED token with an id, the newest coalesced event stored
          for that id is returned (and its slot cleared); the token itself
          is returned when no slot exists.
          """
          token = await self.queue.get()
          if token.type != EventType.UPDATED or token.id is None:
              return token
          async with self.lock:
              return self.latest_by_key.pop(token.id, token)
  class EventBus:
      """Topic-based fan-out of events to local ``Subscriber`` queues.

      With a coordinator attached, ``publish`` hands events to the
      coordinator and delivery happens through the callbacks registered in
      ``_subscribe_to_coordinator``. Without one, or when the coordinator
      publish fails, events are routed straight to the local subscribers.
      """

      def __init__(self):
          """
          Initialize EventBus.

          Uses coordinator for distributed pub/sub when available,
          otherwise operates in local-only mode.
          """
          # topic -> subscribers registered for that topic
          self.subscribers: Dict[str, List[Subscriber]] = {}
          self._coordinator = None
          # Cancelled by ``stop()``; nothing in this class assigns it.
          self._listen_task: Optional[asyncio.Task] = None
          # Topics whose callback is already registered with the coordinator.
          self._subscribed_channels: set = set()
          # Holds strong references to fire-and-forget tasks so the GC
          # doesn't reap them mid-execution (Python's create_task only
          # holds a weak reference to the task it returns).
          self._pending_tasks: Set[asyncio.Task] = set()

      def _spawn(self, coro) -> asyncio.Task:
          """``asyncio.create_task`` plus retain-and-discard bookkeeping.

          Raises:
              RuntimeError: when no event loop is running. The coroutine is
                  closed first so it does not trigger a "never awaited"
                  warning.
          """
          try:
              task = asyncio.create_task(coro)
          except RuntimeError:
              # create_task never took ownership of the coroutine.
              coro.close()
              raise
          self._pending_tasks.add(task)
          task.add_done_callback(self._pending_tasks.discard)
          return task

      def set_coordinator(self, coordinator):
          """Set the coordinator for distributed pub/sub."""
          self._coordinator = coordinator

      async def start(self):
          """Start the EventBus listener.

          With a coordinator attached, registers a coordinator callback for
          every topic that already has local subscribers.
          """
          if self._coordinator:
              # Register ourselves as a subscriber to coordinator. Iterate a
              # snapshot: ``subscribe``/``unsubscribe`` may mutate the dict
              # if the awaited call ever yields to the loop.
              for topic in list(self.subscribers):
                  await self._subscribe_to_coordinator(topic)
              logger.info("EventBus started with coordinator")

      async def stop(self):
          """Stop the EventBus by cancelling the listen task, if any."""
          if self._listen_task:
              self._listen_task.cancel()
              try:
                  await self._listen_task
              except asyncio.CancelledError:
                  pass
          logger.info("EventBus stopped")

      def subscribe(
          self,
          topic: str,
          source: Optional[str] = None,
          event_types: Optional[Iterable[EventType]] = None,
      ) -> Subscriber:
          """Subscribe to a topic.

          ``source`` is a free-form label used in queue-full log lines so
          operators can identify which consumer is backpressuring. ``event_types``
          is an optional whitelist applied before enqueue -- events not matching
          are dropped without occupying a queue slot.

          Returns:
              The new ``Subscriber``, already registered under ``topic``.
          """
          subscriber = Subscriber(topic=topic, source=source, event_types=event_types)
          is_new_topic = topic not in self.subscribers
          # Register locally first so a scheduling failure below cannot leave
          # an empty topic entry with the subscriber missing.
          self.subscribers.setdefault(topic, []).append(subscriber)

          # Subscribe to coordinator if available
          if is_new_topic and self._coordinator:
              try:
                  self._spawn(self._subscribe_to_coordinator(topic))
              except RuntimeError:
                  # No running loop yet; ``start()`` registers every topic
                  # present in ``self.subscribers``.
                  logger.warning(
                      f"No running event loop while subscribing to topic {topic}, "
                      f"deferring coordinator subscription to start()"
                  )
          return subscriber

      async def _subscribe_to_coordinator(self, topic: str):
          """Subscribe to coordinator for a topic.

          Does nothing when the topic is already registered. Failures are
          logged and swallowed, leaving the topic unregistered.
          """
          if topic in self._subscribed_channels:
              return

          try:
              # Create a closure that captures the topic
              def on_event(event: Event):
                  self._on_coordinator_event(event, topic)

              # Register callback with coordinator
              self._coordinator.subscribe(topic, on_event)
              self._subscribed_channels.add(topic)
              logger.debug(f"Subscribed to coordinator topic: {topic}")
          except Exception as e:
              logger.error(f"Failed to subscribe to coordinator topic {topic}: {e}")

      def _on_coordinator_event(self, event: Event, topic: str):
          """Handle event received from coordinator.

          Coordinator implementations must invoke this callback from the main
          event loop (see Coordinator.subscribe); a coordinator whose driver
          fires events from a background thread is responsible for bridging
          to the main loop itself (e.g. via loop.call_soon_threadsafe).
          """
          try:
              self._spawn(self._process_coordinator_event(event, topic))
          except RuntimeError:
              logger.warning(
                  f"No running event loop for coordinator event on topic {topic}, skipping"
              )

      async def _process_coordinator_event(self, event: Event, topic: str):
          """
          Process event from coordinator.

          For cross-instance events (only ID received), this method:
          1. Fetches full data from database
          2. Detects changes using local cache
          3. Reconstructs the event with complete data and changed_fields

          Events that cannot be enriched are skipped rather than routed
          with incomplete data; the exception is a DELETED event with no
          cached object, which is routed ID-only.
          """
          # Delay import to avoid circular imports
          from gpustack.server.db import async_session

          # Check if this is a cross-instance event (only has ID)
          is_id_only = (
              event.data is not None
              and isinstance(event.data, dict)
              and set(event.data.keys()) == {"id"}
          )

          if not is_id_only:
              # Local event or cache event, route directly
              logger.trace(
                  f"Routing non-ID-only event for topic {topic}: "
                  f"data type={type(event.data).__name__}, "
                  f"keys={list(event.data.keys()) if isinstance(event.data, dict) else 'N/A'}, "
                  f"id={event.id}"
              )
              self._route_event(event, topic)
              return

          # Skip events with no ID - we can't fetch from database
          if event.id is None:
              logger.warning(
                  f"Skipping event for topic {topic}: no ID present, cannot fetch data."
              )
              return

          try:
              model_class = get_model_for_topic(topic)
              if model_class is None:
                  # Unknown topic, skip to avoid sending incomplete data
                  logger.debug(f"Skipping event for topic {topic}: no model class found.")
                  return

              # Use ChangeDetector to detect changes and manage cache
              detector = get_change_detector(topic)
              old_obj = detector.get(event.id)

              # NOTE(review): the original nesting could not be recovered; the
              # handling below is kept inside the session block, which stays
              # valid either way.
              async with async_session() as session:
                  # Fetch full object from database
                  obj = await model_class.one_by_id(session, event.id)

                  if event.type == EventType.DELETED:
                      # For DELETED events, object is already gone from DB
                      # Use cached old_obj as the data for the event
                      if old_obj is not None:
                          # Use cached object to provide full data for DELETED event
                          enriched_event = Event(
                              type=event.type,
                              data=old_obj,
                              changed_fields={},
                              id=event.id,
                          )
                          logger.trace(
                              f"Enriched DELETED event for topic {topic}: id={event.id}, "
                              f"using cached {type(old_obj).__name__}"
                          )
                          self._route_event(enriched_event, topic)
                      else:
                          # No cached object, route ID-only event for DELETED
                          # so clients know the object was deleted
                          logger.trace(
                              f"Routing ID-only DELETED event for topic {topic}: id={event.id}, "
                              f"no cached object available"
                          )
                          self._route_event(event, topic)
                      # Always remove from cache on DELETE
                      detector.remove(event.id)
                      return

                  if obj is None:
                      # Object not in DB (race condition or already deleted), skip
                      logger.debug(
                          f"Skipping event for topic {topic}: object {event.id} not found in database."
                      )
                      return

                  # Detect changes for non-DELETE events
                  changed_fields = detector.detect_changes(old_obj, obj)
                  # Update cache with new object
                  detector.put(event.id, obj)

                  # Reconstruct event with full data and detected changes
                  enriched_event = Event(
                      type=event.type,
                      data=obj,
                      changed_fields=changed_fields,
                      id=event.id,
                  )
                  logger.trace(
                      f"Enriched event for topic {topic}: id={event.id}, "
                      f"model={type(obj).__name__}, changed_fields={list(changed_fields.keys())}"
                  )
                  self._route_event(enriched_event, topic)
          except Exception as e:
              # Skip the event rather than sending incomplete data
              logger.error(
                  f"Failed to enrich coordinator event for {topic}: {e}. "
                  f"Skipping event to avoid sending incomplete data."
              )

      def _route_event(self, event: Event, topic: str):
          """Route event to subscribers of the specific topic.

          Per-subscriber enqueue runs in its own task so a slow consumer
          cannot head-of-line block its peers under blocking backpressure.
          Trade-off: this fan-out is unbounded -- for very hot topics with no
          coalescing protection (CREATED/DELETED), bursts can spawn many
          pending enqueue tasks on slow consumers. UPDATED is naturally
          bounded by ``latest_by_key`` coalescing.
          """
          for subscriber in self.subscribers.get(topic, ()):
              self._spawn(subscriber.enqueue(event))

      def unsubscribe(self, topic: str, subscriber: Subscriber):
          """Unsubscribe from a topic.

          Idempotent: an unknown topic or a subscriber that is not (or no
          longer) registered is ignored. The topic entry is dropped once
          its last subscriber is gone.
          """
          registered = self.subscribers.get(topic)
          if registered is None:
              return
          try:
              registered.remove(subscriber)
          except ValueError:
              # Already removed, e.g. unsubscribe called twice from cleanup.
              pass
          if not registered:
              del self.subscribers[topic]

      async def publish(self, topic: str, event: Event):
          """Publish an event to a topic.

          With a coordinator, distribution flows through it so every instance
          sees the event on the same path. On failure or in standalone mode,
          fall back to ``_route_event`` for local fan-out -- each subscriber's
          enqueue runs in its own task, so backpressure on one consumer does
          not head-of-line block its peers.
          """
          if self._coordinator:
              try:
                  await self._coordinator.publish(topic, event)
                  return
              except Exception as e:
                  logger.error(
                      f"Failed to publish event to coordinator, "
                      f"falling back to local delivery: {e}"
                  )

          self._route_event(event, topic)
  # Module-level bus instance, exported through ``__all__``.
  event_bus = EventBus()


  def set_coordinator(coordinator):
      """Attach ``coordinator`` to the module-level ``event_bus``.

      Thin wrapper around ``EventBus.set_coordinator``.
      """
      event_bus.set_coordinator(coordinator)