base.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. """
  2. Coordinator Abstract Base Class.
  3. This module defines the interface for coordinating multiple server instances.
  4. The open-source edition provides a local (single-node) implementation, while
  5. the enterprise edition provides distributed implementations using Redis or
  6. PostgreSQL.
  7. """
  8. from abc import ABC, abstractmethod
  9. from dataclasses import dataclass, field
  10. from typing import Any, Callable, Dict, List, Optional, Tuple
  11. from enum import Enum
  12. import logging
  13. logger = logging.getLogger(__name__)
  14. class EventType(Enum):
  15. CREATED = 1
  16. UPDATED = 2
  17. DELETED = 3
  18. UNKNOWN = 4
  19. HEARTBEAT = 5
  20. def __str__(self):
  21. return self.name
  22. @dataclass
  23. class Event:
  24. type: EventType
  25. data: Any
  26. changed_fields: Dict[str, Tuple[Any, Any]] = field(default_factory=dict)
  27. id: Optional[Any] = None
  28. def __post_init__(self):
  29. if isinstance(self.type, int):
  30. self.type = EventType(self.type)
  31. if self.id is None:
  32. self.id = self._derive_id_from_data()
  33. def _derive_id_from_data(self) -> Optional[Any]:
  34. if self.data is None:
  35. return None
  36. # For SQLModel objects
  37. if hasattr(self.data, "id"):
  38. return getattr(self.data, "id")
  39. # For plain dict
  40. if isinstance(self.data, dict):
  41. return self.data.get("id")
  42. return None
  43. def to_dict(self) -> Dict:
  44. """Serialize event to dict for transmission.
  45. For cross-instance communication, only the ID is transmitted.
  46. Subscribers should fetch full data from database and maintain local cache
  47. to detect changes if needed.
  48. """
  49. # Only pass ID to avoid serialization issues and NOTIFY payload limits
  50. data = None
  51. if self.data is not None:
  52. if hasattr(self.data, "id"):
  53. # SQLModel object - only get ID
  54. data = {"id": getattr(self.data, 'id')}
  55. elif isinstance(self.data, dict):
  56. data = {"id": self.data.get("id")} if "id" in self.data else self.data
  57. else:
  58. data = {"id": self.id} if self.id is not None else None
  59. return {
  60. "type": self.type.name,
  61. "data": data,
  62. # changed_fields is not transmitted across instances
  63. # Subscribers should detect changes using local cache
  64. "id": self.id,
  65. }
  66. @classmethod
  67. def from_dict(cls, data: Dict) -> "Event":
  68. """Deserialize event from dict."""
  69. return cls(
  70. type=EventType[data.get("type", "UNKNOWN")],
  71. data=data.get("data"),
  72. # changed_fields is not transmitted, subscribers detect changes locally
  73. id=data.get("id"),
  74. )
  75. class Coordinator(ABC):
  76. """
  77. Abstract base class for coordinating server instances.
  78. Implementations must provide:
  79. - Leader election for active-passive mode
  80. - Pub/Sub for event distribution across instances
  81. """
  82. def __init__(
  83. self,
  84. config: Any,
  85. leader_election_ttl: int = 30,
  86. leader_election_renew_interval: int = 10,
  87. ):
  88. self._config = config
  89. self._leader_election_ttl = leader_election_ttl
  90. self._leader_election_renew_interval = leader_election_renew_interval
  91. self._subscribers: Dict[str, List[Callable[[Event], Any]]] = {}
  92. self._is_leader = False
  93. @property
  94. def leader_election_ttl(self) -> int:
  95. """Get the leader election TTL in seconds."""
  96. return self._leader_election_ttl
  97. @property
  98. def leader_election_renew_interval(self) -> int:
  99. """Get the leader election renew interval in seconds."""
  100. return self._leader_election_renew_interval
  101. @abstractmethod
  102. async def start(self):
  103. """Start the coordinator and establish connections."""
  104. pass
  105. @abstractmethod
  106. async def stop(self):
  107. """Stop the coordinator and release resources."""
  108. pass
  109. # Leader Election
  110. @abstractmethod
  111. async def acquire_leadership(self, ttl: int) -> bool:
  112. """
  113. Try to acquire leadership lock.
  114. Args:
  115. ttl: Time to live in seconds for the leadership lock
  116. Returns:
  117. True if leadership was acquired, False otherwise
  118. """
  119. pass
  120. @abstractmethod
  121. async def renew_leadership(self, ttl: int) -> bool:
  122. """
  123. Renew the current leadership lock.
  124. Args:
  125. ttl: Time to live in seconds
  126. Returns:
  127. True if renewal was successful, False if leadership was lost
  128. """
  129. pass
  130. @abstractmethod
  131. async def release_leadership(self):
  132. """Release the current leadership lock."""
  133. pass
  134. def is_leader(self) -> bool:
  135. """Check if this instance is the current leader."""
  136. return self._is_leader
  137. def _set_leader(self, is_leader: bool):
  138. """Internal method to set leadership status."""
  139. was_leader = self._is_leader
  140. self._is_leader = is_leader
  141. if was_leader != is_leader:
  142. logger.info(f"Leadership changed: {was_leader} -> {is_leader}")
  143. # Pub/Sub
  144. @abstractmethod
  145. async def publish(self, channel: str, event: Event):
  146. """
  147. Publish an event to a channel.
  148. Args:
  149. channel: Channel name (e.g., 'model', 'worker')
  150. event: Event to publish
  151. """
  152. pass
  153. def subscribe(self, channel: str, callback: Callable[[Event], Any]):
  154. """
  155. Subscribe to a channel.
  156. Implementations MUST invoke ``callback`` on the main asyncio event
  157. loop. Coordinators whose underlying driver delivers events from a
  158. background thread must bridge to the main loop themselves (e.g. via
  159. ``loop.call_soon_threadsafe``) before calling the callback.
  160. Args:
  161. channel: Channel name
  162. callback: Function to call when event is received
  163. """
  164. if channel not in self._subscribers:
  165. self._subscribers[channel] = []
  166. self._subscribers[channel].append(callback)
  167. logger.debug(f"Subscribed to channel: {channel}")
  168. def unsubscribe(self, channel: str, callback: Callable[[Event], Any]):
  169. """Unsubscribe from a channel."""
  170. if channel in self._subscribers:
  171. self._subscribers[channel].remove(callback)
  172. if not self._subscribers[channel]:
  173. del self._subscribers[channel]
  174. def _notify_local_subscribers(self, channel: str, event: Event):
  175. """Notify local subscribers of an event."""
  176. if channel in self._subscribers:
  177. for callback in self._subscribers[channel]:
  178. try:
  179. callback(event)
  180. except Exception as e:
  181. logger.error(f"Error notifying subscriber: {e}")