test_bus.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import asyncio
  2. import logging
  3. import pytest
  4. from gpustack.server.bus import Event, EventType, Subscriber
  5. @pytest.mark.asyncio
  6. async def test_updated_event_overflow_does_not_leave_unreceivable_latest_event():
  7. """Regression for #4794: queue-full UPDATED ids must remain deliverable."""
  8. queue_size = 4
  9. subscriber = Subscriber(topic="modelinstance", source="test", queue_size=queue_size)
  10. total = queue_size + 5
  11. enqueue_tasks = [
  12. asyncio.create_task(
  13. subscriber.enqueue(
  14. Event(
  15. type=EventType.UPDATED,
  16. data={"id": event_id, "value": event_id},
  17. id=event_id,
  18. )
  19. )
  20. )
  21. for event_id in range(total)
  22. ]
  23. received_ids = []
  24. for _ in range(total):
  25. event = await asyncio.wait_for(subscriber.receive(), timeout=2)
  26. received_ids.append(event.id)
  27. await asyncio.gather(*enqueue_tasks)
  28. assert sorted(received_ids) == list(range(total))
  29. assert subscriber.latest_by_key == {}
  30. assert subscriber.queue.empty()
  31. @pytest.mark.asyncio
  32. async def test_updated_events_for_same_id_are_coalesced_to_latest():
  33. subscriber = Subscriber(topic="modelinstance", source="test")
  34. await subscriber.enqueue(
  35. Event(type=EventType.UPDATED, data={"id": 1, "value": "old"}, id=1)
  36. )
  37. await subscriber.enqueue(
  38. Event(type=EventType.UPDATED, data={"id": 1, "value": "mid"}, id=1)
  39. )
  40. await subscriber.enqueue(
  41. Event(type=EventType.UPDATED, data={"id": 1, "value": "new"}, id=1)
  42. )
  43. event = await asyncio.wait_for(subscriber.receive(), timeout=1)
  44. assert event.id == 1
  45. assert event.data["value"] == "new"
  46. assert subscriber.latest_by_key == {}
  47. assert subscriber.queue.empty()
  48. @pytest.mark.asyncio
  49. async def test_subscriber_filters_event_types_before_enqueue():
  50. subscriber = Subscriber(
  51. topic="modelinstance",
  52. source="scheduler",
  53. event_types={EventType.CREATED},
  54. )
  55. await subscriber.enqueue(Event(type=EventType.UPDATED, data={"id": 1}, id=1))
  56. await subscriber.enqueue(Event(type=EventType.DELETED, data={"id": 2}, id=2))
  57. assert subscriber.queue.empty()
  58. assert subscriber.latest_by_key == {}
  59. await subscriber.enqueue(Event(type=EventType.CREATED, data={"id": 3}, id=3))
  60. event = await asyncio.wait_for(subscriber.receive(), timeout=1)
  61. assert event.type == EventType.CREATED
  62. assert event.id == 3
  63. @pytest.mark.asyncio
  64. async def test_queue_full_log_includes_metadata(caplog):
  65. """The warning must identify which subscriber backpressured."""
  66. subscriber = Subscriber(topic="modelinstance", source="scheduler", queue_size=1)
  67. await subscriber.enqueue(Event(type=EventType.CREATED, data={"id": 1}, id=1))
  68. caplog.set_level(logging.WARNING, logger="gpustack.server.bus")
  69. pending = asyncio.create_task(
  70. subscriber.enqueue(Event(type=EventType.CREATED, data={"id": 2}, id=2))
  71. )
  72. # Yield so the enqueue task hits the full-queue branch.
  73. await asyncio.sleep(0)
  74. await asyncio.sleep(0)
  75. await asyncio.wait_for(subscriber.receive(), timeout=1)
  76. await asyncio.wait_for(subscriber.receive(), timeout=1)
  77. await pending
  78. matching = [
  79. rec
  80. for rec in caplog.records
  81. if "queue full, applying backpressure" in rec.getMessage()
  82. ]
  83. assert matching, "expected queue-full backpressure log entry"
  84. msg = matching[0].getMessage()
  85. assert "source=scheduler" in msg
  86. assert "topic=modelinstance" in msg
  87. assert "event_type=CREATED" in msg
  88. assert "id=2" in msg
  89. assert "queue_size=1" in msg
  90. @pytest.mark.asyncio
  91. async def test_publish_does_not_let_slow_subscriber_block_peers():
  92. """A full-queue subscriber must not head-of-line block its peers."""
  93. from gpustack.server.bus import EventBus
  94. bus = EventBus()
  95. topic = "_test_publish_fanout"
  96. slow = bus.subscribe(topic, source="slow")
  97. fast = bus.subscribe(topic, source="fast")
  98. slow.queue = asyncio.Queue(maxsize=1)
  99. await slow.enqueue(Event(type=EventType.CREATED, data={"id": 0}, id=0))
  100. try:
  101. await bus.publish(topic, Event(type=EventType.CREATED, data={"id": 1}, id=1))
  102. delivered = await asyncio.wait_for(fast.receive(), timeout=1)
  103. assert delivered.id == 1
  104. assert slow.queue.qsize() == 1 # still backpressured
  105. finally:
  106. bus.unsubscribe(topic, slow)
  107. bus.unsubscribe(topic, fast)
  108. @pytest.mark.asyncio
  109. async def test_cancelled_updated_put_rolls_back_latest_by_key():
  110. """If the producer task is cancelled while awaiting backpressure,
  111. ``latest_by_key`` must be rolled back so the next UPDATED for the same
  112. id can re-enter the queue. Without rollback this reproduces the
  113. #4794 stranded-id bug, just triggered by cancel rather than QueueFull.
  114. """
  115. subscriber = Subscriber(topic="modelinstance", source="test", queue_size=1)
  116. # Fill the queue with an unrelated event so the next put will block.
  117. await subscriber.enqueue(Event(type=EventType.CREATED, data={"id": 0}, id=0))
  118. # Start an UPDATED enqueue for id=42 — it writes latest_by_key[42]
  119. # then awaits put on the full queue.
  120. cancelled = asyncio.create_task(
  121. subscriber.enqueue(Event(type=EventType.UPDATED, data={"id": 42}, id=42))
  122. )
  123. for _ in range(5):
  124. await asyncio.sleep(0)
  125. if 42 in subscriber.latest_by_key:
  126. break
  127. assert 42 in subscriber.latest_by_key
  128. cancelled.cancel()
  129. try:
  130. await cancelled
  131. except asyncio.CancelledError:
  132. pass
  133. # Rollback should clear the orphan entry.
  134. assert 42 not in subscriber.latest_by_key
  135. # A fresh UPDATED for id=42 must be deliverable. Drain the prefill
  136. # first to avoid a second blocking put.
  137. drained = await asyncio.wait_for(subscriber.receive(), timeout=1)
  138. assert drained.id == 0
  139. await subscriber.enqueue(
  140. Event(type=EventType.UPDATED, data={"id": 42, "v": "fresh"}, id=42)
  141. )
  142. delivered = await asyncio.wait_for(subscriber.receive(), timeout=1)
  143. assert delivered.id == 42
  144. assert delivered.data["v"] == "fresh"
  145. @pytest.mark.asyncio
  146. async def test_non_updated_events_block_under_backpressure_not_drop():
  147. subscriber = Subscriber(topic="modelinstance", source="test", queue_size=2)
  148. await subscriber.enqueue(Event(type=EventType.CREATED, data={"id": 1}, id=1))
  149. await subscriber.enqueue(Event(type=EventType.CREATED, data={"id": 2}, id=2))
  150. pending = asyncio.create_task(
  151. subscriber.enqueue(Event(type=EventType.CREATED, data={"id": 3}, id=3))
  152. )
  153. await asyncio.sleep(0)
  154. assert not pending.done()
  155. first = await asyncio.wait_for(subscriber.receive(), timeout=1)
  156. assert first.id == 1
  157. await asyncio.wait_for(pending, timeout=1)
  158. second = await asyncio.wait_for(subscriber.receive(), timeout=1)
  159. third = await asyncio.wait_for(subscriber.receive(), timeout=1)
  160. assert {second.id, third.id} == {2, 3}