services.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3
  1. import logging
  2. from datetime import datetime, timezone
  3. from typing import List, Optional, Union, Set, Tuple
  4. from sqlmodel import SQLModel, select
  5. from sqlmodel.ext.asyncio.session import AsyncSession
  6. from sqlalchemy.orm import selectinload
  7. from gpustack.api.exceptions import InternalServerErrorException
  8. from gpustack.schemas.api_keys import ApiKey
  9. from gpustack.schemas.links import ModelRoutePrincipalLink
  10. from gpustack.schemas.model_files import ModelFile
  11. from gpustack.schemas.models import (
  12. Model,
  13. ModelInstance,
  14. ModelInstanceStateEnum,
  15. )
  16. from gpustack.schemas.model_routes import (
  17. ModelRoute,
  18. MyModel,
  19. ModelRouteTarget,
  20. TargetStateEnum,
  21. AccessPolicyEnum,
  22. effective_route_name,
  23. )
  24. from gpustack.schemas.principals import (
  25. OrgRole,
  26. PLATFORM_PRINCIPAL_ID,
  27. Principal,
  28. PrincipalMembership,
  29. PrincipalType,
  30. )
  31. from gpustack.schemas.users import User
  32. from gpustack.schemas.clusters import Cluster
  33. from gpustack.schemas.workers import Worker
  34. from gpustack.server.cache import (
  35. delete_cache_by_key,
  36. locked_cached,
  37. )
  38. logger = logging.getLogger(__name__)
  39. class UserService:
  40. def __init__(self, session: AsyncSession):
  41. self.session = session
  42. @locked_cached()
  43. async def get_by_id(self, user_id: int) -> Optional[User]:
  44. result = await User.one_by_id(
  45. self.session,
  46. user_id,
  47. options=[selectinload(User.worker), selectinload(User.cluster)],
  48. )
  49. if result is None:
  50. return None
  51. if result.worker is not None:
  52. # detach worker to avoid lazy loading
  53. self.session.expunge(result.worker)
  54. self.session.expunge(result)
  55. return result
  56. @locked_cached()
  57. async def get_by_username(self, username: str) -> Optional[User]:
  58. result = await User.one_by_field(self.session, "username", username)
  59. if result is None:
  60. return None
  61. self.session.expunge(result)
  62. return result
  63. async def create(self, user: User):
  64. return await create_user_with_principal(self.session, user)
  65. async def update(self, user: User, source: Union[dict, SQLModel, None] = None):
  66. result = await user.update(self.session, source)
  67. await delete_cache_by_key(self.get_by_id, user.id)
  68. await delete_cache_by_key(self.get_user_accessible_model_names, user.id)
  69. await delete_cache_by_key(self.get_by_username, user.username)
  70. return result
  71. async def delete(self, user: User):
  72. apikeys = await APIKeyService(self.session).get_by_user_id(user.id)
  73. result = await user.delete(self.session)
  74. await delete_cache_by_key(self.get_by_id, user.id)
  75. await delete_cache_by_key(self.get_user_accessible_model_names, user.id)
  76. await delete_cache_by_key(self.get_by_username, user.username)
  77. for apikey in apikeys:
  78. await delete_cache_by_key(
  79. APIKeyService.get_by_access_key, apikey.access_key
  80. )
  81. return result
  82. async def model_allowed_for_user(
  83. self, model_name: str, user_id: int, api_key: Optional[ApiKey]
  84. ) -> bool:
  85. limited_model_names: Optional[Set[str]] = (
  86. set(api_key.allowed_model_names)
  87. if api_key is not None
  88. and api_key.allowed_model_names is not None
  89. and len(api_key.allowed_model_names) > 0
  90. else None
  91. )
  92. accessible_model_names: Set[str] = await self.get_user_accessible_model_names(
  93. user_id
  94. )
  95. allowed = model_name in intersection_nullable_set(
  96. accessible_model_names, limited_model_names
  97. )
  98. if not allowed:
  99. logger.info(
  100. "Access denied: model_name=%r user_id=%d " "accessible=%s limited=%s",
  101. model_name,
  102. user_id,
  103. sorted(accessible_model_names),
  104. sorted(limited_model_names) if limited_model_names else None,
  105. )
  106. return allowed
  107. @locked_cached()
  108. async def get_user_accessible_model_names(self, user_id: int) -> Set[str]:
  109. # Get all accessible model names for the user. The set holds two
  110. # forms per route:
  111. # 1. Org-effective name (`<slug>/<route>` for non-platform
  112. # Orgs, raw for platform) — matches `/v1/models` output and
  113. # the gateway's ingress header matcher.
  114. # 2. Raw `route.name` — matches the post-`modelMapping` value
  115. # that Higress's AI proxy hands back via
  116. # `x-higress-llm-model` on the auth callback. Without this
  117. # the callback would deny chat traffic for non-platform
  118. # Orgs even though the gateway already routed it to the
  119. # correct ingress.
  120. # Cross-Org collisions on raw names are fine: each user's set is
  121. # isolated, and Higress's per-Org ingress already disambiguates
  122. # which underlying instance receives the request.
  123. user: User = await self.get_by_id(user_id)
  124. if user is None:
  125. return set()
  126. if user.is_admin or user.is_system:
  127. routes = await ModelRoute.all_by_field(self.session, "deleted_at", None)
  128. else:
  129. routes = await MyModel.all_by_fields(
  130. self.session, {"user_id": user.id, "deleted_at": None}
  131. )
  132. principal_ids = {
  133. r.owner_principal_id for r in routes if r.owner_principal_id is not None
  134. }
  135. principal_by_id = {}
  136. if principal_ids:
  137. rows = (
  138. await self.session.exec(
  139. select(Principal).where(Principal.id.in_(principal_ids))
  140. )
  141. ).all()
  142. principal_by_id = {p.id: p for p in rows}
  143. names: Set[str] = set()
  144. for r in routes:
  145. owner = (
  146. principal_by_id.get(r.owner_principal_id)
  147. if r.owner_principal_id
  148. else None
  149. )
  150. names.add(
  151. effective_route_name(
  152. r.name,
  153. getattr(owner, "slug", None),
  154. getattr(owner, "id", None) == PLATFORM_PRINCIPAL_ID,
  155. )
  156. )
  157. names.add(r.name)
  158. return names
