cluster_proxy.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import asyncio
  2. import functools
  3. import logging
  4. import os
  5. import ssl
  6. from pathlib import Path
  7. from typing import Optional, Tuple
  8. import aiohttp
  9. from fastapi import APIRouter, Depends, Request
  10. from fastapi.responses import StreamingResponse
  11. from gpustack.api.auth import worker_auth
  12. from gpustack.api.exceptions import (
  13. InternalServerErrorException,
  14. )
  15. router = APIRouter(dependencies=[Depends(worker_auth)])
  16. logger = logging.getLogger(__name__)
  17. SERVICE_ACCOUNT_DIR = Path("/var/run/secrets/kubernetes.io/serviceaccount")
  18. TOKEN_PATH = SERVICE_ACCOUNT_DIR / "token"
  19. CA_PATH = SERVICE_ACCOUNT_DIR / "ca.crt"
  20. # Hop-by-hop and connection-control headers we never forward in either direction.
  21. _REQUEST_HEADER_SKIP = {
  22. "host",
  23. "content-length",
  24. "transfer-encoding",
  25. "connection",
  26. "keep-alive",
  27. "proxy-authenticate",
  28. "proxy-authorization",
  29. "te",
  30. "trailer",
  31. "upgrade",
  32. # Drop the inbound auth — the API server is reached with the SA token below.
  33. "authorization",
  34. "cookie",
  35. "x-forwarded-host",
  36. "x-forwarded-port",
  37. "x-forwarded-proto",
  38. }
  39. _RESPONSE_HEADER_SKIP = {
  40. "transfer-encoding",
  41. "content-length",
  42. "connection",
  43. "keep-alive",
  44. "proxy-authenticate",
  45. "proxy-authorization",
  46. "te",
  47. "trailer",
  48. "upgrade",
  49. }
  50. _session_lock = asyncio.Lock()
  51. def _read_token() -> str:
  52. return TOKEN_PATH.read_text().strip()
  53. @functools.lru_cache(maxsize=1)
  54. def _resolve_kube_target() -> Tuple[str, ssl.SSLContext]:
  55. # Cached for the lifetime of the worker process: KUBERNETES_SERVICE_HOST/PORT
  56. # and the cluster CA certificate are static for a pod, and building an
  57. # SSLContext (parses CA, sets up trust store) is non-trivial. The
  58. # ServiceAccount token is *not* cached here — kubelet rotates the projected
  59. # token file in place, so it is re-read per request via _read_token().
  60. # lru_cache does not memoize raised exceptions, so transient setup errors
  61. # (e.g. token file briefly missing during pod start) will be retried.
  62. host = os.environ.get("KUBERNETES_SERVICE_HOST")
  63. port = os.environ.get("KUBERNETES_SERVICE_PORT_HTTPS") or os.environ.get(
  64. "KUBERNETES_SERVICE_PORT"
  65. )
  66. if not host or not port:
  67. raise InternalServerErrorException(
  68. message=(
  69. "Worker is not running inside a Kubernetes pod: "
  70. "KUBERNETES_SERVICE_HOST/PORT environment variables are not set."
  71. )
  72. )
  73. if not TOKEN_PATH.is_file() or not CA_PATH.is_file():
  74. raise InternalServerErrorException(
  75. message=(
  76. "Worker pod ServiceAccount credentials not found at "
  77. f"{SERVICE_ACCOUNT_DIR}; ensure automountServiceAccountToken is enabled."
  78. )
  79. )
  80. # IPv6 host literal needs brackets in URL.
  81. if ":" in host and not host.startswith("["):
  82. host = f"[{host}]"
  83. base_url = f"https://{host}:{port}"
  84. ssl_ctx = ssl.create_default_context(cafile=str(CA_PATH))
  85. return base_url, ssl_ctx
  86. async def _get_kube_session(request: Request) -> aiohttp.ClientSession:
  87. session: Optional[aiohttp.ClientSession] = getattr(
  88. request.app.state, "kube_api_session", None
  89. )
  90. if session is not None and not session.closed:
  91. return session
  92. async with _session_lock:
  93. session = getattr(request.app.state, "kube_api_session", None)
  94. if session is not None and not session.closed:
  95. return session
  96. _, ssl_ctx = _resolve_kube_target()
  97. connector = aiohttp.TCPConnector(ssl=ssl_ctx, limit=64)
  98. # No explicit total timeout — keep watch / log streams open until the
  99. # client disconnects. Connect timeout still applies.
  100. session = aiohttp.ClientSession(
  101. connector=connector,
  102. timeout=aiohttp.ClientTimeout(total=None, sock_connect=10),
  103. )
  104. request.app.state.kube_api_session = session
  105. return session
  106. @router.api_route(
  107. "/cluster-proxy/{path:path}",
  108. methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"],
  109. include_in_schema=False,
  110. )
  111. async def cluster_proxy(path: str, request: Request):
  112. """
  113. Forward an HTTP request to the in-cluster Kubernetes API server using the
  114. pod's ServiceAccount credentials.
  115. Designed to be invoked by the GPUStack server through the standard
  116. server→worker request channel. The server is responsible for
  117. authenticating the original caller; this endpoint trusts the worker
  118. bearer token (see worker_auth dependency).
  119. """
  120. base_url, _ = _resolve_kube_target()
  121. target_url = f"{base_url}/{path}"
  122. headers = {
  123. k: v
  124. for k, v in request.headers.items()
  125. if k.lower() not in _REQUEST_HEADER_SKIP
  126. }
  127. headers["Authorization"] = f"Bearer {_read_token()}"
  128. # Stream the request body through to avoid buffering large payloads
  129. # (e.g. apply of big manifests) in worker memory.
  130. body = (
  131. request.stream() if request.method not in ("GET", "HEAD", "OPTIONS") else None
  132. )
  133. params = list(request.query_params.multi_items()) or None
  134. session = await _get_kube_session(request)
  135. resp = await session.request(
  136. method=request.method,
  137. url=target_url,
  138. headers=headers,
  139. data=body,
  140. params=params,
  141. allow_redirects=False,
  142. )
  143. async def streamer():
  144. try:
  145. async for chunk in resp.content.iter_any():
  146. yield chunk
  147. except asyncio.CancelledError:
  148. raise
  149. except Exception as e:
  150. logger.warning(
  151. "cluster-proxy stream interrupted for %s %s: %s",
  152. request.method,
  153. target_url,
  154. e,
  155. )
  156. finally:
  157. resp.release()
  158. response_headers = {
  159. k: v for k, v in resp.headers.items() if k.lower() not in _RESPONSE_HEADER_SKIP
  160. }
  161. return StreamingResponse(
  162. streamer(),
  163. status_code=resp.status,
  164. headers=response_headers,
  165. media_type=resp.headers.get("Content-Type"),
  166. )