proxy.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import asyncio
  2. import logging
  3. import aiohttp
  4. from typing import Callable, Optional, Tuple
  5. from fastapi import APIRouter, Depends, HTTPException, Request
  6. from fastapi.responses import StreamingResponse, JSONResponse
  7. from starlette.background import BackgroundTask
  8. from gpustack.api.auth import worker_auth
  9. from gpustack.api.exceptions import (
  10. GatewayTimeoutException,
  11. ServiceUnavailableException,
  12. NotFoundException,
  13. ErrorResponse,
  14. )
  15. from gpustack import envs
  16. from gpustack.utils.network import use_proxy_env_for_url
  17. from gpustack.gateway.utils import get_instance_id_from_header, router_header_key
  18. router = APIRouter(dependencies=[Depends(worker_auth)])
  19. logger = logging.getLogger(__name__)
  20. @router.api_route(
  21. "/proxy/{path:path}",
  22. methods=["GET", "POST", "OPTIONS", "HEAD"],
  23. )
  24. async def proxy(path: str, request: Request): # noqa: C901
  25. worker_ip_getter: Callable[[], str] = request.app.state.worker_ip_getter
  26. if worker_ip_getter is None:
  27. worker_ip_getter = localhost_fallback
  28. target_service_port = getattr(request.state, "x_target_port", None)
  29. if not target_service_port:
  30. raise HTTPException(
  31. status_code=400,
  32. detail="Missing target port; ensure the request includes the routing header",
  33. )
  34. try:
  35. logger.debug(
  36. f"Proxying request to worker at port {target_service_port} for path: {path}"
  37. )
  38. url = f"http://{worker_ip_getter()}:{target_service_port}/{path}"
  39. if request.url.query:
  40. url = f"{url}?{request.url.query}"
  41. headers = dict(request.headers)
  42. headers.pop("host", None)
  43. headers.pop("transfer-encoding", None)
  44. if headers.get("transfer-encoding", "").lower() == "chunked":
  45. async def body_generator():
  46. async for chunk in request.stream():
  47. yield chunk
  48. content = body_generator()
  49. else:
  50. content = await request.body()
  51. async def stream_response(resp):
  52. async for chunk in resp.content.iter_chunked(1024):
  53. yield chunk
  54. use_proxy_env = use_proxy_env_for_url(url)
  55. http_client: aiohttp.ClientSession = (
  56. request.app.state.http_client
  57. if use_proxy_env
  58. else request.app.state.http_client_no_proxy
  59. )
  60. timeout = aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT)
  61. resp = await http_client.request(
  62. method=request.method,
  63. url=url,
  64. headers=headers,
  65. data=content,
  66. timeout=timeout,
  67. )
  68. # Heuristic: treat a non-error HTTP status as a successful inference
  69. # signal so the active health-check loop can skip this instance.
  70. # For streaming responses the status is available before body
  71. # transfer, so a mid-stream failure will still be counted — this is
  72. # acceptable as a best-effort optimisation.
  73. target_instance_id = getattr(request.state, "x_target_instance_id", None)
  74. if resp.status < 400 and target_instance_id:
  75. record_fn = getattr(request.app.state, "record_successful_inference", None)
  76. if record_fn:
  77. record_fn(int(target_instance_id))
  78. return StreamingResponse(
  79. stream_response(resp),
  80. status_code=resp.status,
  81. headers=dict(resp.headers),
  82. background=BackgroundTask(resp.close),
  83. )
  84. except asyncio.TimeoutError as e:
  85. error_message = f"Request to {url} timed out"
  86. if str(e):
  87. error_message += f": {e}"
  88. raise GatewayTimeoutException(
  89. message=error_message,
  90. is_openai_exception=True,
  91. )
  92. except Exception as e:
  93. error_message = "An unexpected error occurred"
  94. if str(e):
  95. error_message += f": {e}"
  96. raise ServiceUnavailableException(
  97. message=error_message,
  98. is_openai_exception=True,
  99. )
  100. def localhost_fallback() -> str:
  101. return "127.0.0.1"
  102. def get_model_instance_info_from_model_name(
  103. request: Request,
  104. ) -> Tuple[int, int]:
  105. """
  106. Get model instance port and instance id from model name in header
  107. "x-gpustack-model-instance".
  108. Return (port, model_instance_id).
  109. """
  110. model_instance_id = get_instance_id_from_header(request.headers)
  111. port: Optional[int] = request.app.state.get_instance_port_by_model_instance_id(
  112. model_instance_id
  113. )
  114. if not port:
  115. raise NotFoundException(
  116. message=f"No running model instance found for model name: {model_instance_id}",
  117. )
  118. logger.debug(f"Found port {port} from model instance id {model_instance_id}")
  119. return port, model_instance_id
  120. async def set_port_from_model_name(request: Request, call_next):
  121. model_name = request.headers.get(router_header_key, None)
  122. if model_name is None:
  123. return await call_next(request)
  124. try:
  125. port, model_instance_id = get_model_instance_info_from_model_name(request)
  126. request.scope["path"] = f"/proxy{request.url.path}"
  127. request.state.x_target_port = str(port)
  128. request.state.x_target_instance_id = model_instance_id
  129. return await call_next(request)
  130. except NotFoundException as e:
  131. logger.debug("failed to find model instance for proxying: %s", e.message)
  132. return JSONResponse(
  133. status_code=e.status_code,
  134. content=ErrorResponse(
  135. code=e.status_code,
  136. reason=e.reason,
  137. message=e.message,
  138. ).model_dump(),
  139. )
  140. except HTTPException as e:
  141. logger.debug("failed to find model instance for proxying: %s", e.detail)
  142. return await call_next(request)