| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- import inspect
- import asyncio
- import logging
- import functools
- from typing import Any, Callable, Optional, TYPE_CHECKING
- from cachetools import LRUCache
- from aiocache import Cache, BaseCache
- from gpustack import envs
- from gpustack.server.coordinator.base import Event, EventType
- if TYPE_CHECKING:
- from gpustack.server.coordinator.base import Coordinator
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared in-memory aiocache instance. It is the default backend of
# locked_cached and the cache the *_by_key helpers below operate on.
cache = Cache(Cache.MEMORY)

# Per-key asyncio locks used by the locked_cached decorator.
# One lock is created per cache key; delete_cache_by_key and
# _local_delete_cache drop the lock together with the cached entry.
# The mapping is a cachetools LRUCache bounded by
# envs.SERVER_CACHE_LOCKS_MAX_SIZE, so locks that are never explicitly
# removed are eventually evicted in least-recently-used order.
_cache_locks: LRUCache[str, asyncio.Lock] = LRUCache(
    maxsize=envs.SERVER_CACHE_LOCKS_MAX_SIZE
)

# Global coordinator reference for distributed cache synchronization.
# None until set_coordinator() stores one; while it is None,
# _broadcast_invalidation() is a no-op.
_coordinator: Optional["Coordinator"] = None
def set_coordinator(coordinator: Optional["Coordinator"]) -> None:
    """Register the coordinator used for distributed cache synchronization.

    Intended to be called during server startup. The reference is stored in
    the module-level ``_coordinator``. When a truthy coordinator is given,
    ``_handle_cache_invalidate`` is subscribed under the name "cache" so that
    invalidations published by other instances are applied locally.

    Args:
        coordinator: The coordinator to use, or None. A falsy value is only
            stored; nothing is subscribed and no earlier subscription is
            removed here.
    """
    global _coordinator
    _coordinator = coordinator

    if not coordinator:
        return

    # Listen for cache invalidation events coming from other instances.
    coordinator.subscribe("cache", _handle_cache_invalidate)
    logger.debug("Distributed cache synchronization enabled")
# Strong references to in-flight invalidation tasks. The event loop only keeps
# weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_invalidation_tasks: set = set()


def _handle_cache_invalidate(event: "Event") -> None:
    """Handle a cache invalidation event received from another instance.

    Only ``EventType.DELETED`` events that carry a non-empty "key" in their
    data are acted on; everything else is ignored. The local deletion is
    scheduled on the running event loop because this handler is synchronous.

    Args:
        event: The event delivered by the coordinator subscription.
    """
    if not (event.type == EventType.DELETED and event.data):
        return

    key = event.data.get("key")
    if not key:
        return

    try:
        # This handler is called from a sync context, so the async delete is
        # scheduled as a task instead of being awaited.
        task = asyncio.get_running_loop().create_task(_local_delete_cache(key))
    except RuntimeError:
        # No event loop running, ignore.
        return

    # Hold the task until it completes, then release the reference.
    _invalidation_tasks.add(task)
    task.add_done_callback(_invalidation_tasks.discard)
async def _local_delete_cache(key: str) -> None:
    """Remove ``key`` from the local cache without broadcasting.

    Used for invalidations that originated on another instance, so the event
    is not published again.

    Args:
        key: The cache key to remove.
    """
    message = f"Deleting cache for key: {key} (from remote)"
    logger.trace(message)

    await cache.delete(key)
    # Drop the per-key lock along with the cached entry.
    _cache_locks.pop(key, None)
async def _broadcast_invalidation(key: str) -> None:
    """Publish a cache invalidation for ``key`` to other instances.

    Does nothing when no coordinator has been registered. Publishing is
    best-effort: any exception is logged as a warning and not re-raised.

    Args:
        key: The cache key that was invalidated locally.
    """
    coordinator = _coordinator
    if coordinator is not None:
        try:
            await coordinator.publish(
                "cache", Event(type=EventType.DELETED, data={"key": key})
            )
            logger.trace(f"Broadcasted cache invalidation for key: {key}")
        except Exception as e:
            # A failed broadcast must not fail the local invalidation.
            logger.warning(f"Failed to broadcast cache invalidation: {e}")
