| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465 |
- import json
- import re
- import random
- import asyncio
- from typing import AsyncGenerator, List, Optional, Tuple, Union, Dict
- import aiohttp
- import logging
- from fastapi import APIRouter, Query, Request, Response, status
- from openai.types import Model as OAIModel
- from openai.pagination import SyncPage
- from sqlmodel import or_, select
- from sqlmodel.ext.asyncio.session import AsyncSession
- from starlette.datastructures import UploadFile
- from gpustack.api.exceptions import (
- BadRequestException,
- NotFoundException,
- InternalServerErrorException,
- OpenAIAPIError,
- OpenAIAPIErrorResponse,
- ServiceUnavailableException,
- GatewayTimeoutException,
- )
- from gpustack.api.responses import StreamingResponseWithStatusCode
- from gpustack import envs
- from gpustack.http_proxy.load_balancer import LoadBalancer
- from gpustack.routes.model_common import build_category_conditions
- from gpustack.schemas.models import Model
- from gpustack.schemas.model_routes import (
- ModelRoute,
- MyModel,
- effective_route_name,
- )
- from gpustack.schemas.principals import Principal, PLATFORM_PRINCIPAL_ID
- from gpustack.schemas.workers import Worker
- from gpustack.server.deps import SessionDep, CurrentUserDep
- from gpustack.server.services import (
- ModelInstanceService,
- ModelRouteService,
- WorkerService,
- UserService,
- )
- from gpustack.server.worker_request import request_to_worker, stream_to_worker
- from gpustack.gateway.utils import (
- model_instance_prefix,
- openai_model_prefixes,
- router_header_key,
- )
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Endpoints that already have a dedicated server router (e.g. rerank.router).
# The auto-registration loops below skip them so that /v1/<endpoint> is not
# registered twice.
_server_managed_elsewhere = {"/rerank"}

# Shared balancer used to pick one instance out of a model's running instances.
load_balancer = LoadBalancer()
def get_api_router() -> APIRouter:
    """Build the full OpenAI-compatible router (mounted at /v1).

    Registers ``GET /models`` plus one ``POST`` route for every endpoint
    listed in ``openai_model_prefixes``, leaving out the endpoints named in
    ``_server_managed_elsewhere``.
    """
    api_router = APIRouter()
    api_router.add_api_route("/models", list_models, methods=["GET"])

    # Collect the proxied endpoints first, then register them in order.
    proxied_endpoints = [
        path
        for route_prefix in openai_model_prefixes
        for path in route_prefix.prefixes
        if path not in _server_managed_elsewhere
    ]
    for path in proxied_endpoints:
        api_router.add_api_route(path, proxy_request_by_model, methods=["POST"])

    return api_router
def get_legacy_api_router() -> APIRouter:
    """Build the legacy router (mounted at /v1-openai).

    This is a frozen subset of the full API: only prefix groups flagged with
    ``support_legacy`` are registered. New endpoints go to /v1 only.
    """
    legacy_router = APIRouter()
    legacy_router.add_api_route("/models", list_models, methods=["GET"])

    for route_prefix in openai_model_prefixes:
        if route_prefix.support_legacy:
            # Drop endpoints that a dedicated server router already serves.
            allowed_paths = [
                path
                for path in route_prefix.prefixes
                if path not in _server_managed_elsewhere
            ]
            for path in allowed_paths:
                legacy_router.add_api_route(
                    path, proxy_request_by_model, methods=["POST"]
                )

    return legacy_router
