cloud_credentials.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. from urllib.parse import urljoin
  2. from functools import partial
  3. from fastapi import APIRouter, Depends, Request
  4. from fastapi.responses import StreamingResponse
  5. from gpustack.api.exceptions import (
  6. AlreadyExistsException,
  7. InternalServerErrorException,
  8. NotFoundException,
  9. )
  10. from gpustack.api.tenant import (
  11. assert_org_owned_writable,
  12. assert_resource_visible,
  13. tenant_list_conditions,
  14. validate_owner_principal,
  15. )
  16. from gpustack.server.db import async_session
  17. from gpustack.server.deps import SessionDep, TenantContextDep
  18. from gpustack.schemas.clusters import (
  19. CloudCredentialCreate,
  20. CloudCredentialListParams,
  21. CloudCredentialPublic,
  22. CloudCredentialsPublic,
  23. CloudCredentialUpdate,
  24. CloudCredential,
  25. ClusterProvider,
  26. )
  27. from gpustack.cloud_providers.common import factory
  28. from gpustack.routes.proxy import proxy_to
  29. from gpustack.schemas.organizations import PLATFORM_PRINCIPAL_ID
# Router that the cloud-credential endpoints in this module register on.
router = APIRouter()
  31. @router.get("", response_model=CloudCredentialsPublic)
  32. async def list(
  33. ctx: TenantContextDep,
  34. params: CloudCredentialListParams = Depends(),
  35. name: str = None,
  36. search: str = None,
  37. ):
  38. fuzzy_fields = {}
  39. if search:
  40. fuzzy_fields = {"name": search}
  41. fields = {"deleted_at": None}
  42. if name:
  43. fields = {"name": name}
  44. if params.watch:
  45. return StreamingResponse(
  46. CloudCredential.streaming(fields=fields, fuzzy_fields=fuzzy_fields),
  47. media_type="text/event-stream",
  48. )
  49. async with async_session() as session:
  50. extra_conditions = tenant_list_conditions(ctx, CloudCredential)
  51. return await CloudCredential.paginated_by_query(
  52. session=session,
  53. fields=fields,
  54. fuzzy_fields=fuzzy_fields,
  55. extra_conditions=extra_conditions,
  56. page=params.page,
  57. per_page=params.perPage,
  58. order_by=params.order_by,
  59. )
  60. @router.get("/{id}", response_model=CloudCredentialPublic)
  61. async def get(session: SessionDep, ctx: TenantContextDep, id: int):
  62. existing = await CloudCredential.one_by_id(session, id)
  63. if not existing or existing.deleted_at is not None:
  64. raise NotFoundException(message=f"cloud credential {id} not found")
  65. assert_resource_visible(
  66. ctx,
  67. existing,
  68. not_found_message=f"cloud credential {id} not found",
  69. )
  70. return existing
  71. @router.post("", response_model=CloudCredentialPublic)
  72. async def create(
  73. session: SessionDep, ctx: TenantContextDep, input: CloudCredentialCreate
  74. ):
  75. # Mirror cluster-create: every credential has an owner Org. Fill in
  76. # ctx.current_principal_id, or PLATFORM_ORG for admin in "All" mode.
  77. if input.owner_principal_id is None:
  78. input.owner_principal_id = ctx.current_principal_id or PLATFORM_PRINCIPAL_ID
  79. validate_owner_principal(
  80. input.owner_principal_id, ctx, resource_label="cloud credential"
  81. )
  82. # Names are unique within their owning Org.
  83. existing = await CloudCredential.one_by_fields(
  84. session,
  85. {
  86. "deleted_at": None,
  87. "name": input.name,
  88. "owner_principal_id": input.owner_principal_id,
  89. },
  90. )
  91. if existing:
  92. raise AlreadyExistsException(
  93. message=f"cloud credential {input.name} already exists"
  94. )
  95. try:
  96. return await CloudCredential.create(session, input)
  97. except Exception as e:
  98. raise InternalServerErrorException(
  99. message=f"Failed to create cloud credential: {e}"
  100. )
  101. @router.put("/{id}", response_model=CloudCredentialPublic)
  102. async def update(
  103. session: SessionDep,
  104. ctx: TenantContextDep,
  105. id: int,
  106. input: CloudCredentialUpdate,
  107. ):
  108. existing = await CloudCredential.one_by_id(session, id)
  109. if not existing or existing.deleted_at is not None:
  110. raise NotFoundException(message=f"cloud credential {id} not found")
  111. assert_org_owned_writable(ctx, existing, resource_label="cloud credential")
  112. try:
  113. await CloudCredential.update(existing, session=session, source=input)
  114. except Exception as e:
  115. raise InternalServerErrorException(
  116. message=f"Failed to update cloud credential: {e}"
  117. )
  118. return await CloudCredential.one_by_id(session, id)
  119. @router.delete("/{id}")
  120. async def delete(session: SessionDep, ctx: TenantContextDep, id: int):
  121. existing = await CloudCredential.one_by_id(session, id)
  122. if not existing or existing.deleted_at is not None:
  123. raise NotFoundException(message=f"cloud credential {id} not found")
  124. assert_org_owned_writable(ctx, existing, resource_label="cloud credential")
  125. try:
  126. await existing.delete(session=session)
  127. except Exception as e:
  128. raise InternalServerErrorException(
  129. message=f"Failed to delete cloud credential: {e}"
  130. )
  131. @router.api_route("/{id}/provider-proxy/{path:path}", methods=["GET"])
  132. async def proxy_cluster_provider_api(
  133. request: Request, session: SessionDep, ctx: TenantContextDep, id: int, path: str
  134. ):
  135. """
  136. To support other provider in the future, use api_route instead of get.
  137. """
  138. credential = await CloudCredential.one_by_id(session=session, id=id)
  139. if not credential:
  140. raise NotFoundException(message=f"Credential {id} not found")
  141. # Proxying via the credential's secret bridges into the cloud
  142. # provider's API; treat as a "use" / read-class permission, gated
  143. # the same way as a visibility check.
  144. assert_resource_visible(
  145. ctx,
  146. credential,
  147. not_found_message=f"Credential {id} not found",
  148. )
  149. if credential.provider in [ClusterProvider.Docker, ClusterProvider.Kubernetes]:
  150. raise NotFoundException(message=f"Provider {credential.provider} not supported")
  151. provider = factory.get(credential.provider, None)
  152. if provider is None:
  153. raise NotFoundException(message=f"Provider {credential.provider} not found")
  154. url = urljoin(provider[0].get_api_endpoint(), path)
  155. if request.query_params:
  156. url = f"{url}?{str(request.query_params)}"
  157. options = {
  158. **(credential.options or {}),
  159. }
  160. header_modifier = partial(
  161. provider[0].process_header, credential.key, credential.secret, options
  162. )
  163. response = await proxy_to(request, url, header_modifier)
  164. if response.status_code in [401, 403, 404]:
  165. original_status = response.status_code
  166. response.status_code = 400
  167. response.headers.append("X-GPUStack-Original-Status", str(original_status))
  168. return response