__init__.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841
  1. import time
  2. import asyncio
  3. import base64
  4. import os
  5. import logging
  6. import yaml
  7. import copy
  8. from functools import partial
  9. from typing import Any, Dict, Tuple, List, Optional, Literal
  10. from pydantic import BaseModel
  11. from kubernetes_asyncio import client as k8s_client
  12. from kubernetes_asyncio.client import Configuration
  13. from kubernetes_asyncio.config.kube_config import KubeConfigLoader, KubeConfigMerger
  14. from kubernetes_asyncio.config.incluster_config import (
  15. InClusterConfigLoader,
  16. SERVICE_TOKEN_FILENAME,
  17. SERVICE_CERT_FILENAME,
  18. )
  19. from kubernetes_asyncio.client.rest import ApiException
  20. from gpustack.api.auth import GATEWAY_AUTH_TOKEN_HEADER
  21. from gpustack.config.config import Config
  22. from gpustack.schemas.config import GatewayModeEnum
  23. from gpustack import envs
  24. from gpustack.gateway import client as gw_client
  25. from gpustack.gateway.client import (
  26. McpBridge,
  27. McpBridgeSpec,
  28. McpBridgeRegistry,
  29. WasmPluginSpec,
  30. WasmPluginMatchRule,
  31. )
  32. from gpustack.gateway.labels_annotations import managed_labels, match_labels
  33. from gpustack.gateway.utils import (
  34. default_mcp_bridge_name,
  35. openai_model_prefixes,
  36. anthropic_model_exact,
  37. gpustack_ai_proxy_name,
  38. gpustack_model_mapper_name,
  39. gpustack_generic_route_transformer_name,
  40. mcp_ingress_equal,
  41. get_default_mcpbridge_ref,
  42. ensure_wasm_plugin,
  43. router_header_key,
  44. gpustack_original_path_header,
  45. gpustack_fallback_path_header,
  46. )
  47. from gpustack.gateway.plugins import (
  48. get_plugin_url_with_name_and_version,
  49. )
  50. from gpustack.security import AUTH_CACHE_HEADER
  51. logger = logging.getLogger(__name__)
  52. mcp_registry_port = 80
  53. supported_openai_routes = [
  54. route for v in openai_model_prefixes for route in v.flattened_prefixes()
  55. ]
  56. supported_anthropic_routes = [
  57. route for v in anthropic_model_exact for route in v.flattened_prefixes()
  58. ]
  59. async_gateway_config: Configuration = None
  60. def init_async_k8s_config(cfg: Config):
  61. if cfg.gateway_mode == GatewayModeEnum.disabled:
  62. return
  63. global async_gateway_config
  64. if async_gateway_config is not None:
  65. return
  66. configuration = Configuration()
  67. if cfg.gateway_mode == GatewayModeEnum.incluster:
  68. cfg_loader = InClusterConfigLoader(
  69. token_filename=SERVICE_TOKEN_FILENAME,
  70. cert_filename=SERVICE_CERT_FILENAME,
  71. )
  72. cfg_loader.load_and_set(configuration)
  73. else:
  74. kubeconfig_path = cfg.gateway_kubeconfig
  75. if not kubeconfig_path or not os.path.isfile(kubeconfig_path):
  76. logger.debug(f"Kubeconfig not found at {kubeconfig_path}, skipping k8s config initialization")
  77. return
  78. config_dict = KubeConfigMerger(cfg.gateway_kubeconfig).config
  79. if not config_dict or not config_dict.get("current-context"):
  80. logger.debug(f"Kubeconfig at {kubeconfig_path} is empty or missing current-context, skipping k8s config initialization")
  81. return
  82. cfg_loader = KubeConfigLoader(config_dict=config_dict)
  83. if not cfg_loader._load_user_token():
  84. cfg_loader._load_user_pass_token()
  85. cfg_loader._load_cluster_info()
  86. cfg_loader._set_config(configuration)
  87. async_gateway_config = configuration
  88. def get_async_k8s_config(cfg: Config) -> Optional[Configuration]:
  89. if cfg.gateway_mode == GatewayModeEnum.disabled:
  90. return None
  91. global async_gateway_config
  92. if async_gateway_config is None:
  93. init_async_k8s_config(cfg=cfg)
  94. return async_gateway_config
  95. def wait_for_apiserver_ready(cfg: Config, timeout: int = 60, interval: int = 5):
  96. async def get_api_resources():
  97. config = get_async_k8s_config(cfg)
  98. start = time.time()
  99. v1 = k8s_client.CoreV1Api(k8s_client.ApiClient(configuration=config))
  100. while True:
  101. try:
  102. await v1.get_api_resources()
  103. break
  104. except Exception:
  105. if time.time() - start > timeout:
  106. raise
  107. await asyncio.sleep(interval)
  108. try:
  109. asyncio.run(get_api_resources())
  110. except asyncio.CancelledError:
  111. raise
  112. def get_gpustack_higress_registry(cfg: Config) -> McpBridgeRegistry:
  113. registry_type = "dns"
  114. domain = f"{cfg.service_discovery_name}.{cfg.get_namespace()}.svc"
  115. mcp_port = cfg.get_api_port()
  116. if cfg.gateway_mode != GatewayModeEnum.incluster:
  117. registry_type = "static"
  118. mcp_port = mcp_registry_port
  119. port = cfg.get_api_port()
  120. if cfg.gateway_mode == GatewayModeEnum.external:
  121. address = cfg.get_advertise_address()
  122. elif cfg.gateway_mode == GatewayModeEnum.embedded:
  123. address = "127.0.0.1"
  124. domain = f"{address}:{port}"
  125. mcp_registry_name = (
  126. "gpustack"
  127. if cfg.server_role() != Config.ServerRole.WORKER
  128. else "gpustack-worker"
  129. )
  130. registry = McpBridgeRegistry(
  131. type=registry_type,
  132. name=mcp_registry_name,
  133. port=mcp_port,
  134. protocol="http",
  135. domain=domain,
  136. )
  137. return registry
  138. async def ensure_mcp_resources(cfg: Config, api_client: k8s_client.ApiClient):
  139. api = gw_client.NetworkingHigressIoV1Api(api_client)
  140. # use default name for embedded mode
  141. gateway_namespace = cfg.gateway_namespace
  142. try:
  143. data: Dict[str, Any] = await api.get_mcpbridge(
  144. namespace=gateway_namespace, name=default_mcp_bridge_name
  145. )
  146. default_bridge = McpBridge.model_validate(data)
  147. except ApiException as e:
  148. if e.status == 404:
  149. default_bridge = None
  150. else:
  151. raise
  152. target_registry = get_gpustack_higress_registry(cfg=cfg)
  153. try:
  154. if not default_bridge:
  155. bridge = McpBridge(
  156. metadata={"name": "default", "namespace": gateway_namespace},
  157. spec=McpBridgeSpec(registries=[target_registry]),
  158. )
  159. await api.create_mcpbridge(namespace=gateway_namespace, body=bridge)
  160. else:
  161. should_update = False
  162. registries = (
  163. default_bridge.spec.registries
  164. if default_bridge.spec and default_bridge.spec.registries
  165. else []
  166. )
  167. if not any(r.name == target_registry.name for r in registries):
  168. if default_bridge.spec is None:
  169. default_bridge.spec = McpBridgeSpec()
  170. registries.append(target_registry)
  171. default_bridge.spec.registries = registries
  172. should_update = True
  173. else:
  174. registry = next(r for r in registries if r.name == target_registry.name)
  175. if (
  176. registry.type != target_registry.type
  177. or registry.domain != target_registry.domain
  178. or registry.port != target_registry.port
  179. or registry.protocol != target_registry.protocol
  180. ):
  181. registry.type = target_registry.type
  182. registry.domain = target_registry.domain
  183. registry.port = target_registry.port
  184. registry.protocol = target_registry.protocol
  185. should_update = True
  186. if should_update:
  187. await api.edit_mcpbridge(
  188. namespace=gateway_namespace, name='default', body=default_bridge
  189. )
  190. except ApiException as e:
  191. raise RuntimeError("Failed to ensure ingress resources") from e
  192. async def ensure_ingress_resources(cfg: Config, api_client: k8s_client.ApiClient):
  193. """
  194. Ensure the ingress resources to route traffic to mcpbridge are created.
  195. """
  196. gateway_namespace = cfg.gateway_namespace
  197. hostname = cfg.get_external_hostname()
  198. tls_secret_name = cfg.get_tls_secret_name()
  199. network_v1_client = k8s_client.NetworkingV1Api(api_client=api_client)
  200. ingress_name = envs.GATEWAY_MIRROR_INGRESS_NAME
  201. try:
  202. ingress: k8s_client.V1Ingress = await network_v1_client.read_namespaced_ingress(
  203. name=ingress_name, namespace=gateway_namespace
  204. )
  205. except ApiException as e:
  206. if e.status == 404:
  207. ingress = None
  208. else:
  209. raise
  210. registry = get_gpustack_higress_registry(cfg=cfg)
  211. expected_rule = k8s_client.V1IngressRule(
  212. http=k8s_client.V1HTTPIngressRuleValue(
  213. paths=[
  214. k8s_client.V1HTTPIngressPath(
  215. path="/",
  216. path_type="Prefix",
  217. backend=k8s_client.V1IngressBackend(
  218. resource=get_default_mcpbridge_ref()
  219. ),
  220. )
  221. ]
  222. ),
  223. )
  224. expected_ingress = k8s_client.V1Ingress(
  225. metadata=k8s_client.V1ObjectMeta(
  226. name=ingress_name,
  227. namespace=gateway_namespace,
  228. annotations={
  229. "higress.io/destination": f"{registry.get_service_name_with_port()}",
  230. "higress.io/ignore-path-case": "false",
  231. },
  232. labels=managed_labels,
  233. ),
  234. spec=k8s_client.V1IngressSpec(
  235. ingress_class_name=cfg.gateway_ingress_class,
  236. rules=[expected_rule],
  237. ),
  238. )
  239. if tls_secret_name is not None:
  240. expected_ingress.spec.tls = [
  241. k8s_client.V1IngressTLS(
  242. hosts=[hostname] if hostname is not None else None,
  243. secret_name=tls_secret_name,
  244. )
  245. ]
  246. if hostname is not None:
  247. host_rule = copy.deepcopy(expected_rule)
  248. host_rule.host = hostname
  249. expected_ingress.spec.rules.append(host_rule)
  250. if not ingress:
  251. await network_v1_client.create_namespaced_ingress(
  252. namespace=gateway_namespace, body=expected_ingress
  253. )
  254. elif match_labels(getattr(ingress.metadata, 'labels', {}), managed_labels):
  255. # only update ingress managed by gpustack
  256. if not mcp_ingress_equal(ingress, expected_ingress):
  257. await network_v1_client.replace_namespaced_ingress(
  258. name=ingress_name, namespace=gateway_namespace, body=expected_ingress
  259. )
  260. def get_match_rules(
  261. match_type: Literal["whitelist", "blacklist"],
  262. paths: List[Tuple[str, str]],
  263. ) -> Dict[str, Any]:
  264. match_list = [
  265. {
  266. "match_rule_path": pair[0],
  267. "match_rule_type": pair[1],
  268. }
  269. for pair in paths
  270. ]
  271. return {
  272. "match_list": match_list,
  273. "match_type": match_type,
  274. }
  275. def ext_auth_plugin(cfg: Config) -> Tuple[str, WasmPluginSpec]:
  276. resource_name = "gpustack-llm-ext-auth"
  277. registry = get_gpustack_higress_registry(cfg=cfg)
  278. # this is to auth requests except for gpustack
  279. default_match_rule = get_match_rules(
  280. match_type="blacklist",
  281. paths=[("/", "prefix")],
  282. )
  283. gpustack_match_rule = get_match_rules(
  284. match_type="whitelist",
  285. paths=[("/", "prefix")],
  286. )
  287. http_service = {
  288. "authorization_request": {
  289. "allowed_headers": [
  290. {"exact": "X-GPUStack-Real-IP"},
  291. {"exact": "x-higress-llm-model"},
  292. {"exact": "x-api-key"},
  293. {"exact": "cookie"},
  294. {"exact": AUTH_CACHE_HEADER},
  295. ],
  296. "headers_to_add": {
  297. GATEWAY_AUTH_TOKEN_HEADER: cfg.get_derived_gateway_token(),
  298. },
  299. },
  300. "authorization_response": {
  301. "allowed_upstream_headers": [
  302. {"exact": "X-Mse-Consumer"},
  303. {"exact": "Authorization"},
  304. {"exact": "cookie"},
  305. {"exact": AUTH_CACHE_HEADER},
  306. ]
  307. },
  308. "endpoint": {
  309. "path": "/token-auth",
  310. "request_method": "GET",
  311. "service_name": registry.get_service_name(),
  312. "service_port": registry.port,
  313. },
  314. "endpoint_mode": "forward_auth",
  315. "timeout": envs.HIGRESS_EXT_AUTH_TIMEOUT_MS,
  316. }
  317. namespace = cfg.get_namespace()
  318. if namespace == cfg.gateway_namespace:
  319. namespace = ""
  320. # the ingress in plugin matchRules should not contains namespace prefix
  321. # if it is in the same namespace with the gateway.
  322. ingress_name = f"{namespace}/{envs.GATEWAY_MIRROR_INGRESS_NAME}".lstrip("/")
  323. expected_spec = WasmPluginSpec(
  324. defaultConfig={
  325. "http_service": http_service,
  326. **default_match_rule,
  327. },
  328. defaultConfigDisable=False,
  329. failStrategy="FAIL_OPEN",
  330. phase="AUTHN",
  331. priority=360,
  332. url=get_plugin_url_with_name_and_version(
  333. name="ext-auth", version="2.0.0", cfg=cfg
  334. ),
  335. matchRules=[
  336. WasmPluginMatchRule(
  337. config={
  338. "http_service": http_service,
  339. **gpustack_match_rule,
  340. },
  341. configDisable=False,
  342. ingress=[ingress_name],
  343. )
  344. ],
  345. )
  346. return resource_name, expected_spec
  347. def ai_statistics_plugin(cfg: Config) -> Tuple[str, WasmPluginSpec]:
  348. resource_name = "gpustack-ai-statistics"
  349. expected_spec = WasmPluginSpec(
  350. defaultConfig={
  351. "enable_content_types": envs.GATEWAY_AI_STATISTICS_PLUGIN_CONTENT_TYPES,
  352. "attributes": [
  353. {
  354. "apply_to_log": True,
  355. "apply_to_span": False,
  356. "key": "consumer",
  357. "value": "x-mse-consumer",
  358. "value_source": "request_header",
  359. }
  360. ],
  361. },
  362. defaultConfigDisable=False,
  363. failStrategy="FAIL_OPEN",
  364. imagePullPolicy="UNSPECIFIED_POLICY",
  365. matchRules=[],
  366. phase="UNSPECIFIED_PHASE",
  367. priority=900,
  368. url=get_plugin_url_with_name_and_version(
  369. name="ai-statistics", version="2.0.0", cfg=cfg
  370. ),
  371. )
  372. return resource_name, expected_spec
  373. def model_router_plugin(cfg: Config) -> Tuple[str, WasmPluginSpec]:
  374. resource_name = "gpustack-model-router"
  375. enabled_paths = supported_openai_routes + supported_anthropic_routes
  376. enabled_paths.append("/model/proxy")
  377. expected_spec = WasmPluginSpec(
  378. defaultConfig={
  379. 'modelToHeader': 'x-higress-llm-model',
  380. 'enableOnPathSuffix': enabled_paths,
  381. },
  382. defaultConfigDisable=False,
  383. failStrategy="FAIL_OPEN",
  384. imagePullPolicy="UNSPECIFIED_POLICY",
  385. matchRules=[],
  386. phase="AUTHN",
  387. priority=900,
  388. url=get_plugin_url_with_name_and_version(
  389. name="model-router", version="2.0.0", cfg=cfg
  390. ),
  391. )
  392. return resource_name, expected_spec
  393. def model_pre_route_plugin(cfg: Config) -> Tuple[str, WasmPluginSpec]:
  394. resource_name = "gpustack-set-model-pre-route"
  395. enabled_path_suffixes = supported_openai_routes + supported_anthropic_routes
  396. enabled_path_prefixes = ["/model/proxy"]
  397. expected_spec = WasmPluginSpec(
  398. defaultConfig={
  399. 'clusterNameHeader': router_header_key,
  400. 'routeNameHeader': 'X-GPUStack-Route-Name',
  401. 'enableOnPathSuffix': enabled_path_suffixes,
  402. 'enableOnPathPrefix': enabled_path_prefixes,
  403. },
  404. defaultConfigDisable=False,
  405. failStrategy="FAIL_OPEN",
  406. imagePullPolicy="UNSPECIFIED_POLICY",
  407. matchRules=[],
  408. phase="AUTHN",
  409. priority=90,
  410. url=get_plugin_url_with_name_and_version(
  411. name="gpustack-set-header-pre-route", version="1.0.0", cfg=cfg
  412. ),
  413. )
  414. return resource_name, expected_spec
  415. def model_mapper_plugin(cfg: Config) -> Tuple[str, WasmPluginSpec]:
  416. return gpustack_model_mapper_name, WasmPluginSpec(
  417. phase="AUTHN",
  418. priority=800,
  419. url=get_plugin_url_with_name_and_version(
  420. name="model-mapper", version="2.0.0", cfg=cfg
  421. ),
  422. defaultConfigDisable=False,
  423. defaultConfig={"modelMapping": {}},
  424. matchRules=[],
  425. failStrategy="FAIL_OPEN",
  426. )
  427. class HeaderRule(BaseModel):
  428. key: Optional[str] = None
  429. newKey: Optional[str] = None
  430. oldKey: Optional[str] = None
  431. fromKey: Optional[str] = None
  432. toKey: Optional[str] = None
  433. value: Optional[str] = None
  434. newValue: Optional[str] = None
  435. appendValue: Optional[str] = None
  436. value_type: Optional[Literal["object", "bool", "number", "string"]] = None
  437. strategy: Optional[Literal["RETAIN_FIRST", "RETAIN_LAST", "RETAIN_UNIQUE"]] = None
  438. host_pattern: Optional[str] = None
  439. path_pattern: Optional[str] = None
  440. def transform_header(
  441. operate: Literal["remove", "rename", "replace", "add", "append", "map", "dedupe"],
  442. *rules: HeaderRule,
  443. ) -> Dict[str, Any]:
  444. # TODO: add validation in the future
  445. return {
  446. "headers": [rule.model_dump(exclude_none=True) for rule in rules],
  447. "operate": operate,
  448. }
  449. def transformer_plugin(cfg: Config) -> Tuple[str, WasmPluginSpec]:
  450. resource_name = "gpustack-header-transformer"
  451. expected_spec = WasmPluginSpec(
  452. defaultConfig={
  453. "reqRules": [
  454. transform_header(
  455. "remove",
  456. HeaderRule(
  457. key=GATEWAY_AUTH_TOKEN_HEADER,
  458. ),
  459. HeaderRule(
  460. key=router_header_key,
  461. ),
  462. ),
  463. transform_header(
  464. "rename",
  465. HeaderRule(
  466. oldKey="x-gpustack-model",
  467. newKey="x-higress-llm-model",
  468. ),
  469. HeaderRule(
  470. oldKey=gpustack_fallback_path_header,
  471. newKey=":path",
  472. ),
  473. ),
  474. transform_header(
  475. "dedupe",
  476. HeaderRule(
  477. key="x-gpustack-model",
  478. strategy="RETAIN_FIRST",
  479. ),
  480. HeaderRule(
  481. key="x-higress-llm-model",
  482. strategy="RETAIN_FIRST",
  483. ),
  484. HeaderRule(
  485. key=":path",
  486. strategy="RETAIN_LAST",
  487. ),
  488. ),
  489. transform_header(
  490. "map",
  491. HeaderRule(
  492. fromKey=':path',
  493. toKey=gpustack_original_path_header,
  494. ),
  495. ),
  496. transform_header(
  497. "remove",
  498. HeaderRule(
  499. key=gpustack_fallback_path_header,
  500. ),
  501. ),
  502. ],
  503. },
  504. defaultConfigDisable=False,
  505. failStrategy="FAIL_OPEN",
  506. imagePullPolicy="UNSPECIFIED_POLICY",
  507. matchRules=[],
  508. phase="AUTHN",
  509. priority=810,
  510. url=get_plugin_url_with_name_and_version(
  511. name="transformer", version="2.0.0", cfg=cfg
  512. ),
  513. )
  514. return resource_name, expected_spec
  515. def generic_route_transformer_plugin(cfg: Config) -> Tuple[str, WasmPluginSpec]:
  516. """
  517. Pre-route transformer that injects x-higress-llm-model based on the route id
  518. captured from /model/proxy/<id>/... paths. Per-route HeaderRules are merged
  519. into defaultConfig.reqRules by the per-route reconciler.
  520. defaultConfigDisable is fixed to False for the lifetime of the plugin —
  521. toggling it rewrites Envoy's filter chain and drops every in-flight
  522. connection through the gateway.
  523. Runtime shape after two generic routes (id=1 "route-one", id=2 "route-two")
  524. have been reconciled — the reconciler only mutates the `headers` list:
  525. apiVersion: extensions.higress.io/v1alpha1
  526. kind: WasmPlugin
  527. metadata:
  528. name: gpustack-generic-route-transformer
  529. spec:
  530. phase: AUTHN
  531. priority: 905
  532. defaultConfigDisable: false
  533. defaultConfig:
  534. reqRules:
  535. - operate: add
  536. headers:
  537. - key: x-higress-llm-model
  538. value: route-one
  539. path_pattern: ^/model/proxy/1(/.*)?$
  540. - key: x-higress-llm-model
  541. value: route-two
  542. path_pattern: ^/model/proxy/2(/.*)?$
  543. On a request for ``/model/proxy/1/v1/chat/completions`` Higress rewrites the
  544. match of path_pattern inside ``:path`` with ``value`` — the whole path is
  545. consumed by the pattern (``(/.*)?$`` tail), so the header becomes exactly
  546. ``route-one`` and routing falls through to the main ingress's header
  547. matcher.
  548. """
  549. expected_spec = WasmPluginSpec(
  550. defaultConfig={"reqRules": []},
  551. defaultConfigDisable=False,
  552. failStrategy="FAIL_OPEN",
  553. imagePullPolicy="UNSPECIFIED_POLICY",
  554. matchRules=[],
  555. phase="AUTHN",
  556. priority=905, # ahead of model-router (900) so header wins
  557. url=get_plugin_url_with_name_and_version(
  558. name="transformer", version="2.0.0", cfg=cfg
  559. ),
  560. )
  561. return gpustack_generic_route_transformer_name, expected_spec
  562. def token_usage_plugin(cfg: Config) -> Tuple[str, WasmPluginSpec]:
  563. registry = get_gpustack_higress_registry(cfg=cfg)
  564. resource_name = "gpustack-token-usage"
  565. expected_spec = WasmPluginSpec(
  566. defaultConfig={
  567. 'realIPToHeader': "X-GPUStack-Real-IP",
  568. 'endpoint': {
  569. "path": "/v2/usage/gateway-metrics",
  570. "service_name": registry.get_service_name(),
  571. "service_port": registry.port,
  572. },
  573. 'header_add': {
  574. GATEWAY_AUTH_TOKEN_HEADER: cfg.get_derived_gateway_token(),
  575. },
  576. },
  577. defaultConfigDisable=False,
  578. failStrategy="FAIL_OPEN",
  579. imagePullPolicy="UNSPECIFIED_POLICY",
  580. matchRules=[],
  581. phase="UNSPECIFIED_PHASE",
  582. priority=910,
  583. url=get_plugin_url_with_name_and_version(
  584. name="gpustack-token-usage", version="1.0.0", cfg=cfg
  585. ),
  586. )
  587. return resource_name, expected_spec
  588. def ai_proxy_plugin(cfg: Config) -> Tuple[str, WasmPluginSpec]:
  589. resource_name = gpustack_ai_proxy_name
  590. expected_spec = WasmPluginSpec(
  591. defaultConfig={},
  592. defaultConfigDisable=False,
  593. failStrategy="FAIL_OPEN",
  594. imagePullPolicy="UNSPECIFIED_POLICY",
  595. matchRules=[],
  596. priority=100,
  597. phase="UNSPECIFIED_PHASE",
  598. url=get_plugin_url_with_name_and_version(
  599. name="ai-proxy", version="2.0.0", cfg=cfg
  600. ),
  601. )
  602. return resource_name, expected_spec
  603. async def ensure_tls_secret(cfg: Config, api_client: k8s_client.ApiClient):
  604. """
  605. Ensure the TLS secret if ssl key pair is provided.
  606. """
  607. ssl_keyfile = cfg.ssl_keyfile
  608. ssl_certfile = cfg.ssl_certfile
  609. if not ssl_keyfile or not ssl_certfile:
  610. return
  611. if not (os.path.isfile(ssl_keyfile) and os.path.isfile(ssl_certfile)):
  612. raise RuntimeError(
  613. f"SSL keyfile {ssl_keyfile} or certfile {ssl_certfile} does not exist"
  614. )
  615. # read key and cert files and encode into base64
  616. with open(ssl_keyfile, 'rb') as f:
  617. ssl_key_bytes = f.read()
  618. with open(ssl_certfile, 'rb') as f:
  619. ssl_cert_bytes = f.read()
  620. ssl_key_data = base64.b64encode(ssl_key_bytes).decode()
  621. ssl_cert_data = base64.b64encode(ssl_cert_bytes).decode()
  622. gateway_namespace = cfg.gateway_namespace
  623. core_v1_client = k8s_client.CoreV1Api(api_client=api_client)
  624. secret_name = cfg.get_tls_secret_name()
  625. to_create_tls_secret = k8s_client.V1Secret(
  626. metadata=k8s_client.V1ObjectMeta(
  627. name=secret_name,
  628. namespace=gateway_namespace,
  629. labels=managed_labels,
  630. ),
  631. type="kubernetes.io/tls",
  632. data={
  633. "tls.key": ssl_key_data,
  634. "tls.crt": ssl_cert_data,
  635. },
  636. )
  637. try:
  638. existing_secret: k8s_client.V1Secret = (
  639. await core_v1_client.read_namespaced_secret(
  640. name=secret_name, namespace=gateway_namespace
  641. )
  642. )
  643. except ApiException as e:
  644. if e.status == 404:
  645. existing_secret = None
  646. else:
  647. raise
  648. if not existing_secret:
  649. await core_v1_client.create_namespaced_secret(
  650. namespace=gateway_namespace, body=to_create_tls_secret
  651. )
  652. elif match_labels(getattr(existing_secret.metadata, 'labels', {}), managed_labels):
  653. if existing_secret.data != to_create_tls_secret.data:
  654. await core_v1_client.replace_namespaced_secret(
  655. name=secret_name, namespace=gateway_namespace, body=to_create_tls_secret
  656. )
  657. async def ensure_gateway_timeout(cfg: Config, api_client: k8s_client.ApiClient):
  658. namespace = cfg.gateway_namespace
  659. higress_config_name = "higress-config"
  660. core_v1_client = k8s_client.CoreV1Api(api_client=api_client)
  661. try:
  662. higress_config: k8s_client.V1ConfigMap = (
  663. await core_v1_client.read_namespaced_config_map(
  664. name=higress_config_name, namespace=namespace
  665. )
  666. )
  667. should_update = False
  668. config_data: str = higress_config.data["higress"]
  669. config = yaml.safe_load(config_data)
  670. idle_timeout = (
  671. config.get("downstream", {}).get("idleTimeout")
  672. if isinstance(config, dict)
  673. else None
  674. )
  675. if idle_timeout is None or str(idle_timeout) != f"{envs.PROXY_TIMEOUT}":
  676. config.setdefault("downstream", {})["idleTimeout"] = envs.PROXY_TIMEOUT
  677. higress_config.data["higress"] = yaml.safe_dump(config)
  678. should_update = True
  679. upstream_idle_timeout = (
  680. config.get("upstream", {}).get("idleTimeout")
  681. if isinstance(config, dict)
  682. else None
  683. )
  684. if (
  685. upstream_idle_timeout is None
  686. or str(upstream_idle_timeout) != f"{envs.PROXY_UPSTREAM_IDLE_TIMEOUT}"
  687. ):
  688. config.setdefault("upstream", {})[
  689. "idleTimeout"
  690. ] = envs.PROXY_UPSTREAM_IDLE_TIMEOUT
  691. higress_config.data["higress"] = yaml.safe_dump(config)
  692. should_update = True
  693. if should_update:
  694. await core_v1_client.replace_namespaced_config_map(
  695. name=higress_config_name,
  696. namespace=namespace,
  697. body=higress_config,
  698. )
  699. except Exception as e:
  700. logger.error(f"Failed to read or parse Higress config map: {e}")
  701. raise
  702. def spec_replace(
  703. current_spec: Optional[WasmPluginSpec],
  704. expected_spec: WasmPluginSpec,
  705. create_only: bool = False,
  706. ) -> WasmPluginSpec:
  707. if current_spec is None:
  708. return expected_spec
  709. if create_only:
  710. if current_spec.url != expected_spec.url:
  711. current_spec.url = expected_spec.url
  712. return current_spec
  713. return expected_spec
  714. def validate_ai_statistics_plugin_content_types():
  715. for content_type in envs.GATEWAY_AI_STATISTICS_PLUGIN_CONTENT_TYPES:
  716. if content_type == "audio/pcm":
  717. raise ValueError(
  718. "audio/pcm content type is not supported in ai statistics plugin"
  719. )
  720. def initialize_gateway(cfg: Config, timeout: int = 60, interval: int = 5):
  721. if cfg.gateway_mode == GatewayModeEnum.disabled:
  722. return
  723. init_async_k8s_config(cfg=cfg)
  724. # If k8s config couldn't be initialized (e.g., no valid kubeconfig), skip gateway setup
  725. if async_gateway_config is None:
  726. logger.warning("Gateway k8s config could not be initialized, skipping gateway setup")
  727. return
  728. wait_for_apiserver_ready(cfg=cfg, timeout=timeout, interval=interval)
  729. if cfg.gateway_mode in [
  730. GatewayModeEnum.embedded,
  731. GatewayModeEnum.external,
  732. GatewayModeEnum.incluster,
  733. ]:
  734. validate_ai_statistics_plugin_content_types()
  735. plugin_list: List[Tuple[str, WasmPluginSpec]] = [
  736. ext_auth_plugin(cfg=cfg),
  737. ai_statistics_plugin(cfg=cfg),
  738. model_router_plugin(cfg=cfg),
  739. ai_proxy_plugin(cfg=cfg),
  740. model_pre_route_plugin(cfg=cfg),
  741. model_mapper_plugin(cfg=cfg),
  742. generic_route_transformer_plugin(cfg=cfg),
  743. ]
  744. if cfg.server_role() != Config.ServerRole.WORKER:
  745. plugin_list.append(transformer_plugin(cfg=cfg))
  746. plugin_list.append(token_usage_plugin(cfg=cfg))
  747. async def prepare():
  748. api_client = k8s_client.ApiClient(
  749. configuration=get_async_k8s_config(cfg=cfg)
  750. )
  751. await ensure_tls_secret(cfg=cfg, api_client=api_client)
  752. await ensure_mcp_resources(cfg=cfg, api_client=api_client)
  753. if cfg.gateway_mode != GatewayModeEnum.incluster:
  754. await ensure_gateway_timeout(cfg=cfg, api_client=api_client)
  755. await ensure_ingress_resources(cfg=cfg, api_client=api_client)
  756. for plugin_name, plugin_spec in plugin_list:
  757. create_only = plugin_name in [
  758. gpustack_ai_proxy_name,
  759. gpustack_model_mapper_name,
  760. gpustack_generic_route_transformer_name,
  761. ]
  762. spec_diff_func = partial(
  763. spec_replace, expected_spec=plugin_spec, create_only=create_only
  764. )
  765. await ensure_wasm_plugin(
  766. api=gw_client.ExtensionsHigressIoV1Api(api_client),
  767. name=plugin_name,
  768. namespace=cfg.gateway_namespace,
  769. spec_diff=spec_diff_func,
  770. )
  771. try:
  772. asyncio.run(prepare())
  773. except asyncio.CancelledError:
  774. raise
  775. except Exception as e:
  776. raise RuntimeError("Failed to initialize gateway resources") from e