async def list_models(
    user: CurrentUserDep,
    session: SessionDep,
    categories: List[str] = Query(
        [],
        description="Model categories to filter by.",
    ),
    with_meta: Optional[bool] = Query(
        None,
        description="Include model meta information.",
    ),
):
    """List the models visible to the caller in OpenAI ``/models`` format.

    Admins query ``ModelRoute``; every other user queries ``MyModel``
    restricted to rows whose ``user_id`` matches the caller. Only rows with
    ``ready_targets > 0`` are returned.

    Args:
        user: The authenticated caller.
        session: Database session used for all queries.
        categories: Optional category filter; the conditions built for the
            categories are OR-ed together.
        with_meta: When truthy, each entry carries the row's ``meta``;
            otherwise ``meta`` is ``None``.

    Returns:
        A ``SyncPage[OAIModel]`` whose entries use the effective route name
        as ``id``.
    """
    target_class = ModelRoute if user.is_admin else MyModel
    statement = select(target_class).where(target_class.ready_targets > 0)

    if target_class is MyModel:
        # Non-admin users only ever see their own rows, whether or not a
        # category filter is supplied.
        statement = statement.where(target_class.user_id == user.id)

    if categories:
        conditions = build_category_conditions(session, target_class, categories)
        statement = statement.where(or_(*conditions))

    models = (await session.exec(statement)).all()

    # Bulk-load owner principals to resolve each route's effective
    # name (slug-prefixed for non-platform owners). Without the prefix,
    # two owners holding routes named "qwen3-0.6b" would publish the
    # same ``id`` here and Higress's AI proxy would dispatch
    # ambiguously.
    principal_ids = {
        m.owner_principal_id for m in models if m.owner_principal_id is not None
    }
    principal_by_id: Dict[int, Principal] = {}
    if principal_ids:
        rows = (
            await session.exec(select(Principal).where(Principal.id.in_(principal_ids)))
        ).all()
        principal_by_id = {p.id: p for p in rows}

    result = SyncPage[OAIModel](data=[], object="list")
    for model in models:
        # A plain lookup keeps this consistent with the ``is not None`` test
        # used to build ``principal_ids``: an owner id of 0 is resolved, and
        # a missing owner (``None``) simply yields ``None``.
        owner = principal_by_id.get(model.owner_principal_id)
        eff_name = effective_route_name(
            model.name,
            getattr(owner, "slug", None),
            getattr(owner, "id", None) == PLATFORM_PRINCIPAL_ID,
        )
        result.data.append(
            OAIModel(
                id=eff_name,
                object="model",
                created=int(model.created_at.timestamp()),
                owned_by="gpustack",
                meta=model.meta if with_meta else None,
            )
        )
    return result
