# worker_status_buffer.py
  1. import asyncio
  2. import datetime
  3. import logging
  4. from typing import Dict, Set
  5. from sqlalchemy import update
  6. from gpustack.schemas.workers import Worker
  7. from gpustack.server.db import async_session
  8. from gpustack.server.services import WorkerService
  9. logger = logging.getLogger(__name__)
  10. FLUSH_INTERVAL_SECONDS = 5
  11. # Buffer to store worker IDs that need heartbeat update
  12. heartbeat_flush_buffer: Set[int] = set()
  13. heartbeat_flush_buffer_lock = asyncio.Lock()
  14. # Buffer to store worker status updates: {worker_id: input_dict}
  15. worker_status_flush_buffer: Dict[int, dict] = {}
  16. worker_status_flush_buffer_lock = asyncio.Lock()
  17. async def flush_heartbeats():
  18. """
  19. Flush worker heartbeat updates to the database periodically.
  20. Uses a single UPDATE statement to update all workers with the same timestamp.
  21. """
  22. if not heartbeat_flush_buffer:
  23. return
  24. # Copy buffer and clear it atomically
  25. async with heartbeat_flush_buffer_lock:
  26. local_buffer = set(heartbeat_flush_buffer)
  27. heartbeat_flush_buffer.clear()
  28. try:
  29. async with async_session() as session:
  30. # Single UPDATE for all workers with the same timestamp
  31. # UPDATE workers SET heartbeat_time = '2024-01-27 10:00:00' WHERE id IN (1, 2, 3, ...)
  32. heartbeat_time = datetime.datetime.now(datetime.timezone.utc).replace(
  33. microsecond=0
  34. )
  35. stmt = (
  36. update(Worker)
  37. .where(Worker.id.in_(local_buffer))
  38. .values(heartbeat_time=heartbeat_time)
  39. )
  40. await session.execute(stmt)
  41. await session.commit()
  42. except Exception as e:
  43. logger.error(f"Error flushing heartbeats to DB: {e}")
  44. async def flush_worker_status():
  45. """
  46. Flush worker status updates to the database periodically.
  47. Uses batch_update to update multiple workers with different status data.
  48. """
  49. if not worker_status_flush_buffer:
  50. return
  51. async with worker_status_flush_buffer_lock:
  52. to_update_worker_ids = list(worker_status_flush_buffer.keys())
  53. to_update_worker_status = dict(worker_status_flush_buffer)
  54. worker_status_flush_buffer.clear()
  55. try:
  56. async with async_session() as session:
  57. # Query workers by ids
  58. workers = await Worker.all_by_fields(
  59. session=session, extra_conditions=[Worker.id.in_(to_update_worker_ids)]
  60. )
  61. for worker in workers:
  62. for key, value in to_update_worker_status.get(worker.id, {}).items():
  63. setattr(worker, key, value)
  64. worker.compute_state()
  65. await WorkerService(session).batch_update(workers)
  66. except Exception as e:
  67. logger.error(f"Error flushing worker status to DB: {e}")
  68. async def flush_worker_status_to_db():
  69. """
  70. Flush both worker heartbeats and status updates to the database periodically.
  71. """
  72. while True:
  73. await asyncio.sleep(FLUSH_INTERVAL_SECONDS)
  74. await flush_heartbeats()
  75. await flush_worker_status()