| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- """
- Local cache for detecting changes in cross-instance events.
- When using PostgreSQL-based pub/sub, only the ID is transmitted across instances.
- Subscribers use this cache to store the previous state and detect what fields changed.
- """
- import logging
- from typing import Any, Dict, Generic, Optional, Tuple, TypeVar
- from cachetools import LRUCache
- logger = logging.getLogger(__name__)
- T = TypeVar("T")
- class ChangeDetector(Generic[T]):
- """
- Detects changes between old and new object states.
- Usage:
- detector = ChangeDetector[Worker]("worker")
- # On first event
- worker = await Worker.one_by_id(session, event_id)
- detector.put(worker.id, worker)
- # On subsequent events
- old_worker = detector.get(event_id) # Get cached old state
- worker = await Worker.one_by_id(session, event_id) # Query new state
- changed_fields = detector.detect_changes(old_worker, worker)
- detector.put(worker.id, worker) # Update cache
- """
- def __init__(self, entity_name: str, maxsize: int = 10000):
- self._entity_name = entity_name
- self._cache: LRUCache[Any, T] = LRUCache(maxsize=maxsize)
- def get(self, id: Any) -> Optional[T]:
- """Get cached object by ID."""
- return self._cache.get(id)
- def put(self, id: Any, obj: T) -> None:
- """Cache an object."""
- self._cache[id] = obj
- def remove(self, id: Any) -> None:
- """Remove an object from cache."""
- self._cache.pop(id, None)
- def detect_changes(
- self, old_obj: Optional[T], new_obj: T
- ) -> Dict[str, Tuple[Any, Any]]:
- """
- Detect field changes between old and new object.
- For list (relationship) fields, emit a ``(removed, added)`` delta
- matching the shape produced by the local ``find_history`` hook in
- ``active_record.py``, so callbacks work identically on local and
- cross-instance paths. Fields that fail to load (e.g. lazy relationship
- on a detached instance) are silently skipped.
- Returns:
- Dict of field_name -> (old_value, new_value) for scalar fields, or
- (removed_list, added_list) for relationship fields.
- """
- if old_obj is None:
- return {}
- changed_fields = {}
- # Get fields to compare (exclude internal SQLModel fields)
- fields_to_compare = getattr(new_obj, "model_fields", None)
- if fields_to_compare is None:
- # Fallback: compare all attributes
- fields_to_compare = [
- attr
- for attr in dir(new_obj)
- if not attr.startswith("_")
- and not callable(getattr(new_obj, attr, None))
- ]
- for field_name in fields_to_compare:
- if field_name.startswith("_"):
- continue
- try:
- old_val = getattr(old_obj, field_name, None)
- new_val = getattr(new_obj, field_name, None)
- if isinstance(old_val, list) or isinstance(new_val, list):
- old_list = old_val if isinstance(old_val, list) else []
- new_list = new_val if isinstance(new_val, list) else []
- diff = self._list_diff(old_list, new_list)
- if diff is not None:
- changed_fields[field_name] = diff
- continue
- if old_val != new_val:
- changed_fields[field_name] = (old_val, new_val)
- except Exception as e:
- logger.debug(
- f"Error comparing field {field_name} for {self._entity_name}: {e}"
- )
- continue
- return changed_fields
- @staticmethod
- def _list_diff(old_list: list, new_list: list) -> Optional[Tuple[list, list]]:
- """Return a ``(removed, added)`` delta between two relationship lists.
- Elements are keyed by ``.id`` (attribute) or ``["id"]`` (dict). If any
- element is keyless, fall back to whole-list equality and emit an empty
- delta to signal a change without trying to attribute add/remove.
- Returns None when the lists are equivalent (no change).
- """
- def key_of(item: Any) -> Any:
- if item is None:
- return None
- if hasattr(item, "id"):
- return getattr(item, "id")
- if isinstance(item, dict):
- return item.get("id")
- return None
- old_keys = [key_of(o) for o in old_list]
- new_keys = [key_of(n) for n in new_list]
- if any(k is None for k in old_keys) or any(k is None for k in new_keys):
- # Keyless elements — can't reliably attribute add/remove.
- return None if old_list == new_list else ([], [])
- old_set = set(old_keys)
- new_set = set(new_keys)
- if old_set == new_set:
- return None
- removed = [o for o, k in zip(old_list, old_keys) if k not in new_set]
- added = [n for n, k in zip(new_list, new_keys) if k not in old_set]
- return (removed, added)
- class EventCacheManager:
- """
- Manages change detectors for different entity types.
- This is a singleton-style manager that provides cached change detection
- for cross-instance events where only IDs are transmitted.
- """
- def __init__(self):
- self._detectors: Dict[str, ChangeDetector] = {}
- self._preloaded: set = set()
- def get_detector(self, entity_name: str) -> ChangeDetector:
- """Get or create a change detector for an entity type."""
- if entity_name not in self._detectors:
- self._detectors[entity_name] = ChangeDetector(entity_name)
- return self._detectors[entity_name]
- def is_preloaded(self, entity_name: str) -> bool:
- """Check if an entity type has been preloaded."""
- return entity_name in self._preloaded
- def mark_preloaded(self, entity_name: str) -> None:
- """Mark an entity type as preloaded."""
- self._preloaded.add(entity_name)
- def clear(self) -> None:
- """Clear all caches."""
- self._detectors.clear()
- self._preloaded.clear()
- # Global cache manager instance
- _cache_manager = EventCacheManager()
- def get_change_detector(entity_name: str) -> ChangeDetector:
- """Get a change detector for the specified entity type."""
- return _cache_manager.get_detector(entity_name)
- def clear_all_caches() -> None:
- """Clear all event caches. Useful for testing."""
- _cache_manager.clear()
- async def preload_cache(entity_name: str, model_class, session) -> int:
- """
- Preload cache for an entity type by querying all records.
- This ensures that the first cross-instance event can detect changes correctly.
- Without preloading, the first event will have empty changed_fields.
- Args:
- entity_name: The entity type name (e.g., 'worker', 'model')
- model_class: The SQLModel class to query
- session: Database session
- Returns:
- Number of records cached
- Example:
- async with async_session() as session:
- count = await preload_cache('worker', Worker, session)
- logger.info(f"Preloaded {count} workers into cache")
- """
- manager = _cache_manager
- if manager.is_preloaded(entity_name):
- return 0
- detector = manager.get_detector(entity_name)
- records = await model_class.all(session)
- for record in records:
- if hasattr(record, 'id'):
- detector.put(record.id, record)
- manager.mark_preloaded(entity_name)
- logger.info(f"Preloaded {len(records)} {entity_name} records into event cache")
- return len(records)
|