cache.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import inspect
  2. import asyncio
  3. import logging
  4. import functools
  5. from typing import Any, Callable, Optional, TYPE_CHECKING
  6. from cachetools import LRUCache
  7. from aiocache import Cache, BaseCache
  8. from gpustack import envs
  9. from gpustack.server.coordinator.base import Event, EventType
  10. if TYPE_CHECKING:
  11. from gpustack.server.coordinator.base import Coordinator
# Module-level logger used by every helper in this file.
logger = logging.getLogger(__name__)

# Shared aiocache instance using the in-memory backend (Cache.MEMORY). It is
# also the default backend of the locked_cached decorator.
cache = Cache(Cache.MEMORY)

# Cache locks for locked_cached decorator
# Locks are created per cache key and should be cleaned up when cache expires
# Using LRUCache from cachetools for automatic LRU eviction
# The capacity comes from envs.SERVER_CACHE_LOCKS_MAX_SIZE.
_cache_locks: LRUCache[str, asyncio.Lock] = LRUCache(
    maxsize=envs.SERVER_CACHE_LOCKS_MAX_SIZE
)

# Global coordinator reference for distributed cache synchronization
# Stays None until set_coordinator() installs one; while it is None,
# _broadcast_invalidation() returns without publishing anything.
_coordinator: Optional["Coordinator"] = None
  22. def set_coordinator(coordinator: Optional["Coordinator"]) -> None:
  23. """Set the coordinator for distributed cache synchronization.
  24. This is called during server startup to enable cache invalidation
  25. broadcasting across instances.
  26. """
  27. global _coordinator
  28. _coordinator = coordinator
  29. if coordinator:
  30. # Subscribe to cache invalidation events
  31. coordinator.subscribe("cache", _handle_cache_invalidate)
  32. logger.debug("Distributed cache synchronization enabled")
  33. def _handle_cache_invalidate(event: "Event") -> None:
  34. """Handle cache invalidation events from other instances."""
  35. if event.type == EventType.DELETED and event.data:
  36. key = event.data.get("key")
  37. if key:
  38. # Use asyncio.create_task since this is called from sync context
  39. try:
  40. loop = asyncio.get_running_loop()
  41. loop.create_task(_local_delete_cache(key))
  42. except RuntimeError:
  43. # No event loop running, ignore
  44. pass
  45. async def _local_delete_cache(key: str) -> None:
  46. """Delete cache locally without broadcasting (for remote events)."""
  47. logger.trace(f"Deleting cache for key: {key} (from remote)")
  48. await cache.delete(key)
  49. _cache_locks.pop(key, None)
  50. async def _broadcast_invalidation(key: str) -> None:
  51. """Broadcast cache invalidation to other instances."""
  52. if _coordinator is None:
  53. return
  54. try:
  55. await _coordinator.publish(
  56. "cache", Event(type=EventType.DELETED, data={"key": key})
  57. )
  58. logger.trace(f"Broadcasted cache invalidation for key: {key}")
  59. except Exception as e:
  60. logger.warning(f"Failed to broadcast cache invalidation: {e}")
  61. def build_cache_key(func: Callable, *args, **kwargs):
  62. sig = inspect.signature(func)
  63. params = list(sig.parameters.values())
  64. # locked_cached.decorator strips 'self' before calling here, but unbound
  65. # functions still have 'self' in their signature. Strip it so keys match
  66. # when delete_cache_by_key is called with a bound method (no self in sig).
  67. if params and params[0].name in ("self", "cls") and not hasattr(func, "__self__"):
  68. sig = sig.replace(parameters=params[1:])
  69. try:
  70. bound = sig.bind(*args, **kwargs)
  71. bound.apply_defaults()
  72. # bound.arguments follows declaration order, so kwargs ordering is stable.
  73. return func.__qualname__ + str(tuple(bound.arguments.values()))
  74. except TypeError:
  75. # Fallback for callers that pass args not matching the function signature
  76. # (e.g. build_cache_key used as a manual key-construction helper).
  77. # Sort kwargs for a stable key regardless of call-site ordering.
  78. return func.__qualname__ + str(args) + str(sorted(kwargs.items()))
  79. async def delete_cache_by_key(
  80. func=None, *args, sync_coordinator: bool = True, **kwargs
  81. ):
  82. """Delete cache by key or function.
  83. Args:
  84. func: The cached function (optional)
  85. *args: Arguments to build the cache key
  86. sync_coordinator: Whether to broadcast invalidation to other instances via coordinator.
  87. Default is True for security-sensitive data.
  88. Set to False for high-frequency, non-critical caches (e.g., Worker status).
  89. **kwargs: Additional arguments including `_key` for explicit key
  90. """
  91. key = kwargs.pop("_key", None)
  92. if key is None:
  93. if func is None:
  94. raise ValueError("Either func or key must be provided")
  95. key = build_cache_key(func, *args, **kwargs)
  96. logger.trace(f"Deleting cache for key: {key}")
  97. await cache.delete(key)
  98. _cache_locks.pop(key, None)
  99. # Broadcast to other instances via coordinator
  100. if sync_coordinator:
  101. await _broadcast_invalidation(key)
  102. async def set_cache_by_key(key: str, value: Any):
  103. logger.trace(f"Set cache for key: {key}")
  104. await cache.set(key, value)
  105. def class_key(suffix: str):
  106. """Generate a cache key builder for class methods.
  107. Usage:
  108. @locked_cached(key=class_key("all_cached"))
  109. async def cached_all(cls, session, ...):
  110. ...
  111. The generated key will be "{ClassName}.{suffix}", e.g., "Worker.all_cached"
  112. """
  113. # FIXME: The kwargs should be taken into account for more fine-grained cache keys,
  114. # but for now we just use the class name and suffix for simplicity.
  115. # Using kwargs as key causes https://github.com/gpustack/gpustack/issues/4813.
  116. def builder(f, *args, **kwargs):
  117. cls = args[0] # First arg is cls for classmethod
  118. return f"{cls.__name__}.{suffix}"
  119. return builder
  120. class locked_cached:
  121. def __init__(
  122. self,
  123. ttl: int = envs.SERVER_CACHE_TTL_SECONDS,
  124. cache: BaseCache = cache,
  125. key: str = None,
  126. ):
  127. self.cache = cache
  128. self.ttl = ttl
  129. self.key = key
  130. def __call__(self, f):
  131. @functools.wraps(f)
  132. async def wrapper(*args, **kwargs):
  133. return await self.decorator(f, *args, **kwargs)
  134. wrapper.cache = self.cache
  135. wrapper.cache_key = self.key
  136. return wrapper
  137. async def get_from_cache(self, key: str):
  138. return await self.cache.get(key)
  139. async def set_in_cache(self, key: str, value: Any):
  140. await self.cache.set(key, value, ttl=self.ttl)
  141. async def decorator(self, f, *args, **kwargs):
  142. if self.key is not None:
  143. key = self.key(f, *args, **kwargs) if callable(self.key) else self.key
  144. else:
  145. # no self arg
  146. key = build_cache_key(f, *args[1:], **kwargs)
  147. value = await self.get_from_cache(key)
  148. if value is not None:
  149. return value
  150. lock = _cache_locks.setdefault(key, asyncio.Lock())
  151. async with lock:
  152. value = await self.get_from_cache(key)
  153. if value is not None:
  154. return value
  155. logger.trace(f"cache miss for key: {key}")
  156. result = await f(*args, **kwargs)
  157. if result is not None:
  158. await self.set_in_cache(key, result)
  159. return result