| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- import asyncio
- import logging
- import aiohttp
- from typing import Callable, Optional, Tuple
- from fastapi import APIRouter, Depends, HTTPException, Request
- from fastapi.responses import StreamingResponse, JSONResponse
- from starlette.background import BackgroundTask
- from gpustack.api.auth import worker_auth
- from gpustack.api.exceptions import (
- GatewayTimeoutException,
- ServiceUnavailableException,
- NotFoundException,
- ErrorResponse,
- )
- from gpustack import envs
- from gpustack.utils.network import use_proxy_env_for_url
- from gpustack.gateway.utils import get_instance_id_from_header, router_header_key
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Every route registered on this router runs the worker_auth dependency first.
router = APIRouter(dependencies=[Depends(worker_auth)])
@router.api_route(
    "/proxy/{path:path}",
    methods=["GET", "POST", "OPTIONS", "HEAD"],
)
async def proxy(path: str, request: Request):  # noqa: C901
    """Forward the incoming request to a service on the worker and stream the reply.

    The target URL is built from the worker IP (``app.state.worker_ip_getter``,
    falling back to ``localhost_fallback`` when it is ``None``), the port stored
    in ``request.state.x_target_port`` and ``path`` plus the original query
    string.

    Args:
        path: Path to request on the target service (without a leading slash).
        request: The incoming request whose method, headers and body are
            forwarded.

    Returns:
        A ``StreamingResponse`` relaying the upstream status, headers and body.

    Raises:
        HTTPException: 400 when ``request.state.x_target_port`` is missing or
            empty.
        GatewayTimeoutException: when the upstream call raises
            ``asyncio.TimeoutError``.
        ServiceUnavailableException: on any other error while proxying.
    """
    worker_ip_getter: Callable[[], str] = request.app.state.worker_ip_getter
    if worker_ip_getter is None:
        worker_ip_getter = localhost_fallback

    target_service_port = getattr(request.state, "x_target_port", None)
    if not target_service_port:
        raise HTTPException(
            status_code=400,
            detail="Missing target port; ensure the request includes the routing header",
        )

    try:
        logger.debug(
            f"Proxying request to worker at port {target_service_port} for path: {path}"
        )

        url = f"http://{worker_ip_getter()}:{target_service_port}/{path}"
        if request.url.query:
            url = f"{url}?{request.url.query}"

        headers = dict(request.headers)
        # Read the chunked flag BEFORE dropping the header; checking after the
        # pop made the streaming branch unreachable and forced every body to
        # be buffered in memory.
        is_chunked = headers.get("transfer-encoding", "").lower() == "chunked"
        headers.pop("host", None)
        headers.pop("transfer-encoding", None)

        if is_chunked:

            async def body_generator():
                async for chunk in request.stream():
                    yield chunk

            content = body_generator()
        else:
            content = await request.body()

        async def stream_response(resp):
            try:
                async for chunk in resp.content.iter_chunked(1024):
                    yield chunk
            finally:
                # Make sure the upstream response is closed even when the
                # stream is aborted part-way and the background task below
                # never gets to run.
                resp.close()

        use_proxy_env = use_proxy_env_for_url(url)
        http_client: aiohttp.ClientSession = (
            request.app.state.http_client
            if use_proxy_env
            else request.app.state.http_client_no_proxy
        )
        timeout = aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT)

        resp = await http_client.request(
            method=request.method,
            url=url,
            headers=headers,
            data=content,
            timeout=timeout,
        )

        # Heuristic: treat a non-error HTTP status as a successful inference
        # signal so the active health-check loop can skip this instance.
        # For streaming responses the status is available before body
        # transfer, so a mid-stream failure will still be counted — this is
        # acceptable as a best-effort optimisation.
        target_instance_id = getattr(request.state, "x_target_instance_id", None)
        if resp.status < 400 and target_instance_id:
            record_fn = getattr(request.app.state, "record_successful_inference", None)
            if record_fn:
                record_fn(int(target_instance_id))

        return StreamingResponse(
            stream_response(resp),
            status_code=resp.status,
            headers=dict(resp.headers),
            background=BackgroundTask(resp.close),
        )
    except asyncio.TimeoutError as e:
        error_message = f"Request to {url} timed out"
        if str(e):
            error_message += f": {e}"
        raise GatewayTimeoutException(
            message=error_message,
            is_openai_exception=True,
        ) from e
    except Exception as e:
        error_message = "An unexpected error occurred"
        if str(e):
            error_message += f": {e}"
        raise ServiceUnavailableException(
            message=error_message,
            is_openai_exception=True,
        ) from e
def localhost_fallback() -> str:
    """Return the IPv4 loopback address as a string."""
    loopback_address = "127.0.0.1"
    return loopback_address
def get_model_instance_info_from_model_name(
    request: Request,
) -> Tuple[int, int]:
    """
    Resolve the target model instance for a request.

    The instance id is extracted from the request headers with
    get_instance_id_from_header, and its port is looked up through
    app.state.get_instance_port_by_model_instance_id.

    Return (port, model_instance_id).
    Raise NotFoundException when the lookup yields no (or a falsy) port.
    """
    model_instance_id = get_instance_id_from_header(request.headers)
    lookup_port = request.app.state.get_instance_port_by_model_instance_id
    port: Optional[int] = lookup_port(model_instance_id)

    if port:
        logger.debug(f"Found port {port} from model instance id {model_instance_id}")
        return port, model_instance_id

    raise NotFoundException(
        message=f"No running model instance found for model name: {model_instance_id}",
    )
async def set_port_from_model_name(request: Request, call_next):
    """Route requests carrying the routing header to the ``/proxy`` endpoint.

    Requests without the ``router_header_key`` header are passed through
    untouched. Otherwise the target port and model instance id are resolved;
    on success the request path is rewritten to ``/proxy<original path>`` and
    both values are stored on ``request.state`` (``x_target_port`` as a
    string, ``x_target_instance_id``) before the request is forwarded.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the next handler and
            returns its response.

    Returns:
        The downstream response, or a JSON error response built from the
        ``NotFoundException`` when no model instance could be resolved.
    """
    if request.headers.get(router_header_key, None) is None:
        return await call_next(request)

    # Only the lookup is guarded: keeping ``call_next`` outside the ``try``
    # ensures a downstream exception is neither misreported as a lookup
    # failure nor causes the request to be forwarded a second time.
    try:
        port, model_instance_id = get_model_instance_info_from_model_name(request)
    except NotFoundException as e:
        logger.debug("failed to find model instance for proxying: %s", e.message)
        return JSONResponse(
            status_code=e.status_code,
            content=ErrorResponse(
                code=e.status_code,
                reason=e.reason,
                message=e.message,
            ).model_dump(),
        )
    except HTTPException as e:
        # Lookup failed with a generic HTTP error: fall back to normal
        # routing with the request left unmodified.
        logger.debug("failed to find model instance for proxying: %s", e.detail)
        return await call_next(request)
    else:
        request.scope["path"] = f"/proxy{request.url.path}"
        request.state.x_target_port = str(port)
        request.state.x_target_instance_id = model_instance_id
        return await call_next(request)
|