def build_cache_key(func: Callable, *args, **kwargs):
    """Build a deterministic cache key for calling ``func`` with the given arguments.

    The arguments are bound to the function signature so that positional and
    keyword spellings of the same call, and omitted defaults, all produce the
    same key.

    Args:
        func: The cached callable. Bound and unbound methods are both
            accepted; a leading ``self``/``cls`` parameter is never part of
            the key.
        *args: Positional call arguments, without ``self``/``cls``.
        **kwargs: Keyword call arguments.

    Returns:
        ``func.__qualname__`` followed by the string form of the bound
        argument values in declaration order. If the arguments cannot be
        bound to the signature, ``func.__qualname__`` followed by the raw
        positional arguments and the sorted keyword items.
    """
    try:
        sig = inspect.signature(func)
        params = list(sig.parameters.values())
        # locked_cached.decorator strips 'self' before calling here, but unbound
        # functions still have 'self' in their signature. Strip it so keys match
        # when delete_cache_by_key is called with a bound method (no self in sig).
        if (
            params
            and params[0].name in ("self", "cls")
            and not hasattr(func, "__self__")
        ):
            sig = sig.replace(parameters=params[1:])
        bound = sig.bind(*args, **kwargs)
    except (TypeError, ValueError):
        # Fallback for callers that pass args not matching the function signature
        # (e.g. build_cache_key used as a manual key-construction helper), or for
        # callables whose signature cannot be introspected.
        # Sort kwargs for a stable key regardless of call-site ordering.
        return func.__qualname__ + str(args) + str(sorted(kwargs.items()))

    bound.apply_defaults()

    # bound.arguments follows declaration order, so named parameters are stable.
    # A **kwargs parameter is a dict in call-site order, so sort it to keep the
    # key independent of how the caller ordered the extra keywords.
    values = []
    for name, value in bound.arguments.items():
        if sig.parameters[name].kind is inspect.Parameter.VAR_KEYWORD:
            value = dict(sorted(value.items()))
        values.append(value)
    return func.__qualname__ + str(tuple(values))
async def delete_cache_by_key(
    func=None, *args, sync_coordinator: bool = True, **kwargs
):
    """Delete a cache entry, identified by an explicit key or by a cached call.

    Args:
        func: The cached function (optional). Used with ``*args``/``**kwargs``
            to rebuild the key when no explicit ``_key`` is given.
        *args: Arguments to build the cache key.
        sync_coordinator: Whether to broadcast the invalidation to other
            instances via the coordinator.
            Default is True for security-sensitive data.
            Set to False for high-frequency, non-critical caches
            (e.g., Worker status).
        **kwargs: Additional arguments used to build the key. The special
            keyword ``_key`` supplies an explicit key and takes precedence
            over ``func``.

    Raises:
        ValueError: If neither ``func`` nor ``_key`` is provided.
    """
    explicit_key = kwargs.pop("_key", None)
    if explicit_key is not None:
        key = explicit_key
    elif func is not None:
        key = build_cache_key(func, *args, **kwargs)
    else:
        raise ValueError("Either func or key must be provided")

    logger.trace(f"Deleting cache for key: {key}")
    await cache.delete(key)
    # The per-key lock goes away together with the cached entry.
    _cache_locks.pop(key, None)

    if not sync_coordinator:
        return
    # Broadcast to other instances via coordinator.
    await _broadcast_invalidation(key)
async def set_cache_by_key(key: str, value: Any):
    """Store ``value`` under ``key`` in the module-level cache.

    No ttl is passed here, so the cache backend's own default applies.

    Args:
        key: The cache key to write.
        value: The value to store.
    """
    message = f"Set cache for key: {key}"
    logger.trace(message)
    await cache.set(key, value)
