import asyncio
import logging
import aiohttp
from typing import Callable, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from starlette.background import BackgroundTask

from gpustack.api.auth import worker_auth
from gpustack.api.exceptions import (
    GatewayTimeoutException,
    ServiceUnavailableException,
    NotFoundException,
    ErrorResponse,
)
from gpustack import envs
from gpustack.utils.network import use_proxy_env_for_url
from gpustack.gateway.utils import get_instance_id_from_header, router_header_key

router = APIRouter(dependencies=[Depends(worker_auth)])

logger = logging.getLogger(__name__)


@router.api_route(
    "/proxy/{path:path}",
    methods=["GET", "POST", "OPTIONS", "HEAD"],
)
async def proxy(path: str, request: Request):  # noqa: C901
    """Forward the incoming request to a service on the worker and stream back the reply.

    The target port is read from ``request.state.x_target_port`` and the
    worker address from ``request.app.state.worker_ip_getter`` (falling back
    to ``localhost_fallback`` when that attribute is ``None``).

    Args:
        path: The path captured after ``/proxy/``; appended to the upstream URL.
        request: The incoming request whose method, headers, query string and
            body are forwarded.

    Returns:
        A ``StreamingResponse`` relaying the upstream status, headers and body.

    Raises:
        HTTPException: 400 when no target port is set on the request state.
        GatewayTimeoutException: When the upstream request times out.
        ServiceUnavailableException: On any other failure while proxying.
    """
    worker_ip_getter: Callable[[], str] = request.app.state.worker_ip_getter
    if worker_ip_getter is None:
        worker_ip_getter = localhost_fallback

    target_service_port = getattr(request.state, "x_target_port", None)
    if not target_service_port:
        raise HTTPException(
            status_code=400,
            detail="Missing target port; ensure the request includes the routing header",
        )

    try:
        logger.debug(
            f"Proxying request to worker at port {target_service_port} for path: {path}"
        )
        url = f"http://{worker_ip_getter()}:{target_service_port}/{path}"
        if request.url.query:
            url = f"{url}?{request.url.query}"

        headers = dict(request.headers)
        headers.pop("host", None)

        # Decide how to forward the body BEFORE stripping the header. Checking
        # after the pop would always see an empty value, making the streaming
        # branch unreachable and forcing chunked uploads to be fully buffered.
        is_chunked = headers.get("transfer-encoding", "").lower() == "chunked"
        # The header itself is not forwarded; the HTTP client chooses the
        # framing for the outgoing body.
        headers.pop("transfer-encoding", None)

        if is_chunked:

            async def body_generator():
                # Relay the request body chunk by chunk without buffering it.
                async for chunk in request.stream():
                    yield chunk

            content = body_generator()
        else:
            content = await request.body()

        async def stream_response(resp):
            # Relay the upstream body in 1 KiB pieces.
            async for chunk in resp.content.iter_chunked(1024):
                yield chunk

        use_proxy_env = use_proxy_env_for_url(url)
        http_client: aiohttp.ClientSession = (
            request.app.state.http_client
            if use_proxy_env
            else request.app.state.http_client_no_proxy
        )
        timeout = aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT)
        resp = await http_client.request(
            method=request.method,
            url=url,
            headers=headers,
            data=content,
            timeout=timeout,
        )

        # Heuristic: treat a non-error HTTP status as a successful inference
        # signal so the active health-check loop can skip this instance.
        # For streaming responses the status is available before body
        # transfer, so a mid-stream failure will still be counted — this is
        # acceptable as a best-effort optimisation.
        target_instance_id = getattr(request.state, "x_target_instance_id", None)
        if resp.status < 400 and target_instance_id:
            record_fn = getattr(request.app.state, "record_successful_inference", None)
            if record_fn:
                record_fn(int(target_instance_id))

        return StreamingResponse(
            stream_response(resp),
            status_code=resp.status,
            headers=dict(resp.headers),
            # Close the upstream response once the reply has been sent.
            background=BackgroundTask(resp.close),
        )
    except asyncio.TimeoutError as e:
        error_message = f"Request to {url} timed out"
        if str(e):
            error_message += f": {e}"
        raise GatewayTimeoutException(
            message=error_message,
            is_openai_exception=True,
        ) from e
    except Exception as e:
        error_message = "An unexpected error occurred"
        if str(e):
            error_message += f": {e}"
        raise ServiceUnavailableException(
            message=error_message,
            is_openai_exception=True,
        ) from e


def localhost_fallback() -> str:
    """Return the loopback address used when no worker IP getter is configured."""
    return "127.0.0.1"


def get_model_instance_info_from_model_name(
    request: Request,
) -> Tuple[int, int]:
    """
    Get model instance port and instance id from model name in header "x-gpustack-model-instance".

    The instance id is extracted from the request headers via
    ``get_instance_id_from_header`` and the port is looked up through
    ``request.app.state.get_instance_port_by_model_instance_id``.

    Return (port, model_instance_id).

    Raises:
        NotFoundException: When the lookup yields no (or a falsy) port.
    """
    model_instance_id = get_instance_id_from_header(request.headers)
    port: Optional[int] = request.app.state.get_instance_port_by_model_instance_id(
        model_instance_id
    )
    if not port:
        raise NotFoundException(
            message=f"No running model instance found for model name: {model_instance_id}",
        )
    logger.debug(f"Found port {port} from model instance id {model_instance_id}")
    return port, model_instance_id


async def set_port_from_model_name(request: Request, call_next):
    """Route requests carrying the routing header to the ``/proxy`` endpoint.

    Requests without the ``router_header_key`` header are passed to
    ``call_next`` untouched. Otherwise the target port and instance id are
    resolved, stored on ``request.state`` (``x_target_port`` as a string and
    ``x_target_instance_id``), and the request path is prefixed with
    ``/proxy`` before it is handed on.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response, or a JSON error response built from the
        ``NotFoundException`` when no matching instance port is found.
    """
    model_name = request.headers.get(router_header_key, None)
    if model_name is None:
        return await call_next(request)

    try:
        port, model_instance_id = get_model_instance_info_from_model_name(request)
        # Rewrite the path so the request is dispatched to the proxy route.
        request.scope["path"] = f"/proxy{request.url.path}"
        request.state.x_target_port = str(port)
        request.state.x_target_instance_id = model_instance_id
        return await call_next(request)
    except NotFoundException as e:
        logger.debug("failed to find model instance for proxying: %s", e.message)
        return JSONResponse(
            status_code=e.status_code,
            content=ErrorResponse(
                code=e.status_code,
                reason=e.reason,
                message=e.message,
            ).model_dump(),
        )
    except HTTPException as e:
        # NOTE(review): presumably raised while parsing the routing header;
        # the request is then forwarded without proxy routing — confirm.
        logger.debug("failed to find model instance for proxying: %s", e.detail)
        return await call_next(request)