api_keys.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. from datetime import datetime, timedelta, timezone
  2. import secrets
  3. from typing import Optional
  4. from fastapi import APIRouter, Depends, Query
  5. from fastapi.responses import StreamingResponse
  6. from sqlalchemy.orm import selectinload
  7. from sqlmodel import select
  8. from gpustack.api.exceptions import (
  9. AlreadyExistsException,
  10. ForbiddenException,
  11. InternalServerErrorException,
  12. InvalidException,
  13. NotFoundException,
  14. )
  15. from gpustack.security import API_KEY_PREFIX, get_secret_hash, get_key_pair
  16. from gpustack.server.deps import SessionDep, TenantContextDep
  17. from gpustack.schemas.api_keys import (
  18. ApiKey,
  19. ApiKeyCreate,
  20. ApiKeyListParams,
  21. ApiKeyPublic,
  22. ApiKeysPublic,
  23. ApiKeyUpdate,
  24. )
  25. from gpustack.schemas.users import User
  26. from gpustack.server.services import APIKeyService
  27. from gpustack.utils.api_keys import get_masked_api_key_value
# Router on which this module registers its API-key endpoints
# (list, create, delete, update) via the ``@router.*`` decorators below.
router = APIRouter()
  29. def _is_system_owned(api_key: ApiKey) -> bool:
  30. """API keys whose user is system-owned (workers, cluster sync, etc.).
  31. Filters by ``user.is_system`` rather than the api_key name because
  32. not every system-managed key follows the ``system/`` naming scheme
  33. (``Legacy Cluster Token``, ``Default Cluster Token``, …). Requires
  34. ``selectinload(ApiKey.user)`` on the query feeding this so the
  35. relationship is hydrated before access.
  36. """
  37. return bool(api_key.user and api_key.user.is_system)
  38. def _api_key_to_public(
  39. api_key: ApiKey, value: str = None, user_name: str = None
  40. ) -> ApiKeyPublic:
  41. """Convert an ApiKey object to an ApiKeyPublic object."""
  42. return ApiKeyPublic(
  43. name=api_key.name,
  44. description=api_key.description,
  45. id=api_key.id,
  46. user_name=user_name or api_key.user_name,
  47. value=value,
  48. masked_value=get_masked_api_key_value(api_key.access_key, api_key.is_custom),
  49. created_at=api_key.created_at,
  50. updated_at=api_key.updated_at,
  51. expires_at=api_key.expires_at,
  52. allowed_model_names=api_key.allowed_model_names,
  53. is_custom=api_key.is_custom,
  54. scope=api_key.scope,
  55. )
  56. def _is_hidden_api_key(api_key: ApiKey) -> bool:
  57. return _is_system_owned(api_key)
  58. @router.get("", response_model=ApiKeysPublic)
  59. async def get_api_keys(
  60. session: SessionDep,
  61. ctx: TenantContextDep,
  62. params: ApiKeyListParams = Depends(),
  63. user_id: Optional[str] = Query(
  64. None, description="Filter by user_id. Admin can use '*' to list all users."
  65. ),
  66. search: str = None,
  67. ):
  68. user = ctx.user
  69. fields = {"user_id": user.id}
  70. if user.is_admin and user_id is not None:
  71. if user_id == "*":
  72. fields = {}
  73. else:
  74. try:
  75. fields = {"user_id": int(user_id)}
  76. except ValueError:
  77. raise InvalidException(message="user_id must be an integer or '*'")
  78. # Tenant filter: scope keys to the current org context unless the platform
  79. # admin is explicitly browsing across orgs (no header / no api key org).
  80. if ctx.current_principal_id is not None:
  81. fields["owner_principal_id"] = ctx.current_principal_id
  82. fuzzy_fields = {}
  83. if search:
  84. fuzzy_fields = {"name": search}
  85. # Hide system-owned keys (workers, cluster sync, legacy/default
  86. # cluster tokens). The set isn't covered by a clean name prefix —
  87. # entries like "Default Cluster Token" / "Legacy Cluster Token" exist
  88. # alongside the "system/..." names — so we filter by the owning
  89. # user's ``is_system`` flag, which catches every variant.
  90. extra_conditions = [
  91. ApiKey.user_id.notin_(select(User.id).where(User.is_system.is_(True))),
  92. ]
  93. if params.watch:
  94. return StreamingResponse(
  95. ApiKey.streaming(
  96. fields=fields,
  97. fuzzy_fields=fuzzy_fields,
  98. filter_func=lambda api_key: not _is_hidden_api_key(api_key),
  99. options=[selectinload(ApiKey.user)],
  100. ),
  101. media_type="text/event-stream",
  102. )
  103. result = await ApiKey.paginated_by_query(
  104. session=session,
  105. fields=fields,
  106. fuzzy_fields=fuzzy_fields,
  107. extra_conditions=extra_conditions,
  108. page=params.page,
  109. per_page=params.perPage,
  110. order_by=params.order_by,
  111. options=[selectinload(ApiKey.user)],
  112. )
  113. # Convert ApiKey to ApiKeyPublic
  114. items = [_api_key_to_public(item) for item in result.items]
  115. result.items = items
  116. return result
  117. @router.post("", response_model=ApiKeyPublic)
  118. async def create_api_key(
  119. session: SessionDep, ctx: TenantContextDep, key_in: ApiKeyCreate
  120. ):
  121. user = ctx.user
  122. target_org_id = ctx.target_principal_id_for_write()
  123. if target_org_id is None:
  124. raise ForbiddenException(
  125. message="Organization context is required to create an API key"
  126. )
  127. fields = {
  128. "user_id": user.id,
  129. "owner_principal_id": target_org_id,
  130. "name": key_in.name,
  131. }
  132. existing = await ApiKey.one_by_fields(session, fields)
  133. if existing:
  134. raise AlreadyExistsException(message=f"Api key {key_in.name} already exists")
  135. if key_in.custom is None:
  136. access_key, secret_key = secrets.token_hex(8), secrets.token_hex(16)
  137. else:
  138. access_key, secret_key = get_key_pair(key_in.custom)
  139. existing_key = await ApiKey.one_by_field(
  140. session=session, field="access_key", value=access_key
  141. )
  142. if existing_key:
  143. expired = (
  144. existing_key.expires_at is not None
  145. and existing_key.expires_at <= datetime.now(timezone.utc)
  146. )
  147. message = (
  148. "Custom API Key duplicate with existing key "
  149. f"{existing_key.name} (id: {existing_key.id}, expired: {expired})"
  150. )
  151. raise AlreadyExistsException(message=message)
  152. current = datetime.now(timezone.utc)
  153. expires_at = None
  154. if key_in.expires_in and key_in.expires_in > 0:
  155. expires_at = current + timedelta(seconds=key_in.expires_in)
  156. try:
  157. api_key = ApiKey(
  158. name=key_in.name,
  159. description=key_in.description,
  160. user_id=user.id,
  161. owner_principal_id=target_org_id,
  162. access_key=access_key,
  163. hashed_secret_key=get_secret_hash(secret_key),
  164. expires_at=expires_at,
  165. allowed_model_names=key_in.allowed_model_names,
  166. is_custom=key_in.custom is not None,
  167. scope=key_in.scope,
  168. )
  169. api_key = await ApiKey.create(session, api_key)
  170. except Exception as e:
  171. raise InternalServerErrorException(message=f"Failed to create api key: {e}")
  172. value = (
  173. key_in.custom
  174. if key_in.custom
  175. else f"{API_KEY_PREFIX}_{access_key}_{secret_key}"
  176. )
  177. return _api_key_to_public(api_key, value=value, user_name=user.username)
  178. def _api_key_in_scope(api_key: ApiKey, ctx) -> bool:
  179. """An api_key is in the caller's scope if the caller is its owner, or a
  180. platform admin acting either across all orgs or in the key's org.
  181. """
  182. user = ctx.user
  183. if api_key.user_id == user.id and (
  184. ctx.current_principal_id is None
  185. or api_key.owner_principal_id == ctx.current_principal_id
  186. ):
  187. return True
  188. if user.is_admin and (
  189. ctx.current_principal_id is None
  190. or api_key.owner_principal_id == ctx.current_principal_id
  191. ):
  192. return True
  193. return False
  194. @router.delete("/{id}")
  195. async def delete_api_key(session: SessionDep, ctx: TenantContextDep, id: int):
  196. api_key = await ApiKey.one_by_id(session, id)
  197. if not api_key or not _api_key_in_scope(api_key, ctx):
  198. raise NotFoundException(message="Api key not found")
  199. try:
  200. await APIKeyService(session).delete(api_key)
  201. except Exception as e:
  202. raise InternalServerErrorException(message=f"Failed to delete api key: {e}")
  203. @router.put("/{id}", response_model=ApiKeyPublic)
  204. async def update_api_key(
  205. session: SessionDep, ctx: TenantContextDep, id: int, key_in: ApiKeyUpdate
  206. ):
  207. api_key = await ApiKey.one_by_id(session, id, options=[selectinload(ApiKey.user)])
  208. user_name = api_key.user.username if api_key and api_key.user else None
  209. if not api_key or not _api_key_in_scope(api_key, ctx):
  210. raise NotFoundException(message="Api key not found")
  211. try:
  212. await APIKeyService(session).update(
  213. api_key=api_key,
  214. source=key_in.model_dump(exclude_unset=True),
  215. )
  216. except Exception as e:
  217. raise InternalServerErrorException(message=f"Failed to update api key: {e}")
  218. return _api_key_to_public(api_key, user_name=user_name)