def class_key(suffix: str):
    """Create a cache key builder for class methods.

    Usage:
        @locked_cached(key=class_key("all_cached"))
        async def cached_all(cls, session, ...):
            ...

    The builder returns "{ClassName}.{suffix}", e.g. "Worker.all_cached".
    It takes the class from the first positional call argument and ignores
    every other argument, so all calls on one class share a single key.
    """

    # FIXME: The kwargs should be taken into account for more fine-grained cache keys,
    # but for now we just use the class name and suffix for simplicity.
    # Using kwargs as key causes https://github.com/gpustack/gpustack/issues/4813.
    def builder(f, *args, **kwargs):
        # The first positional argument is cls for a classmethod.
        cls = args[0]
        cache_key = f"{cls.__name__}.{suffix}"
        return cache_key

    return builder
class locked_cached:
    """Decorator that caches an async callable's result behind a per-key lock.

    Usage:
        @locked_cached(ttl=60)
        async def get_by_id(self, id):
            ...

    On a cache miss, a per-key ``asyncio.Lock`` from the module-level
    ``_cache_locks`` is held while the wrapped coroutine runs, and the cache
    is checked again after the lock is acquired. Concurrent callers waiting on
    the same key therefore reuse the first result instead of recomputing it.

    ``None`` means "no value": a ``None`` result is never stored, and a cached
    ``None`` cannot be told apart from a miss.
    """

    def __init__(
        self,
        ttl: int = envs.SERVER_CACHE_TTL_SECONDS,
        cache: BaseCache = cache,
        key: Any = None,
    ):
        """Configure the decorator.

        Args:
            ttl: Time-to-live passed to the cache when a result is stored.
            cache: Cache backend to read from and write to.
            key: Either a fixed key string, or a callable invoked as
                ``key(f, *args, **kwargs)`` that returns the key. When None,
                the key is derived from the call via ``build_cache_key``.
        """
        self.cache = cache
        self.ttl = ttl
        self.key = key

    def __call__(self, f):
        """Wrap ``f`` and expose the backend and key on the wrapper."""

        @functools.wraps(f)
        async def wrapper(*args, **kwargs):
            return await self.decorator(f, *args, **kwargs)

        wrapper.cache = self.cache
        wrapper.cache_key = self.key
        return wrapper

    async def get_from_cache(self, key: str):
        """Return the cached value for ``key`` as reported by the backend."""
        return await self.cache.get(key)

    async def set_in_cache(self, key: str, value: Any):
        """Store ``value`` under ``key`` using the configured ttl."""
        await self.cache.set(key, value, ttl=self.ttl)

    @staticmethod
    def _key_args(f, args):
        """Return ``args`` without the leading receiver when ``f`` is a method.

        ``build_cache_key`` expects arguments without ``self``/``cls``. The
        receiver is dropped only when the first declared parameter is named
        ``self`` or ``cls``; dropping it unconditionally would discard a real
        argument of a plain function and make different calls share one key.
        """
        try:
            first = next(iter(inspect.signature(f).parameters), None)
        except (TypeError, ValueError):
            # Signature not introspectable: keep every argument.
            first = None
        return args[1:] if first in ("self", "cls") else args

    async def decorator(self, f, *args, **kwargs):
        """Return the cached result for this call, computing it on a miss.

        Args:
            f: The wrapped coroutine function.
            *args: Positional arguments of the call, including the receiver
                for methods.
            **kwargs: Keyword arguments of the call.

        Returns:
            The cached value if present, otherwise the result of awaiting
            ``f(*args, **kwargs)``.
        """
        if self.key is not None:
            key = self.key(f, *args, **kwargs) if callable(self.key) else self.key
        else:
            key = build_cache_key(f, *self._key_args(f, args), **kwargs)

        value = await self.get_from_cache(key)
        if value is not None:
            return value

        lock = _cache_locks.setdefault(key, asyncio.Lock())

        async with lock:
            # Another caller may have filled the cache while we waited.
            value = await self.get_from_cache(key)
            if value is not None:
                return value

            logger.trace(f"cache miss for key: {key}")

            result = await f(*args, **kwargs)

            # None is reserved for "missing", so it is not cached.
            if result is not None:
                await self.set_in_cache(key, result)

            return result
|