async def proxy_request_by_model(
    request: Request,
    user: CurrentUserDep,
    session: SessionDep,
):
    """
    Proxy the request to the model instance that is running the model specified in the
    request body.

    The handler parses the body to find the model name, checks that the
    caller may use it, picks one backing model at random, selects a running
    instance through the load balancer and forwards the request to that
    instance's worker. Streaming requests are relayed as server-sent events.

    Side effects on ``request.state``: sets ``stream``, ``model`` and
    ``model_route_id`` for downstream middleware.

    Raises:
        NotFoundException: The model is not allowed for the user, or the
            route resolves to no models.
        InternalServerErrorException: The selected instance's worker cannot
            be found.
        GatewayTimeoutException: The worker request timed out.
        ServiceUnavailableException: Any other failure while preparing or
            sending the worker request.
    """
    endpoint = re.sub(r"^/(v1|v1-openai)/", "", request.url.path)

    model_name, stream, body_json, form_data = await parse_request_body(request)

    if not await UserService(session).model_allowed_for_user(
        model_name=model_name,
        user_id=user.id,
        api_key=getattr(request.state, "api_key", None),
    ):
        raise NotFoundException(
            message="Model not found",
            is_openai_exception=True,
        )

    model_route_service = ModelRouteService(session)
    models: List[Model] = await model_route_service.get_model_ids_by_model_route_name(
        model_name
    )
    if not models:
        raise NotFoundException(
            message="Model not found or no running instances available",
            is_openai_exception=True,
        )

    request.state.stream = stream
    model = random.choice(models)
    request.state.model = model

    # Resolve the route id so downstream middleware (usage recording) can
    # attribute the request to the route it entered through. The lookup
    # is @locked_cached so repeat hits within the same session are cheap.
    model_route = await model_route_service.get_by_name(model_name)
    request.state.model_route_id = model_route.id if model_route else None

    mutate_request(request, model_name, body_json, form_data)

    instance = await get_running_instance(session, model.id)
    worker: Worker = await WorkerService(session).get_by_id(instance.worker_id)
    if not worker:
        raise InternalServerErrorException(
            message=f"Worker with ID {instance.worker_id} not found",
            is_openai_exception=True,
        )

    extra_headers = {
        router_header_key: f"{model_instance_prefix(instance)}.static",
    }
    # The worker is always addressed under /v1, also for /v1-openai requests.
    path = f"v1/{endpoint}"

    logger.debug(
        "proxying to %s:%s, instance port: %s",
        instance.worker_ip,
        instance.port,
        instance.port,
    )

    try:
        headers, data = _prepare_proxy_request(
            request,
            body_json,
            form_data,
            extra_headers,
            add_stream_options=stream,
        )

        if stream:
            return StreamingResponseWithStatusCode(
                _stream_response(
                    worker,
                    request.method,
                    path,
                    headers,
                    data,
                    request.app.state.http_client,
                    request.app.state.http_client_no_proxy,
                ),
                media_type="text/event-stream",
            )

        resp, body = await request_to_worker(
            worker=worker,
            method=request.method,
            path=path,
            proxy_client=request.app.state.http_client,
            no_proxy_client=request.app.state.http_client_no_proxy,
            data=data,
            headers=headers,
            timeout=aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT),
        )
        return Response(
            status_code=resp.status,
            headers=dict(resp.headers),
            content=body,
        )
    except asyncio.TimeoutError as e:
        error_message = f"Request to worker {worker.id} timed out"
        if str(e):
            error_message += f": {e}"
        # Chain the cause so the original traceback is kept for debugging.
        raise GatewayTimeoutException(
            message=error_message,
            is_openai_exception=True,
        ) from e
    except Exception as e:
        error_message = "An unexpected error occurred"
        if str(e):
            error_message += f": {e}"
        raise ServiceUnavailableException(
            message=error_message,
            is_openai_exception=True,
        ) from e
async def parse_request_body(request: Request):
    """Extract the model name, stream flag and body from the request.

    GET requests read ``model`` from the query string, multipart requests
    are parsed as form data, and everything else is parsed as JSON.

    Returns:
        A ``(model_name, stream, body_json, form_data)`` tuple. ``body_json``
        and ``form_data`` are ``None`` unless the matching parser ran.

    Raises:
        BadRequestException: No model name could be found.
    """
    stream = False
    body_json = None
    form_data = None

    if request.method == "GET":
        model_name = request.query_params.get("model")
    else:
        # A missing Content-Type header is treated as JSON.
        content_type = request.headers.get("content-type", "application/json").lower()
        if content_type.startswith("multipart/form-data"):
            form_data, model_name, stream = await parse_form_data(request)
        else:
            body_json, model_name, stream = await parse_json_body(request)

    if model_name:
        return model_name, stream, body_json, form_data

    raise BadRequestException(
        message="Missing 'model' field",
        is_openai_exception=True,
    )
async def parse_form_data(request: Request) -> Tuple[aiohttp.FormData, str, bool]:
    """Parse a multipart request into an ``aiohttp.FormData`` for forwarding.

    Every submitted field is copied into the outgoing form, including
    repeated fields that share a key. Uploaded files are read fully into
    memory and re-attached with their filename and content type.

    Returns:
        A ``(form_data, model_name, stream)`` tuple. ``stream`` is a real
        boolean: form values arrive as strings, so ``"false"`` or ``"0"``
        must not be treated as truthy.

    Raises:
        BadRequestException: The form body could not be parsed.
    """
    try:
        form = await request.form()
        model_name = form.get("model")

        raw_stream = form.get("stream", False)
        if isinstance(raw_stream, str):
            stream = raw_stream.strip().lower() in ("1", "true", "yes", "on")
        else:
            stream = bool(raw_stream)

        form_data = aiohttp.FormData()
        # multi_items() yields every (key, value) pair; items() would keep
        # only one value per key and silently drop repeated fields.
        for key, value in form.multi_items():
            if isinstance(value, UploadFile):
                form_data.add_field(
                    key,
                    await value.read(),
                    filename=value.filename,
                    content_type=value.content_type,
                )
            else:
                form_data.add_field(key, value)

        return form_data, model_name, stream
    except Exception as e:
        raise BadRequestException(
            message=f"We could not parse the form body of your request: {e}",
            is_openai_exception=True,
        ) from e
