| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680 |
- import re
- import logging
- import copy
- import math
- from urllib.parse import urlparse
- from dataclasses import dataclass
- from functools import partial
- from typing import List, Optional, Tuple, Union, Dict, Any, Literal, Callable, Set
- from tenacity import retry, stop_after_attempt, wait_fixed
- from fastapi import HTTPException
- from starlette.datastructures import Headers
- from gpustack.gateway.labels_annotations import managed_labels, match_labels
- from gpustack.gateway import ai_proxy_types
- from gpustack.gateway.client.networking_higress_io_v1_api import (
- NetworkingHigressIoV1Api,
- McpBridge,
- McpBridgeRegistry,
- McpBridgeSpec,
- McpBridgeProxy,
- )
- from gpustack.gateway.client.extensions_higress_io_v1_api import (
- WasmPlugin,
- WasmPluginSpec,
- ExtensionsHigressIoV1Api,
- WasmPluginMatchRule,
- )
- from gpustack.gateway.client.networking_istio_io_v1alpha3_api import (
- NetworkingIstioIoV1Alpha3Api,
- EnvoyFilter,
- get_ingress_fallback_envoyfilter,
- )
- from gpustack.schemas.models import (
- ModelInstance,
- ModelInstancePublic,
- )
- from gpustack.schemas.model_provider import (
- ModelProvider,
- ModelProviderTypeEnum,
- )
- from gpustack.schemas.model_routes import ModelRoute
- from gpustack.server.bus import EventType
- from gpustack.server.db import async_session
- from gpustack.server.services import ModelInstanceService, WorkerService
- from gpustack.schemas.config import ModelInstanceProxyModeEnum
- from gpustack.schemas.workers import Worker
- from gpustack.schemas.clusters import Cluster
- from gpustack.utils.network import is_ipaddress
- from kubernetes_asyncio import client as k8s_client
- from kubernetes_asyncio.client import ApiException, V1IngressTLS
- from gpustack.envs import GATEWAY_MIRROR_INGRESS_NAME
- from gpustack.api.exceptions import NotFoundException
- from gpustack.websocket_proxy.message import ServerInfo, RegisteredClientInfo
- logger = logging.getLogger(__name__)
- default_mcp_bridge_name = "default"
- gpustack_ai_proxy_name = "gpustack-ai-proxy"
- gpustack_model_mapper_name = "gpustack-model-mapper"
- gpustack_generic_route_transformer_name = "gpustack-generic-route-transformer"
- model_ingress_prefix = "ai-route-model-"
- model_route_ingress_prefix = "ai-route-route-"
- provider_id_prefix = "provider-"
- model_id_prefix = "model-"
- router_header_key = "X-GPUStack-Model-Instance"
- gpustack_original_path_header = "x-gpustack-original-path"
- gpustack_fallback_path_header = "x-gpustack-fallback-path"
- # Type alias for destination tuples
- # Each tuple contains (weight: int, model_name: str, registry: McpBridgeRegistry)
- DestinationTupleList = List[Tuple[int, str, McpBridgeRegistry]]
- @dataclass
- class RoutePrefix:
- prefixes: List[str]
- support_legacy: bool = False
- additional_versions: Optional[List[str]] = None
- def flattened_prefixes(self) -> List[str]:
- versioned_prefixes = ["/v1"]
- if self.support_legacy:
- versioned_prefixes.append("/v1-openai")
- if self.additional_versions:
- versioned_prefixes.extend(self.additional_versions)
- flattened = []
- for versioned_prefix in versioned_prefixes:
- for prefix in self.prefixes:
- flattened.append(f"{versioned_prefix}{prefix}")
- return flattened
- def regex_prefixes(self) -> List[str]:
- """
- Returns regex patterns for the prefixes, considering versioning and legacy support.
- It supports removing -openai suffix from the versioned prefix with rewrite-target: /$1$3
- """
- versioned_prefixes = [f"/(v1){'(-openai)?' if self.support_legacy else '()'}"]
- if self.additional_versions:
- versioned_prefixes.extend(
- f"/({re.escape(additional_version.lstrip('/'))})()"
- for additional_version in self.additional_versions
- )
- return [
- f"{versioned_prefix}({prefix})"
- for versioned_prefix in versioned_prefixes
- for prefix in self.prefixes
- ]
- openai_model_prefixes: List[RoutePrefix] = [
- RoutePrefix(
- [
- "/chat/completions",
- "/completions",
- "/responses",
- "/embeddings",
- "/audio/transcriptions",
- "/audio/speech",
- "/images/generations",
- "/images/edits",
- ],
- True,
- ),
- RoutePrefix(
- [
- "/audio/translations",
- "/images/variations",
- "/moderations",
- "/score",
- ]
- ),
- RoutePrefix(["/rerank"], additional_versions=["/v2"]),
- ]
- anthropic_model_exact: List[RoutePrefix] = [
- RoutePrefix(["/messages", "/messages/count_tokens", "/complete"]),
- ]
- def get_default_mcpbridge_ref(
- mcp_bridge_name: str = default_mcp_bridge_name,
- ) -> k8s_client.V1TypedLocalObjectReference:
- # the name is hardcoded in Higress MCP Bridge controller
- return k8s_client.V1TypedLocalObjectReference(
- api_group='networking.higress.io',
- kind='McpBridge',
- name=mcp_bridge_name,
- )
- def wrap_route(
- path: str,
- path_type: str,
- backend: Optional[k8s_client.V1IngressBackend] = None,
- ) -> k8s_client.V1HTTPIngressPath:
- if backend is None:
- backend = k8s_client.V1IngressBackend(
- resource=get_default_mcpbridge_ref(),
- )
- return k8s_client.V1HTTPIngressPath(
- path=path,
- path_type=path_type,
- backend=backend,
- )
- def anthropic_routes() -> List[k8s_client.V1HTTPIngressPath]:
- routes = []
- for route_exact in anthropic_model_exact:
- for prefix in route_exact.regex_prefixes():
- routes.append(wrap_route(path=prefix, path_type="ImplementationSpecific"))
- return routes
- def ingress_rule_for_model() -> k8s_client.V1IngressRule:
- paths: List[k8s_client.V1HTTPIngressPath] = []
- for route_prefix in openai_model_prefixes:
- for prefix in route_prefix.regex_prefixes():
- paths.append(wrap_route(path=prefix, path_type="ImplementationSpecific"))
- return k8s_client.V1IngressRule(http=k8s_client.V1HTTPIngressRuleValue(paths=paths))
- def cluster_mcp_bridge_name(cluster_id: int) -> str:
- # higress_controller has hardcoded mcp bridge name to 'default'
- # the name should be based on cluster_id if higress_controller supports multiple mcp bridges
- return default_mcp_bridge_name
- def model_mcp_bridge_name(cluster_id: int) -> str:
- return cluster_mcp_bridge_name(cluster_id)
- def model_route_cleanup_prefix(model_route_id: int) -> str:
- return f"{model_route_ingress_prefix}{model_route_id}"
- def model_route_ingress_name(model_route_id: int) -> str:
- return f"{model_route_ingress_prefix}{model_route_id}.internal"
- def fallback_ingress_name(name: str) -> str:
- split_name = name.rsplit('.', 1)
- if len(split_name) == 1:
- return f"{name}.fallback"
- return f"{split_name[0]}.fallback.{split_name[1]}"
- def model_ingress_name(model_id: int) -> str:
- return f"{model_ingress_prefix}{model_id}"
- def cluster_worker_prefix(cluster_id: int) -> str:
- return f"cluster-{cluster_id}-worker-"
- def model_prefix(model_id: int) -> str:
- return f"{model_id_prefix}{model_id}-"
- def model_instance_prefix(
- model_instance: Union[ModelInstance, ModelInstancePublic]
- ) -> str:
- return f"{model_prefix(model_instance.model_id)}{model_instance.id}"
- def model_instance_registry(
- model_instance: Union[ModelInstance, ModelInstancePublic],
- worker: Optional[Worker] = None,
- ) -> Optional[McpBridgeRegistry]:
- name = model_instance_prefix(model_instance)
- if worker is not None:
- if worker.proxy_mode == ModelInstanceProxyModeEnum.WORKER:
- return _worker_reserve_proxy_registry(worker, name)
- elif worker.proxy_mode == ModelInstanceProxyModeEnum.TUNNEL:
- return _worker_tunnel_proxy_registry(worker, name)
- address = model_instance.worker_advertise_address or model_instance.worker_ip
- if address is None or address == "" or model_instance.port is None:
- return None
- domain = address
- port = model_instance.port
- registry_type = "dns"
- if is_ipaddress(address):
- domain = f"{address}:{model_instance.port}"
- port = 80
- registry_type = "static"
- return McpBridgeRegistry(
- domain=domain,
- port=port,
- name=name,
- protocol="http",
- type=registry_type,
- )
- def _worker_reserve_proxy_registry(
- worker: Worker, name_override: Optional[str] = None
- ) -> McpBridgeRegistry:
- """Build an McpBridgeRegistry entry for a worker in DIRECT or WORKER proxy mode.
- Uses ``worker.advertise_address`` when available, otherwise falls back to
- ``worker.ip``. For raw IP addresses the registry type is set to ``static``
- and the host:port pair is encoded in the domain field; for hostnames the
- type is ``dns`` and the port is carried separately.
- Returns ``None`` if the worker has no resolvable address or port.
- """
- address = worker.advertise_address or worker.ip
- if address is None or address == "" or worker.port is None:
- return None
- domain = address
- port = worker.port
- registry_type = "dns"
- if is_ipaddress(address):
- domain = f"{address}:{worker.port}"
- port = 80
- registry_type = "static"
- return McpBridgeRegistry(
- domain=domain,
- port=port,
- name=name_override or f"{cluster_worker_prefix(worker.cluster_id)}{worker.id}",
- protocol="http",
- type=registry_type,
- )
- def _worker_tunnel_proxy_registry(
- worker: Worker, name_override: Optional[str] = None
- ) -> Optional[McpBridgeRegistry]:
- """Build an McpBridgeRegistry entry for a worker in TUNNEL proxy mode.
- Points the registry at the server-side HTTP proxy address stored in
- ``worker.proxy_address``, which is populated by
- ``worker_websocket_connect_callback`` when the worker's WebSocket tunnel
- connects. The gateway routes inference requests to this proxy, which then
- tunnels them to the worker via the persistent WebSocket connection.
- Returns ``None`` if the worker has no proxy address (i.e. the WebSocket
- tunnel has not yet connected).
- """
- if worker.get_proxy_address() is None:
- return None
- # proxy address must be a valid URL and the netloc must be a valid IP.
- result = urlparse(worker.get_proxy_address())
- protocol = "http" if result.scheme == "http" else "https"
- port = result.port or (80 if protocol == "http" else 443)
- return McpBridgeRegistry(
- domain=f"{result.hostname}:{port}",
- port=80,
- name=name_override or f"{cluster_worker_prefix(worker.cluster_id)}{worker.id}",
- protocol=protocol,
- type="static",
- )
- def cluster_registry(cluster: Cluster) -> Optional[McpBridgeRegistry]:
- if cluster.gateway_endpoint is None and cluster.reported_gateway_endpoint is None:
- return None
- return McpBridgeRegistry(
- domain=cluster.gateway_endpoint or cluster.reported_gateway_endpoint,
- port=80,
- name="cluster-gateway",
- protocol="http",
- type="static",
- )
- def provider_registry_name(id: int) -> str:
- return f"{provider_id_prefix}{id}"
- def provider_registry(provider: ModelProvider) -> Optional[McpBridgeRegistry]:
- provider_url = provider.config.get_base_url()
- if provider_url is None:
- return None
- result = urlparse(url=provider_url)
- protocol = "http" if result.scheme == "http" else "https"
- port = 443 if protocol == "https" else 80
- registry_type = (
- "static" if result.hostname and is_ipaddress(result.hostname) else "dns"
- )
- if registry_type == "static":
- domain = result.netloc
- if result.port is None:
- domain = f"{domain}:{port}"
- else:
- domain = result.hostname
- if result.port is not None:
- port = result.port
- registry_name = provider_registry_name(provider.id)
- proxyName = f"{registry_name}-proxy" if provider.proxy_url else None
- return McpBridgeRegistry(
- domain=domain,
- port=port,
- name=registry_name,
- protocol=protocol,
- type=registry_type,
- proxyName=proxyName,
- )
- def provider_proxy(provider: ModelProvider) -> Optional[McpBridgeProxy]:
- if provider.proxy_url is None:
- return None
- proxy_url = urlparse(provider.proxy_url)
- scheme = proxy_url.scheme
- port = proxy_url.port
- if port is None:
- port = 443 if scheme == "https" else 80
- # timeout in seconds
- connection_timeout = provider.proxy_timeout or 5
- return McpBridgeProxy(
- name=f"{provider_registry_name(provider.id)}-proxy",
- serverAddress=proxy_url.hostname,
- serverPort=port,
- type=scheme.upper(),
- # convert to milliseconds
- connectTimeout=connection_timeout * 1000,
- )
- def provider_proxy_plugin_spec(
- *providers: ModelProvider,
- ) -> Tuple[List[Dict[str, Any]], List[WasmPluginMatchRule]]:
- provider_list = []
- match_rules = []
- sorted_providers: List[ModelProvider] = sorted(providers, key=lambda p: p.id)
- for provider in sorted_providers:
- registry = provider_registry(provider)
- if registry is None:
- continue
- service_name = registry.get_service_name()
- default_config_data = {
- "id": provider_registry_name(provider.id),
- "apiTokens": provider.api_tokens,
- **provider.config.model_dump_with_default_override(),
- "type": provider.config.type.value,
- }
- accessible_llm_model = next(
- (model.name for model in provider.models or [] if model.category == "llm"),
- None,
- )
- # Failover has more config
- if accessible_llm_model and len(provider.api_tokens) > 1:
- default_config_data["failover"] = ai_proxy_types.FailoverConfig(
- enabled=True,
- healthCheckModel=accessible_llm_model,
- )
- default_config = ai_proxy_types.AIProxyDefaultConfig.model_validate(
- default_config_data
- )
- provider_list.append(
- default_config.model_dump(by_alias=True, exclude_none=True)
- )
- active_config = ai_proxy_types.ActiveConfig(
- activeProviderId=provider_registry_name(provider.id),
- ).model_dump(exclude_none=True)
- match_rules.append(
- WasmPluginMatchRule(
- config=active_config,
- service=[service_name],
- configDisable=False,
- )
- )
- return provider_list, match_rules
- def diff_registries(
- existing: List[McpBridgeRegistry],
- desired: List[McpBridgeRegistry],
- to_delete_prefix: Optional[str] = None,
- ) -> Tuple[bool, List[McpBridgeRegistry]]:
- desired_map = {
- reg.name: idx for idx, reg in enumerate(desired) if reg.name is not None
- }
- total_list = []
- need_update = False
- for registry in existing:
- if registry.name not in desired_map:
- # delete registries that are not in the current list
- if to_delete_prefix is not None and registry.name.startswith(
- to_delete_prefix
- ):
- need_update = True
- else:
- # keep unrelated registries
- total_list.append(registry)
- else:
- # update existing registries
- idx = desired_map.pop(registry.name)
- if registry != desired[idx]:
- need_update = True
- registry = desired[idx]
- total_list.append(registry)
- # add new registries
- for idx in desired_map.values():
- need_update = True
- total_list.append(desired[idx])
- total_list.sort(key=lambda r: r.name or "")
- return need_update, total_list
- def diff_proxies(
- existing: List[McpBridgeProxy],
- desired: List[McpBridgeProxy],
- to_delete_prefix: Optional[str] = None,
- ) -> Tuple[bool, List[McpBridgeProxy]]:
- desired_map = {
- reg.name: idx for idx, reg in enumerate(desired) if reg.name is not None
- }
- total_list = []
- need_update = False
- for proxy in existing:
- if proxy.name not in desired_map:
- # delete registries that are not in the current list
- if to_delete_prefix is not None and proxy.name.startswith(to_delete_prefix):
- need_update = True
- else:
- # keep unrelated proxies
- total_list.append(proxy)
- else:
- # update existing proxies
- idx = desired_map.pop(proxy.name)
- if proxy != desired[idx]:
- need_update = True
- proxy = desired[idx]
- total_list.append(proxy)
- # add new proxies
- for idx in desired_map.values():
- need_update = True
- total_list.append(desired[idx])
- total_list.sort(key=lambda r: r.name or "")
- return need_update, total_list
- @retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
- async def ensure_mcp_bridge(
- client: NetworkingHigressIoV1Api,
- namespace: str,
- mcp_bridge_name: str,
- desired_registries: List[McpBridgeRegistry],
- to_delete_prefix: Optional[str] = None,
- desired_proxies: List[McpBridgeProxy] = None,
- to_delete_proxies_prefix: Optional[str] = None,
- ):
- existing_bridge = None
- try:
- mcpbridge_dict = await client.get_mcpbridge(namespace, mcp_bridge_name)
- existing_bridge = McpBridge.model_validate(mcpbridge_dict)
- except ApiException as e:
- if e.status != 404:
- raise
- if existing_bridge is None:
- mcpbridge_body = McpBridge(
- metadata={
- "name": mcp_bridge_name,
- "namespace": namespace,
- "labels": managed_labels,
- },
- spec=McpBridgeSpec(registries=desired_registries, proxies=desired_proxies),
- )
- await client.create_mcpbridge(
- namespace=namespace,
- body=mcpbridge_body,
- )
- logger.info(f"Created MCP Bridge {mcp_bridge_name} in namespace {namespace}.")
- else:
- registry_need_update, registry_list = diff_registries(
- existing=existing_bridge.spec.registries or [],
- desired=desired_registries,
- to_delete_prefix=to_delete_prefix,
- )
- proxy_need_update = False
- proxy_list = existing_bridge.spec.proxies or []
- if desired_proxies is not None:
- proxy_need_update, proxy_list = diff_proxies(
- existing=existing_bridge.spec.proxies or [],
- desired=desired_proxies,
- to_delete_prefix=to_delete_proxies_prefix,
- )
- if registry_need_update or proxy_need_update:
- registry_list.sort(key=lambda r: r.name or "")
- proxy_list.sort(key=lambda r: r.name or "")
- existing_bridge.spec.registries = registry_list
- existing_bridge.spec.proxies = proxy_list
- await client.edit_mcpbridge(
- name=mcp_bridge_name,
- namespace=namespace,
- body=existing_bridge,
- )
- logger.info(
- f"Updated MCP Bridge {mcp_bridge_name} in namespace {namespace}."
- )
- def generate_model_ingress(
- ingress_name: str,
- namespace: str,
- route_name: str,
- destinations: str,
- hostname: Optional[str] = None,
- tls: Optional[List[V1IngressTLS]] = None,
- included_generic_route: Optional[bool] = False,
- included_proxy_route: Optional[bool] = False,
- extra_annotations: Optional[Dict[str, str]] = None,
- ingress_class_name: str = "higress",
- ) -> k8s_client.V1Ingress:
- retry_policies = "error,timeout,http_503,http_502,non_idempotent"
- annotations = {
- "higress.io/rewrite-target": "/$1$3",
- "higress.io/destination": destinations,
- "higress.io/ignore-path-case": 'true',
- "higress.io/proxy-next-upstream-tries": '2',
- "higress.io/proxy-next-upstream": retry_policies,
- **higress_http_header_matcher("exact", "x-higress-llm-model", route_name),
- }
- if extra_annotations is not None:
- annotations.update(extra_annotations)
- metadata = k8s_client.V1ObjectMeta(
- name=ingress_name,
- namespace=namespace,
- annotations=annotations,
- labels=managed_labels,
- )
- expected_rule = ingress_rule_for_model()
- if included_proxy_route:
- # to compatible with rewrite-target /$1$3, the first capturing group is empty.
- # The /\d+ variant strips the route-id segment from /model/proxy/<id>/<path>
- # so the upstream receives /<path>. The id-less variant preserves the legacy
- # /model/proxy/<path> + X-GPUStack-Model header form. The more specific rule
- # is listed first so Higress tries id-based matching before falling back.
- expected_rule.http.paths.append(
- wrap_route(
- r"/()model/proxy/\d+(/|$)(.*)",
- "ImplementationSpecific",
- )
- )
- expected_rule.http.paths.append(
- wrap_route(
- "/()model/proxy(/|$)(.*)",
- "ImplementationSpecific",
- )
- )
- if included_generic_route:
- expected_rule.http.paths.append(wrap_route("/", "Prefix"))
- # support for Anthropic API
- expected_rule.http.paths.extend(anthropic_routes())
- spec = k8s_client.V1IngressSpec(
- ingress_class_name=ingress_class_name, rules=[expected_rule]
- )
- if hostname is not None:
- hostname_rule = copy.deepcopy(expected_rule)
- hostname_rule.host = hostname
- spec.rules.append(hostname_rule)
- spec.tls = tls
- ingress = k8s_client.V1Ingress(
- api_version="networking.k8s.io/v1",
- kind="Ingress",
- metadata=metadata,
- spec=spec,
- )
- return ingress
- def higress_metadata_equal(
- existing_metadata: Optional[k8s_client.V1ObjectMeta],
- expected_metadata: Optional[k8s_client.V1ObjectMeta],
- ) -> bool:
- existing_metadata = existing_metadata or k8s_client.V1ObjectMeta()
- expected_metadata = expected_metadata or k8s_client.V1ObjectMeta()
- if existing_metadata.annotations is None:
- existing_metadata.annotations = {}
- if expected_metadata.annotations is None:
- expected_metadata.annotations = {}
- for key in set(
- k for k in expected_metadata.annotations if k.startswith("higress.io")
- ):
- if existing_metadata.annotations.get(key) != expected_metadata.annotations.get(
- key
- ):
- return False
- return True
- def ingress_tls_equal(
- existing: Optional[k8s_client.V1IngressTLS],
- expected: Optional[k8s_client.V1IngressTLS],
- ) -> bool:
- if (existing is None) != (expected is None):
- return False
- if existing and expected:
- if len(existing) != len(expected):
- return False
- for etls, xtls in zip(existing, expected):
- # only compares hosts and secret_name for tls equal
- if getattr(etls, 'hosts', None) != getattr(xtls, 'hosts', None):
- return False
- if getattr(etls, 'secret_name', None) != getattr(xtls, 'secret_name', None):
- return False
- return True
- def mcp_ingress_equal(
- existing: k8s_client.V1Ingress, expected: k8s_client.V1Ingress
- ) -> bool:
- if not higress_metadata_equal(
- existing_metadata=existing.metadata, expected_metadata=expected.metadata
- ):
- return False
- if existing.spec is None or expected.spec is None:
- return False
- if not ingress_tls_equal(
- existing=getattr(existing.spec, 'tls', None),
- expected=getattr(expected.spec, 'tls', None),
- ):
- return False
- if len(existing.spec.rules or []) != len(expected.spec.rules or []):
- return False
- for existing_rule, expected_rule in zip(
- existing.spec.rules or [], expected.spec.rules or []
- ):
- if getattr(existing_rule, 'host', None) != getattr(expected_rule, 'host', None):
- return False
- if existing_rule.http is None or expected_rule.http is None:
- return False
- if len(existing_rule.http.paths or []) != len(expected_rule.http.paths or []):
- return False
- for existing_path, expected_path in zip(
- existing_rule.http.paths or [], expected_rule.http.paths or []
- ):
- if existing_path.path != expected_path.path:
- return False
- if existing_path.path_type != expected_path.path_type:
- return False
- if existing_path.backend.resource != expected_path.backend.resource:
- return False
- return True
- def scale_weight(weight_instance_pairs: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
- """
- Scale weights based on the least common multiple of counts to maintain proportionality.
- """
- counts = [count for _, count in weight_instance_pairs if count > 0]
- if not counts:
- return weight_instance_pairs
- lcm_count = math.lcm(*counts)
- scaled = [
- (weight * lcm_count // count if count > 0 else 0, count)
- for weight, count in weight_instance_pairs
- ]
- return scaled
- def hamilton_calculate_weight(
- weight_instance_pairs: List[Tuple[int, int]],
- max_weight: Optional[int] = 0,
- ) -> List[int]:
- """
- hamilton_calculate_weight to allocate percentage based on weight and instance count.
- The total should be 100.
- :param weight_instance_pairs: weight and instance count pairs
- :type weight_instance_pairs: List[Tuple[int, int]]
- :return: list of percentage for instance
- :rtype: List[int]
- """
- weight_instance_pairs = scale_weight(weight_instance_pairs)
- instances_info = []
- for weight, instance_count in weight_instance_pairs:
- for _ in range(instance_count):
- instances_info.append({'weight': weight, 'group_weight': weight})
- total_weight = sum(max(info['weight'], max_weight) for info in instances_info)
- if total_weight == 0:
- return []
- for info in instances_info:
- weight = max(info['weight'], max_weight)
- info['exact_quota'] = weight * 100 / total_weight
- info['floor_quota'] = int(info['exact_quota'])
- info['remainder'] = info['exact_quota'] - info['floor_quota']
- total_floor = sum(info['floor_quota'] for info in instances_info)
- remaining_seats = 100 - total_floor
- sorted_instances = sorted(instances_info, key=lambda x: -x['remainder'])
- for i in range(remaining_seats):
- sorted_instances[i]['floor_quota'] += 1
- return [info['floor_quota'] for info in instances_info]
- def model_instances_registry_list(
- model_instances: List[Union[ModelInstance, ModelInstancePublic]],
- workers: Optional[Dict[int, Worker]] = None,
- ) -> DestinationTupleList:
- registries: DestinationTupleList = []
- for model_instance in model_instances:
- worker = (
- (workers or {}).get(model_instance.worker_id)
- if model_instance.worker_id
- else None
- )
- registry = model_instance_registry(model_instance, worker=worker)
- if registry is not None:
- registries.append((1, model_instance.model_name, registry))
- return registries
- @retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
- async def ensure_model_ingress(
- ingress_name: str,
- ingress_class_name: str,
- route_name: str,
- namespace: str,
- destinations: DestinationTupleList,
- event_type: EventType,
- networking_api: k8s_client.NetworkingV1Api,
- included_generic_route: Optional[bool] = False,
- included_proxy_route: Optional[bool] = False,
- extra_annotations: Optional[Dict[str, str]] = None,
- ):
- """
- Ensure the model ingress resource in Kubernetes matches the desired state.
- Parameters:
- ingress_name (str): The name of the ingress resource.
- namespace (str): The Kubernetes namespace for the ingress resource.
- destinations (DestinationTupleList): Weighted list of MCP Bridge registries for traffic routing.
- route_name (str): The name of the model route for which ingress is managed.
- event_type (EventType): The event type (CREATED, UPDATED, DELETED) triggering reconciliation.
- networking_api (k8s_client.NetworkingV1Api): The Kubernetes networking API client.
- hostname (Optional[str]): The external hostname for ingress routing.
- tls_secret_name (Optional[str]): TLS secret name for HTTPS ingress.
- included_generic_route (bool): Whether to include a generic '/' route for fallback traffic. Used in worker gateway.
- included_proxy_route (bool): Whether to include a proxy route for model traffic (e.g., /model/proxy/{model_name}). Used in server gateway.
- """
- if event_type == EventType.DELETED:
- try:
- await networking_api.delete_namespaced_ingress(
- name=ingress_name, namespace=namespace
- )
- logger.info(
- f"Deleted model ingress {ingress_name} for model route {route_name}"
- )
- except ApiException as e:
- if e.status != 404:
- logger.error(f"Failed to delete ingress {ingress_name}: {e}")
- return
- expected_destinations = '\n'.join(
- [
- f"{persentage}% {registry.get_service_name_with_port()}"
- for persentage, _, registry in destinations
- ]
- )
- try:
- existing_ingress: Optional[k8s_client.V1Ingress] = (
- await networking_api.read_namespaced_ingress(
- name=ingress_name, namespace=namespace
- )
- )
- except ApiException as e:
- if e.status != 404:
- logger.error(f"Failed to get ingress {ingress_name}: {e}")
- return
- existing_ingress = None
- hostname, tls = await mirror_hostname_tls_from_ingress(
- network_v1_client=networking_api,
- gateway_namespace=namespace,
- target_ingress_name=GATEWAY_MIRROR_INGRESS_NAME,
- )
- expected_ingress = generate_model_ingress(
- ingress_name=ingress_name,
- route_name=route_name,
- namespace=namespace,
- destinations=expected_destinations,
- hostname=hostname,
- tls=tls,
- included_generic_route=included_generic_route,
- included_proxy_route=included_proxy_route,
- extra_annotations=extra_annotations,
- ingress_class_name=ingress_class_name,
- )
- if existing_ingress is None:
- await networking_api.create_namespaced_ingress(
- namespace=namespace,
- body=expected_ingress,
- )
- logger.info(
- f"Created model ingress {ingress_name} for model route {route_name}"
- )
- else:
- is_equal = mcp_ingress_equal(
- existing=existing_ingress, expected=expected_ingress
- )
- if not is_equal:
- existing_ingress.spec = expected_ingress.spec
- metadata = existing_ingress.metadata or k8s_client.V1ObjectMeta()
- metadata.annotations = metadata.annotations or {}
- expected_higress_keys = set()
- for key, value in (expected_ingress.metadata.annotations or {}).items():
- if key.startswith("higress.io"):
- metadata.annotations[key] = value
- expected_higress_keys.add(key)
- to_delete = [
- key
- for key in metadata.annotations.keys()
- if key.startswith("higress.io") and key not in expected_higress_keys
- ]
- for key in to_delete:
- del metadata.annotations[key]
- await networking_api.replace_namespaced_ingress(
- name=ingress_name,
- namespace=namespace,
- body=existing_ingress,
- )
- logger.info(
- f"Updated model ingress {ingress_name} for model route {route_name}"
- )
- @retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
- async def ensure_wasm_plugin(
- api: ExtensionsHigressIoV1Api,
- name: str,
- namespace: str,
- spec_diff: Callable[[Optional[WasmPluginSpec]], WasmPluginSpec],
- extra_labels: Optional[Dict[str, str]] = None,
- ):
- labels = copy.deepcopy(managed_labels)
- if extra_labels:
- labels.update(extra_labels)
- current_plugin = None
- try:
- data: Dict[str, Any] = await api.get_wasmplugin(namespace=namespace, name=name)
- current_plugin = WasmPlugin.model_validate(data)
- except ApiException as e:
- if e.status == 404:
- current_plugin = None
- else:
- raise
- current_spec = getattr(current_plugin, 'spec', None)
- expected = spec_diff(copy.deepcopy(current_spec))
- if current_plugin is None:
- wasm_plugin_body = WasmPlugin(
- metadata={
- "name": name,
- "namespace": namespace,
- "labels": labels,
- },
- spec=expected,
- )
- await api.create_wasmplugin(
- namespace=namespace,
- body=wasm_plugin_body,
- )
- logger.info(f"Created WasmPlugin {name} in namespace {namespace}.")
- elif match_labels(current_plugin.metadata.get("labels", {}), labels):
- current_spec = (
- current_plugin.spec.model_dump(exclude_none=True)
- if current_plugin.spec
- else {}
- )
- expected_spec = expected.model_dump(exclude_none=True) if expected else {}
- if current_spec != expected_spec:
- current_plugin.spec = expected
- await api.edit_wasmplugin(
- namespace=namespace,
- name=name,
- body=current_plugin,
- )
- logger.info(f"Updated WasmPlugin {name} in namespace {namespace}.")
- async def cleanup_model_mapper(
- namespace: str,
- expected_ingresses: List[str],
- config: k8s_client.Configuration,
- extra_labels: Optional[Dict[str, str]] = None,
- ):
- if config is None:
- return
- api = ExtensionsHigressIoV1Api(k8s_client.ApiClient(config))
- labels = copy.deepcopy(managed_labels)
- if extra_labels:
- labels.update(extra_labels)
- def spec_diff(current_spec: Optional[WasmPluginSpec]) -> WasmPluginSpec:
- if current_spec is None:
- return current_spec
- to_keep_rules: List[WasmPluginMatchRule] = []
- for rule in current_spec.matchRules or []:
- if any(ingress in expected_ingresses for ingress in rule.ingress):
- to_keep_rules.append(rule)
- else:
- logger.info(
- f"Removing rule with ingress {rule.ingress} from model mapper plugin as it is not in expected ingresses."
- )
- to_keep_rules.sort(key=lambda r: r.ingress[0] if r.ingress else "")
- current_spec.matchRules = to_keep_rules
- return current_spec
- await ensure_wasm_plugin(
- api=api,
- name=gpustack_model_mapper_name,
- namespace=namespace,
- spec_diff=spec_diff,
- extra_labels=extra_labels,
- )
- async def cleanup_ingresses(
- namespace: str,
- expected_names: List[str],
- config: k8s_client.Configuration,
- cleanup_prefix: str,
- reason: str = "orphaned",
- ):
- if config is None:
- return
- networking_api = k8s_client.NetworkingV1Api(k8s_client.ApiClient(config))
- try:
- # Use label selector to filter only managed ingresses
- label_selector = ','.join([f"{k}={v}" for k, v in managed_labels.items()])
- ingresses = await networking_api.list_namespaced_ingress(
- namespace=namespace,
- label_selector=label_selector,
- )
- for ingress in ingresses.items:
- # name must be not None due to label selector
- name: str = ingress.metadata.name
- if name in expected_names or not name.startswith(cleanup_prefix):
- continue
- await networking_api.delete_namespaced_ingress(
- name=name, namespace=namespace
- )
- logger.info(
- f"Deleted {reason} model ingress {name} in namespace {namespace}."
- )
- except Exception as e:
- logger.error(f"Error cleaning up {reason} model ingresses: {e}")
- async def ensure_model_mcp_bridge(
- event_type: EventType,
- model_id: int,
- model_instances: List[Union[ModelInstance, ModelInstancePublic]],
- networking_higress_api: NetworkingHigressIoV1Api,
- namespace: str,
- cluster_id: int,
- workers: Optional[Dict[int, Worker]] = None,
- ) -> List[McpBridgeRegistry]:
- desired_registry: List[McpBridgeRegistry] = []
- to_delete_prefix: Optional[str] = model_prefix(model_id)
- if event_type != EventType.DELETED:
- for model_instance in model_instances:
- worker = (
- (workers or {}).get(model_instance.worker_id)
- if model_instance.worker_id
- else None
- )
- registry = model_instance_registry(model_instance, worker=worker)
- if registry is not None:
- desired_registry.append(registry)
- await ensure_mcp_bridge(
- client=networking_higress_api,
- namespace=namespace,
- mcp_bridge_name=model_mcp_bridge_name(cluster_id),
- desired_registries=desired_registry,
- to_delete_prefix=to_delete_prefix,
- )
- return desired_registry
- async def mirror_hostname_tls_from_ingress(
- network_v1_client: k8s_client.NetworkingV1Api,
- gateway_namespace: str,
- target_ingress_name: str,
- ) -> Tuple[Optional[str], Optional[List[V1IngressTLS]]]:
- """
- Mirror TLS settings from an existing ingress to be used in the gateway.
- Parameters:
- api_client (k8s_client.ApiClient): The Kubernetes API client.
- gateway_namespace (str): The namespace where the gateway ingress resides.
- target_ingress_name (str): The name of the ingress to mirror TLS settings from.
- Returns:
- Tuple[Optional[str], Optional[List[V1IngressTLS]]]: A tuple containing the hostname and ingress TLS settings.
- """
- try:
- ingress: k8s_client.V1Ingress = await network_v1_client.read_namespaced_ingress(
- name=target_ingress_name, namespace=gateway_namespace
- )
- except ApiException as e:
- if e.status == 404:
- logger.warning(
- f"Target ingress {target_ingress_name} not found in namespace {gateway_namespace} for TLS mirroring."
- )
- return None, None
- else:
- raise
- tls = getattr(ingress.spec, 'tls', None)
- hostname = None
- for rule in ingress.spec.rules or []:
- if rule.host:
- hostname = rule.host
- break
- return hostname, tls
- def get_expected_match_list(
- route_name: str,
- ingress_prefix: str,
- ingress_name: str,
- model_name_to_registries: Dict[str, List[str]],
- fallback_model_name_to_registries: Dict[str, List[str]],
- ) -> List[WasmPluginMatchRule]:
- match_list: List[WasmPluginMatchRule] = []
- ingress_name = f"{ingress_prefix}{ingress_name}"
- for model_name, service_names in model_name_to_registries.items():
- config = {"modelMapping": {route_name: model_name}}
- match_list.append(
- WasmPluginMatchRule(
- config=config,
- ingress=[ingress_name],
- configDisable=False,
- service=service_names,
- )
- )
- for model_name, service_names in fallback_model_name_to_registries.items():
- # the fallback mapping should include both normal ingress and fallback ingress
- # as the normal ingress may not exist when only fallback model is set
- fallback_name = fallback_ingress_name(ingress_name)
- config = {"modelMapping": {route_name: model_name}}
- match_list.append(
- WasmPluginMatchRule(
- config=config,
- ingress=[ingress_name, fallback_name],
- configDisable=False,
- service=service_names,
- )
- )
- return match_list
- def higress_http_header_matcher(
- operator: Literal["exact", "regex", "prefix"],
- header_key: str,
- header_value: str,
- ) -> Dict[str, str]:
- header_matcher = "match-header"
- return {
- f"higress.io/{operator}-{header_matcher}-{header_key}": header_value,
- }
- async def cleanup_fallback_filters(
- namespace: str,
- expected_names: List[str],
- cleanup_prefix: str,
- reason: str = "orphaned",
- networking_istio_api: Optional[NetworkingIstioIoV1Alpha3Api] = None,
- k8s_config: Optional[k8s_client.Configuration] = None,
- ):
- if networking_istio_api is None:
- if k8s_config is None:
- return
- networking_istio_api = NetworkingIstioIoV1Alpha3Api(
- k8s_client.ApiClient(k8s_config)
- )
- try:
- label_selector = ','.join([f"{k}={v}" for k, v in managed_labels.items()])
- filters = await networking_istio_api.list_envoyfilters(
- namespace=namespace,
- label_selector=label_selector,
- )
- items: List[Dict[str, Any]] = filters.get('items', [])
- for filter_item in items:
- # name must be not None due to label selector
- name = filter_item.get("metadata", {}).get("name", None)
- if (
- name is None
- or name in expected_names
- or not name.startswith(cleanup_prefix)
- ):
- continue
- await networking_istio_api.delete_envoyfilter(
- name=name, namespace=namespace
- )
- logger.info(
- f"Deleted {reason} fallback filter {name} in namespace {namespace}."
- )
- except Exception as e:
- logger.error(f"Error cleaning up {reason} fallback filters: {e}")
- @retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
- async def ensure_fallback_filter(
- event_type: EventType,
- ingress_name: str,
- namespace: str,
- networking_istio_api: NetworkingIstioIoV1Alpha3Api,
- ):
- if event_type == EventType.DELETED:
- await cleanup_fallback_filters(
- namespace=namespace,
- expected_names=[],
- networking_istio_api=networking_istio_api,
- cleanup_prefix=ingress_name,
- reason="event deleted",
- )
- return
- existing_filter = None
- try:
- filter_dict = await networking_istio_api.get_envoyfilter(
- namespace=namespace, name=ingress_name
- )
- existing_filter = EnvoyFilter.model_validate(filter_dict)
- except ApiException as e:
- if e.status != 404:
- raise
- except Exception as e:
- raise e
- expected_filter = get_ingress_fallback_envoyfilter(
- ingress_name=ingress_name,
- namespace=namespace,
- labels={**managed_labels},
- extra_req_headers={
- gpustack_fallback_path_header: f'%REQ({gpustack_original_path_header.upper()})%'
- },
- )
- if existing_filter is None:
- await networking_istio_api.create_envoyfilter(
- namespace=namespace,
- body=expected_filter,
- )
- logger.info(
- f"Created fallback EnvoyFilter {ingress_name} in namespace {namespace}."
- )
- else:
- existing_spec_dict = existing_filter.spec.model_dump(exclude_none=True)
- expected_spec_dict = expected_filter.spec.model_dump(exclude_none=True)
- if existing_spec_dict != expected_spec_dict:
- existing_filter.spec = expected_filter.spec
- await networking_istio_api.edit_envoyfilter(
- name=ingress_name,
- namespace=namespace,
- body=existing_filter,
- )
- logger.info(
- f"Updated fallback EnvoyFilter {ingress_name} in namespace {namespace}."
- )
- def ai_proxy_openai_provider_config(id: str) -> Dict[str, Any]:
- return ai_proxy_types.AIProxyDefaultConfig(
- type=ModelProviderTypeEnum.OPENAI,
- id=id,
- failover=ai_proxy_types.FailoverConfig(enabled=False),
- retryOnFailure=ai_proxy_types.EnableState(enabled=False),
- ).model_dump(exclude_none=True, exclude_unset=True)
- def compare_and_append_default_proxy_config(
- existing_providers: List[Dict[str, Any]],
- expected_providers: List[Dict[str, Any]],
- operating_id_prefix: Optional[str] = None,
- ) -> List[Dict[str, Any]]:
- to_keep_config = []
- for provider in existing_providers:
- provider_id: Optional[str] = provider.get('id', None)
- if (
- provider_id is None
- or operating_id_prefix is None
- or not provider_id.startswith(operating_id_prefix)
- ):
- to_keep_config.append(provider)
- continue
- return_providers = expected_providers.copy()
- return_providers.extend(to_keep_config)
- return_providers.sort(key=lambda p: p.get("id", ""))
- return return_providers
- def compare_and_append_proxy_match_rules(
- existing_rules: List[WasmPluginMatchRule],
- expected_rules: List[WasmPluginMatchRule],
- operating_id_prefix: Optional[str] = None,
- ) -> List[WasmPluginMatchRule]:
- to_keep_config = []
- for rule in existing_rules:
- provider_id: Optional[str] = rule.config.get('activeProviderId', None)
- if (
- provider_id is None
- or operating_id_prefix is None
- or not provider_id.startswith(operating_id_prefix)
- ):
- to_keep_config.append(rule)
- continue
- return_rules = expected_rules.copy()
- return_rules.extend(to_keep_config)
- return_rules.sort(key=lambda r: (r.config.get("activeProviderId", None) or ""))
- return return_rules
- async def cleanup_ai_proxy_config(
- providers: List[ModelProvider],
- routes: List[ModelRoute],
- k8s_config: k8s_client.Configuration,
- namespace: str,
- ):
- if k8s_config is None:
- return
- prefixes_to_keep = {model_route_cleanup_prefix(route.id) for route in routes}
- prefixes_to_keep.update(
- {provider_registry_name(provider.id) for provider in providers}
- )
- def should_keep(provider_id: str) -> bool:
- for prefix in prefixes_to_keep:
- if provider_id.startswith(prefix):
- return True
- return False
- try:
- extensions_api = ExtensionsHigressIoV1Api(k8s_client.ApiClient(k8s_config))
- ai_proxy_data = await extensions_api.get_wasmplugin(
- namespace=namespace,
- name=gpustack_ai_proxy_name,
- )
- existing_plugin = WasmPlugin.model_validate(ai_proxy_data)
- current_providers = existing_plugin.spec.defaultConfig.get("providers", [])
- filtered_providers = [
- p for p in current_providers if p.get("id") and should_keep(p.get("id"))
- ]
- existing_plugin.spec.defaultConfig["providers"] = filtered_providers
- filtered_provider_ids = {
- p.get("id") for p in filtered_providers if p.get("id") is not None
- }
- filtered_rules = [
- r
- for r in existing_plugin.spec.matchRules or []
- if r.config.get("activeProviderId") in filtered_provider_ids
- ]
- existing_plugin.spec.matchRules = filtered_rules
- await extensions_api.edit_wasmplugin(
- namespace=namespace,
- name=gpustack_ai_proxy_name,
- body=existing_plugin,
- )
- except k8s_client.ApiException as e:
- logger.error(
- f"Failed to cleanup gpustack AI proxy wasmplugin {gpustack_ai_proxy_name}: {e}"
- )
- raise
- def build_generic_route_path_pattern(route_id: int) -> str:
- """Path pattern matching /model/proxy/<id> and /model/proxy/<id>/<anything>.
- The Higress transformer plugin treats `add` + `path_pattern` + `value` as a
- regex substitution: the portion of :path that matches path_pattern is
- replaced by value, and the resulting string is written to the target header.
- So the pattern MUST consume the entire path — otherwise the unmatched tail
- gets concatenated onto the header value (e.g. `<route_name>v1/models`).
- The `(/.*)?$` tail keeps the `/` boundary after <id> so `/model/proxy/10`
- doesn't spuriously match id=1.
- """
- return f"^/model/proxy/{route_id}(/.*)?$"
- def build_generic_route_header_rule(route_id: int, route_name: str) -> Dict[str, Any]:
- """HeaderRule dict injecting x-higress-llm-model when /model/proxy/<id>/ is hit.
- Example for ``route_id=1, route_name="qwen3-0.6b"``::
- {
- "key": "x-higress-llm-model",
- "value": "qwen3-0.6b",
- "path_pattern": "^/model/proxy/1(/.*)?$",
- }
- At runtime Higress substitutes the portion of ``:path`` that matches
- ``path_pattern`` with ``value`` and writes the result to ``key``.
- ``path_pattern`` anchors both ends so every matched path reduces to just
- ``value`` — see build_generic_route_path_pattern for why.
- """
- return {
- "key": "x-higress-llm-model",
- "value": route_name,
- "path_pattern": build_generic_route_path_pattern(route_id),
- }
- # Generic-route rules are identified by the shape of their path_pattern — this
- # lets the diff/cleanup code coexist with other reqRules blocks a future
- # contributor might add for unrelated purposes, instead of assuming the
- # generic-route block is the only `add` block in the plugin.
- _GENERIC_ROUTE_PATH_PATTERN_RE = re.compile(r"^\^/model/proxy/\d+")
- def _is_generic_route_header(header: Dict[str, Any]) -> bool:
- return bool(_GENERIC_ROUTE_PATH_PATTERN_RE.match(header.get("path_pattern", "")))
- def _split_generic_route_req_rules(
- req_rules: List[Dict[str, Any]],
- ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
- """
- Partition existing reqRules into (blocks_to_preserve, generic_route_headers).
- - Any ``add`` block contributes generic-route headers (identified by
- path_pattern shape) to the second list. Non-generic headers from the same
- block are retained verbatim in the first list, so mixed ownership is safe.
- - Blocks with any other ``operate`` (rename, remove, map, ...) are preserved
- untouched.
- """
- preserve: List[Dict[str, Any]] = []
- generic_headers: List[Dict[str, Any]] = []
- for rule in req_rules:
- if rule.get("operate") != "add":
- preserve.append(rule)
- continue
- foreign_headers: List[Dict[str, Any]] = []
- for header in rule.get("headers", []):
- if _is_generic_route_header(header):
- generic_headers.append(header)
- else:
- foreign_headers.append(header)
- if foreign_headers:
- preserve.append({**rule, "headers": foreign_headers})
- return preserve, generic_headers
- def _set_generic_route_headers(
- current_spec: WasmPluginSpec, headers: List[Dict[str, Any]]
- ) -> WasmPluginSpec:
- # Do NOT touch defaultConfigDisable here — flipping it rewrites Envoy's
- # filter chain and tears down every live connection. Only the reqRules
- # list changes between reconciliations; the enable flag is locked at
- # plugin creation (see generic_route_transformer_plugin).
- default_config = current_spec.defaultConfig or {}
- preserve, _ = _split_generic_route_req_rules(default_config.get("reqRules", []))
- new_req_rules: List[Dict[str, Any]] = list(preserve)
- if headers:
- sorted_headers = sorted(headers, key=lambda h: h.get("path_pattern", ""))
- new_req_rules.append({"operate": "add", "headers": sorted_headers})
- current_spec.defaultConfig = {"reqRules": new_req_rules}
- return current_spec
- def generic_route_transformer_diff_spec(
- current_spec: Optional[WasmPluginSpec],
- expected_header_rules: List[Dict[str, Any]],
- operating_path_pattern: str,
- ) -> Optional[WasmPluginSpec]:
- """
- Merge expected_header_rules into the current spec, replacing any existing rule
- whose path_pattern equals operating_path_pattern. Other routes' rules stay
- untouched, as do unrelated reqRules blocks added by other subsystems.
- Returns None if the plugin doesn't exist yet — init handles that.
- Example — reconciling route id=2 renamed to "route-two-renamed" while route
- id=1 is already registered:
- # current_spec.defaultConfig.reqRules (before)
- [{"operate": "add", "headers": [
- {"key": "x-higress-llm-model", "value": "route-one",
- "path_pattern": "^/model/proxy/1(/.*)?$"},
- {"key": "x-higress-llm-model", "value": "route-two",
- "path_pattern": "^/model/proxy/2(/.*)?$"},
- ]}]
- # call
- generic_route_transformer_diff_spec(
- current_spec,
- expected_header_rules=[build_generic_route_header_rule(2, "route-two-renamed")],
- operating_path_pattern="^/model/proxy/2(/.*)?$",
- )
- # current_spec.defaultConfig.reqRules (after)
- [{"operate": "add", "headers": [
- {"key": "x-higress-llm-model", "value": "route-one",
- "path_pattern": "^/model/proxy/1(/.*)?$"},
- {"key": "x-higress-llm-model", "value": "route-two-renamed",
- "path_pattern": "^/model/proxy/2(/.*)?$"},
- ]}]
- Pass expected_header_rules=[] to remove the rule for this route (e.g. when
- generic_proxy is toggled off or the route is deleted).
- """
- if current_spec is None:
- return current_spec
- req_rules = (current_spec.defaultConfig or {}).get("reqRules", [])
- _, generic_headers = _split_generic_route_req_rules(req_rules)
- retained = [
- h for h in generic_headers if h.get("path_pattern") != operating_path_pattern
- ]
- return _set_generic_route_headers(current_spec, retained + expected_header_rules)
- def cleanup_generic_route_transformer_spec_diff(
- current_spec: Optional[WasmPluginSpec],
- expected_path_patterns: Set[str],
- ) -> Optional[WasmPluginSpec]:
- """
- Drop generic-route HeaderRules whose path_pattern is not in
- expected_path_patterns. Non-generic rules (any shape that doesn't look like
- ``^/model/proxy/<id>``) are left untouched. Used on startup to prune rules
- for routes that were deleted or had generic_proxy toggled off while the
- server was down.
- """
- if current_spec is None:
- return current_spec
- req_rules = (current_spec.defaultConfig or {}).get("reqRules", [])
- _, generic_headers = _split_generic_route_req_rules(req_rules)
- retained = [
- h for h in generic_headers if h.get("path_pattern") in expected_path_patterns
- ]
- return _set_generic_route_headers(current_spec, retained)
- async def cleanup_generic_route_transformer(
- routes: List[ModelRoute],
- k8s_config: k8s_client.Configuration,
- namespace: str,
- ):
- """Prune generic-route transformer rules to those for existing generic_proxy routes."""
- if k8s_config is None:
- return
- expected_patterns = {
- build_generic_route_path_pattern(route.id)
- for route in routes
- if getattr(route, "generic_proxy", False)
- }
- api = ExtensionsHigressIoV1Api(k8s_client.ApiClient(k8s_config))
- await ensure_wasm_plugin(
- api=api,
- name=gpustack_generic_route_transformer_name,
- namespace=namespace,
- spec_diff=partial(
- cleanup_generic_route_transformer_spec_diff,
- expected_path_patterns=expected_patterns,
- ),
- )
- async def cleanup_mcpbridge_registry(
- providers: List[ModelProvider],
- model_instances: List[ModelInstance],
- workers: List[Worker],
- namespace: str,
- k8s_config: k8s_client.Configuration,
- ):
- if k8s_config is None:
- return
- worker_by_id = {worker.id: worker for worker in workers}
- networking_higress_api = NetworkingHigressIoV1Api(k8s_client.ApiClient(k8s_config))
- # cleanup providers
- desired_registries = []
- desired_proxies = []
- for provider in providers:
- registry = provider_registry(provider=provider)
- if registry is not None:
- desired_registries.append(registry)
- proxy = provider_proxy(provider=provider)
- if proxy is not None:
- desired_proxies.append(proxy)
- to_delete_prefix = provider_id_prefix
- await ensure_mcp_bridge(
- client=networking_higress_api,
- namespace=namespace,
- mcp_bridge_name=default_mcp_bridge_name,
- desired_registries=desired_registries,
- to_delete_prefix=to_delete_prefix,
- desired_proxies=desired_proxies,
- to_delete_proxies_prefix=provider_id_prefix,
- )
- # cleanup model instances
- desired_registries = []
- to_delete_prefix = model_id_prefix
- for instance in model_instances:
- worker = worker_by_id.get(instance.worker_id)
- registry = model_instance_registry(instance, worker=worker)
- if registry is not None:
- desired_registries.append(registry)
- await ensure_mcp_bridge(
- client=networking_higress_api,
- namespace=namespace,
- mcp_bridge_name=default_mcp_bridge_name,
- desired_registries=desired_registries,
- to_delete_prefix=to_delete_prefix,
- )
- def ai_proxy_diff_spec(
- current_spec: Optional[WasmPluginSpec],
- expected_providers: List[Dict[str, Any]],
- expected_match_rules: List[WasmPluginMatchRule],
- operating_id_prefix: Optional[str] = None,
- ) -> WasmPluginSpec:
- if current_spec is None:
- return current_spec
- current_spec.defaultConfig["providers"] = compare_and_append_default_proxy_config(
- existing_providers=current_spec.defaultConfig.get("providers", []),
- expected_providers=expected_providers,
- operating_id_prefix=operating_id_prefix,
- )
- current_spec.matchRules = compare_and_append_proxy_match_rules(
- existing_rules=current_spec.matchRules or [],
- expected_rules=expected_match_rules,
- operating_id_prefix=operating_id_prefix,
- )
- return current_spec
- def get_instance_id_from_header(headers: Headers) -> int:
- """Parse the model instance ID from the ``x-gpustack-model-instance`` routing header.
- The header value follows the pattern ``model-<model_id>-<instance_id>.<suffix>``
- injected by the API gateway. The instance ID is the last numeric segment
- before the first dot.
- Raises:
- HTTPException (400): if the header is absent.
- NotFoundException: if the header value does not match the expected pattern.
- """
- model_destination = headers.get(router_header_key, None)
- if model_destination is None:
- raise HTTPException(
- status_code=400, detail=f"Missing {router_header_key} header"
- )
- # Match pattern: model-<model_id>-<instance_id>.suffix
- # instance_id is the last numeric segment before the first dot
- match = re.match(r'^model-.*-(\d+)\..+', model_destination)
- if not match:
- raise NotFoundException(
- message=f"Invalid model destination format: {model_destination}"
- )
- return int(match.group(1))
- async def resolve_instance_address_from_model_header(
- headers: Dict[str, str],
- ) -> Tuple[Optional[str], int]:
- """Resolve the target worker (IP, port) for an inference request.
- Parses the ``x-gpustack-model-instance`` routing header injected by the API gateway
- to extract the model instance ID, then queries the database for that
- instance's worker IP and inference port.
- Used as the ``header_router`` callback of ``HTTPSProxyServer`` in tunnel
- proxy mode so the proxy knows which instance address to forward each request to.
- Returns ``(None, 0)`` when the header is absent or the instance cannot be
- resolved, causing the proxy to fall back to URI-based routing.
- """
- try:
- instance_id = get_instance_id_from_header(headers)
- except HTTPException as e:
- logger.trace(f"direct proxying request as: {e}")
- return None, 0
- except Exception as e:
- logger.debug(f"Error parsing model destination header: {e}")
- return None, 0
- async with async_session() as session:
- model_instance_service = ModelInstanceService(session)
- model_instance: ModelInstance = await model_instance_service.get_by_id(
- instance_id
- )
- if model_instance is None:
- logger.error(f"Model instance with ID {instance_id} not found.")
- return None, 0
- if model_instance.worker_ip is None or len(model_instance.ports) == 0:
- logger.error(
- f"Model instance with ID {instance_id} do not get scheduled yet."
- )
- return None, 0
- return model_instance.worker_ip, model_instance.ports[0]
- async def worker_websocket_connect_callback(
- _server: Optional[ServerInfo],
- client: Optional[RegisteredClientInfo],
- proxy_address: Optional[str] = None,
- ) -> None:
- """Update ``worker.proxy_address`` in the database when a tunnel connects or disconnects.
- Called by ``MessageServerHandler`` as the ``callback_on_connect`` /
- ``callback_on_disconnect`` hook. On connect, ``proxy_address`` is the
- server-side HTTP proxy URL the gateway should route to; on disconnect it is
- ``None``, clearing the field so the worker is no longer reachable via tunnel.
- The worker is looked up by matching ``client.client_id`` against
- ``Worker.worker_uuid``. If no matching worker is found the callback logs an
- error and returns without modifying the database.
- """
- if client is None:
- return
- async with async_session() as session:
- worker = await Worker.one_by_field(
- session=session, field="worker_uuid", value=str(client.client_id)
- )
- if worker is None:
- logger.error(f"Worker with UUID {client.client_id} not found.")
- return
- if worker.proxy_address == proxy_address:
- return
- worker.proxy_address = proxy_address
- await WorkerService(session).update(worker)
|