cache.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
"""
Local cache for detecting changes in cross-instance events.

When using PostgreSQL-based pub/sub, only the ID is transmitted across instances.
Subscribers use this cache to store the previous state and detect which fields changed.
"""
  6. import logging
  7. from typing import Any, Dict, Generic, Optional, Tuple, TypeVar
  8. from cachetools import LRUCache
  9. logger = logging.getLogger(__name__)
  10. T = TypeVar("T")
  11. class ChangeDetector(Generic[T]):
  12. """
  13. Detects changes between old and new object states.
  14. Usage:
  15. detector = ChangeDetector[Worker]("worker")
  16. # On first event
  17. worker = await Worker.one_by_id(session, event_id)
  18. detector.put(worker.id, worker)
  19. # On subsequent events
  20. old_worker = detector.get(event_id) # Get cached old state
  21. worker = await Worker.one_by_id(session, event_id) # Query new state
  22. changed_fields = detector.detect_changes(old_worker, worker)
  23. detector.put(worker.id, worker) # Update cache
  24. """
  25. def __init__(self, entity_name: str, maxsize: int = 10000):
  26. self._entity_name = entity_name
  27. self._cache: LRUCache[Any, T] = LRUCache(maxsize=maxsize)
  28. def get(self, id: Any) -> Optional[T]:
  29. """Get cached object by ID."""
  30. return self._cache.get(id)
  31. def put(self, id: Any, obj: T) -> None:
  32. """Cache an object."""
  33. self._cache[id] = obj
  34. def remove(self, id: Any) -> None:
  35. """Remove an object from cache."""
  36. self._cache.pop(id, None)
  37. def detect_changes(
  38. self, old_obj: Optional[T], new_obj: T
  39. ) -> Dict[str, Tuple[Any, Any]]:
  40. """
  41. Detect field changes between old and new object.
  42. For list (relationship) fields, emit a ``(removed, added)`` delta
  43. matching the shape produced by the local ``find_history`` hook in
  44. ``active_record.py``, so callbacks work identically on local and
  45. cross-instance paths. Fields that fail to load (e.g. lazy relationship
  46. on a detached instance) are silently skipped.
  47. Returns:
  48. Dict of field_name -> (old_value, new_value) for scalar fields, or
  49. (removed_list, added_list) for relationship fields.
  50. """
  51. if old_obj is None:
  52. return {}
  53. changed_fields = {}
  54. # Get fields to compare (exclude internal SQLModel fields)
  55. fields_to_compare = getattr(new_obj, "model_fields", None)
  56. if fields_to_compare is None:
  57. # Fallback: compare all attributes
  58. fields_to_compare = [
  59. attr
  60. for attr in dir(new_obj)
  61. if not attr.startswith("_")
  62. and not callable(getattr(new_obj, attr, None))
  63. ]
  64. for field_name in fields_to_compare:
  65. if field_name.startswith("_"):
  66. continue
  67. try:
  68. old_val = getattr(old_obj, field_name, None)
  69. new_val = getattr(new_obj, field_name, None)
  70. if isinstance(old_val, list) or isinstance(new_val, list):
  71. old_list = old_val if isinstance(old_val, list) else []
  72. new_list = new_val if isinstance(new_val, list) else []
  73. diff = self._list_diff(old_list, new_list)
  74. if diff is not None:
  75. changed_fields[field_name] = diff
  76. continue
  77. if old_val != new_val:
  78. changed_fields[field_name] = (old_val, new_val)
  79. except Exception as e:
  80. logger.debug(
  81. f"Error comparing field {field_name} for {self._entity_name}: {e}"
  82. )
  83. continue
  84. return changed_fields
  85. @staticmethod
  86. def _list_diff(old_list: list, new_list: list) -> Optional[Tuple[list, list]]:
  87. """Return a ``(removed, added)`` delta between two relationship lists.
  88. Elements are keyed by ``.id`` (attribute) or ``["id"]`` (dict). If any
  89. element is keyless, fall back to whole-list equality and emit an empty
  90. delta to signal a change without trying to attribute add/remove.
  91. Returns None when the lists are equivalent (no change).
  92. """
  93. def key_of(item: Any) -> Any:
  94. if item is None:
  95. return None
  96. if hasattr(item, "id"):
  97. return getattr(item, "id")
  98. if isinstance(item, dict):
  99. return item.get("id")
  100. return None
  101. old_keys = [key_of(o) for o in old_list]
  102. new_keys = [key_of(n) for n in new_list]
  103. if any(k is None for k in old_keys) or any(k is None for k in new_keys):
  104. # Keyless elements — can't reliably attribute add/remove.
  105. return None if old_list == new_list else ([], [])
  106. old_set = set(old_keys)
  107. new_set = set(new_keys)
  108. if old_set == new_set:
  109. return None
  110. removed = [o for o, k in zip(old_list, old_keys) if k not in new_set]
  111. added = [n for n, k in zip(new_list, new_keys) if k not in old_set]
  112. return (removed, added)
  113. class EventCacheManager:
  114. """
  115. Manages change detectors for different entity types.
  116. This is a singleton-style manager that provides cached change detection
  117. for cross-instance events where only IDs are transmitted.
  118. """
  119. def __init__(self):
  120. self._detectors: Dict[str, ChangeDetector] = {}
  121. self._preloaded: set = set()
  122. def get_detector(self, entity_name: str) -> ChangeDetector:
  123. """Get or create a change detector for an entity type."""
  124. if entity_name not in self._detectors:
  125. self._detectors[entity_name] = ChangeDetector(entity_name)
  126. return self._detectors[entity_name]
  127. def is_preloaded(self, entity_name: str) -> bool:
  128. """Check if an entity type has been preloaded."""
  129. return entity_name in self._preloaded
  130. def mark_preloaded(self, entity_name: str) -> None:
  131. """Mark an entity type as preloaded."""
  132. self._preloaded.add(entity_name)
  133. def clear(self) -> None:
  134. """Clear all caches."""
  135. self._detectors.clear()
  136. self._preloaded.clear()
# Global cache manager instance.
# Created once at import time and shared by the module-level helpers
# get_change_detector(), clear_all_caches() and preload_cache().
_cache_manager = EventCacheManager()
  139. def get_change_detector(entity_name: str) -> ChangeDetector:
  140. """Get a change detector for the specified entity type."""
  141. return _cache_manager.get_detector(entity_name)
  142. def clear_all_caches() -> None:
  143. """Clear all event caches. Useful for testing."""
  144. _cache_manager.clear()
  145. async def preload_cache(entity_name: str, model_class, session) -> int:
  146. """
  147. Preload cache for an entity type by querying all records.
  148. This ensures that the first cross-instance event can detect changes correctly.
  149. Without preloading, the first event will have empty changed_fields.
  150. Args:
  151. entity_name: The entity type name (e.g., 'worker', 'model')
  152. model_class: The SQLModel class to query
  153. session: Database session
  154. Returns:
  155. Number of records cached
  156. Example:
  157. async with async_session() as session:
  158. count = await preload_cache('worker', Worker, session)
  159. logger.info(f"Preloaded {count} workers into cache")
  160. """
  161. manager = _cache_manager
  162. if manager.is_preloaded(entity_name):
  163. return 0
  164. detector = manager.get_detector(entity_name)
  165. records = await model_class.all(session)
  166. for record in records:
  167. if hasattr(record, 'id'):
  168. detector.put(record.id, record)
  169. manager.mark_preloaded(entity_name)
  170. logger.info(f"Preloaded {len(records)} {entity_name} records into event cache")
  171. return len(records)