async def create_user_with_principal(session: AsyncSession, user: User) -> User:
    """Persist a User together with its 1:1 USER-principal.

    Replaces the bare ``User.create(...)`` call at every user-creation
    site (local POST /users, SSO callbacks, bootstrap admin, worker
    registration).

    Why the dance:
    - ``users.principal_id`` is NOT NULL, so the principal row must
      exist before the user row is inserted.
    - Callers naturally construct users with relationship attributes
      (``cluster=cluster``, ``worker=worker``). Those backref-populate
      the parent's ``cluster_users`` / ``workers`` collections at
      construction time, before the user is in any session. That both
      emits a noisy ``SAWarning`` and, more importantly, leaves a
      dangling ``InstanceState`` reference that crashes the bus-event
      payload deepcopy at commit time.

    The fix:
    1. ``session.add(user)`` immediately, so the pre-construction backref
       entries point at a session-tracked object.
    2. Use the ``user.principal`` relationship attribute, so SQLAlchemy's
       unit of work inserts the principal first and auto-populates
       ``user.principal_id`` during a single combined flush.
    3. Patch the principal's ``slug`` to ``user-{user.id}`` afterwards,
       once the auto-generated user id is known.

    Args:
        session: Session the user and principal are added to and flushed in.
        user: Unsaved user; it is mutated in place (``principal`` is set).

    Returns:
        The same ``user`` object, flushed (so ``user.id`` is populated).

    Caller commits: this function only flushes.
    """
    # Step 1 — make the session aware of the user before any flush
    # touches related collections.
    session.add(user)
    # Step 2 — link via the relationship attribute so SQLAlchemy
    # orders ``principal`` before ``user`` and threads the auto-id
    # through automatically.
    # The slug starts as NULL because ``user.id`` is not known until the
    # flush below. Per the Step 3 note, uniqueness only applies to
    # non-NULL slugs.
    principal = Principal(
        kind=PrincipalType.USER,
        name=user.username,
        slug=None,
    )
    user.principal = principal
    session.add(principal)
    await session.flush([principal, user])
    # Step 3 — slug is globally unique among non-NULL values; assign
    # the canonical ``user-{id}`` form now that the user id is known.
    principal.slug = f"user-{user.id}"
    await session.flush([principal])
    return user
  202. async def provision_user_principal(session: AsyncSession, user: User) -> Principal:
  203. """Backfill a USER-principal for an existing user that lacks one.
  204. Used by SSO callbacks for users created before the multi-tenancy
  205. migration shipped — they exist in the database without
  206. ``principal_id``. Fresh user creation goes through
  207. ``create_user_with_principal`` instead.
  208. """
  209. principal = Principal(
  210. kind=PrincipalType.USER,
  211. name=user.username,
  212. slug=f"user-{user.id}",
  213. )
  214. session.add(principal)
  215. await session.flush([principal])
  216. user.principal_id = principal.id
  217. session.add(user)
  218. await session.flush([user])
  219. return principal
  220. async def provision_bootstrap_admin_orgs(session: AsyncSession, user: User) -> None:
  221. """Add the bootstrap admin as ADMIN of the platform Org.
  222. Assumes ``user`` already has a ``principal_id`` (created via
  223. ``create_user_with_principal``). Caller commits.
  224. """
  225. now = datetime.now(timezone.utc).replace(tzinfo=None)
  226. session.add(
  227. PrincipalMembership(
  228. parent_principal_id=PLATFORM_PRINCIPAL_ID,
  229. member_principal_id=user.principal_id,
  230. role=OrgRole.ADMIN,
  231. created_at=now,
  232. updated_at=now,
  233. )
  234. )
  235. class APIKeyService:
  236. def __init__(self, session: AsyncSession):
  237. self.session = session
  238. @locked_cached()
  239. async def get_by_access_key(self, access_key: str) -> Optional[ApiKey]:
  240. result = await ApiKey.one_by_field(self.session, "access_key", access_key)
  241. if result is None:
  242. return None
  243. self.session.expunge(result)
  244. return result
  245. async def get_by_user_id(self, user_id: int) -> List[ApiKey]:
  246. results = await ApiKey.all_by_field(self.session, "user_id", user_id)
  247. if results is None:
  248. return []
  249. for result in results:
  250. self.session.expunge(result)
  251. return results
  252. async def update(self, api_key: ApiKey, source: Union[dict, SQLModel, None] = None):
  253. result = await api_key.update(self.session, source)
  254. await delete_cache_by_key(self.get_by_access_key, api_key.access_key)
  255. return result
  256. async def delete(self, api_key: ApiKey):
  257. result = await api_key.delete(self.session)
  258. await delete_cache_by_key(self.get_by_access_key, api_key.access_key)
  259. return result
  260. class ClusterService:
  261. def __init__(self, session: AsyncSession):
  262. self.session = session
  263. @locked_cached()
  264. async def get_by_id(self, cluster_id: int) -> Optional[Cluster]:
  265. result = await Cluster.one_by_id(self.session, cluster_id)
  266. if result is None:
  267. return None
  268. self.session.expunge(result)
  269. return result
  270. class WorkerService:
  271. def __init__(self, session: AsyncSession):
  272. self.session = session
  273. @locked_cached()
  274. async def get_by_id(self, worker_id: int) -> Optional[Worker]:
  275. result = await Worker.one_by_id(self.session, worker_id)
  276. if result is None:
  277. return None
  278. self.session.expunge(result)
  279. return result
  280. @locked_cached()
  281. async def get_by_cluster_id_name(
  282. self, cluster_id: int, name: str
  283. ) -> Optional[Worker]:
  284. result = await Worker.one_by_fields(
  285. self.session, fields={"cluster_id": cluster_id, "name": name}
  286. )
  287. if result is None:
  288. return None
  289. self.session.expunge(result)
  290. return result
  291. @locked_cached()
  292. async def get_by_name(self, name: str) -> Optional[Worker]:
  293. result = await Worker.one_by_field(self.session, "name", name)
  294. if result is None:
  295. return None
  296. self.session.expunge(result)
  297. return result
  298. async def update(
  299. self, worker: Worker, source: Union[dict, SQLModel, None] = None, **kwargs
  300. ):
  301. result = await worker.update(self.session, source, **kwargs)
  302. # Worker cache is high-frequency, non-security-critical, skip coordinator sync
  303. await delete_cache_by_key(self.get_by_id, worker.id, sync_coordinator=False)
  304. await delete_cache_by_key(self.get_by_name, worker.name, sync_coordinator=False)
  305. return result
  306. async def batch_update(
  307. self,
  308. workers: List[Worker],
  309. source: Union[dict, SQLModel, None] = None,
  310. **kwargs,
  311. ) -> int:
  312. if not workers:
  313. return 0
  314. updated = await Worker.batch_update(self.session, workers)
  315. for w in workers:
  316. # Worker cache is high-frequency, non-security-critical, skip coordinator sync
  317. await delete_cache_by_key(self.get_by_id, w.id, sync_coordinator=False)
  318. await delete_cache_by_key(self.get_by_name, w.name, sync_coordinator=False)
  319. return updated
  320. async def delete(self, worker: Worker, **kwargs):
  321. worker_id = worker.id
  322. worker_name = worker.name
  323. result = await worker.delete(self.session, **kwargs)
  324. # Worker cache is high-frequency, non-security-critical, skip coordinator sync
  325. await delete_cache_by_key(self.get_by_id, worker_id, sync_coordinator=False)
  326. await delete_cache_by_key(self.get_by_name, worker_name, sync_coordinator=False)
  327. return result
  328. class ModelRouteService:
  329. def __init__(self, session: AsyncSession):
  330. self.session = session
  331. @locked_cached()
  332. async def get_by_name(self, name: str) -> Optional[ModelRoute]:
  333. result = await ModelRoute.one_by_field(self.session, "name", name)
  334. if result is None:
  335. return None
  336. self.session.expunge(result)
  337. return result
  338. @locked_cached()
  339. async def get_model_auth_info_by_name(
  340. self, name: str
  341. ) -> Optional[Tuple[AccessPolicyEnum, str]]:
  342. # Higress's auth callback may hand us either the Org-effective
  343. # name (`<slug>/<route>`) or the raw `route.name` depending on
  344. # whether `modelMapping` has fired yet. Resolve both forms.
  345. route: Optional[ModelRoute] = None
  346. if "/" in name:
  347. slug, _, rest = name.partition("/")
  348. if rest:
  349. owner = await Principal.one_by_field(self.session, "slug", slug)
  350. if owner is not None:
  351. route = await ModelRoute.one_by_fields(
  352. self.session,
  353. {"name": rest, "owner_principal_id": owner.id},
  354. )
  355. if route is None:
  356. route = await ModelRoute.one_by_field(self.session, "name", name)
  357. if route is None:
  358. return None
  359. route_targets = await ModelRouteTarget.all_by_fields(
  360. self.session,
  361. fields={"route_id": route.id},
  362. )
  363. if len(route_targets) == 0:
  364. return None
  365. models = await Model.all_by_fields(
  366. session=self.session,
  367. fields={},
  368. extra_conditions=[
  369. Model.id.in_(
  370. [e.model_id for e in route_targets if e.model_id is not None]
  371. )
  372. ],
  373. )
  374. # set a default static token to avoid empty token response for public maas model route
  375. registration_token = "static_token_not_found"
  376. for model in models:
  377. cluster = await Cluster.one_by_id(self.session, model.cluster_id)
  378. if cluster.registration_token is not None:
  379. registration_token = cluster.registration_token
  380. break
  381. return route.access_policy, registration_token
  382. @locked_cached()
  383. async def get_model_ids_by_model_route_name(self, name: str) -> List[Model]:
  384. # Clients send the principal-prefixed effective name (e.g.
  385. # "org1/qwen3-0.6b" or "user-42/qwen3-0.6b"). Targets are stored
  386. # keyed by raw ``route_name``, so split off the prefix and
  387. # constrain by the route's owning principal. Platform routes
  388. # have no prefix — fall back to the legacy lookup.
  389. owner_principal_id: Optional[int] = None
  390. raw_name = name
  391. if "/" in name:
  392. slug, _, rest = name.partition("/")
  393. if rest:
  394. owner = await Principal.one_by_field(self.session, "slug", slug)
  395. if owner is not None:
  396. owner_principal_id = owner.id
  397. raw_name = rest
  398. # If the slug didn't match a principal, fall through and
  399. # try the literal name (handles edge cases like a route
  400. # called "literal/with/slashes" before the prefix
  401. # convention existed).
  402. target_fields = {
  403. "route_name": raw_name,
  404. "state": TargetStateEnum.ACTIVE,
  405. "deleted_at": None,
  406. }
  407. targets = await ModelRouteTarget.all_by_fields(
  408. self.session,
  409. fields=target_fields,
  410. options=[selectinload(ModelRouteTarget.model)],
  411. )
  412. # When a principal slug was parsed, narrow to that owner's
  413. # route by joining through the parent ModelRoute's
  414. # ``owner_principal_id``. Avoids an extra round-trip when the
  415. # route name is globally unique (the typical single-Org case).
  416. if owner_principal_id is not None and len(targets) > 0:
  417. route_ids = {t.route_id for t in targets if t.route_id is not None}
  418. owner_routes = await ModelRoute.all_by_fields(
  419. self.session,
  420. fields={
  421. "owner_principal_id": owner_principal_id,
  422. "deleted_at": None,
  423. },
  424. )
  425. allowed_route_ids = {r.id for r in owner_routes if r.id in route_ids}
  426. targets = [t for t in targets if t.route_id in allowed_route_ids]
  427. models = [target.model for target in targets if target.model is not None]
  428. for model in models:
  429. self.session.expunge(model)
  430. return models
  431. async def update(
  432. self,
  433. model_route: ModelRoute,
  434. source: Union[dict, SQLModel, None] = None,
  435. auto_commit: bool = True,
  436. ):
  437. result = await model_route.update(self.session, source, auto_commit=auto_commit)
  438. await delete_cache_by_key(self.get_model_auth_info_by_name, model_route.name)
  439. await delete_cache_by_key(
  440. self.get_model_ids_by_model_route_name, model_route.name
  441. )
  442. return result
  443. async def delete(self, model_route: ModelRoute, auto_commit: bool = True):
  444. result = await model_route.delete(self.session, auto_commit=auto_commit)
  445. await delete_cache_by_key(self.get_model_auth_info_by_name, model_route.name)
  446. await delete_cache_by_key(
  447. self.get_model_ids_by_model_route_name, model_route.name
  448. )
  449. return result
  450. class ModelService:
  451. def __init__(self, session: AsyncSession):
  452. self.session = session
  453. @locked_cached()
  454. async def get_by_id(self, model_id: int) -> Optional[Model]:
  455. result = await Model.one_by_id(self.session, model_id)
  456. if result is None:
  457. return None
  458. self.session.expunge(result)
  459. return result
  460. @locked_cached()
  461. async def get_by_name(self, name: str) -> Optional[Model]:
  462. result = await Model.one_by_field(self.session, "name", name)
  463. if result is None:
  464. return None
  465. self.session.expunge(result)
  466. return result
  467. async def update(self, model: Model, source: Union[dict, SQLModel, None] = None):
  468. result = await model.update(self.session, source)
  469. await delete_cache_by_key(self.get_by_id, model.id)
  470. await delete_cache_by_key(self.get_by_name, model.name)
  471. return result
  472. async def delete(self, model: Model):
  473. result = await model.delete(self.session)
  474. await delete_cache_by_key(self.get_by_id, model.id)
  475. await delete_cache_by_key(self.get_by_name, model.name)
  476. return result
  477. class ModelInstanceService:
  478. def __init__(self, session: AsyncSession):
  479. self.session = session
  480. @locked_cached()
  481. async def get_by_id(self, id: int) -> Optional[ModelInstance]:
  482. result = await ModelInstance.one_by_id(self.session, id)
  483. if result is None:
  484. return None
  485. self.session.expunge(result)
  486. return result
  487. @locked_cached()
  488. async def get_running_instances(self, model_id: int) -> List[ModelInstance]:
  489. results = await ModelInstance.all_by_fields(
  490. self.session,
  491. fields={"model_id": model_id, "state": ModelInstanceStateEnum.RUNNING},
  492. )
  493. if results is None:
  494. return []
  495. for result in results:
  496. self.session.expunge(result)
  497. return results
  498. async def create(self, model_instance):
  499. result = await ModelInstance.create(self.session, model_instance)
  500. await delete_cache_by_key(self.get_running_instances, model_instance.model_id)
  501. return result
  502. async def update(
  503. self, model_instance: ModelInstance, source: Union[dict, SQLModel, None] = None
  504. ):
  505. result = await model_instance.update(self.session, source)
  506. await delete_cache_by_key(self.get_running_instances, model_instance.model_id)
  507. await delete_cache_by_key(self.get_by_id, model_instance.id)
  508. return result
  509. async def delete(self, model_instance: ModelInstance):
  510. result = await model_instance.delete(self.session)
  511. await delete_cache_by_key(self.get_running_instances, model_instance.model_id)
  512. await delete_cache_by_key(self.get_by_id, model_instance.id)
  513. return result
  514. async def batch_delete(self, model_instances: List[ModelInstance]):
  515. if not model_instances:
  516. return []
  517. names = [mi.name for mi in model_instances]
  518. ids = set()
  519. try:
  520. for m in model_instances:
  521. await m.delete(self.session, auto_commit=False)
  522. ids.add(m.model_id)
  523. await self.session.commit()
  524. for id in ids:
  525. await delete_cache_by_key(self.get_running_instances, id)
  526. return names
  527. except Exception as e:
  528. await self.session.rollback()
  529. raise InternalServerErrorException(
  530. message=f"Failed to delete model instances {names}: {e}"
  531. )
  532. async def batch_update(
  533. self,
  534. model_instances: List[ModelInstance],
  535. source: Union[dict, SQLModel, None] = None,
  536. ):
  537. names = [mi.name for mi in model_instances]
  538. ids = set()
  539. try:
  540. for m in model_instances:
  541. await m.update(self.session, source, auto_commit=False)
  542. ids.add(m.model_id)
  543. await self.session.commit()
  544. for id in ids:
  545. await delete_cache_by_key(self.get_running_instances, id)
  546. return names
  547. except Exception as e:
  548. await self.session.rollback()
  549. raise InternalServerErrorException(
  550. message=f"Failed to update model instances {names}: {e}"
  551. )
  552. class ModelFileService:
  553. def __init__(self, session: AsyncSession):
  554. self.session = session
  555. async def get_by_resolved_path(self, path: str) -> List[ModelFile]:
  556. results = await ModelFile.all_by_fields(
  557. self.session,
  558. )
  559. filtered_results = []
  560. for result in results:
  561. self.session.expunge(result)
  562. if path in result.resolved_paths:
  563. filtered_results.append(result)
  564. return filtered_results
  565. async def get_by_source_index(self, source_index: str) -> List[ModelFile]:
  566. results = await ModelFile.all_by_field(
  567. self.session, "source_index", source_index
  568. )
  569. if results is None:
  570. return None
  571. for result in results:
  572. self.session.expunge(result)
  573. return results
  574. async def create(self, model_file: ModelFile):
  575. return await ModelFile.create(self.session, model_file)
  576. def intersection_nullable_set(set1: Set[str], set2: Optional[Set[str]]) -> Set[str]:
  577. if set2 is None:
  578. return set1
  579. return set1.intersection(set2)
  580. async def delete_accessible_model_cache(
  581. *user_ids: int,
  582. ):
  583. for user_id in user_ids:
  584. await delete_cache_by_key(UserService.get_user_accessible_model_names, user_id)
  585. async def revoke_model_access_cache(
  586. session: AsyncSession,
  587. model: Optional[ModelRoute] = None,
  588. extra_user_ids: Optional[set[int]] = None,
  589. ):
  590. user_ids = set()
  591. if model is None:
  592. result = await session.exec(select(User.id))
  593. user_ids = set(result.all())
  594. else:
  595. # Users with a direct grant on this route's ACL — i.e. their
  596. # USER-principal appears in ``model_route_principals`` for this
  597. # route. Group / Org grants are intentionally not expanded
  598. # here: this helper invalidates per-user caches and the broader
  599. # invalidation path uses ``model=None`` (cache-bust everyone).
  600. stmt = (
  601. select(User.id)
  602. .join(
  603. ModelRoutePrincipalLink,
  604. ModelRoutePrincipalLink.principal_id == User.principal_id,
  605. )
  606. .where(ModelRoutePrincipalLink.route_id == model.id)
  607. )
  608. user_ids = set((await session.exec(stmt)).all())
  609. if extra_user_ids:
  610. user_ids.update(extra_user_ids)
  611. await delete_accessible_model_cache(*user_ids)