model_route_principals.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. """Manage principal-based access on a ModelRoute (ALLOWED_PRINCIPALS).
  2. Mounted under /v2/model-routes/{id}/principals on the admin router.
  3. The set of principals (USER / ORG / GROUP) attached to a route only
  4. takes effect when the route's ``access_policy`` is ``ALLOWED_PRINCIPALS``.
  5. Storage: each principal is one ``model_route_principals`` row with a
  6. single ``principal_id`` FK. Kind is read from the joined ``principals``
  7. row. The API surface keeps the legacy ``(principal_type, principal_id)``
  8. shape, where ``principal_id`` here is the principals.id of the target
  9. (USER / ORG / GROUP principal).
  10. """
  11. from typing import List
  12. from fastapi import APIRouter
  13. from pydantic import BaseModel
  14. from sqlmodel import select
  15. from gpustack.api.exceptions import (
  16. AlreadyExistsException,
  17. InvalidException,
  18. NotFoundException,
  19. )
  20. from gpustack.schemas.links import ModelRoutePrincipalLink
  21. from gpustack.schemas.model_routes import ModelRoute
  22. from gpustack.schemas.principals import Principal, PrincipalType
  23. from gpustack.schemas.users import User
  24. from gpustack.server.deps import SessionDep
  25. from gpustack.server.services import revoke_model_access_cache
# Router for principal management on model routes. Its paths are declared as
# ``/{id}/principals...``; per the module docstring it is mounted so they are
# served under /v2/model-routes/{id}/principals on the admin router.
router = APIRouter()
class PrincipalRef(BaseModel):
    """Request body identifying a principal to attach to a model route."""

    # Expected kind of the target principal; the add endpoint rejects the
    # request if it differs from the kind stored on the principals row.
    principal_type: PrincipalType
    # ``principals.id`` of the target principal (not a users/orgs/groups id).
    principal_id: int
class PrincipalView(BaseModel):
    """A principal attached to a model route, as returned by the API."""

    # ID of the model route the principal is attached to.
    route_id: int
    # Kind of the attached principal, as looked up from the principals table.
    principal_type: PrincipalType
    # ``principals.id`` of the attached principal.
    principal_id: int
  34. async def _load_route(session, route_id: int) -> ModelRoute:
  35. route = await ModelRoute.one_by_id(session, route_id)
  36. if not route or route.deleted_at is not None:
  37. raise NotFoundException(message="Model route not found")
  38. return route
  39. async def _validate_principal(
  40. session, principal_type: PrincipalType, principal_id: int
  41. ) -> Principal:
  42. target = await Principal.one_by_id(session, principal_id)
  43. if not target or target.deleted_at is not None:
  44. raise InvalidException(message=f"Principal {principal_id} not found")
  45. if target.kind != principal_type:
  46. raise InvalidException(
  47. message=(
  48. f"Principal {principal_id} is a {target.kind.value}, "
  49. f"not a {principal_type.value}"
  50. )
  51. )
  52. if target.kind == PrincipalType.USER:
  53. user = await User.one_by_field(session, "principal_id", principal_id)
  54. if user is None or user.is_system or user.deleted_at is not None:
  55. raise InvalidException(message=f"User principal {principal_id} not found")
  56. return target
  57. def _row_to_view(row: ModelRoutePrincipalLink, kind: PrincipalType) -> PrincipalView:
  58. return PrincipalView(
  59. route_id=row.route_id,
  60. principal_type=kind,
  61. principal_id=row.principal_id,
  62. )
  63. async def _resolve_views(
  64. session, rows: List[ModelRoutePrincipalLink]
  65. ) -> List[PrincipalView]:
  66. principal_ids = {r.principal_id for r in rows}
  67. kinds: dict[int, PrincipalType] = {}
  68. if principal_ids:
  69. result = await session.exec(
  70. select(Principal).where(Principal.id.in_(principal_ids))
  71. )
  72. kinds = {p.id: p.kind for p in result.all()}
  73. return [
  74. _row_to_view(r, kinds.get(r.principal_id, PrincipalType.USER)) for r in rows
  75. ]
  76. @router.get("/{id}/principals", response_model=List[PrincipalView])
  77. async def list_route_principals(session: SessionDep, id: int):
  78. await _load_route(session, id)
  79. stmt = select(ModelRoutePrincipalLink).where(ModelRoutePrincipalLink.route_id == id)
  80. rows = list((await session.exec(stmt)).all())
  81. return await _resolve_views(session, rows)
  82. @router.post("/{id}/principals", response_model=PrincipalView)
  83. async def add_route_principal(session: SessionDep, id: int, body: PrincipalRef):
  84. await _load_route(session, id)
  85. target = await _validate_principal(session, body.principal_type, body.principal_id)
  86. existing_stmt = select(ModelRoutePrincipalLink).where(
  87. ModelRoutePrincipalLink.route_id == id,
  88. ModelRoutePrincipalLink.principal_id == body.principal_id,
  89. )
  90. if (await session.exec(existing_stmt)).first() is not None:
  91. raise AlreadyExistsException(message="Principal already attached to route")
  92. try:
  93. link = ModelRoutePrincipalLink(
  94. route_id=id,
  95. principal_id=body.principal_id,
  96. )
  97. session.add(link)
  98. await session.commit()
  99. await session.refresh(link)
  100. # Visibility may have widened; bust the access cache for the
  101. # route. Pass model=None to broadly invalidate accessible-model
  102. # caches; the set of affected users for an org/group principal
  103. # can't be derived cheaply from ``route`` alone, so we err on
  104. # the side of correctness.
  105. await revoke_model_access_cache(session=session)
  106. except Exception as e:
  107. await session.rollback()
  108. raise InvalidException(message=f"Failed to add principal: {e}")
  109. return _row_to_view(link, target.kind)
  110. @router.delete("/{id}/principals/{principal_type}/{principal_id}")
  111. async def remove_route_principal(
  112. session: SessionDep,
  113. id: int,
  114. principal_type: PrincipalType,
  115. principal_id: int,
  116. ):
  117. await _load_route(session, id)
  118. stmt = select(ModelRoutePrincipalLink).where(
  119. ModelRoutePrincipalLink.route_id == id,
  120. ModelRoutePrincipalLink.principal_id == principal_id,
  121. )
  122. link = (await session.exec(stmt)).first()
  123. if not link:
  124. raise NotFoundException(message="Principal not attached to route")
  125. try:
  126. await session.delete(link)
  127. await session.commit()
  128. await revoke_model_access_cache(session=session)
  129. except Exception as e:
  130. await session.rollback()
  131. raise InvalidException(message=f"Failed to remove principal: {e}")