user_groups.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. """UserGroup management — Org admin+ or platform admin.
  2. Groups are GROUP-kind ``Principal`` rows whose ``parent_principal_id``
  3. points at their owning ORG-principal. Group memberships live in the
  4. unified ``principal_memberships`` table with ``role=NULL`` (groups
  5. don't have role tiers).
  6. """
  7. from datetime import datetime, timezone
  8. from typing import List, Optional
  9. from fastapi import APIRouter, Depends
  10. from pydantic import BaseModel
  11. from sqlmodel import select
  12. from gpustack.api.exceptions import (
  13. AlreadyExistsException,
  14. ForbiddenException,
  15. InvalidException,
  16. NotFoundException,
  17. )
  18. from gpustack.schemas.principals import (
  19. OrgRole,
  20. Principal,
  21. PrincipalMembership,
  22. PrincipalType,
  23. )
  24. from gpustack.schemas.user_groups import (
  25. UserGroupCreate,
  26. UserGroupListParams,
  27. UserGroupMembershipPublic,
  28. UserGroupPublic,
  29. UserGroupUpdate,
  30. UserGroupsPublic,
  31. )
  32. from gpustack.schemas.users import User
  33. from gpustack.server.deps import SessionDep, TenantContextDep
# Every endpoint defined in this module is registered on this router.
router = APIRouter()
  35. class GroupMembershipCreate(BaseModel):
  36. user_id: int
  37. def _can_manage_groups(ctx, org_id: int) -> bool:
  38. if ctx.is_platform_admin:
  39. return True
  40. if ctx.current_principal_id != org_id:
  41. return False
  42. return ctx.org_role == OrgRole.ADMIN
  43. async def _load_org(session, org_id: int) -> Principal:
  44. org = await Principal.one_by_id(session, org_id)
  45. if not org or org.deleted_at is not None or org.kind != PrincipalType.ORG:
  46. raise NotFoundException(message="Organization not found")
  47. return org
  48. async def _load_group(session, org_id: int, group_id: int) -> Principal:
  49. group = await Principal.one_by_id(session, group_id)
  50. if (
  51. not group
  52. or group.deleted_at is not None
  53. or group.kind != PrincipalType.GROUP
  54. or group.parent_principal_id != org_id
  55. ):
  56. raise NotFoundException(message="Group not found")
  57. return group
  58. def _group_to_public(group: Principal) -> UserGroupPublic:
  59. return UserGroupPublic.from_principal(group)
  60. # ---- groups ----------------------------------------------------------------
  61. @router.get("/organizations/{org_id}/groups", response_model=UserGroupsPublic)
  62. async def list_groups(
  63. session: SessionDep,
  64. ctx: TenantContextDep,
  65. org_id: int,
  66. params: UserGroupListParams = Depends(),
  67. search: Optional[str] = None,
  68. ):
  69. await _load_org(session, org_id)
  70. if not ctx.is_platform_admin and ctx.current_principal_id != org_id:
  71. raise ForbiddenException(message="Not a member of this organization")
  72. fuzzy_fields = {"name": search} if search else {}
  73. page = await Principal.paginated_by_query(
  74. session=session,
  75. fields={
  76. "kind": PrincipalType.GROUP,
  77. "parent_principal_id": org_id,
  78. "deleted_at": None,
  79. },
  80. fuzzy_fields=fuzzy_fields,
  81. page=params.page,
  82. per_page=params.perPage,
  83. order_by=params.order_by,
  84. )
  85. page.items = [_group_to_public(g) for g in page.items]
  86. return page
  87. @router.post("/organizations/{org_id}/groups", response_model=UserGroupPublic)
  88. async def create_group(
  89. session: SessionDep,
  90. ctx: TenantContextDep,
  91. org_id: int,
  92. body: UserGroupCreate,
  93. ):
  94. await _load_org(session, org_id)
  95. if not _can_manage_groups(ctx, org_id):
  96. raise ForbiddenException(message="Insufficient permission to manage groups")
  97. existing = await Principal.one_by_fields(
  98. session,
  99. {
  100. "kind": PrincipalType.GROUP,
  101. "parent_principal_id": org_id,
  102. "name": body.name,
  103. "deleted_at": None,
  104. },
  105. )
  106. if existing:
  107. raise AlreadyExistsException(
  108. message=f"Group '{body.name}' already exists in this organization"
  109. )
  110. try:
  111. group = Principal(
  112. kind=PrincipalType.GROUP,
  113. parent_principal_id=org_id,
  114. name=body.name,
  115. description=body.description,
  116. )
  117. created = await Principal.create(session, group)
  118. except Exception as e:
  119. raise InvalidException(message=f"Failed to create group: {e}")
  120. return _group_to_public(created)
  121. @router.put("/organizations/{org_id}/groups/{group_id}", response_model=UserGroupPublic)
  122. async def update_group(
  123. session: SessionDep,
  124. ctx: TenantContextDep,
  125. org_id: int,
  126. group_id: int,
  127. body: UserGroupUpdate,
  128. ):
  129. group = await _load_group(session, org_id, group_id)
  130. if not _can_manage_groups(ctx, org_id):
  131. raise ForbiddenException(message="Insufficient permission to manage groups")
  132. try:
  133. await group.update(session, body.model_dump(exclude_unset=True))
  134. except Exception as e:
  135. raise InvalidException(message=f"Failed to update group: {e}")
  136. return _group_to_public(group)
  137. @router.delete("/organizations/{org_id}/groups/{group_id}")
  138. async def delete_group(
  139. session: SessionDep, ctx: TenantContextDep, org_id: int, group_id: int
  140. ):
  141. group = await _load_group(session, org_id, group_id)
  142. if not _can_manage_groups(ctx, org_id):
  143. raise ForbiddenException(message="Insufficient permission to manage groups")
  144. try:
  145. await group.delete(session)
  146. except Exception as e:
  147. raise InvalidException(message=f"Failed to delete group: {e}")
  148. # ---- group members ---------------------------------------------------------
  149. async def _resolve_user(session, user_id: int) -> Optional[User]:
  150. user = await User.one_by_id(session, user_id)
  151. if not user or user.is_system or user.deleted_at is not None:
  152. return None
  153. return user
  154. @router.get(
  155. "/organizations/{org_id}/groups/{group_id}/members",
  156. response_model=List[UserGroupMembershipPublic],
  157. )
  158. async def list_group_members(
  159. session: SessionDep,
  160. ctx: TenantContextDep,
  161. org_id: int,
  162. group_id: int,
  163. ):
  164. await _load_group(session, org_id, group_id)
  165. if not ctx.is_platform_admin and ctx.current_principal_id != org_id:
  166. raise ForbiddenException(message="Not a member of this organization")
  167. stmt = select(PrincipalMembership).where(
  168. PrincipalMembership.parent_principal_id == group_id,
  169. PrincipalMembership.deleted_at.is_(None),
  170. )
  171. rows = list((await session.exec(stmt)).all())
  172. member_ids = {r.member_principal_id for r in rows}
  173. user_by_principal: dict[int, User] = {}
  174. if member_ids:
  175. result = await session.exec(
  176. select(User).where(User.principal_id.in_(member_ids))
  177. )
  178. user_by_principal = {u.principal_id: u for u in result.all()}
  179. out: List[UserGroupMembershipPublic] = []
  180. for r in rows:
  181. u = user_by_principal.get(r.member_principal_id)
  182. out.append(
  183. UserGroupMembershipPublic(
  184. user_id=getattr(u, "id", 0),
  185. group_id=group_id,
  186. created_at=r.created_at,
  187. username=getattr(u, "username", None),
  188. full_name=getattr(u, "full_name", None),
  189. )
  190. )
  191. return out
  192. @router.post(
  193. "/organizations/{org_id}/groups/{group_id}/members",
  194. response_model=UserGroupMembershipPublic,
  195. )
  196. async def add_group_member(
  197. session: SessionDep,
  198. ctx: TenantContextDep,
  199. org_id: int,
  200. group_id: int,
  201. body: GroupMembershipCreate,
  202. ):
  203. await _load_group(session, org_id, group_id)
  204. if not _can_manage_groups(ctx, org_id):
  205. raise ForbiddenException(message="Insufficient permission to manage groups")
  206. user = await _resolve_user(session, body.user_id)
  207. if not user:
  208. raise NotFoundException(message="User not found")
  209. # User must be an active member of the group's org first.
  210. org_membership_stmt = select(PrincipalMembership.id).where(
  211. PrincipalMembership.parent_principal_id == org_id,
  212. PrincipalMembership.member_principal_id == user.principal_id,
  213. PrincipalMembership.deleted_at.is_(None),
  214. )
  215. if (await session.exec(org_membership_stmt)).first() is None:
  216. raise InvalidException(
  217. message=(
  218. f"User {body.user_id} is not a member of " f"organization {org_id}"
  219. )
  220. )
  221. existing_stmt = select(PrincipalMembership).where(
  222. PrincipalMembership.parent_principal_id == group_id,
  223. PrincipalMembership.member_principal_id == user.principal_id,
  224. PrincipalMembership.deleted_at.is_(None),
  225. )
  226. if (await session.exec(existing_stmt)).first() is not None:
  227. raise AlreadyExistsException(
  228. message=f"User {body.user_id} is already in group {group_id}"
  229. )
  230. try:
  231. now = datetime.now(timezone.utc).replace(tzinfo=None)
  232. link = PrincipalMembership(
  233. parent_principal_id=group_id,
  234. member_principal_id=user.principal_id,
  235. role=None,
  236. created_at=now,
  237. updated_at=now,
  238. )
  239. session.add(link)
  240. await session.commit()
  241. await session.refresh(link)
  242. except Exception as e:
  243. await session.rollback()
  244. raise InvalidException(message=f"Failed to add group member: {e}")
  245. return UserGroupMembershipPublic(
  246. user_id=user.id,
  247. group_id=group_id,
  248. created_at=link.created_at,
  249. username=user.username,
  250. full_name=user.full_name,
  251. )
  252. @router.delete("/organizations/{org_id}/groups/{group_id}/members/{user_id}")
  253. async def remove_group_member(
  254. session: SessionDep,
  255. ctx: TenantContextDep,
  256. org_id: int,
  257. group_id: int,
  258. user_id: int,
  259. ):
  260. await _load_group(session, org_id, group_id)
  261. if not _can_manage_groups(ctx, org_id):
  262. raise ForbiddenException(message="Insufficient permission to manage groups")
  263. user = await _resolve_user(session, user_id)
  264. if not user:
  265. raise NotFoundException(message="Group membership not found")
  266. stmt = select(PrincipalMembership).where(
  267. PrincipalMembership.parent_principal_id == group_id,
  268. PrincipalMembership.member_principal_id == user.principal_id,
  269. PrincipalMembership.deleted_at.is_(None),
  270. )
  271. link = (await session.exec(stmt)).first()
  272. if not link:
  273. raise NotFoundException(message="Group membership not found")
  274. try:
  275. await link.delete(session, soft=True)
  276. except Exception as e:
  277. await session.rollback()
  278. raise InvalidException(message=f"Failed to remove group member: {e}")