| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- import asyncio
- import datetime
- import logging
- from typing import Dict, Set
- from sqlalchemy import update
- from gpustack.schemas.workers import Worker
- from gpustack.server.db import async_session
- from gpustack.server.services import WorkerService
# Seconds the background loop sleeps between two flush passes.
FLUSH_INTERVAL_SECONDS = 5

logger = logging.getLogger(__name__)

# Heartbeats: IDs of workers whose heartbeat still has to be written to the
# database. The lock is held while the set is copied and cleared by the flush.
heartbeat_flush_buffer: Set[int] = set()
heartbeat_flush_buffer_lock = asyncio.Lock()

# Status: pending field updates per worker, shaped {worker_id: input_dict}.
# The lock is held while the dict is copied and cleared by the flush.
worker_status_flush_buffer: Dict[int, dict] = {}
worker_status_flush_buffer_lock = asyncio.Lock()
async def flush_heartbeats():
    """
    Write all buffered worker heartbeats to the database in one statement.

    The worker IDs collected in ``heartbeat_flush_buffer`` are copied and the
    buffer is cleared while ``heartbeat_flush_buffer_lock`` is held. Every
    worker in that snapshot then gets the same ``heartbeat_time`` (current UTC
    time truncated to whole seconds) through a single UPDATE, e.g.::

        UPDATE workers SET heartbeat_time = '2024-01-27 10:00:00'
        WHERE id IN (1, 2, 3, ...)

    Returns immediately when the buffer is empty. Database errors are logged
    together with their traceback and then swallowed, so a failed flush does
    not propagate to the caller; the IDs of a failed snapshot are not
    re-queued.
    """
    if not heartbeat_flush_buffer:
        return

    # Copy and clear under the lock so the snapshot and the reset happen
    # together; the lock is released before any database I/O starts.
    async with heartbeat_flush_buffer_lock:
        pending_ids = set(heartbeat_flush_buffer)
        heartbeat_flush_buffer.clear()

    # One shared, second-precision UTC timestamp for the whole batch.
    heartbeat_time = datetime.datetime.now(datetime.timezone.utc).replace(
        microsecond=0
    )
    stmt = (
        update(Worker)
        .where(Worker.id.in_(pending_ids))
        .values(heartbeat_time=heartbeat_time)
    )

    try:
        async with async_session() as session:
            await session.execute(stmt)
            await session.commit()
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback,
        # which a plain formatted error message would drop.
        logger.exception("Error flushing heartbeats to DB: %s", e)
async def flush_worker_status():
    """
    Apply all buffered worker status updates to the database.

    The ``{worker_id: input_dict}`` entries in ``worker_status_flush_buffer``
    are copied and the buffer is cleared while
    ``worker_status_flush_buffer_lock`` is held. The matching ``Worker`` rows
    are then loaded, each buffered key/value pair is set as an attribute on
    its worker, ``compute_state()`` is called on the worker, and the modified
    workers are handed to ``WorkerService.batch_update``.

    Returns immediately when the buffer is empty. Buffered IDs with no
    matching row are skipped. Errors are logged together with their traceback
    and then swallowed, so a failed flush does not propagate to the caller;
    the updates of a failed snapshot are not re-queued.
    """
    if not worker_status_flush_buffer:
        return

    # Copy and clear under the lock; the lock is released before any
    # database I/O starts.
    async with worker_status_flush_buffer_lock:
        pending_status = dict(worker_status_flush_buffer)
        worker_status_flush_buffer.clear()

    try:
        async with async_session() as session:
            # Load only the workers that have a pending update.
            workers = await Worker.all_by_fields(
                session=session,
                extra_conditions=[Worker.id.in_(list(pending_status))],
            )
            for worker in workers:
                for key, value in pending_status.get(worker.id, {}).items():
                    setattr(worker, key, value)
                # NOTE(review): state is recomputed once per worker after all
                # buffered fields are applied — confirm against the original
                # nesting.
                worker.compute_state()
            # No explicit commit here; presumably batch_update persists the
            # changes — verify in WorkerService.
            await WorkerService(session).batch_update(workers)
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback,
        # which a plain formatted error message would drop.
        logger.exception("Error flushing worker status to DB: %s", e)
async def flush_worker_status_to_db():
    """
    Run the periodic flush loop for worker heartbeats and status updates.

    Loops forever: each cycle sleeps for FLUSH_INTERVAL_SECONDS, then flushes
    the buffered heartbeats followed by the buffered status updates.
    """
    while True:
        await asyncio.sleep(FLUSH_INTERVAL_SECONDS)
        # Heartbeats first, then status updates, one after the other.
        for flush in (flush_heartbeats, flush_worker_status):
            await flush()
|