async def parse_json_body(request: Request):
    """Parse a JSON request body and pull out the model name and stream flag.

    Returns:
        A ``(body_json, model_name, stream)`` tuple. ``stream`` defaults to
        ``False`` when the body has no ``stream`` key.

    Raises:
        BadRequestException: The body is not valid JSON, or is valid JSON
            but not an object.
    """
    try:
        body_json = await request.json()
    except Exception as e:
        raise BadRequestException(
            message=f"We could not parse the JSON body of your request: {e}",
            is_openai_exception=True,
        ) from e

    # A JSON array, string or number decodes fine but has no fields to read;
    # report that directly instead of surfacing an AttributeError message.
    if not isinstance(body_json, dict):
        raise BadRequestException(
            message=(
                "We could not parse the JSON body of your request: "
                "expected a JSON object"
            ),
            is_openai_exception=True,
        )

    return body_json, body_json.get("model"), body_json.get("stream", False)
def _prepare_proxy_request(
    request: Request,
    body_json: Optional[dict],
    form_data: Optional[aiohttp.FormData],
    extra_headers: Optional[dict] = None,
    add_stream_options: bool = False,
) -> Tuple[Dict[str, str], Optional[Union[bytes, aiohttp.FormData]]]:
    """
    Build the outgoing headers and payload for a proxied request.

    Form data takes precedence over a JSON body. When ``add_stream_options``
    is set and the JSON body has no ``stream_options`` key, one is added in
    place so that usage is included in the stream.

    Returns (headers, data) tuple.
    """
    outgoing_headers = filter_headers(request.headers)
    outgoing_headers.update(extra_headers or {})

    needs_usage_default = (
        add_stream_options and body_json and "stream_options" not in body_json
    )
    if needs_usage_default:
        # Defaults to include usage.
        # TODO Record usage without client awareness.
        body_json["stream_options"] = {"include_usage": True}

    if form_data:
        return outgoing_headers, form_data

    if body_json:
        # The payload is sent as raw bytes (data=, not json=), so aiohttp
        # does not set the Content-Type for us.
        outgoing_headers["Content-Type"] = "application/json"
        return outgoing_headers, json.dumps(body_json).encode()

    return outgoing_headers, None
async def _stream_response(
    worker: Worker,
    method: str,
    path: str,
    headers: Dict[str, str],
    data: Optional[Union[bytes, aiohttp.FormData]],
    proxy_client: aiohttp.ClientSession,
    no_proxy_client: aiohttp.ClientSession,
) -> AsyncGenerator[Tuple[Union[bytes, str], Dict[str, str], int], None]:
    """
    Relay a streamed worker response as (chunk, headers, status) tuples.

    Failures are not raised to the caller: an ``aiohttp.ClientError`` is
    reported as one final 503 item and any other exception as one final 500
    item, each carrying an OpenAI-style JSON error body and empty headers.
    """

    def _error_item(message: str, code: int, error_type: str):
        # One terminal stream item describing the failure.
        payload = OpenAIAPIErrorResponse(
            error=OpenAIAPIError(message=message, code=code, type=error_type),
        )
        return payload.model_dump_json(), {}, code

    try:
        async for piece, upstream_headers, upstream_status in stream_to_worker(
            worker=worker,
            method=method,
            path=path,
            proxy_client=proxy_client,
            no_proxy_client=no_proxy_client,
            data=data,
            headers=headers,
            timeout=aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT),
        ):
            yield piece, upstream_headers, upstream_status
    except aiohttp.ClientError as client_err:
        yield _error_item(
            f"Service unavailable. Please retry your requests after a brief wait. Original error: {client_err}",
            status.HTTP_503_SERVICE_UNAVAILABLE,
            "ServiceUnavailable",
        )
    except Exception as err:
        yield _error_item(
            f"Internal server error: {err}",
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            "InternalServerError",
        )
