# active_record.py (32 KB)
  1. import asyncio
  2. from datetime import datetime, timedelta, timezone
  3. import importlib
  4. import json
  5. import logging
  6. import math
  7. from typing import Any, AsyncGenerator, Callable, Iterable, List, Optional, Union, Tuple
  8. import anyio
  9. from fastapi.encoders import jsonable_encoder
  10. from sqlalchemy import func, event as sa_event, inspect
  11. from sqlmodel import SQLModel, and_, asc, col, desc, or_, select, text
  12. from sqlmodel.ext.asyncio.session import AsyncSession
  13. from sqlalchemy.exc import IntegrityError, OperationalError
  14. from sqlalchemy.orm import Session
  15. from sqlalchemy.orm.exc import FlushError
  16. from sqlalchemy.orm.state import InstanceState
  17. from gpustack.schemas.common import PaginatedList, Pagination
  18. from gpustack.server.bus import Event, EventType, event_bus
  19. from gpustack.server.cache import locked_cached, delete_cache_by_key, class_key
  20. from gpustack.server.db import async_session
  21. logger = logging.getLogger(__name__)
  22. class CommitEvent:
  23. name: str
  24. event: Event
  25. def __init__(self, name: str, type: EventType, data: Any):
  26. self.name = name
  27. self.event = Event(type=type, data=data)
  28. @sa_event.listens_for(Session, "before_flush")
  29. def find_history(session: AsyncSession, flush_context, instances):
  30. events: List[CommitEvent] = session.info.get("pending_events", [])
  31. if len(events) == 0:
  32. return
  33. dirty_objs = [obj for obj in session.dirty if instances is None or obj in instances]
  34. for event in events:
  35. if event.event.data not in dirty_objs:
  36. continue
  37. obj = event.event.data
  38. relationship_keys = set(rel.key for rel in obj.__mapper__.relationships)
  39. state = inspect(obj)
  40. for attr in state.attrs:
  41. hist = attr.history
  42. if hist.has_changes():
  43. if attr.key in relationship_keys:
  44. added_dump = [
  45. obj.__class__.model_validate(obj.model_dump())
  46. for obj in hist.added
  47. if obj is not None
  48. ]
  49. deleted_dump = [
  50. obj.__class__.model_validate(obj.model_dump())
  51. for obj in hist.deleted
  52. if obj is not None
  53. ]
  54. event.event.changed_fields[attr.key] = (deleted_dump, added_dump)
  55. else:
  56. event.event.changed_fields[attr.key] = (hist.deleted, hist.added)
  57. # commit hook to send events after a database commit
  58. @sa_event.listens_for(Session, "after_commit")
  59. def send_post_commit_events(session: AsyncSession):
  60. events: List[CommitEvent] = session.info.pop("pending_events", [])
  61. for event in events:
  62. # copy before submit to avoid mutation
  63. id = getattr(event.event.data, "id", None)
  64. logger.trace(f"Sending event {event.name} of type {event.event.type}, id {id}")
  65. bus_event = event.event
  66. try:
  67. # Detach from the SQLAlchemy session so subscribers don't see
  68. # lazy loads or further mutations on the same row.
  69. bus_event.data = bus_event.data.model_copy(deep=True)
  70. asyncio.create_task(event_bus.publish(event.name, bus_event))
  71. except Exception as e:
  72. logger.exception(f"Failed to publish events: {e}")
  73. class ActiveRecordMixin:
  74. """ActiveRecordMixin provides a set of methods to interact with the database."""
  75. __config__ = None
  76. @property
  77. def primary_key(self):
  78. """Return the primary key of the object."""
  79. return self.__mapper__.primary_key_from_instance(self)
  80. @classmethod
  81. async def first(cls, session: AsyncSession):
  82. """Return the first object of the model."""
  83. statement = select(cls)
  84. result = await session.exec(statement)
  85. return result.first()
  86. @classmethod
  87. async def one_by_id(
  88. cls,
  89. session: AsyncSession,
  90. id: int,
  91. for_update: bool = False,
  92. options: Optional[List] = None,
  93. ):
  94. """Return the object with the given id. Return None if not found.
  95. If `for_update` is True, the row will be locked until the end of the transaction.
  96. If `options` is provided, it will be passed to the query for eager loading relationships.
  97. """
  98. return await session.get(cls, id, with_for_update=for_update, options=options)
  99. @classmethod
  100. async def first_by_field(cls, session: AsyncSession, field: str, value: Any):
  101. """Return the first object with the given field and value. Return None if not found."""
  102. return await cls.first_by_fields(session, {field: value})
  103. @classmethod
  104. async def one_by_field(
  105. cls,
  106. session: AsyncSession,
  107. field: str,
  108. value: Any,
  109. options: Optional[List] = None,
  110. ):
  111. """Return the object with the given field and value. Return None if not found."""
  112. return await cls.one_by_fields(session, {field: value}, options=options)
  113. @classmethod
  114. async def first_by_fields(cls, session: AsyncSession, fields: dict):
  115. """
  116. Return the first object with the given fields and values.
  117. Return None if not found.
  118. """
  119. statement = select(cls)
  120. for key, value in fields.items():
  121. statement = statement.where(getattr(cls, key) == value)
  122. result = await session.exec(statement)
  123. return result.first()
  124. @classmethod
  125. async def one_by_fields(
  126. cls, session: AsyncSession, fields: dict, options: Optional[List] = None
  127. ):
  128. """Return the object with the given fields and values. Return None if not found."""
  129. statement = select(cls)
  130. for key, value in fields.items():
  131. statement = statement.where(getattr(cls, key) == value)
  132. if options:
  133. statement = statement.options(*options)
  134. result = await session.exec(statement)
  135. return result.first()
  136. @classmethod
  137. async def all_by_field(
  138. cls, session: AsyncSession, field: str, value: Any, for_update: bool = False
  139. ):
  140. """
  141. Return all objects with the given field and value.
  142. Return an empty list if not found.
  143. """
  144. statement = select(cls).where(getattr(cls, field) == value)
  145. if for_update:
  146. statement = statement.with_for_update()
  147. result = await session.exec(statement)
  148. return result.all()
  149. @classmethod
  150. async def all_by_fields(
  151. cls,
  152. session: AsyncSession,
  153. fields: dict = {},
  154. fuzzy_fields: Optional[dict] = None,
  155. extra_conditions: Optional[List] = None,
  156. options: Optional[List] = None,
  157. ):
  158. """
  159. Return all objects with the given fields and values.
  160. Return an empty list if not found.
  161. """
  162. statement = select(cls)
  163. for key, value in fields.items():
  164. statement = statement.where(getattr(cls, key) == value)
  165. if fuzzy_fields:
  166. fuzzy_conditions = [
  167. func.lower(getattr(cls, key)).like(f"%{str(value).lower()}%")
  168. for key, value in fuzzy_fields.items()
  169. ]
  170. statement = statement.where(or_(*fuzzy_conditions))
  171. if extra_conditions:
  172. statement = statement.where(and_(*extra_conditions))
  173. if options:
  174. statement = statement.options(*options)
  175. result = await session.exec(statement)
  176. return result.all()
  177. @classmethod
  178. async def paginated_by_query(
  179. cls,
  180. session: AsyncSession,
  181. fields: Optional[dict] = None,
  182. fuzzy_fields: Optional[dict] = None,
  183. extra_conditions: Optional[List] = None,
  184. page: int = 1,
  185. per_page: int = 100,
  186. order_by: Optional[List[Tuple[Union[str, Any], str]]] = None,
  187. options: Optional[List] = None,
  188. ) -> PaginatedList[SQLModel]:
  189. """
  190. Return a paginated and optionally sorted list of objects matching the given query criteria.
  191. Args:
  192. session (AsyncSession): The SQLAlchemy async session used to interact with the database.
  193. fields (Optional[dict]): Exact match filters as key-value pairs.
  194. fuzzy_fields (Optional[dict]): Fuzzy match filters using the SQL `LIKE` operator.
  195. extra_conditions (Optional[List]): Additional SQLAlchemy conditions to apply to the query.
  196. page (int): Page number for pagination, starting from 1. Default is 1.
  197. per_page (int): Number of items per page. Default is 100.
  198. order_by (Optional[List[Tuple[Union[str, Any], str]]]): List of tuples specifying the
  199. fields or expressions to sort by and their respective directions ('asc' or 'desc').
  200. If not provided, defaults to `created_at DESC`.
  201. Returns:
  202. PaginatedList[SQLModel]: A paginated list of matching objects with pagination metadata.
  203. """
  204. statement = select(cls)
  205. if fields:
  206. conditions = [
  207. col(getattr(cls, key)) == value for key, value in fields.items()
  208. ]
  209. statement = statement.where(and_(*conditions))
  210. if fuzzy_fields:
  211. fuzzy_conditions = [
  212. func.lower(getattr(cls, key)).like(f"%{str(value).lower()}%")
  213. for key, value in fuzzy_fields.items()
  214. ]
  215. statement = statement.where(or_(*fuzzy_conditions))
  216. if extra_conditions:
  217. statement = statement.where(and_(*extra_conditions))
  218. if options:
  219. statement = statement.options(*options)
  220. if not order_by:
  221. order_by = [("created_at", "desc")]
  222. for field_expression, direction in order_by:
  223. if isinstance(field_expression, str):
  224. if '.' in str(field_expression):
  225. # Nested fields for JSON columns
  226. expr = cls._parse_nested_field_expression(session, field_expression)
  227. statement = statement.order_by(
  228. asc(expr) if direction.lower() == "asc" else desc(expr)
  229. )
  230. else:
  231. # Regular fields
  232. column = col(getattr(cls, field_expression))
  233. statement = statement.order_by(
  234. asc(column) if direction.lower() == "asc" else desc(column)
  235. )
  236. else:
  237. # Expression
  238. statement = statement.order_by(
  239. asc(field_expression)
  240. if direction.lower() == "asc"
  241. else desc(field_expression)
  242. )
  243. if page is not None and page > 0 and per_page is not None:
  244. statement = statement.offset((page - 1) * per_page).limit(per_page)
  245. items = (await session.exec(statement)).all()
  246. count_statement = select(func.count(cls.id))
  247. if fields:
  248. conditions = [
  249. col(getattr(cls, key)) == value for key, value in fields.items()
  250. ]
  251. count_statement = count_statement.where(and_(*conditions))
  252. if fuzzy_fields:
  253. fuzzy_conditions = [
  254. col(getattr(cls, key)).like(f"%{value}%")
  255. for key, value in fuzzy_fields.items()
  256. ]
  257. count_statement = count_statement.where(or_(*fuzzy_conditions))
  258. if extra_conditions:
  259. count_statement = count_statement.where(and_(*extra_conditions))
  260. count = (await session.exec(count_statement)).one()
  261. total_page = math.ceil(count / per_page)
  262. pagination = Pagination(
  263. page=page,
  264. perPage=per_page,
  265. total=count,
  266. totalPage=total_page,
  267. )
  268. return PaginatedList[cls](items=items, pagination=pagination)
  269. @classmethod
  270. def _parse_nested_field_expression(
  271. cls, session: AsyncSession, field_expression: Union[str, Any]
  272. ) -> Any:
  273. """
  274. Parse dot-separated nested field expressions and generate appropriate
  275. database expressions for sorting.
  276. Supports JSON field sorting with dot notation, e.g., "memory.utilization_rate".
  277. Supports casting using "::type" suffix, e.g., "status.memory.utilization_rate::numeric".
  278. Supported casting types include: numeric, boolean, text, date, datetime.
  279. Supports getting the length of JSON arrays using "[]" suffix.
  280. Compatible with both PostgreSQL and MySQL databases.
  281. Args:
  282. cls: SQLModel class
  283. session: Database session (used to detect database dialect)
  284. field_expression: Field or expression, e.g., "memory.utilization_rate"
  285. Returns:
  286. SQLAlchemy expression object suitable for use in ORDER BY clause
  287. Raises:
  288. AttributeError: If the base column doesn't exist in the model
  289. """
  290. if not isinstance(field_expression, str):
  291. # Already an expression
  292. return field_expression
  293. cast_type = None
  294. if "::" in field_expression:
  295. field_expression, cast_type = field_expression.rsplit("::", 1)
  296. ALLOWED_CASTS = {'numeric', 'boolean', 'text', 'date', 'datetime'}
  297. if cast_type and cast_type not in ALLOWED_CASTS:
  298. raise ValueError(f"Invalid cast type: {cast_type}")
  299. is_array_length = False
  300. if field_expression.endswith("[]"):
  301. is_array_length = True
  302. field_expression = field_expression[:-2]
  303. parts = field_expression.split('.')
  304. if len(parts) < 2:
  305. # Regular fields
  306. return getattr(cls, field_expression)
  307. # The first part is the column name in the database table
  308. column_name = parts[0]
  309. json_path_parts = parts[1:]
  310. dialect = None
  311. if session.bind and hasattr(session.bind, 'dialect'):
  312. dialect = session.bind.dialect.name
  313. if dialect == 'postgresql':
  314. if is_array_length:
  315. # Build PGSQL JSON path length expression like jsonb_array_length(status->'gpus')
  316. json_path_str = "->".join([f"'{part}'" for part in json_path_parts])
  317. json_expr = text(
  318. f"COALESCE(jsonb_array_length(({column_name}->{json_path_str})::jsonb), 0)"
  319. )
  320. else:
  321. # Build PGSQL JSON path like '(status#>>{"memory","utilization_rate"})::numeric'
  322. json_path_str = ",".join([f'"{part}"' for part in json_path_parts])
  323. if not cast_type:
  324. cast_type = "numeric"
  325. json_expr = text(
  326. f"({column_name}#>>'{{{json_path_str}}}')::{cast_type}"
  327. )
  328. elif dialect == 'mysql':
  329. if is_array_length:
  330. # Build MySQL JSON path length expression like JSON_LENGTH(status, '$.gpus')
  331. json_path = '.'.join(json_path_parts)
  332. json_expr = text(
  333. f"COALESCE(JSON_LENGTH({column_name}, '$.{json_path}'), 0)"
  334. )
  335. else:
  336. # Build MySQL JSON path like '$.utilization_rate' or '$.memory.utilization_rate'
  337. json_path = '.'.join(json_path_parts)
  338. json_expr = text(f"JSON_VALUE({column_name}, '$.{json_path}')")
  339. else:
  340. # Should not reach
  341. raise RuntimeError(f"Unsupported database dialect: {dialect}")
  342. return json_expr
  343. @classmethod
  344. def convert_without_saving(
  345. cls, source: Union[dict, SQLModel], update: Optional[dict] = None
  346. ) -> SQLModel:
  347. """
  348. Convert the source to the model without saving to the database.
  349. Return None if failed.
  350. """
  351. if isinstance(source, cls):
  352. obj = source
  353. if update:
  354. for k, v in update.items():
  355. setattr(obj, k, v)
  356. elif isinstance(source, SQLModel):
  357. obj = cls.from_orm(source, update=update)
  358. elif isinstance(source, dict):
  359. obj = cls.parse_obj(source, update=update)
  360. return obj
  361. async def _refresh_related_objects(self, session: AsyncSession):
  362. """Refresh all related objects of the given object."""
  363. for rel in self.__mapper__.relationships:
  364. if rel.direction.name != "MANYTOONE":
  365. continue
  366. if not hasattr(self, rel.key):
  367. continue
  368. rel_obj = getattr(self, rel.key, None)
  369. if rel_obj is None:
  370. continue
  371. elif isinstance(rel_obj, list) and len(rel_obj) == 0:
  372. continue
  373. elif isinstance(rel_obj, InstanceState):
  374. continue
  375. await session.refresh(rel_obj)
  376. def _get_flush_targets(self, session: AsyncSession) -> List[SQLModel]:
  377. # always needs to flush self
  378. rtn = [self]
  379. state: InstanceState = inspect(self)
  380. dirty_objs = [obj for obj in session.dirty]
  381. for rel in self.__mapper__.relationships:
  382. if rel.direction.name != "MANYTOMANY":
  383. continue
  384. attr = state.attrs[rel.key]
  385. if not attr.history.has_changes():
  386. continue
  387. for obj in attr.value:
  388. if obj in dirty_objs:
  389. rtn.append(obj)
  390. for obj in attr.history.deleted:
  391. if obj in dirty_objs:
  392. rtn.append(obj)
  393. return rtn
  394. @classmethod
  395. async def create(
  396. cls,
  397. session: AsyncSession,
  398. source: Union[dict, SQLModel],
  399. update: Optional[dict] = None,
  400. auto_commit: bool = True,
  401. ) -> Optional[SQLModel]:
  402. """Create and save a new record for the model."""
  403. obj = cls.convert_without_saving(source, update)
  404. if obj is None:
  405. return None
  406. cls._publish_event_after_commit(session, EventType.CREATED, obj)
  407. await obj.save(session, auto_commit=auto_commit)
  408. return obj
  409. @classmethod
  410. async def create_or_update(
  411. cls,
  412. session: AsyncSession,
  413. source: Union[dict, SQLModel],
  414. update: Optional[dict] = None,
  415. auto_commit: bool = True,
  416. ) -> Optional[SQLModel]:
  417. """Create or update a record for the model."""
  418. obj = cls.convert_without_saving(source, update)
  419. if obj is None:
  420. return None
  421. pk = cls.__mapper__.primary_key_from_instance(obj)
  422. if pk[0] is not None:
  423. existing = await session.get(cls, pk)
  424. if existing is None:
  425. return None
  426. else:
  427. await existing.update(session, obj, auto_commit=auto_commit)
  428. return existing
  429. else:
  430. return await cls.create(session, obj, auto_commit=auto_commit)
  431. @classmethod
  432. async def count(cls, session: AsyncSession) -> int:
  433. """Return the number of records in the model."""
  434. statement = select(func.count()).select_from(cls)
  435. result = await session.exec(statement)
  436. return result.one()
  437. @classmethod
  438. async def count_by_field(cls, session: AsyncSession, field: str, value: Any) -> int:
  439. """Return the number of records matching the given field and value."""
  440. return await cls.count_by_fields(session, {field: value})
  441. @classmethod
  442. async def count_by_fields(
  443. cls,
  444. session: AsyncSession,
  445. fields: dict = {},
  446. extra_conditions: Optional[List] = None,
  447. ) -> int:
  448. """
  449. Return the number of records matching the given fields and conditions.
  450. """
  451. statement = select(func.count(cls.id))
  452. for key, value in fields.items():
  453. statement = statement.where(getattr(cls, key) == value)
  454. if extra_conditions:
  455. statement = statement.where(and_(*extra_conditions))
  456. result = await session.exec(statement)
  457. return result.one_or_none() or 0
  458. async def refresh(self, session: AsyncSession):
  459. """Refresh the object from the database."""
  460. await session.refresh(self)
  461. async def save(self, session: AsyncSession, auto_commit=True):
  462. """Save the object to the database. Raise exception if failed."""
  463. session.add(self)
  464. try:
  465. targets = self._get_flush_targets(session)
  466. await session.flush(targets)
  467. if auto_commit:
  468. await session.commit()
  469. await session.refresh(self)
  470. if session.is_active:
  471. await self._refresh_related_objects(session)
  472. # Invalidate cached_all cache on successful write
  473. await self.__class__._invalidate_cached_all()
  474. except (IntegrityError, OperationalError, FlushError) as e:
  475. await session.rollback()
  476. raise e
  477. except Exception as e:
  478. await session.rollback()
  479. raise e
  480. async def update(
  481. self,
  482. session: AsyncSession,
  483. source: Union[dict, SQLModel, None] = None,
  484. auto_commit=True,
  485. ):
  486. """Update the object with the source and save to the database."""
  487. if isinstance(source, SQLModel):
  488. source = {
  489. key: getattr(source, key, None) for key in source.model_fields_set
  490. }
  491. elif source is None:
  492. source = {}
  493. for key, value in source.items():
  494. setattr(self, key, value)
  495. self._publish_event_after_commit(session, EventType.UPDATED, self)
  496. await self.save(session, auto_commit=auto_commit)
  497. @classmethod
  498. async def batch_update(
  499. cls,
  500. session: AsyncSession,
  501. updates: List[SQLModel],
  502. auto_commit: bool = True,
  503. ) -> int:
  504. """Batch update multiple records with different data.
  505. Args:
  506. session: The database session
  507. updates: A list of SQLModel objects with id field
  508. auto_commit: Whether to commit the transaction automatically
  509. Returns:
  510. The number of records successfully updated
  511. Example:
  512. updates = [
  513. Model(id=1, name="llama", state="ready"),
  514. Model(id=2, name="qwen", state="not_ready"),
  515. ]
  516. count = await Model.batch_update(session, updates)
  517. """
  518. if not updates:
  519. return 0
  520. try:
  521. for obj in updates:
  522. cls._publish_event_after_commit(session, EventType.UPDATED, obj)
  523. session.add(obj)
  524. if auto_commit:
  525. await session.commit()
  526. # Invalidate cached_all cache after successful batch update
  527. await cls._invalidate_cached_all()
  528. return len(updates)
  529. except Exception as e:
  530. await session.rollback()
  531. raise e
  532. async def delete(self, session: AsyncSession, soft=False, auto_commit=True):
  533. """Delete the object from the database."""
  534. self._publish_event_after_commit(session, EventType.DELETED, self)
  535. if soft or self._has_cascade_delete():
  536. if hasattr(self, "deleted_at"):
  537. # timestamp is stored without timezone in db
  538. self.deleted_at = datetime.now(timezone.utc).replace(tzinfo=None)
  539. await self.save(session, auto_commit=False)
  540. await self._handle_cascade_delete(session, soft=soft, auto_commit=False)
  541. if not soft:
  542. await session.delete(self)
  543. if not auto_commit:
  544. return
  545. await session.commit()
  546. # Invalidate cached_all cache after successful delete
  547. await self.__class__._invalidate_cached_all()
  548. async def _handle_cascade_delete(
  549. self, session: AsyncSession, soft=False, auto_commit=True
  550. ):
  551. """Handle cascading deletes for all defined relationships."""
  552. for rel in self.__mapper__.relationships:
  553. if rel.cascade.delete:
  554. # Load the related objects
  555. await session.refresh(self)
  556. related_objects = getattr(self, rel.key)
  557. # Delete each related object
  558. if isinstance(related_objects, list):
  559. for related_object in related_objects:
  560. await related_object.delete(
  561. session, soft=soft, auto_commit=auto_commit
  562. )
  563. elif related_objects:
  564. await related_objects.delete(
  565. session, soft=soft, auto_commit=auto_commit
  566. )
  567. def _has_cascade_delete(self):
  568. """Check if the model has cascade delete relationships."""
  569. return any(rel.cascade.delete for rel in self.__mapper__.relationships)
  570. @classmethod
  571. async def all(cls, session: AsyncSession, options: Optional[List] = None):
  572. """Return all objects of the model."""
  573. statement = select(cls)
  574. if options:
  575. statement = statement.options(*options)
  576. result = await session.exec(statement)
  577. return result.all()
  578. @classmethod
  579. async def _do_cached_all_query(cls, options: Optional[List] = None):
  580. """Execute the cached_all query in a shielded context.
  581. This runs the entire database operation including session cleanup
  582. in a way that's protected from anyio cancellation.
  583. """
  584. session = async_session()
  585. try:
  586. results = await cls.all(session, options=options)
  587. for item in results:
  588. session.expunge(item)
  589. return results
  590. finally:
  591. await session.close()
  592. @classmethod
  593. @locked_cached(key=class_key("cached_all"))
  594. async def cached_all(cls, options: Optional[List] = None):
  595. """Return all objects with caching for subscribe() initial data loading."""
  596. logger.debug(f"Loading cached {cls.__name__} with options={options}")
  597. # Run the entire database operation in a shielded context to protect
  598. # from anyio cancellation. This prevents connection pool issues when
  599. # CancelledError interrupts database operations or session cleanup.
  600. with anyio.CancelScope(shield=True):
  601. return await cls._do_cached_all_query(options)
  602. @classmethod
  603. async def _invalidate_cached_all(cls):
  604. """Invalidate cached_all cache for this model class."""
  605. cache_key = class_key("cached_all")(None, cls)
  606. await delete_cache_by_key(_key=cache_key)
  607. @classmethod
  608. async def delete_all(cls, session: AsyncSession, soft=False):
  609. """Delete all objects of the model."""
  610. for obj in await cls.all(session):
  611. cls._publish_event_after_commit(session, EventType.DELETED, obj)
  612. await obj.delete(session, soft=soft, auto_commit=False)
  613. try:
  614. await session.commit()
  615. # Invalidate cached_all cache after successful delete_all
  616. await cls._invalidate_cached_all()
  617. except Exception as e:
  618. await session.rollback()
  619. logger.error(f"Failed to delete all objects of {cls.__name__}: {e}")
  620. raise
  621. @classmethod
  622. def _publish_event_after_commit(
  623. cls, session: AsyncSession, event_type: str, data: Any
  624. ):
  625. session.info.setdefault("pending_events", []).append(
  626. CommitEvent(
  627. name=cls.__name__.lower(),
  628. type=event_type,
  629. data=data,
  630. )
  631. )
  632. @classmethod
  633. async def subscribe(
  634. cls,
  635. source: str,
  636. options: Optional[List] = None,
  637. event_types: Optional[Iterable[EventType]] = None,
  638. replay_existing: bool = True,
  639. ) -> AsyncGenerator[Event, None]:
  640. """Subscribe to bus events for this model.
  641. ``source`` labels the consumer in queue-full logs. ``event_types``
  642. whitelists pre-enqueue, so filtered events don't take queue slots.
  643. ``replay_existing=False`` skips the initial CREATED snapshot for
  644. consumers that bootstrap themselves.
  645. """
  646. topic = cls.__name__.lower()
  647. subscriber = event_bus.subscribe(topic, source=source, event_types=event_types)
  648. logger.info(
  649. "subscribed, source=%s topic=%s subscriber=%s",
  650. source,
  651. topic,
  652. id(subscriber),
  653. )
  654. if replay_existing:
  655. include_created = event_types is None or EventType.CREATED in event_types
  656. if include_created:
  657. initial_items = await cls.cached_all(options=options)
  658. for item in initial_items:
  659. yield Event(type=EventType.CREATED, data=item)
  660. heartbeat_interval = timedelta(seconds=15)
  661. last_event_time = datetime.now(timezone.utc)
  662. try:
  663. while True:
  664. try:
  665. event = await asyncio.wait_for(
  666. subscriber.receive(), timeout=heartbeat_interval.total_seconds()
  667. )
  668. yield event
  669. except asyncio.TimeoutError:
  670. if (
  671. datetime.now(timezone.utc) - last_event_time
  672. >= heartbeat_interval
  673. ):
  674. yield Event(type=EventType.HEARTBEAT, data=None)
  675. last_event_time = datetime.now(timezone.utc)
  676. finally:
  677. event_bus.unsubscribe(cls.__name__.lower(), subscriber)
  678. @classmethod
  679. async def streaming(
  680. cls,
  681. fields: Optional[dict] = None,
  682. fuzzy_fields: Optional[dict] = None,
  683. filter_func: Optional[Callable[[Any], bool]] = None,
  684. options: Optional[List] = None,
  685. ) -> AsyncGenerator[str, None]:
  686. """Stream events matching the given criteria as JSON strings.
  687. Args:
  688. fields: Exact match filters as key-value pairs
  689. fuzzy_fields: Fuzzy match filters
  690. filter_func: Optional filter function to apply to event data
  691. options: SQLAlchemy options for eager loading relationships (e.g., selectinload)
  692. """
  693. try:
  694. async for event in cls.subscribe(source="streaming", options=options):
  695. if event.type == EventType.HEARTBEAT:
  696. yield "\n\n"
  697. continue
  698. if not cls._match_fields(event, fields):
  699. continue
  700. if not cls._match_fuzzy_fields(event, fuzzy_fields):
  701. continue
  702. if filter_func and not filter_func(event.data):
  703. continue
  704. public_event = Event(
  705. type=event.type,
  706. data=cls._convert_to_public_class(event.data),
  707. changed_fields=event.changed_fields,
  708. id=event.id,
  709. )
  710. formatted = cls._format_event(public_event)
  711. if formatted is not None:
  712. yield formatted
  713. except asyncio.CancelledError:
  714. pass
  715. except Exception as e:
  716. logger.error(f"Error in streaming {cls.__name__}: {e}")
  717. @classmethod
  718. def _match_fields(cls, event: Any, fields: Optional[dict]) -> bool:
  719. """Match fields using AND condition."""
  720. for key, value in (fields or {}).items():
  721. if getattr(event.data, key, None) != value:
  722. return False
  723. return True
  724. @classmethod
  725. def _match_fuzzy_fields(cls, event: Any, fuzzy_fields: Optional[dict]) -> bool:
  726. """Match fuzzy fields using OR condition."""
  727. for key, value in (fuzzy_fields or {}).items():
  728. attr_value = str(getattr(event.data, key, "")).lower()
  729. if str(value).lower() in attr_value:
  730. return True
  731. return not fuzzy_fields
  732. @classmethod
  733. def _convert_to_public_class(cls, data: Any) -> Any:
  734. """Convert the instance to the corresponding Public class if it exists."""
  735. # If data is a dict (e.g., ID-only event), return as-is
  736. if isinstance(data, dict):
  737. return data
  738. class_module = importlib.import_module(cls.__module__)
  739. public_class = getattr(class_module, f"{cls.__name__}Public", None)
  740. return public_class.model_validate(data) if public_class else data
  741. @staticmethod
  742. def _format_event(event: Any) -> Optional[str]:
  743. """Format the event as a JSON string."""
  744. # Skip ID-only events for CREATED/UPDATED - they can't be properly validated by clients
  745. # But allow ID-only for DELETED - clients can still remove by ID from their cache
  746. if (
  747. event.type != EventType.DELETED
  748. and isinstance(event.data, dict)
  749. and set(event.data.keys()) == {"id"}
  750. ):
  751. return None
  752. return json.dumps(jsonable_encoder(event), separators=(",", ":")) + "\n\n"