utils.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680
  1. import re
  2. import logging
  3. import copy
  4. import math
  5. from urllib.parse import urlparse
  6. from dataclasses import dataclass
  7. from functools import partial
  8. from typing import List, Optional, Tuple, Union, Dict, Any, Literal, Callable, Set
  9. from tenacity import retry, stop_after_attempt, wait_fixed
  10. from fastapi import HTTPException
  11. from starlette.datastructures import Headers
  12. from gpustack.gateway.labels_annotations import managed_labels, match_labels
  13. from gpustack.gateway import ai_proxy_types
  14. from gpustack.gateway.client.networking_higress_io_v1_api import (
  15. NetworkingHigressIoV1Api,
  16. McpBridge,
  17. McpBridgeRegistry,
  18. McpBridgeSpec,
  19. McpBridgeProxy,
  20. )
  21. from gpustack.gateway.client.extensions_higress_io_v1_api import (
  22. WasmPlugin,
  23. WasmPluginSpec,
  24. ExtensionsHigressIoV1Api,
  25. WasmPluginMatchRule,
  26. )
  27. from gpustack.gateway.client.networking_istio_io_v1alpha3_api import (
  28. NetworkingIstioIoV1Alpha3Api,
  29. EnvoyFilter,
  30. get_ingress_fallback_envoyfilter,
  31. )
  32. from gpustack.schemas.models import (
  33. ModelInstance,
  34. ModelInstancePublic,
  35. )
  36. from gpustack.schemas.model_provider import (
  37. ModelProvider,
  38. ModelProviderTypeEnum,
  39. )
  40. from gpustack.schemas.model_routes import ModelRoute
  41. from gpustack.server.bus import EventType
  42. from gpustack.server.db import async_session
  43. from gpustack.server.services import ModelInstanceService, WorkerService
  44. from gpustack.schemas.config import ModelInstanceProxyModeEnum
  45. from gpustack.schemas.workers import Worker
  46. from gpustack.schemas.clusters import Cluster
  47. from gpustack.utils.network import is_ipaddress
  48. from kubernetes_asyncio import client as k8s_client
  49. from kubernetes_asyncio.client import ApiException, V1IngressTLS
  50. from gpustack.envs import GATEWAY_MIRROR_INGRESS_NAME
  51. from gpustack.api.exceptions import NotFoundException
  52. from gpustack.websocket_proxy.message import ServerInfo, RegisteredClientInfo
  53. logger = logging.getLogger(__name__)
  54. default_mcp_bridge_name = "default"
  55. gpustack_ai_proxy_name = "gpustack-ai-proxy"
  56. gpustack_model_mapper_name = "gpustack-model-mapper"
  57. gpustack_generic_route_transformer_name = "gpustack-generic-route-transformer"
  58. model_ingress_prefix = "ai-route-model-"
  59. model_route_ingress_prefix = "ai-route-route-"
  60. provider_id_prefix = "provider-"
  61. model_id_prefix = "model-"
  62. router_header_key = "X-GPUStack-Model-Instance"
  63. gpustack_original_path_header = "x-gpustack-original-path"
  64. gpustack_fallback_path_header = "x-gpustack-fallback-path"
  65. # Type alias for destination tuples
  66. # Each tuple contains (weight: int, model_name: str, registry: McpBridgeRegistry)
  67. DestinationTupleList = List[Tuple[int, str, McpBridgeRegistry]]
  68. @dataclass
  69. class RoutePrefix:
  70. prefixes: List[str]
  71. support_legacy: bool = False
  72. additional_versions: Optional[List[str]] = None
  73. def flattened_prefixes(self) -> List[str]:
  74. versioned_prefixes = ["/v1"]
  75. if self.support_legacy:
  76. versioned_prefixes.append("/v1-openai")
  77. if self.additional_versions:
  78. versioned_prefixes.extend(self.additional_versions)
  79. flattened = []
  80. for versioned_prefix in versioned_prefixes:
  81. for prefix in self.prefixes:
  82. flattened.append(f"{versioned_prefix}{prefix}")
  83. return flattened
  84. def regex_prefixes(self) -> List[str]:
  85. """
  86. Returns regex patterns for the prefixes, considering versioning and legacy support.
  87. It supports removing -openai suffix from the versioned prefix with rewrite-target: /$1$3
  88. """
  89. versioned_prefixes = [f"/(v1){'(-openai)?' if self.support_legacy else '()'}"]
  90. if self.additional_versions:
  91. versioned_prefixes.extend(
  92. f"/({re.escape(additional_version.lstrip('/'))})()"
  93. for additional_version in self.additional_versions
  94. )
  95. return [
  96. f"{versioned_prefix}({prefix})"
  97. for versioned_prefix in versioned_prefixes
  98. for prefix in self.prefixes
  99. ]
  100. openai_model_prefixes: List[RoutePrefix] = [
  101. RoutePrefix(
  102. [
  103. "/chat/completions",
  104. "/completions",
  105. "/responses",
  106. "/embeddings",
  107. "/audio/transcriptions",
  108. "/audio/speech",
  109. "/images/generations",
  110. "/images/edits",
  111. ],
  112. True,
  113. ),
  114. RoutePrefix(
  115. [
  116. "/audio/translations",
  117. "/images/variations",
  118. "/moderations",
  119. "/score",
  120. ]
  121. ),
  122. RoutePrefix(["/rerank"], additional_versions=["/v2"]),
  123. ]
  124. anthropic_model_exact: List[RoutePrefix] = [
  125. RoutePrefix(["/messages", "/messages/count_tokens", "/complete"]),
  126. ]
  127. def get_default_mcpbridge_ref(
  128. mcp_bridge_name: str = default_mcp_bridge_name,
  129. ) -> k8s_client.V1TypedLocalObjectReference:
  130. # the name is hardcoded in Higress MCP Bridge controller
  131. return k8s_client.V1TypedLocalObjectReference(
  132. api_group='networking.higress.io',
  133. kind='McpBridge',
  134. name=mcp_bridge_name,
  135. )
  136. def wrap_route(
  137. path: str,
  138. path_type: str,
  139. backend: Optional[k8s_client.V1IngressBackend] = None,
  140. ) -> k8s_client.V1HTTPIngressPath:
  141. if backend is None:
  142. backend = k8s_client.V1IngressBackend(
  143. resource=get_default_mcpbridge_ref(),
  144. )
  145. return k8s_client.V1HTTPIngressPath(
  146. path=path,
  147. path_type=path_type,
  148. backend=backend,
  149. )
  150. def anthropic_routes() -> List[k8s_client.V1HTTPIngressPath]:
  151. routes = []
  152. for route_exact in anthropic_model_exact:
  153. for prefix in route_exact.regex_prefixes():
  154. routes.append(wrap_route(path=prefix, path_type="ImplementationSpecific"))
  155. return routes
  156. def ingress_rule_for_model() -> k8s_client.V1IngressRule:
  157. paths: List[k8s_client.V1HTTPIngressPath] = []
  158. for route_prefix in openai_model_prefixes:
  159. for prefix in route_prefix.regex_prefixes():
  160. paths.append(wrap_route(path=prefix, path_type="ImplementationSpecific"))
  161. return k8s_client.V1IngressRule(http=k8s_client.V1HTTPIngressRuleValue(paths=paths))
  162. def cluster_mcp_bridge_name(cluster_id: int) -> str:
  163. # higress_controller has hardcoded mcp bridge name to 'default'
  164. # the name should be based on cluster_id if higress_controller supports multiple mcp bridges
  165. return default_mcp_bridge_name
  166. def model_mcp_bridge_name(cluster_id: int) -> str:
  167. return cluster_mcp_bridge_name(cluster_id)
  168. def model_route_cleanup_prefix(model_route_id: int) -> str:
  169. return f"{model_route_ingress_prefix}{model_route_id}"
  170. def model_route_ingress_name(model_route_id: int) -> str:
  171. return f"{model_route_ingress_prefix}{model_route_id}.internal"
  172. def fallback_ingress_name(name: str) -> str:
  173. split_name = name.rsplit('.', 1)
  174. if len(split_name) == 1:
  175. return f"{name}.fallback"
  176. return f"{split_name[0]}.fallback.{split_name[1]}"
  177. def model_ingress_name(model_id: int) -> str:
  178. return f"{model_ingress_prefix}{model_id}"
  179. def cluster_worker_prefix(cluster_id: int) -> str:
  180. return f"cluster-{cluster_id}-worker-"
  181. def model_prefix(model_id: int) -> str:
  182. return f"{model_id_prefix}{model_id}-"
  183. def model_instance_prefix(
  184. model_instance: Union[ModelInstance, ModelInstancePublic]
  185. ) -> str:
  186. return f"{model_prefix(model_instance.model_id)}{model_instance.id}"
  187. def model_instance_registry(
  188. model_instance: Union[ModelInstance, ModelInstancePublic],
  189. worker: Optional[Worker] = None,
  190. ) -> Optional[McpBridgeRegistry]:
  191. name = model_instance_prefix(model_instance)
  192. if worker is not None:
  193. if worker.proxy_mode == ModelInstanceProxyModeEnum.WORKER:
  194. return _worker_reserve_proxy_registry(worker, name)
  195. elif worker.proxy_mode == ModelInstanceProxyModeEnum.TUNNEL:
  196. return _worker_tunnel_proxy_registry(worker, name)
  197. address = model_instance.worker_advertise_address or model_instance.worker_ip
  198. if address is None or address == "" or model_instance.port is None:
  199. return None
  200. domain = address
  201. port = model_instance.port
  202. registry_type = "dns"
  203. if is_ipaddress(address):
  204. domain = f"{address}:{model_instance.port}"
  205. port = 80
  206. registry_type = "static"
  207. return McpBridgeRegistry(
  208. domain=domain,
  209. port=port,
  210. name=name,
  211. protocol="http",
  212. type=registry_type,
  213. )
  214. def _worker_reserve_proxy_registry(
  215. worker: Worker, name_override: Optional[str] = None
  216. ) -> McpBridgeRegistry:
  217. """Build an McpBridgeRegistry entry for a worker in DIRECT or WORKER proxy mode.
  218. Uses ``worker.advertise_address`` when available, otherwise falls back to
  219. ``worker.ip``. For raw IP addresses the registry type is set to ``static``
  220. and the host:port pair is encoded in the domain field; for hostnames the
  221. type is ``dns`` and the port is carried separately.
  222. Returns ``None`` if the worker has no resolvable address or port.
  223. """
  224. address = worker.advertise_address or worker.ip
  225. if address is None or address == "" or worker.port is None:
  226. return None
  227. domain = address
  228. port = worker.port
  229. registry_type = "dns"
  230. if is_ipaddress(address):
  231. domain = f"{address}:{worker.port}"
  232. port = 80
  233. registry_type = "static"
  234. return McpBridgeRegistry(
  235. domain=domain,
  236. port=port,
  237. name=name_override or f"{cluster_worker_prefix(worker.cluster_id)}{worker.id}",
  238. protocol="http",
  239. type=registry_type,
  240. )
  241. def _worker_tunnel_proxy_registry(
  242. worker: Worker, name_override: Optional[str] = None
  243. ) -> Optional[McpBridgeRegistry]:
  244. """Build an McpBridgeRegistry entry for a worker in TUNNEL proxy mode.
  245. Points the registry at the server-side HTTP proxy address stored in
  246. ``worker.proxy_address``, which is populated by
  247. ``worker_websocket_connect_callback`` when the worker's WebSocket tunnel
  248. connects. The gateway routes inference requests to this proxy, which then
  249. tunnels them to the worker via the persistent WebSocket connection.
  250. Returns ``None`` if the worker has no proxy address (i.e. the WebSocket
  251. tunnel has not yet connected).
  252. """
  253. if worker.get_proxy_address() is None:
  254. return None
  255. # proxy address must be a valid URL and the netloc must be a valid IP.
  256. result = urlparse(worker.get_proxy_address())
  257. protocol = "http" if result.scheme == "http" else "https"
  258. port = result.port or (80 if protocol == "http" else 443)
  259. return McpBridgeRegistry(
  260. domain=f"{result.hostname}:{port}",
  261. port=80,
  262. name=name_override or f"{cluster_worker_prefix(worker.cluster_id)}{worker.id}",
  263. protocol=protocol,
  264. type="static",
  265. )
  266. def cluster_registry(cluster: Cluster) -> Optional[McpBridgeRegistry]:
  267. if cluster.gateway_endpoint is None and cluster.reported_gateway_endpoint is None:
  268. return None
  269. return McpBridgeRegistry(
  270. domain=cluster.gateway_endpoint or cluster.reported_gateway_endpoint,
  271. port=80,
  272. name="cluster-gateway",
  273. protocol="http",
  274. type="static",
  275. )
  276. def provider_registry_name(id: int) -> str:
  277. return f"{provider_id_prefix}{id}"
  278. def provider_registry(provider: ModelProvider) -> Optional[McpBridgeRegistry]:
  279. provider_url = provider.config.get_base_url()
  280. if provider_url is None:
  281. return None
  282. result = urlparse(url=provider_url)
  283. protocol = "http" if result.scheme == "http" else "https"
  284. port = 443 if protocol == "https" else 80
  285. registry_type = (
  286. "static" if result.hostname and is_ipaddress(result.hostname) else "dns"
  287. )
  288. if registry_type == "static":
  289. domain = result.netloc
  290. if result.port is None:
  291. domain = f"{domain}:{port}"
  292. else:
  293. domain = result.hostname
  294. if result.port is not None:
  295. port = result.port
  296. registry_name = provider_registry_name(provider.id)
  297. proxyName = f"{registry_name}-proxy" if provider.proxy_url else None
  298. return McpBridgeRegistry(
  299. domain=domain,
  300. port=port,
  301. name=registry_name,
  302. protocol=protocol,
  303. type=registry_type,
  304. proxyName=proxyName,
  305. )
  306. def provider_proxy(provider: ModelProvider) -> Optional[McpBridgeProxy]:
  307. if provider.proxy_url is None:
  308. return None
  309. proxy_url = urlparse(provider.proxy_url)
  310. scheme = proxy_url.scheme
  311. port = proxy_url.port
  312. if port is None:
  313. port = 443 if scheme == "https" else 80
  314. # timeout in seconds
  315. connection_timeout = provider.proxy_timeout or 5
  316. return McpBridgeProxy(
  317. name=f"{provider_registry_name(provider.id)}-proxy",
  318. serverAddress=proxy_url.hostname,
  319. serverPort=port,
  320. type=scheme.upper(),
  321. # convert to milliseconds
  322. connectTimeout=connection_timeout * 1000,
  323. )
  324. def provider_proxy_plugin_spec(
  325. *providers: ModelProvider,
  326. ) -> Tuple[List[Dict[str, Any]], List[WasmPluginMatchRule]]:
  327. provider_list = []
  328. match_rules = []
  329. sorted_providers: List[ModelProvider] = sorted(providers, key=lambda p: p.id)
  330. for provider in sorted_providers:
  331. registry = provider_registry(provider)
  332. if registry is None:
  333. continue
  334. service_name = registry.get_service_name()
  335. default_config_data = {
  336. "id": provider_registry_name(provider.id),
  337. "apiTokens": provider.api_tokens,
  338. **provider.config.model_dump_with_default_override(),
  339. "type": provider.config.type.value,
  340. }
  341. accessible_llm_model = next(
  342. (model.name for model in provider.models or [] if model.category == "llm"),
  343. None,
  344. )
  345. # Failover has more config
  346. if accessible_llm_model and len(provider.api_tokens) > 1:
  347. default_config_data["failover"] = ai_proxy_types.FailoverConfig(
  348. enabled=True,
  349. healthCheckModel=accessible_llm_model,
  350. )
  351. default_config = ai_proxy_types.AIProxyDefaultConfig.model_validate(
  352. default_config_data
  353. )
  354. provider_list.append(
  355. default_config.model_dump(by_alias=True, exclude_none=True)
  356. )
  357. active_config = ai_proxy_types.ActiveConfig(
  358. activeProviderId=provider_registry_name(provider.id),
  359. ).model_dump(exclude_none=True)
  360. match_rules.append(
  361. WasmPluginMatchRule(
  362. config=active_config,
  363. service=[service_name],
  364. configDisable=False,
  365. )
  366. )
  367. return provider_list, match_rules
  368. def diff_registries(
  369. existing: List[McpBridgeRegistry],
  370. desired: List[McpBridgeRegistry],
  371. to_delete_prefix: Optional[str] = None,
  372. ) -> Tuple[bool, List[McpBridgeRegistry]]:
  373. desired_map = {
  374. reg.name: idx for idx, reg in enumerate(desired) if reg.name is not None
  375. }
  376. total_list = []
  377. need_update = False
  378. for registry in existing:
  379. if registry.name not in desired_map:
  380. # delete registries that are not in the current list
  381. if to_delete_prefix is not None and registry.name.startswith(
  382. to_delete_prefix
  383. ):
  384. need_update = True
  385. else:
  386. # keep unrelated registries
  387. total_list.append(registry)
  388. else:
  389. # update existing registries
  390. idx = desired_map.pop(registry.name)
  391. if registry != desired[idx]:
  392. need_update = True
  393. registry = desired[idx]
  394. total_list.append(registry)
  395. # add new registries
  396. for idx in desired_map.values():
  397. need_update = True
  398. total_list.append(desired[idx])
  399. total_list.sort(key=lambda r: r.name or "")
  400. return need_update, total_list
  401. def diff_proxies(
  402. existing: List[McpBridgeProxy],
  403. desired: List[McpBridgeProxy],
  404. to_delete_prefix: Optional[str] = None,
  405. ) -> Tuple[bool, List[McpBridgeProxy]]:
  406. desired_map = {
  407. reg.name: idx for idx, reg in enumerate(desired) if reg.name is not None
  408. }
  409. total_list = []
  410. need_update = False
  411. for proxy in existing:
  412. if proxy.name not in desired_map:
  413. # delete registries that are not in the current list
  414. if to_delete_prefix is not None and proxy.name.startswith(to_delete_prefix):
  415. need_update = True
  416. else:
  417. # keep unrelated proxies
  418. total_list.append(proxy)
  419. else:
  420. # update existing proxies
  421. idx = desired_map.pop(proxy.name)
  422. if proxy != desired[idx]:
  423. need_update = True
  424. proxy = desired[idx]
  425. total_list.append(proxy)
  426. # add new proxies
  427. for idx in desired_map.values():
  428. need_update = True
  429. total_list.append(desired[idx])
  430. total_list.sort(key=lambda r: r.name or "")
  431. return need_update, total_list
  432. @retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
  433. async def ensure_mcp_bridge(
  434. client: NetworkingHigressIoV1Api,
  435. namespace: str,
  436. mcp_bridge_name: str,
  437. desired_registries: List[McpBridgeRegistry],
  438. to_delete_prefix: Optional[str] = None,
  439. desired_proxies: List[McpBridgeProxy] = None,
  440. to_delete_proxies_prefix: Optional[str] = None,
  441. ):
  442. existing_bridge = None
  443. try:
  444. mcpbridge_dict = await client.get_mcpbridge(namespace, mcp_bridge_name)
  445. existing_bridge = McpBridge.model_validate(mcpbridge_dict)
  446. except ApiException as e:
  447. if e.status != 404:
  448. raise
  449. if existing_bridge is None:
  450. mcpbridge_body = McpBridge(
  451. metadata={
  452. "name": mcp_bridge_name,
  453. "namespace": namespace,
  454. "labels": managed_labels,
  455. },
  456. spec=McpBridgeSpec(registries=desired_registries, proxies=desired_proxies),
  457. )
  458. await client.create_mcpbridge(
  459. namespace=namespace,
  460. body=mcpbridge_body,
  461. )
  462. logger.info(f"Created MCP Bridge {mcp_bridge_name} in namespace {namespace}.")
  463. else:
  464. registry_need_update, registry_list = diff_registries(
  465. existing=existing_bridge.spec.registries or [],
  466. desired=desired_registries,
  467. to_delete_prefix=to_delete_prefix,
  468. )
  469. proxy_need_update = False
  470. proxy_list = existing_bridge.spec.proxies or []
  471. if desired_proxies is not None:
  472. proxy_need_update, proxy_list = diff_proxies(
  473. existing=existing_bridge.spec.proxies or [],
  474. desired=desired_proxies,
  475. to_delete_prefix=to_delete_proxies_prefix,
  476. )
  477. if registry_need_update or proxy_need_update:
  478. registry_list.sort(key=lambda r: r.name or "")
  479. proxy_list.sort(key=lambda r: r.name or "")
  480. existing_bridge.spec.registries = registry_list
  481. existing_bridge.spec.proxies = proxy_list
  482. await client.edit_mcpbridge(
  483. name=mcp_bridge_name,
  484. namespace=namespace,
  485. body=existing_bridge,
  486. )
  487. logger.info(
  488. f"Updated MCP Bridge {mcp_bridge_name} in namespace {namespace}."
  489. )
  490. def generate_model_ingress(
  491. ingress_name: str,
  492. namespace: str,
  493. route_name: str,
  494. destinations: str,
  495. hostname: Optional[str] = None,
  496. tls: Optional[List[V1IngressTLS]] = None,
  497. included_generic_route: Optional[bool] = False,
  498. included_proxy_route: Optional[bool] = False,
  499. extra_annotations: Optional[Dict[str, str]] = None,
  500. ingress_class_name: str = "higress",
  501. ) -> k8s_client.V1Ingress:
  502. retry_policies = "error,timeout,http_503,http_502,non_idempotent"
  503. annotations = {
  504. "higress.io/rewrite-target": "/$1$3",
  505. "higress.io/destination": destinations,
  506. "higress.io/ignore-path-case": 'true',
  507. "higress.io/proxy-next-upstream-tries": '2',
  508. "higress.io/proxy-next-upstream": retry_policies,
  509. **higress_http_header_matcher("exact", "x-higress-llm-model", route_name),
  510. }
  511. if extra_annotations is not None:
  512. annotations.update(extra_annotations)
  513. metadata = k8s_client.V1ObjectMeta(
  514. name=ingress_name,
  515. namespace=namespace,
  516. annotations=annotations,
  517. labels=managed_labels,
  518. )
  519. expected_rule = ingress_rule_for_model()
  520. if included_proxy_route:
  521. # to compatible with rewrite-target /$1$3, the first capturing group is empty.
  522. # The /\d+ variant strips the route-id segment from /model/proxy/<id>/<path>
  523. # so the upstream receives /<path>. The id-less variant preserves the legacy
  524. # /model/proxy/<path> + X-GPUStack-Model header form. The more specific rule
  525. # is listed first so Higress tries id-based matching before falling back.
  526. expected_rule.http.paths.append(
  527. wrap_route(
  528. r"/()model/proxy/\d+(/|$)(.*)",
  529. "ImplementationSpecific",
  530. )
  531. )
  532. expected_rule.http.paths.append(
  533. wrap_route(
  534. "/()model/proxy(/|$)(.*)",
  535. "ImplementationSpecific",
  536. )
  537. )
  538. if included_generic_route:
  539. expected_rule.http.paths.append(wrap_route("/", "Prefix"))
  540. # support for Anthropic API
  541. expected_rule.http.paths.extend(anthropic_routes())
  542. spec = k8s_client.V1IngressSpec(
  543. ingress_class_name=ingress_class_name, rules=[expected_rule]
  544. )
  545. if hostname is not None:
  546. hostname_rule = copy.deepcopy(expected_rule)
  547. hostname_rule.host = hostname
  548. spec.rules.append(hostname_rule)
  549. spec.tls = tls
  550. ingress = k8s_client.V1Ingress(
  551. api_version="networking.k8s.io/v1",
  552. kind="Ingress",
  553. metadata=metadata,
  554. spec=spec,
  555. )
  556. return ingress
  557. def higress_metadata_equal(
  558. existing_metadata: Optional[k8s_client.V1ObjectMeta],
  559. expected_metadata: Optional[k8s_client.V1ObjectMeta],
  560. ) -> bool:
  561. existing_metadata = existing_metadata or k8s_client.V1ObjectMeta()
  562. expected_metadata = expected_metadata or k8s_client.V1ObjectMeta()
  563. if existing_metadata.annotations is None:
  564. existing_metadata.annotations = {}
  565. if expected_metadata.annotations is None:
  566. expected_metadata.annotations = {}
  567. for key in set(
  568. k for k in expected_metadata.annotations if k.startswith("higress.io")
  569. ):
  570. if existing_metadata.annotations.get(key) != expected_metadata.annotations.get(
  571. key
  572. ):
  573. return False
  574. return True
  575. def ingress_tls_equal(
  576. existing: Optional[k8s_client.V1IngressTLS],
  577. expected: Optional[k8s_client.V1IngressTLS],
  578. ) -> bool:
  579. if (existing is None) != (expected is None):
  580. return False
  581. if existing and expected:
  582. if len(existing) != len(expected):
  583. return False
  584. for etls, xtls in zip(existing, expected):
  585. # only compares hosts and secret_name for tls equal
  586. if getattr(etls, 'hosts', None) != getattr(xtls, 'hosts', None):
  587. return False
  588. if getattr(etls, 'secret_name', None) != getattr(xtls, 'secret_name', None):
  589. return False
  590. return True
  591. def mcp_ingress_equal(
  592. existing: k8s_client.V1Ingress, expected: k8s_client.V1Ingress
  593. ) -> bool:
  594. if not higress_metadata_equal(
  595. existing_metadata=existing.metadata, expected_metadata=expected.metadata
  596. ):
  597. return False
  598. if existing.spec is None or expected.spec is None:
  599. return False
  600. if not ingress_tls_equal(
  601. existing=getattr(existing.spec, 'tls', None),
  602. expected=getattr(expected.spec, 'tls', None),
  603. ):
  604. return False
  605. if len(existing.spec.rules or []) != len(expected.spec.rules or []):
  606. return False
  607. for existing_rule, expected_rule in zip(
  608. existing.spec.rules or [], expected.spec.rules or []
  609. ):
  610. if getattr(existing_rule, 'host', None) != getattr(expected_rule, 'host', None):
  611. return False
  612. if existing_rule.http is None or expected_rule.http is None:
  613. return False
  614. if len(existing_rule.http.paths or []) != len(expected_rule.http.paths or []):
  615. return False
  616. for existing_path, expected_path in zip(
  617. existing_rule.http.paths or [], expected_rule.http.paths or []
  618. ):
  619. if existing_path.path != expected_path.path:
  620. return False
  621. if existing_path.path_type != expected_path.path_type:
  622. return False
  623. if existing_path.backend.resource != expected_path.backend.resource:
  624. return False
  625. return True
  626. def scale_weight(weight_instance_pairs: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
  627. """
  628. Scale weights based on the least common multiple of counts to maintain proportionality.
  629. """
  630. counts = [count for _, count in weight_instance_pairs if count > 0]
  631. if not counts:
  632. return weight_instance_pairs
  633. lcm_count = math.lcm(*counts)
  634. scaled = [
  635. (weight * lcm_count // count if count > 0 else 0, count)
  636. for weight, count in weight_instance_pairs
  637. ]
  638. return scaled
  639. def hamilton_calculate_weight(
  640. weight_instance_pairs: List[Tuple[int, int]],
  641. max_weight: Optional[int] = 0,
  642. ) -> List[int]:
  643. """
  644. hamilton_calculate_weight to allocate percentage based on weight and instance count.
  645. The total should be 100.
  646. :param weight_instance_pairs: weight and instance count pairs
  647. :type weight_instance_pairs: List[Tuple[int, int]]
  648. :return: list of percentage for instance
  649. :rtype: List[int]
  650. """
  651. weight_instance_pairs = scale_weight(weight_instance_pairs)
  652. instances_info = []
  653. for weight, instance_count in weight_instance_pairs:
  654. for _ in range(instance_count):
  655. instances_info.append({'weight': weight, 'group_weight': weight})
  656. total_weight = sum(max(info['weight'], max_weight) for info in instances_info)
  657. if total_weight == 0:
  658. return []
  659. for info in instances_info:
  660. weight = max(info['weight'], max_weight)
  661. info['exact_quota'] = weight * 100 / total_weight
  662. info['floor_quota'] = int(info['exact_quota'])
  663. info['remainder'] = info['exact_quota'] - info['floor_quota']
  664. total_floor = sum(info['floor_quota'] for info in instances_info)
  665. remaining_seats = 100 - total_floor
  666. sorted_instances = sorted(instances_info, key=lambda x: -x['remainder'])
  667. for i in range(remaining_seats):
  668. sorted_instances[i]['floor_quota'] += 1
  669. return [info['floor_quota'] for info in instances_info]
  670. def model_instances_registry_list(
  671. model_instances: List[Union[ModelInstance, ModelInstancePublic]],
  672. workers: Optional[Dict[int, Worker]] = None,
  673. ) -> DestinationTupleList:
  674. registries: DestinationTupleList = []
  675. for model_instance in model_instances:
  676. worker = (
  677. (workers or {}).get(model_instance.worker_id)
  678. if model_instance.worker_id
  679. else None
  680. )
  681. registry = model_instance_registry(model_instance, worker=worker)
  682. if registry is not None:
  683. registries.append((1, model_instance.model_name, registry))
  684. return registries
  685. @retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
  686. async def ensure_model_ingress(
  687. ingress_name: str,
  688. ingress_class_name: str,
  689. route_name: str,
  690. namespace: str,
  691. destinations: DestinationTupleList,
  692. event_type: EventType,
  693. networking_api: k8s_client.NetworkingV1Api,
  694. included_generic_route: Optional[bool] = False,
  695. included_proxy_route: Optional[bool] = False,
  696. extra_annotations: Optional[Dict[str, str]] = None,
  697. ):
  698. """
  699. Ensure the model ingress resource in Kubernetes matches the desired state.
  700. Parameters:
  701. ingress_name (str): The name of the ingress resource.
  702. namespace (str): The Kubernetes namespace for the ingress resource.
  703. destinations (DestinationTupleList): Weighted list of MCP Bridge registries for traffic routing.
  704. route_name (str): The name of the model route for which ingress is managed.
  705. event_type (EventType): The event type (CREATED, UPDATED, DELETED) triggering reconciliation.
  706. networking_api (k8s_client.NetworkingV1Api): The Kubernetes networking API client.
  707. hostname (Optional[str]): The external hostname for ingress routing.
  708. tls_secret_name (Optional[str]): TLS secret name for HTTPS ingress.
  709. included_generic_route (bool): Whether to include a generic '/' route for fallback traffic. Used in worker gateway.
  710. included_proxy_route (bool): Whether to include a proxy route for model traffic (e.g., /model/proxy/{model_name}). Used in server gateway.
  711. """
  712. if event_type == EventType.DELETED:
  713. try:
  714. await networking_api.delete_namespaced_ingress(
  715. name=ingress_name, namespace=namespace
  716. )
  717. logger.info(
  718. f"Deleted model ingress {ingress_name} for model route {route_name}"
  719. )
  720. except ApiException as e:
  721. if e.status != 404:
  722. logger.error(f"Failed to delete ingress {ingress_name}: {e}")
  723. return
  724. expected_destinations = '\n'.join(
  725. [
  726. f"{persentage}% {registry.get_service_name_with_port()}"
  727. for persentage, _, registry in destinations
  728. ]
  729. )
  730. try:
  731. existing_ingress: Optional[k8s_client.V1Ingress] = (
  732. await networking_api.read_namespaced_ingress(
  733. name=ingress_name, namespace=namespace
  734. )
  735. )
  736. except ApiException as e:
  737. if e.status != 404:
  738. logger.error(f"Failed to get ingress {ingress_name}: {e}")
  739. return
  740. existing_ingress = None
  741. hostname, tls = await mirror_hostname_tls_from_ingress(
  742. network_v1_client=networking_api,
  743. gateway_namespace=namespace,
  744. target_ingress_name=GATEWAY_MIRROR_INGRESS_NAME,
  745. )
  746. expected_ingress = generate_model_ingress(
  747. ingress_name=ingress_name,
  748. route_name=route_name,
  749. namespace=namespace,
  750. destinations=expected_destinations,
  751. hostname=hostname,
  752. tls=tls,
  753. included_generic_route=included_generic_route,
  754. included_proxy_route=included_proxy_route,
  755. extra_annotations=extra_annotations,
  756. ingress_class_name=ingress_class_name,
  757. )
  758. if existing_ingress is None:
  759. await networking_api.create_namespaced_ingress(
  760. namespace=namespace,
  761. body=expected_ingress,
  762. )
  763. logger.info(
  764. f"Created model ingress {ingress_name} for model route {route_name}"
  765. )
  766. else:
  767. is_equal = mcp_ingress_equal(
  768. existing=existing_ingress, expected=expected_ingress
  769. )
  770. if not is_equal:
  771. existing_ingress.spec = expected_ingress.spec
  772. metadata = existing_ingress.metadata or k8s_client.V1ObjectMeta()
  773. metadata.annotations = metadata.annotations or {}
  774. expected_higress_keys = set()
  775. for key, value in (expected_ingress.metadata.annotations or {}).items():
  776. if key.startswith("higress.io"):
  777. metadata.annotations[key] = value
  778. expected_higress_keys.add(key)
  779. to_delete = [
  780. key
  781. for key in metadata.annotations.keys()
  782. if key.startswith("higress.io") and key not in expected_higress_keys
  783. ]
  784. for key in to_delete:
  785. del metadata.annotations[key]
  786. await networking_api.replace_namespaced_ingress(
  787. name=ingress_name,
  788. namespace=namespace,
  789. body=existing_ingress,
  790. )
  791. logger.info(
  792. f"Updated model ingress {ingress_name} for model route {route_name}"
  793. )
  794. @retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
  795. async def ensure_wasm_plugin(
  796. api: ExtensionsHigressIoV1Api,
  797. name: str,
  798. namespace: str,
  799. spec_diff: Callable[[Optional[WasmPluginSpec]], WasmPluginSpec],
  800. extra_labels: Optional[Dict[str, str]] = None,
  801. ):
  802. labels = copy.deepcopy(managed_labels)
  803. if extra_labels:
  804. labels.update(extra_labels)
  805. current_plugin = None
  806. try:
  807. data: Dict[str, Any] = await api.get_wasmplugin(namespace=namespace, name=name)
  808. current_plugin = WasmPlugin.model_validate(data)
  809. except ApiException as e:
  810. if e.status == 404:
  811. current_plugin = None
  812. else:
  813. raise
  814. current_spec = getattr(current_plugin, 'spec', None)
  815. expected = spec_diff(copy.deepcopy(current_spec))
  816. if current_plugin is None:
  817. wasm_plugin_body = WasmPlugin(
  818. metadata={
  819. "name": name,
  820. "namespace": namespace,
  821. "labels": labels,
  822. },
  823. spec=expected,
  824. )
  825. await api.create_wasmplugin(
  826. namespace=namespace,
  827. body=wasm_plugin_body,
  828. )
  829. logger.info(f"Created WasmPlugin {name} in namespace {namespace}.")
  830. elif match_labels(current_plugin.metadata.get("labels", {}), labels):
  831. current_spec = (
  832. current_plugin.spec.model_dump(exclude_none=True)
  833. if current_plugin.spec
  834. else {}
  835. )
  836. expected_spec = expected.model_dump(exclude_none=True) if expected else {}
  837. if current_spec != expected_spec:
  838. current_plugin.spec = expected
  839. await api.edit_wasmplugin(
  840. namespace=namespace,
  841. name=name,
  842. body=current_plugin,
  843. )
  844. logger.info(f"Updated WasmPlugin {name} in namespace {namespace}.")
  845. async def cleanup_model_mapper(
  846. namespace: str,
  847. expected_ingresses: List[str],
  848. config: k8s_client.Configuration,
  849. extra_labels: Optional[Dict[str, str]] = None,
  850. ):
  851. if config is None:
  852. return
  853. api = ExtensionsHigressIoV1Api(k8s_client.ApiClient(config))
  854. labels = copy.deepcopy(managed_labels)
  855. if extra_labels:
  856. labels.update(extra_labels)
  857. def spec_diff(current_spec: Optional[WasmPluginSpec]) -> WasmPluginSpec:
  858. if current_spec is None:
  859. return current_spec
  860. to_keep_rules: List[WasmPluginMatchRule] = []
  861. for rule in current_spec.matchRules or []:
  862. if any(ingress in expected_ingresses for ingress in rule.ingress):
  863. to_keep_rules.append(rule)
  864. else:
  865. logger.info(
  866. f"Removing rule with ingress {rule.ingress} from model mapper plugin as it is not in expected ingresses."
  867. )
  868. to_keep_rules.sort(key=lambda r: r.ingress[0] if r.ingress else "")
  869. current_spec.matchRules = to_keep_rules
  870. return current_spec
  871. await ensure_wasm_plugin(
  872. api=api,
  873. name=gpustack_model_mapper_name,
  874. namespace=namespace,
  875. spec_diff=spec_diff,
  876. extra_labels=extra_labels,
  877. )
  878. async def cleanup_ingresses(
  879. namespace: str,
  880. expected_names: List[str],
  881. config: k8s_client.Configuration,
  882. cleanup_prefix: str,
  883. reason: str = "orphaned",
  884. ):
  885. if config is None:
  886. return
  887. networking_api = k8s_client.NetworkingV1Api(k8s_client.ApiClient(config))
  888. try:
  889. # Use label selector to filter only managed ingresses
  890. label_selector = ','.join([f"{k}={v}" for k, v in managed_labels.items()])
  891. ingresses = await networking_api.list_namespaced_ingress(
  892. namespace=namespace,
  893. label_selector=label_selector,
  894. )
  895. for ingress in ingresses.items:
  896. # name must be not None due to label selector
  897. name: str = ingress.metadata.name
  898. if name in expected_names or not name.startswith(cleanup_prefix):
  899. continue
  900. await networking_api.delete_namespaced_ingress(
  901. name=name, namespace=namespace
  902. )
  903. logger.info(
  904. f"Deleted {reason} model ingress {name} in namespace {namespace}."
  905. )
  906. except Exception as e:
  907. logger.error(f"Error cleaning up {reason} model ingresses: {e}")
  908. async def ensure_model_mcp_bridge(
  909. event_type: EventType,
  910. model_id: int,
  911. model_instances: List[Union[ModelInstance, ModelInstancePublic]],
  912. networking_higress_api: NetworkingHigressIoV1Api,
  913. namespace: str,
  914. cluster_id: int,
  915. workers: Optional[Dict[int, Worker]] = None,
  916. ) -> List[McpBridgeRegistry]:
  917. desired_registry: List[McpBridgeRegistry] = []
  918. to_delete_prefix: Optional[str] = model_prefix(model_id)
  919. if event_type != EventType.DELETED:
  920. for model_instance in model_instances:
  921. worker = (
  922. (workers or {}).get(model_instance.worker_id)
  923. if model_instance.worker_id
  924. else None
  925. )
  926. registry = model_instance_registry(model_instance, worker=worker)
  927. if registry is not None:
  928. desired_registry.append(registry)
  929. await ensure_mcp_bridge(
  930. client=networking_higress_api,
  931. namespace=namespace,
  932. mcp_bridge_name=model_mcp_bridge_name(cluster_id),
  933. desired_registries=desired_registry,
  934. to_delete_prefix=to_delete_prefix,
  935. )
  936. return desired_registry
  937. async def mirror_hostname_tls_from_ingress(
  938. network_v1_client: k8s_client.NetworkingV1Api,
  939. gateway_namespace: str,
  940. target_ingress_name: str,
  941. ) -> Tuple[Optional[str], Optional[List[V1IngressTLS]]]:
  942. """
  943. Mirror TLS settings from an existing ingress to be used in the gateway.
  944. Parameters:
  945. api_client (k8s_client.ApiClient): The Kubernetes API client.
  946. gateway_namespace (str): The namespace where the gateway ingress resides.
  947. target_ingress_name (str): The name of the ingress to mirror TLS settings from.
  948. Returns:
  949. Tuple[Optional[str], Optional[List[V1IngressTLS]]]: A tuple containing the hostname and ingress TLS settings.
  950. """
  951. try:
  952. ingress: k8s_client.V1Ingress = await network_v1_client.read_namespaced_ingress(
  953. name=target_ingress_name, namespace=gateway_namespace
  954. )
  955. except ApiException as e:
  956. if e.status == 404:
  957. logger.warning(
  958. f"Target ingress {target_ingress_name} not found in namespace {gateway_namespace} for TLS mirroring."
  959. )
  960. return None, None
  961. else:
  962. raise
  963. tls = getattr(ingress.spec, 'tls', None)
  964. hostname = None
  965. for rule in ingress.spec.rules or []:
  966. if rule.host:
  967. hostname = rule.host
  968. break
  969. return hostname, tls
  970. def get_expected_match_list(
  971. route_name: str,
  972. ingress_prefix: str,
  973. ingress_name: str,
  974. model_name_to_registries: Dict[str, List[str]],
  975. fallback_model_name_to_registries: Dict[str, List[str]],
  976. ) -> List[WasmPluginMatchRule]:
  977. match_list: List[WasmPluginMatchRule] = []
  978. ingress_name = f"{ingress_prefix}{ingress_name}"
  979. for model_name, service_names in model_name_to_registries.items():
  980. config = {"modelMapping": {route_name: model_name}}
  981. match_list.append(
  982. WasmPluginMatchRule(
  983. config=config,
  984. ingress=[ingress_name],
  985. configDisable=False,
  986. service=service_names,
  987. )
  988. )
  989. for model_name, service_names in fallback_model_name_to_registries.items():
  990. # the fallback mapping should include both normal ingress and fallback ingress
  991. # as the normal ingress may not exist when only fallback model is set
  992. fallback_name = fallback_ingress_name(ingress_name)
  993. config = {"modelMapping": {route_name: model_name}}
  994. match_list.append(
  995. WasmPluginMatchRule(
  996. config=config,
  997. ingress=[ingress_name, fallback_name],
  998. configDisable=False,
  999. service=service_names,
  1000. )
  1001. )
  1002. return match_list
  1003. def higress_http_header_matcher(
  1004. operator: Literal["exact", "regex", "prefix"],
  1005. header_key: str,
  1006. header_value: str,
  1007. ) -> Dict[str, str]:
  1008. header_matcher = "match-header"
  1009. return {
  1010. f"higress.io/{operator}-{header_matcher}-{header_key}": header_value,
  1011. }
  1012. async def cleanup_fallback_filters(
  1013. namespace: str,
  1014. expected_names: List[str],
  1015. cleanup_prefix: str,
  1016. reason: str = "orphaned",
  1017. networking_istio_api: Optional[NetworkingIstioIoV1Alpha3Api] = None,
  1018. k8s_config: Optional[k8s_client.Configuration] = None,
  1019. ):
  1020. if networking_istio_api is None:
  1021. if k8s_config is None:
  1022. return
  1023. networking_istio_api = NetworkingIstioIoV1Alpha3Api(
  1024. k8s_client.ApiClient(k8s_config)
  1025. )
  1026. try:
  1027. label_selector = ','.join([f"{k}={v}" for k, v in managed_labels.items()])
  1028. filters = await networking_istio_api.list_envoyfilters(
  1029. namespace=namespace,
  1030. label_selector=label_selector,
  1031. )
  1032. items: List[Dict[str, Any]] = filters.get('items', [])
  1033. for filter_item in items:
  1034. # name must be not None due to label selector
  1035. name = filter_item.get("metadata", {}).get("name", None)
  1036. if (
  1037. name is None
  1038. or name in expected_names
  1039. or not name.startswith(cleanup_prefix)
  1040. ):
  1041. continue
  1042. await networking_istio_api.delete_envoyfilter(
  1043. name=name, namespace=namespace
  1044. )
  1045. logger.info(
  1046. f"Deleted {reason} fallback filter {name} in namespace {namespace}."
  1047. )
  1048. except Exception as e:
  1049. logger.error(f"Error cleaning up {reason} fallback filters: {e}")
  1050. @retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
  1051. async def ensure_fallback_filter(
  1052. event_type: EventType,
  1053. ingress_name: str,
  1054. namespace: str,
  1055. networking_istio_api: NetworkingIstioIoV1Alpha3Api,
  1056. ):
  1057. if event_type == EventType.DELETED:
  1058. await cleanup_fallback_filters(
  1059. namespace=namespace,
  1060. expected_names=[],
  1061. networking_istio_api=networking_istio_api,
  1062. cleanup_prefix=ingress_name,
  1063. reason="event deleted",
  1064. )
  1065. return
  1066. existing_filter = None
  1067. try:
  1068. filter_dict = await networking_istio_api.get_envoyfilter(
  1069. namespace=namespace, name=ingress_name
  1070. )
  1071. existing_filter = EnvoyFilter.model_validate(filter_dict)
  1072. except ApiException as e:
  1073. if e.status != 404:
  1074. raise
  1075. except Exception as e:
  1076. raise e
  1077. expected_filter = get_ingress_fallback_envoyfilter(
  1078. ingress_name=ingress_name,
  1079. namespace=namespace,
  1080. labels={**managed_labels},
  1081. extra_req_headers={
  1082. gpustack_fallback_path_header: f'%REQ({gpustack_original_path_header.upper()})%'
  1083. },
  1084. )
  1085. if existing_filter is None:
  1086. await networking_istio_api.create_envoyfilter(
  1087. namespace=namespace,
  1088. body=expected_filter,
  1089. )
  1090. logger.info(
  1091. f"Created fallback EnvoyFilter {ingress_name} in namespace {namespace}."
  1092. )
  1093. else:
  1094. existing_spec_dict = existing_filter.spec.model_dump(exclude_none=True)
  1095. expected_spec_dict = expected_filter.spec.model_dump(exclude_none=True)
  1096. if existing_spec_dict != expected_spec_dict:
  1097. existing_filter.spec = expected_filter.spec
  1098. await networking_istio_api.edit_envoyfilter(
  1099. name=ingress_name,
  1100. namespace=namespace,
  1101. body=existing_filter,
  1102. )
  1103. logger.info(
  1104. f"Updated fallback EnvoyFilter {ingress_name} in namespace {namespace}."
  1105. )
  1106. def ai_proxy_openai_provider_config(id: str) -> Dict[str, Any]:
  1107. return ai_proxy_types.AIProxyDefaultConfig(
  1108. type=ModelProviderTypeEnum.OPENAI,
  1109. id=id,
  1110. failover=ai_proxy_types.FailoverConfig(enabled=False),
  1111. retryOnFailure=ai_proxy_types.EnableState(enabled=False),
  1112. ).model_dump(exclude_none=True, exclude_unset=True)
  1113. def compare_and_append_default_proxy_config(
  1114. existing_providers: List[Dict[str, Any]],
  1115. expected_providers: List[Dict[str, Any]],
  1116. operating_id_prefix: Optional[str] = None,
  1117. ) -> List[Dict[str, Any]]:
  1118. to_keep_config = []
  1119. for provider in existing_providers:
  1120. provider_id: Optional[str] = provider.get('id', None)
  1121. if (
  1122. provider_id is None
  1123. or operating_id_prefix is None
  1124. or not provider_id.startswith(operating_id_prefix)
  1125. ):
  1126. to_keep_config.append(provider)
  1127. continue
  1128. return_providers = expected_providers.copy()
  1129. return_providers.extend(to_keep_config)
  1130. return_providers.sort(key=lambda p: p.get("id", ""))
  1131. return return_providers
  1132. def compare_and_append_proxy_match_rules(
  1133. existing_rules: List[WasmPluginMatchRule],
  1134. expected_rules: List[WasmPluginMatchRule],
  1135. operating_id_prefix: Optional[str] = None,
  1136. ) -> List[WasmPluginMatchRule]:
  1137. to_keep_config = []
  1138. for rule in existing_rules:
  1139. provider_id: Optional[str] = rule.config.get('activeProviderId', None)
  1140. if (
  1141. provider_id is None
  1142. or operating_id_prefix is None
  1143. or not provider_id.startswith(operating_id_prefix)
  1144. ):
  1145. to_keep_config.append(rule)
  1146. continue
  1147. return_rules = expected_rules.copy()
  1148. return_rules.extend(to_keep_config)
  1149. return_rules.sort(key=lambda r: (r.config.get("activeProviderId", None) or ""))
  1150. return return_rules
  1151. async def cleanup_ai_proxy_config(
  1152. providers: List[ModelProvider],
  1153. routes: List[ModelRoute],
  1154. k8s_config: k8s_client.Configuration,
  1155. namespace: str,
  1156. ):
  1157. if k8s_config is None:
  1158. return
  1159. prefixes_to_keep = {model_route_cleanup_prefix(route.id) for route in routes}
  1160. prefixes_to_keep.update(
  1161. {provider_registry_name(provider.id) for provider in providers}
  1162. )
  1163. def should_keep(provider_id: str) -> bool:
  1164. for prefix in prefixes_to_keep:
  1165. if provider_id.startswith(prefix):
  1166. return True
  1167. return False
  1168. try:
  1169. extensions_api = ExtensionsHigressIoV1Api(k8s_client.ApiClient(k8s_config))
  1170. ai_proxy_data = await extensions_api.get_wasmplugin(
  1171. namespace=namespace,
  1172. name=gpustack_ai_proxy_name,
  1173. )
  1174. existing_plugin = WasmPlugin.model_validate(ai_proxy_data)
  1175. current_providers = existing_plugin.spec.defaultConfig.get("providers", [])
  1176. filtered_providers = [
  1177. p for p in current_providers if p.get("id") and should_keep(p.get("id"))
  1178. ]
  1179. existing_plugin.spec.defaultConfig["providers"] = filtered_providers
  1180. filtered_provider_ids = {
  1181. p.get("id") for p in filtered_providers if p.get("id") is not None
  1182. }
  1183. filtered_rules = [
  1184. r
  1185. for r in existing_plugin.spec.matchRules or []
  1186. if r.config.get("activeProviderId") in filtered_provider_ids
  1187. ]
  1188. existing_plugin.spec.matchRules = filtered_rules
  1189. await extensions_api.edit_wasmplugin(
  1190. namespace=namespace,
  1191. name=gpustack_ai_proxy_name,
  1192. body=existing_plugin,
  1193. )
  1194. except k8s_client.ApiException as e:
  1195. logger.error(
  1196. f"Failed to cleanup gpustack AI proxy wasmplugin {gpustack_ai_proxy_name}: {e}"
  1197. )
  1198. raise
  1199. def build_generic_route_path_pattern(route_id: int) -> str:
  1200. """Path pattern matching /model/proxy/<id> and /model/proxy/<id>/<anything>.
  1201. The Higress transformer plugin treats `add` + `path_pattern` + `value` as a
  1202. regex substitution: the portion of :path that matches path_pattern is
  1203. replaced by value, and the resulting string is written to the target header.
  1204. So the pattern MUST consume the entire path — otherwise the unmatched tail
  1205. gets concatenated onto the header value (e.g. `<route_name>v1/models`).
  1206. The `(/.*)?$` tail keeps the `/` boundary after <id> so `/model/proxy/10`
  1207. doesn't spuriously match id=1.
  1208. """
  1209. return f"^/model/proxy/{route_id}(/.*)?$"
  1210. def build_generic_route_header_rule(route_id: int, route_name: str) -> Dict[str, Any]:
  1211. """HeaderRule dict injecting x-higress-llm-model when /model/proxy/<id>/ is hit.
  1212. Example for ``route_id=1, route_name="qwen3-0.6b"``::
  1213. {
  1214. "key": "x-higress-llm-model",
  1215. "value": "qwen3-0.6b",
  1216. "path_pattern": "^/model/proxy/1(/.*)?$",
  1217. }
  1218. At runtime Higress substitutes the portion of ``:path`` that matches
  1219. ``path_pattern`` with ``value`` and writes the result to ``key``.
  1220. ``path_pattern`` anchors both ends so every matched path reduces to just
  1221. ``value`` — see build_generic_route_path_pattern for why.
  1222. """
  1223. return {
  1224. "key": "x-higress-llm-model",
  1225. "value": route_name,
  1226. "path_pattern": build_generic_route_path_pattern(route_id),
  1227. }
# Generic-route rules are identified by the shape of their path_pattern. This
# lets the diff/cleanup code coexist with other reqRules blocks that a future
# contributor might add for unrelated purposes, instead of assuming the
# generic-route block is the only `add` block in the plugin.
# The regex matches a path_pattern string that begins with a literal "^"
# followed by "/model/proxy/" and at least one digit.
_GENERIC_ROUTE_PATH_PATTERN_RE = re.compile(r"^\^/model/proxy/\d+")
  1233. def _is_generic_route_header(header: Dict[str, Any]) -> bool:
  1234. return bool(_GENERIC_ROUTE_PATH_PATTERN_RE.match(header.get("path_pattern", "")))
  1235. def _split_generic_route_req_rules(
  1236. req_rules: List[Dict[str, Any]],
  1237. ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
  1238. """
  1239. Partition existing reqRules into (blocks_to_preserve, generic_route_headers).
  1240. - Any ``add`` block contributes generic-route headers (identified by
  1241. path_pattern shape) to the second list. Non-generic headers from the same
  1242. block are retained verbatim in the first list, so mixed ownership is safe.
  1243. - Blocks with any other ``operate`` (rename, remove, map, ...) are preserved
  1244. untouched.
  1245. """
  1246. preserve: List[Dict[str, Any]] = []
  1247. generic_headers: List[Dict[str, Any]] = []
  1248. for rule in req_rules:
  1249. if rule.get("operate") != "add":
  1250. preserve.append(rule)
  1251. continue
  1252. foreign_headers: List[Dict[str, Any]] = []
  1253. for header in rule.get("headers", []):
  1254. if _is_generic_route_header(header):
  1255. generic_headers.append(header)
  1256. else:
  1257. foreign_headers.append(header)
  1258. if foreign_headers:
  1259. preserve.append({**rule, "headers": foreign_headers})
  1260. return preserve, generic_headers
  1261. def _set_generic_route_headers(
  1262. current_spec: WasmPluginSpec, headers: List[Dict[str, Any]]
  1263. ) -> WasmPluginSpec:
  1264. # Do NOT touch defaultConfigDisable here — flipping it rewrites Envoy's
  1265. # filter chain and tears down every live connection. Only the reqRules
  1266. # list changes between reconciliations; the enable flag is locked at
  1267. # plugin creation (see generic_route_transformer_plugin).
  1268. default_config = current_spec.defaultConfig or {}
  1269. preserve, _ = _split_generic_route_req_rules(default_config.get("reqRules", []))
  1270. new_req_rules: List[Dict[str, Any]] = list(preserve)
  1271. if headers:
  1272. sorted_headers = sorted(headers, key=lambda h: h.get("path_pattern", ""))
  1273. new_req_rules.append({"operate": "add", "headers": sorted_headers})
  1274. current_spec.defaultConfig = {"reqRules": new_req_rules}
  1275. return current_spec
  1276. def generic_route_transformer_diff_spec(
  1277. current_spec: Optional[WasmPluginSpec],
  1278. expected_header_rules: List[Dict[str, Any]],
  1279. operating_path_pattern: str,
  1280. ) -> Optional[WasmPluginSpec]:
  1281. """
  1282. Merge expected_header_rules into the current spec, replacing any existing rule
  1283. whose path_pattern equals operating_path_pattern. Other routes' rules stay
  1284. untouched, as do unrelated reqRules blocks added by other subsystems.
  1285. Returns None if the plugin doesn't exist yet — init handles that.
  1286. Example — reconciling route id=2 renamed to "route-two-renamed" while route
  1287. id=1 is already registered:
  1288. # current_spec.defaultConfig.reqRules (before)
  1289. [{"operate": "add", "headers": [
  1290. {"key": "x-higress-llm-model", "value": "route-one",
  1291. "path_pattern": "^/model/proxy/1(/.*)?$"},
  1292. {"key": "x-higress-llm-model", "value": "route-two",
  1293. "path_pattern": "^/model/proxy/2(/.*)?$"},
  1294. ]}]
  1295. # call
  1296. generic_route_transformer_diff_spec(
  1297. current_spec,
  1298. expected_header_rules=[build_generic_route_header_rule(2, "route-two-renamed")],
  1299. operating_path_pattern="^/model/proxy/2(/.*)?$",
  1300. )
  1301. # current_spec.defaultConfig.reqRules (after)
  1302. [{"operate": "add", "headers": [
  1303. {"key": "x-higress-llm-model", "value": "route-one",
  1304. "path_pattern": "^/model/proxy/1(/.*)?$"},
  1305. {"key": "x-higress-llm-model", "value": "route-two-renamed",
  1306. "path_pattern": "^/model/proxy/2(/.*)?$"},
  1307. ]}]
  1308. Pass expected_header_rules=[] to remove the rule for this route (e.g. when
  1309. generic_proxy is toggled off or the route is deleted).
  1310. """
  1311. if current_spec is None:
  1312. return current_spec
  1313. req_rules = (current_spec.defaultConfig or {}).get("reqRules", [])
  1314. _, generic_headers = _split_generic_route_req_rules(req_rules)
  1315. retained = [
  1316. h for h in generic_headers if h.get("path_pattern") != operating_path_pattern
  1317. ]
  1318. return _set_generic_route_headers(current_spec, retained + expected_header_rules)
  1319. def cleanup_generic_route_transformer_spec_diff(
  1320. current_spec: Optional[WasmPluginSpec],
  1321. expected_path_patterns: Set[str],
  1322. ) -> Optional[WasmPluginSpec]:
  1323. """
  1324. Drop generic-route HeaderRules whose path_pattern is not in
  1325. expected_path_patterns. Non-generic rules (any shape that doesn't look like
  1326. ``^/model/proxy/<id>``) are left untouched. Used on startup to prune rules
  1327. for routes that were deleted or had generic_proxy toggled off while the
  1328. server was down.
  1329. """
  1330. if current_spec is None:
  1331. return current_spec
  1332. req_rules = (current_spec.defaultConfig or {}).get("reqRules", [])
  1333. _, generic_headers = _split_generic_route_req_rules(req_rules)
  1334. retained = [
  1335. h for h in generic_headers if h.get("path_pattern") in expected_path_patterns
  1336. ]
  1337. return _set_generic_route_headers(current_spec, retained)
  1338. async def cleanup_generic_route_transformer(
  1339. routes: List[ModelRoute],
  1340. k8s_config: k8s_client.Configuration,
  1341. namespace: str,
  1342. ):
  1343. """Prune generic-route transformer rules to those for existing generic_proxy routes."""
  1344. if k8s_config is None:
  1345. return
  1346. expected_patterns = {
  1347. build_generic_route_path_pattern(route.id)
  1348. for route in routes
  1349. if getattr(route, "generic_proxy", False)
  1350. }
  1351. api = ExtensionsHigressIoV1Api(k8s_client.ApiClient(k8s_config))
  1352. await ensure_wasm_plugin(
  1353. api=api,
  1354. name=gpustack_generic_route_transformer_name,
  1355. namespace=namespace,
  1356. spec_diff=partial(
  1357. cleanup_generic_route_transformer_spec_diff,
  1358. expected_path_patterns=expected_patterns,
  1359. ),
  1360. )
  1361. async def cleanup_mcpbridge_registry(
  1362. providers: List[ModelProvider],
  1363. model_instances: List[ModelInstance],
  1364. workers: List[Worker],
  1365. namespace: str,
  1366. k8s_config: k8s_client.Configuration,
  1367. ):
  1368. if k8s_config is None:
  1369. return
  1370. worker_by_id = {worker.id: worker for worker in workers}
  1371. networking_higress_api = NetworkingHigressIoV1Api(k8s_client.ApiClient(k8s_config))
  1372. # cleanup providers
  1373. desired_registries = []
  1374. desired_proxies = []
  1375. for provider in providers:
  1376. registry = provider_registry(provider=provider)
  1377. if registry is not None:
  1378. desired_registries.append(registry)
  1379. proxy = provider_proxy(provider=provider)
  1380. if proxy is not None:
  1381. desired_proxies.append(proxy)
  1382. to_delete_prefix = provider_id_prefix
  1383. await ensure_mcp_bridge(
  1384. client=networking_higress_api,
  1385. namespace=namespace,
  1386. mcp_bridge_name=default_mcp_bridge_name,
  1387. desired_registries=desired_registries,
  1388. to_delete_prefix=to_delete_prefix,
  1389. desired_proxies=desired_proxies,
  1390. to_delete_proxies_prefix=provider_id_prefix,
  1391. )
  1392. # cleanup model instances
  1393. desired_registries = []
  1394. to_delete_prefix = model_id_prefix
  1395. for instance in model_instances:
  1396. worker = worker_by_id.get(instance.worker_id)
  1397. registry = model_instance_registry(instance, worker=worker)
  1398. if registry is not None:
  1399. desired_registries.append(registry)
  1400. await ensure_mcp_bridge(
  1401. client=networking_higress_api,
  1402. namespace=namespace,
  1403. mcp_bridge_name=default_mcp_bridge_name,
  1404. desired_registries=desired_registries,
  1405. to_delete_prefix=to_delete_prefix,
  1406. )
  def ai_proxy_diff_spec(
      current_spec: Optional[WasmPluginSpec],
      expected_providers: List[Dict[str, Any]],
      expected_match_rules: List[WasmPluginMatchRule],
      operating_id_prefix: Optional[str] = None,
  ) -> WasmPluginSpec:
      """Reconcile an existing plugin spec with the expected providers and rules.

      The spec is updated in place: its ``defaultConfig["providers"]`` entry and
      its ``matchRules`` attribute are replaced with the results of
      ``compare_and_append_default_proxy_config`` and
      ``compare_and_append_proxy_match_rules`` respectively.

      Args:
          current_spec: The spec to update, or ``None`` when there is none.
          expected_providers: Provider entries handed to the provider merge helper.
          expected_match_rules: Match rules handed to the rule merge helper.
          operating_id_prefix: Forwarded unchanged to both merge helpers.

      Returns:
          The same ``current_spec`` object after mutation. When ``current_spec``
          is ``None`` it is returned as-is (i.e. ``None``), despite the declared
          return annotation.
      """
      # Without an existing spec there is nothing to merge into.
      if current_spec is None:
          return None

      # Providers first: read what is configured today, then overwrite the entry.
      configured_providers = current_spec.defaultConfig.get("providers", [])
      current_spec.defaultConfig["providers"] = compare_and_append_default_proxy_config(
          existing_providers=configured_providers,
          expected_providers=expected_providers,
          operating_id_prefix=operating_id_prefix,
      )

      # Match rules second; a falsy ``matchRules`` is treated as an empty list.
      configured_rules = current_spec.matchRules or []
      current_spec.matchRules = compare_and_append_proxy_match_rules(
          existing_rules=configured_rules,
          expected_rules=expected_match_rules,
          operating_id_prefix=operating_id_prefix,
      )
      return current_spec
  def get_instance_id_from_header(headers: Headers) -> int:
      """Parse the model instance ID from the ``x-gpustack-model-instance`` routing header.

      The header value follows the pattern
      ``model-<model_id>-<instance_id>.<suffix>`` injected by the API gateway.
      The instance ID is the last numeric segment before the first dot.

      Args:
          headers: Request headers; only ``headers.get`` is used, so any
              mapping-like object works.

      Returns:
          The instance ID parsed as an ``int``.

      Raises:
          HTTPException (400): if the header is absent.
          NotFoundException: if the header value does not match the expected pattern.
      """
      model_destination = headers.get(router_header_key, None)
      if model_destination is None:
          raise HTTPException(
              status_code=400, detail=f"Missing {router_header_key} header"
          )

      # Match pattern: model-<model_id>-<instance_id>.suffix
      # ``[^.]*`` keeps the ``model-...-<instance_id>`` part inside the text
      # before the first dot. A greedy ``.*`` would cross dots and capture a
      # ``-<digits>.`` sequence that merely happens to appear in the suffix.
      match = re.match(r'^model-[^.]*-(\d+)\..+', model_destination)
      if match is None:
          raise NotFoundException(
              message=f"Invalid model destination format: {model_destination}"
          )
      return int(match.group(1))
  async def resolve_instance_address_from_model_header(
      headers: Dict[str, str],
  ) -> Tuple[Optional[str], int]:
      """Resolve the target worker (IP, port) for an inference request.

      Parses the ``x-gpustack-model-instance`` routing header injected by the API
      gateway to extract the model instance ID, then queries the database for
      that instance's worker IP and inference port.

      Used as the ``header_router`` callback of ``HTTPSProxyServer`` in tunnel
      proxy mode so the proxy knows which instance address to forward each
      request to.

      Args:
          headers: Request headers, passed to ``get_instance_id_from_header``.

      Returns:
          ``(worker_ip, first_port)`` for the resolved instance, or ``(None, 0)``
          when the header is absent or malformed, the instance does not exist,
          or it has no worker IP / ports yet. ``(None, 0)`` causes the proxy to
          fall back to URI-based routing.
      """
      try:
          instance_id = get_instance_id_from_header(headers)
      except HTTPException as e:
          # Missing header is the normal case for non-routed requests.
          logger.trace(f"direct proxying request as: {e}")
          return None, 0
      except Exception as e:
          logger.debug(f"Error parsing model destination header: {e}")
          return None, 0

      async with async_session() as session:
          model_instance_service = ModelInstanceService(session)
          model_instance: Optional[ModelInstance] = (
              await model_instance_service.get_by_id(instance_id)
          )
          if model_instance is None:
              logger.error(f"Model instance with ID {instance_id} not found.")
              return None, 0

          # Truthiness covers both an empty list and a ``None`` value, so an
          # instance without ports falls back instead of raising ``TypeError``.
          if model_instance.worker_ip is None or not model_instance.ports:
              logger.error(
                  f"Model instance with ID {instance_id} do not get scheduled yet."
              )
              return None, 0

          return model_instance.worker_ip, model_instance.ports[0]
  async def worker_websocket_connect_callback(
      _server: Optional[ServerInfo],
      client: Optional[RegisteredClientInfo],
      proxy_address: Optional[str] = None,
  ) -> None:
      """Persist a worker's tunnel proxy address when its tunnel state changes.

      Called by ``MessageServerHandler`` as the ``callback_on_connect`` /
      ``callback_on_disconnect`` hook. On connect, ``proxy_address`` is the
      server-side HTTP proxy URL the gateway should route to; on disconnect it
      is ``None``, which clears the field so the worker is no longer reachable
      via the tunnel.

      Args:
          _server: Unused.
          client: The tunnel client; its ``client_id`` is matched (as a string)
              against ``Worker.worker_uuid``. ``None`` makes the call a no-op.
          proxy_address: New value for ``worker.proxy_address``.

      If no worker matches, an error is logged and the database is left
      untouched. If the stored address already equals ``proxy_address`` no
      update is issued.
      """
      if client is None:
          return

      async with async_session() as db_session:
          target_worker = await Worker.one_by_field(
              session=db_session,
              field="worker_uuid",
              value=str(client.client_id),
          )
          if target_worker is None:
              logger.error(f"Worker with UUID {client.client_id} not found.")
          elif target_worker.proxy_address != proxy_address:
              # Only write when the stored address actually changes.
              target_worker.proxy_address = proxy_address
              await WorkerService(db_session).update(target_worker)