# Lower-cased names of request headers that are never copied to the worker.
_UNFORWARDED_HEADERS = frozenset(
    {
        "content-length",
        "host",
        "content-type",
        "transfer-encoding",
        "authorization",
    }
)


def filter_headers(headers):
    """Return a new dict of ``headers`` without the entries that must not be forwarded.

    Header names are compared case-insensitively against
    ``_UNFORWARDED_HEADERS``; the original key casing is kept for the
    headers that remain. The input mapping is not modified.
    """
    return {
        key: value
        for key, value in headers.items()
        if key.lower() not in _UNFORWARDED_HEADERS
    }
async def get_running_instance(session: AsyncSession, model_id: int):
    """Pick one running instance of the given model through the load balancer.

    Raises:
        ServiceUnavailableException: The model has no running instances.
    """
    instance_service = ModelInstanceService(session)
    candidates = await instance_service.get_running_instances(model_id)

    if candidates:
        return await load_balancer.get_instance(candidates)

    raise ServiceUnavailableException(
        message="No running instances available",
        is_openai_exception=True,
    )
def mutate_request(
    request: Request,
    model_name: str,
    body_json: Optional[dict],
    form_data: Optional[aiohttp.FormData],
):
    """Adjust the parsed request in place before it is forwarded.

    Two rewrites are applied:

    * On ``/v1/rerank`` with a JSON body, the Qwen3 reranker templates are
      applied when the model's ``env`` enables
      ``GPUSTACK_APPLY_QWEN3_RERANKER_TEMPLATES``.
    * When the requested name differs from the selected model's name, the
      ``model`` field is set to the selected model's name.

    The selected model is read from ``request.state.model``.
    """
    path = request.url.path
    model: Model = request.state.model

    if path == "/v1/rerank" and body_json and model.env:
        flag = model.env.get("GPUSTACK_APPLY_QWEN3_RERANKER_TEMPLATES", False)
        if isinstance(flag, str):
            # String values such as "false" or "0" are truthy in Python, so
            # read explicit "off" spellings as disabled. Any other non-empty
            # value still enables the templates.
            flag = flag.strip().lower() not in ("", "0", "false", "no", "off")
        if flag:
            apply_qwen3_reranker_templates(body_json)

    if model_name != model.name:
        if body_json is not None:
            body_json["model"] = model.name
        elif form_data is not None:
            # NOTE(review): this appends a second "model" field rather than
            # replacing the one copied from the client's form; which of the
            # two the backend honours is not visible here. Confirm.
            form_data.add_field("model", model.name)
def apply_qwen3_reranker_templates(body_json: dict):
    """
    Apply Qwen3 reranker templates to the request body.

    The body is modified in place: ``query`` is wrapped in the system/user
    prompt and every entry of ``documents`` is wrapped in the document
    template. The body is left untouched when either key is missing or when
    ``documents`` is not a list, so malformed input is passed through as-is
    instead of raising here or being split into characters.

    See instructions in https://huggingface.co/Qwen/Qwen3-Reranker-0.6B.

    Note: Once vLLM supports built-in template rendering for this model, this can be removed.
    """
    prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    query_template = "{prefix}<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n"
    document_template = "<Document>: {doc}{suffix}"

    if "query" not in body_json or "documents" not in body_json:
        return

    documents = body_json["documents"]
    if not isinstance(documents, (list, tuple)):
        # None, a number or a bare string: do not template, pass through.
        return

    body_json["query"] = query_template.format(prefix=prefix, query=body_json["query"])
    body_json["documents"] = [
        document_template.format(doc=doc, suffix=suffix) for doc in documents
    ]
|