| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655 |
- import logging
- import random
- import string
- import asyncio
- import yaml
- from importlib.resources import files
- from functools import partial
- from typing import Any, Dict, List, Tuple, Optional, Set
- from pydantic import BaseModel
- from sqlmodel import select
- from sqlmodel.ext.asyncio.session import AsyncSession
- from sqlalchemy.orm import selectinload
- from sqlalchemy.orm.attributes import flag_modified
- from gpustack.config.config import (
- Config,
- get_cluster_image_name,
- )
- from gpustack.policies.scorers.offload_layer_scorer import OffloadLayerScorer
- from gpustack.policies.scorers.placement_scorer import PlacementScorer, ScaleTypeEnum
- from gpustack.policies.scorers.score_chain import (
- ModelInstanceScoreChain,
- )
- from gpustack.policies.base import ModelInstanceScore
- from gpustack.policies.scorers.status_scorer import StatusScorer
- from gpustack.schemas.inference_backend import (
- InferenceBackend,
- get_built_in_backend,
- VersionConfig,
- VersionConfigDict,
- )
- from gpustack.schemas.links import ModelRoutePrincipalLink
- from gpustack.schemas.model_files import ModelFile, ModelFileStateEnum
- from gpustack.schemas.model_routes import (
- ModelRoute,
- ModelRouteTarget,
- MyModel,
- TargetStateEnum,
- effective_route_name,
- )
- from gpustack.schemas.principals import (
- Principal,
- PLATFORM_PRINCIPAL_ID,
- )
- from gpustack.schemas.models import (
- BackendEnum,
- BackendSourceEnum,
- ModelSource,
- Model,
- ModelInstance,
- ModelInstanceCreate,
- ModelInstanceStateEnum,
- SourceEnum,
- get_backend,
- )
- from gpustack.schemas.config import (
- GatewayModeEnum,
- SensitivePredefinedConfig,
- )
- from gpustack.schemas.workers import (
- Worker,
- WorkerStateEnum,
- WorkerStatus,
- )
- from gpustack.schemas.clusters import (
- Cluster,
- WorkerPool,
- CloudCredential,
- Credential,
- CredentialType,
- ClusterStateEnum,
- SSHKeyOptions,
- ClusterProvider,
- )
- from gpustack.schemas.users import (
- User,
- is_default_cluster_user,
- )
- from gpustack.server.bus import Event, EventType, event_bus
- from gpustack.utils.model_source import get_draft_model_source
- from gpustack import envs
- from gpustack.server.db import async_session
- from gpustack.server.services import (
- ModelFileService,
- ModelInstanceService,
- ModelService,
- WorkerService,
- ModelRouteService,
- )
- from gpustack.utils.model_instance_workers import get_model_instance_worker_match
- from gpustack.cloud_providers.common import (
- get_client_from_provider,
- construct_cloud_instance,
- generate_ssh_key_pair,
- )
- from gpustack.cloud_providers.abstract import (
- ProviderClientBase,
- CloudInstance,
- InstanceState,
- )
- from kubernetes_asyncio import client as k8s_client
- from gpustack.gateway.client.networking_higress_io_v1_api import (
- NetworkingHigressIoV1Api,
- McpBridgeRegistry,
- )
- from gpustack.gateway.client.extensions_higress_io_v1_api import (
- ExtensionsHigressIoV1Api,
- WasmPluginMatchRule,
- WasmPluginSpec,
- )
- from gpustack.gateway.client.networking_istio_io_v1alpha3_api import (
- NetworkingIstioIoV1Alpha3Api,
- )
- from gpustack.gateway import utils as mcp_handler
- from gpustack.gateway import get_async_k8s_config
- from gpustack.schemas.model_provider import (
- ModelProvider,
- )
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ModelController:
    """
    Reconcile Model objects as events arrive on the Model subscription.

    For every non-heartbeat event it syncs instance replicas, notifies model
    route targets, syncs categories/meta and, unless the gateway is disabled,
    ensures the model's MCP bridge in the gateway namespace.
    """

    def __init__(self, cfg: Config):
        self._config = cfg
        self._k8s_config = get_async_k8s_config(cfg=cfg)
        self._disable_gateway = cfg.gateway_mode == GatewayModeEnum.disabled
        # Built in start() when the gateway is enabled. Initialised here so the
        # attribute always exists, even if _reconcile runs before start().
        self._higress_network_api: Optional[NetworkingHigressIoV1Api] = None

    async def start(self):
        """
        Start the controller.

        Creates the Higress networking client (unless the gateway is disabled)
        and then reconciles every non-heartbeat Model event.
        """
        if not self._disable_gateway:
            base_client = k8s_client.ApiClient(configuration=self._k8s_config)
            self._higress_network_api = NetworkingHigressIoV1Api(base_client)

        async for event in Model.subscribe(source="model_controller"):
            if event.type == EventType.HEARTBEAT:
                continue

            await self._reconcile(event)

    async def _ensure_model_mcp_bridge(
        self, session: AsyncSession, event_type: EventType, model: Model
    ):
        """
        Ensure the gateway MCP bridge for ``model`` reflects its instances.

        Loads the model's non-deleted instances and the workers they run on,
        then delegates to ``mcp_handler.ensure_model_mcp_bridge``. This is a
        no-op when the gateway is disabled or the Higress client is missing.
        """
        if self._disable_gateway:
            return
        if self._higress_network_api is None:
            logger.warning(
                f"Skipping MCP bridge sync for model {model.name}: "
                "gateway client is not initialized"
            )
            return

        model_instances = await ModelInstance.all_by_fields(
            session,
            fields={"model_id": model.id, "deleted_at": None},
        )

        # Only look up workers that instances are actually placed on.
        worker_by_id = None
        worker_ids = {
            instance.worker_id for instance in model_instances if instance.worker_id
        }
        if worker_ids:
            workers = await Worker.all_by_fields(
                session,
                extra_conditions=[
                    Worker.id.in_(worker_ids),
                ],
            )
            worker_by_id = {worker.id: worker for worker in workers}

        await mcp_handler.ensure_model_mcp_bridge(
            event_type=event_type,
            model_id=model.id,
            model_instances=model_instances,
            networking_higress_api=self._higress_network_api,
            namespace=self._config.gateway_namespace,
            cluster_id=model.cluster_id,
            workers=worker_by_id,
        )

    async def _reconcile(self, event: Event):
        """
        Reconcile the model carried by ``event``.

        Errors are logged with their traceback and then swallowed, so a single
        failing event does not stop the subscription loop in start().
        """
        model: Model = event.data
        try:
            async with async_session() as session:
                await sync_replicas(session, model)
                await notify_model_route_target(
                    session=session, model=model, event=event
                )
                await sync_categories_and_meta(session, model, event)
                await self._ensure_model_mcp_bridge(session, event.type, model)
        except Exception as e:
            logger.exception(f"Failed to reconcile model {model.name}: {e}")
class ModelInstanceController:
    """
    Reconcile ModelInstance objects as events arrive.

    It keeps the parent Model in step with its instances in three ways:
    - re-publishes a Model UPDATED event when an instance is deleted, which
      triggers a replica sync;
    - prepares model files for instances in INITIALIZING state;
    - recomputes the model's ready replica count.
    """

    def __init__(self, cfg: Config):
        self._config = cfg
        # Strong references to fire-and-forget tasks. The event loop keeps only
        # weak references to tasks, so an unreferenced task may be garbage
        # collected before it finishes.
        self._background_tasks: Set[asyncio.Task] = set()

    async def start(self):
        """
        Start the controller.

        Reconciles every non-heartbeat ModelInstance event.
        """
        async for event in ModelInstance.subscribe(source="model_instance_controller"):
            if event.type == EventType.HEARTBEAT:
                continue

            await self._reconcile(event)

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a task and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _reconcile(self, event: Event):
        """
        Reconcile the parent model of the model instance carried by ``event``.

        Errors are logged with their traceback and then swallowed, so a single
        failing event does not stop the subscription loop in start().
        """
        model_instance: ModelInstance = event.data
        try:
            async with async_session() as session:
                model = await Model.one_by_id(session, model_instance.model_id)
                if not model:
                    return

                model_deleting = model.deleted_at is not None
                if event.type == EventType.DELETED:
                    # Trigger a model replica sync, but only if the model itself
                    # is not being deleted.
                    if not model_deleting:
                        copied_model = Model.model_validate(model.model_dump())
                        self._spawn(
                            event_bus.publish(
                                Model.__name__.lower(),
                                Event(type=EventType.UPDATED, data=copied_model),
                            )
                        )
                elif model_instance.state == ModelInstanceStateEnum.INITIALIZING:
                    await ensure_instance_model_file(session, model_instance)
                    return

                if model_deleting:
                    return

                await model.refresh(session)
                await sync_ready_replicas(session, model)
        except Exception as e:
            logger.exception(
                f"Failed to reconcile model instance {model_instance.name}: {e}"
            )
async def sync_replicas(session: AsyncSession, model: Model):
    """
    Make the number of ModelInstance rows for ``model`` match ``model.replicas``.

    The model is re-read from the database first, because the caller's copy
    may come from an event or another session and be stale. Missing or
    soft-deleted models are ignored.

    - Missing replicas are created as PENDING instances that inherit the
      model's source, backend, cluster and owner.
    - Surplus replicas are deleted in the order given by
      ``find_scale_down_candidates`` (lowest score first).
    """
    latest = await Model.one_by_id(session, model.id)
    if not latest or latest.deleted_at is not None:
        return
    model = latest

    existing = await ModelInstance.all_by_field(session, "model_id", model.id)
    delta = model.replicas - len(existing)

    if delta > 0:
        for _ in range(delta):
            suffix = "".join(
                random.choices(string.ascii_letters + string.digits, k=5)
            )
            new_instance = ModelInstanceCreate(
                name=f"{model.name}-{suffix}",
                model_id=model.id,
                model_name=model.name,
                source=model.source,
                huggingface_repo_id=model.huggingface_repo_id,
                huggingface_filename=model.huggingface_filename,
                model_scope_model_id=model.model_scope_model_id,
                model_scope_file_path=model.model_scope_file_path,
                local_path=model.local_path,
                state=ModelInstanceStateEnum.PENDING,
                cluster_id=model.cluster_id,
                # Inherit the parent Model's tenant binding. Otherwise the
                # schema default (PLATFORM_PRINCIPAL_ID) would place instances
                # of a non-Default-Org Model in the Default Org.
                owner_principal_id=model.owner_principal_id,
                draft_model_source=get_draft_model_source(model),
                backend=get_backend(model),
                backend_version=model.backend_version,
            )
            await ModelInstanceService(session).create(new_instance)
            logger.debug(f"Created model instance for model {model.name}")
        return

    if delta < 0:
        # Re-read the instances with a row lock, so the scale-down does not
        # race with the scheduler.
        locked = await ModelInstance.all_by_field(
            session, "model_id", model.id, for_update=True
        )
        ranked = await find_scale_down_candidates(locked, model)
        excess = len(ranked) - model.replicas
        if excess <= 0:
            return

        victims = [scored.model_instance for scored in ranked[:excess]]
        deleted_names = await ModelInstanceService(session).batch_delete(victims)
        if deleted_names:
            logger.debug(f"Deleted model instances: {deleted_names}")
async def _route_linked_user_ids(session: AsyncSession, route_id: int) -> Set[int]:
    """Return IDs of non-deleted, non-admin users whose principal is linked to the route."""
    linked_users = await User.all_by_fields(
        session,
        fields={"deleted_at": None, "is_admin": False},
        extra_conditions=[
            User.principal_id.in_(
                select(ModelRoutePrincipalLink.principal_id).where(
                    ModelRoutePrincipalLink.route_id == route_id
                )
            )
        ],
    )
    return {user.id for user in linked_users}


async def distribute_models_to_user(
    session: AsyncSession, model: ModelRoute, event: Event
):
    """
    Fan a ModelRoute event out as per-user MyModel events.

    Only non-deleted, non-admin users are considered. Events are published
    concurrently on the MyModel topic with ``pid="<route_id>:<user_id>"``.

    - DELETED route: every such user receives DELETED.
    - UPDATED route: users added to ``users`` receive CREATED and users removed
      receive DELETED. If other fields changed too, the remaining linked users
      receive UPDATED.
    - CREATED route: every linked user receives CREATED.
    """
    # NOTE(review): this returns for CREATED events whose changed_fields is
    # empty. Confirm that CREATED events populate changed_fields; otherwise
    # the CREATED branch below never runs.
    if len(event.changed_fields) == 0 and event.type == EventType.CREATED:
        return

    payload = model.model_dump(exclude={"instances", "users", "cluster"})
    route_id = model.id

    created_ids: Set[int] = set()
    deleted_ids: Set[int] = set()
    updated_ids: Set[int] = set()

    if event.type == EventType.DELETED:
        active_users = await User.all_by_fields(
            session, fields={"deleted_at": None, "is_admin": False}
        )
        deleted_ids = {user.id for user in active_users}

    if event.type == EventType.UPDATED:
        remaining_changes = event.changed_fields.copy()
        user_change = remaining_changes.pop("users", None)
        if user_change is not None:
            before, after = user_change
            before_ids = {user.id for user in before}
            after_ids = {user.id for user in after}
            created_ids = after_ids - before_ids
            deleted_ids = before_ids - after_ids
        if remaining_changes:
            current_ids = await _route_linked_user_ids(session, model.id)
            updated_ids = current_ids - created_ids

    if event.type == EventType.CREATED:
        created_ids = await _route_linked_user_ids(session, model.id)

    publishes = [
        event_bus.publish(
            MyModel.__name__.lower(),
            Event(
                type=event_type,
                data=MyModel(pid=f"{route_id}:{user_id}", user_id=user_id, **payload),
            ),
        )
        for event_type, user_ids in (
            (EventType.CREATED, created_ids),
            (EventType.DELETED, deleted_ids),
            (EventType.UPDATED, updated_ids),
        )
        for user_id in user_ids
    ]
    if publishes:
        await asyncio.gather(*publishes)
async def ensure_instance_model_file(session: AsyncSession, instance: ModelInstance):
    """
    Ensure a scheduled model instance has model files, and sync their state.

    Behaviour:
    - Does nothing until the instance is assigned a worker.
    - If the instance already has linked model files, only their state is
      synced.
    - Otherwise model files are fetched or created for the instance's
      workers, plus draft-model files when ``draft_model_source`` is set. Any
      file in ERROR state is reset to DOWNLOADING so its download is retried.

    Args:
        session: Database session used for all reads and writes.
        instance: The model instance to prepare. It is re-read from the
            database, so a stale copy is acceptable.
    """
    if instance.worker_id is None:
        # Not scheduled yet
        return

    instance = await ModelInstance.one_by_id(
        session,
        instance.id,
        options=[
            selectinload(ModelInstance.model_files),
        ],
    )
    if not instance:
        return

    if instance.model_files:
        await sync_instance_files_state(session, instance, instance.model_files)
        return

    model_files = await get_or_create_model_files_for_instance(session, instance)
    draft_model_files: List[ModelFile] = []
    if instance.draft_model_source:
        draft_model_files = await get_or_create_model_files_for_instance(
            session, instance, is_draft_model=True
        )

    retry_model_files = []
    for model_file in model_files + draft_model_files:
        if model_file.state == ModelFileStateEnum.ERROR:
            # Retry the download. The reset is staged here (auto_commit=False)
            # and committed once below for all files.
            retry_model_files.append(model_file.readable_source)
            model_file.state = ModelFileStateEnum.DOWNLOADING
            model_file.download_progress = 0
            model_file.state_message = ""
            await model_file.update(session, auto_commit=False)

    if retry_model_files:
        await session.commit()
        logger.info(
            f"Retrying download for model files {retry_model_files} for model instance {instance.name}"
        )

    # Re-read the instance before attaching the files. Like the lookup above,
    # one_by_id can return None (e.g. the row no longer exists), in which case
    # there is nothing left to sync.
    instance = await ModelInstance.one_by_id(session, instance.id)
    if not instance:
        return
    instance.model_files = model_files
    instance.draft_model_files = draft_model_files
    await sync_instance_files_state(session, instance, model_files + draft_model_files)
async def get_or_create_model_files_for_instance(
    session: AsyncSession, instance: ModelInstance, is_draft_model: bool = False
) -> List[ModelFile]:
    """
    Get or create model files for the given model instance.

    Every worker returned by ``_get_worker_ids_for_file_download`` should end
    up with a model file. A DOWNLOADING file is created for each worker that
    has none yet.

    Args:
        session: Database session.
        instance: The model instance whose files are needed.
        is_draft_model: When True, use the instance's ``draft_model_source``
            instead of the instance's own source.

    Returns:
        The model files for the instance's workers, after any creation.
    """
    model_files = await get_model_files_for_instance(session, instance, is_draft_model)
    worker_ids = _get_worker_ids_for_file_download(instance)

    # Compare worker coverage, not counts. get_model_files_for_instance does
    # not guarantee one file per worker, so equal lengths can still leave a
    # worker without a file.
    missing_worker_ids = set(worker_ids) - {
        model_file.worker_id for model_file in model_files
    }
    if not missing_worker_ids:
        return model_files

    model_source: ModelSource = (
        instance.draft_model_source if is_draft_model else instance
    )

    # Create model files for the missing worker IDs.
    for worker_id in missing_worker_ids:
        model_file = ModelFile(
            source=model_source.source,
            huggingface_repo_id=model_source.huggingface_repo_id,
            huggingface_filename=model_source.huggingface_filename,
            model_scope_model_id=model_source.model_scope_model_id,
            model_scope_file_path=model_source.model_scope_file_path,
            local_path=model_source.local_path,
            state=ModelFileStateEnum.DOWNLOADING,
            worker_id=worker_id,
            source_index=model_source.model_source_index,
        )
        await ModelFile.create(session, model_file)
        logger.info(
            f"Created model file for model instance {instance.name} and worker {worker_id}"
        )

    # Fetch again so the returned list includes the newly created files.
    return await get_model_files_for_instance(session, instance, is_draft_model)
async def get_model_files_for_instance(
    session: AsyncSession, instance: ModelInstance, is_draft_model: bool = False
) -> List[ModelFile]:
    """
    Get the existing model files for the given model instance's workers.

    Files are matched by the source's ``model_source_index`` and restricted to
    the workers returned by ``_get_worker_ids_for_file_download``. For
    local-path sources, files whose resolved path matches ``local_path`` are
    also added, for those workers that do not already have a matched file.

    Args:
        session: Database session.
        instance: The model instance whose files are wanted.
        is_draft_model: When True, look up files for the instance's
            ``draft_model_source`` instead of the instance's own source.

    Returns:
        The matching model files (possibly empty).
    """
    # A set gives O(1) membership checks in the filters below.
    worker_ids = set(_get_worker_ids_for_file_download(instance))
    model_source: ModelSource = (
        instance.draft_model_source if is_draft_model else instance
    )

    model_files = await ModelFileService(session).get_by_source_index(
        model_source.model_source_index
    )
    model_files = [
        model_file for model_file in model_files if model_file.worker_id in worker_ids
    ]

    if model_source.source == SourceEnum.LOCAL_PATH and model_source.local_path:
        # If the source is a local path, also pick up model files that resolve
        # to the same path.
        local_path_model_files = await ModelFileService(session).get_by_resolved_path(
            model_source.local_path
        )
        existing_worker_ids = {mf.worker_id for mf in model_files}
        model_files.extend(
            model_file
            for model_file in local_path_model_files
            if model_file.worker_id in worker_ids
            and model_file.worker_id not in existing_worker_ids
        )

    return model_files
async def find_scale_down_candidates(
    instances: List[ModelInstance],
    model: Model,
    *,
    status_max_score: Optional[float] = None,
    offload_max_score: Optional[float] = None,
    placement_max_score: Optional[float] = None,
    total_max_score: Optional[float] = None,
) -> List[ModelInstanceScore]:
    """
    Rank the instances of ``model`` for removal during a scale-down.

    Each instance is scored by a chain of status, offload-layer and placement
    scorers. Any per-scorer cap that is not given falls back to the matching
    ``envs.SCHEDULER_SCALE_DOWN_*_MAX_SCORE`` setting. ``total_max_score`` is
    passed to the chain unchanged.

    Returns:
        The scored instances sorted by ascending score, so the first entries
        are removed first by sync_replicas. On any failure the error is logged
        and an empty list is returned, which makes sync_replicas delete
        nothing.
    """
    try:
        status_cap = (
            envs.SCHEDULER_SCALE_DOWN_STATUS_MAX_SCORE
            if status_max_score is None
            else status_max_score
        )
        offload_cap = (
            envs.SCHEDULER_SCALE_DOWN_OFFLOAD_MAX_SCORE
            if offload_max_score is None
            else offload_max_score
        )
        placement_cap = (
            envs.SCHEDULER_SCALE_DOWN_PLACEMENT_MAX_SCORE
            if placement_max_score is None
            else placement_max_score
        )

        scorers = [
            StatusScorer(model, max_score=status_cap),
            OffloadLayerScorer(model, max_score=offload_cap),
            PlacementScorer(
                model,
                instances,
                scale_type=ScaleTypeEnum.SCALE_DOWN,
                max_score=placement_cap,
            ),
        ]
        chain = ModelInstanceScoreChain(
            scorers=scorers, total_max_score=total_max_score
        )
        scored = await chain.score(instances)
        return sorted(scored, key=lambda candidate: candidate.score)
    except Exception as e:
        logger.error(
            f"Failed to find scale down candidates for model {model.name}: {e}"
        )
        return []
async def sync_ready_replicas(session: AsyncSession, model: Model):
    """
    Recount the RUNNING instances of ``model`` and persist ``ready_replicas``.

    Soft-deleted models are skipped. The model is written back only when the
    count actually changed.
    """
    if model.deleted_at is not None:
        return

    instances = await ModelInstance.all_by_field(session, "model_id", model.id)
    ready_replicas = sum(
        1 for instance in instances if instance.state == ModelInstanceStateEnum.RUNNING
    )

    if model.ready_replicas != ready_replicas:
        model.ready_replicas = ready_replicas
        await ModelService(session).update(model)
async def get_cluster_registry(
    session: AsyncSession, cluster_id: int
) -> Optional[McpBridgeRegistry]:
    """
    Return the MCP bridge registry for a cluster, if it has one.

    Looks up the user bound to ``cluster_id`` (with its cluster eagerly
    loaded) and builds the registry from that user's cluster.

    Returns:
        The registry, or None when any of these holds:
        - no user is bound to the cluster;
        - the user is the default cluster user;
        - ``mcp_handler.cluster_registry`` yields no registry.
    """
    cluster_user = await User.one_by_field(
        session=session,
        field="cluster_id",
        value=cluster_id,
        options=[selectinload(User.cluster)],
    )
    # one_by_field may return None; there is no cluster to build from then.
    if cluster_user is None or is_default_cluster_user(cluster_user):
        return None

    return mcp_handler.cluster_registry(cluster_user.cluster)
async def sync_model_route_mapper(
    cfg: Config,
    extensions_api: ExtensionsHigressIoV1Api,
    ingress_name: str,
    route_name: str,
    destinations: mcp_handler.DestinationTupleList,
    fallback_destinations: mcp_handler.DestinationTupleList,
):
    """
    Synchronize the model route mapper.

    Rebuilds the match rules that ``ingress_name`` contributes to the shared
    model-mapper WasmPlugin from the route's destinations and fallback
    destinations. Rules belonging to other ingresses are preserved. Self
    mappings (a destination model name equal to ``route_name``) are skipped
    for the primary destinations.
    """
    namespace = cfg.get_namespace()
    ingress_prefix = "" if namespace == cfg.gateway_namespace else f"{namespace}/"

    model_name_to_registries: Dict[str, List[str]] = {}
    for _, model_name, registry in destinations:
        if route_name == model_name:
            # Skip self mapping
            continue
        model_name_to_registries.setdefault(model_name, []).append(
            registry.get_service_name()
        )

    fallback_model_name_to_registries: Dict[str, List[str]] = {}
    for _, model_name, registry in fallback_destinations:
        fallback_model_name_to_registries.setdefault(model_name, []).append(
            registry.get_service_name()
        )

    expected_rules = mcp_handler.get_expected_match_list(
        route_name=route_name,
        ingress_prefix=ingress_prefix,
        ingress_name=ingress_name,
        model_name_to_registries=model_name_to_registries,
        fallback_model_name_to_registries=fallback_model_name_to_registries,
    )
    full_ingress_name = f"{ingress_prefix}{ingress_name}"

    def spec_diff(current_spec: Optional[WasmPluginSpec]) -> Optional[WasmPluginSpec]:
        # The current spec must exist. If it does not, the plugin has been
        # deleted manually; do not recreate it until the next update event,
        # to avoid potential misconfiguration.
        if current_spec is None:
            return None

        # Keep rules that do not target this ingress. ``rule.ingress`` may be
        # empty or None (the sort key below already allows for that).
        to_keep_rules: List[WasmPluginMatchRule] = [
            rule
            for rule in current_spec.matchRules or []
            if full_ingress_name not in (rule.ingress or [])
        ]
        to_keep_rules.extend(expected_rules)
        # Stable ordering by first ingress.
        to_keep_rules.sort(key=lambda r: r.ingress[0] if r.ingress else "")
        current_spec.matchRules = to_keep_rules
        return current_spec

    await mcp_handler.ensure_wasm_plugin(
        api=extensions_api,
        name=mcp_handler.gpustack_model_mapper_name,
        namespace=cfg.gateway_namespace,
        spec_diff=spec_diff,
    )
async def ensure_route_generic_transformer_config(
    cfg: Config,
    model_route: ModelRoute,
    effective_name: str,
    extensions_api: ExtensionsHigressIoV1Api,
    generic_proxy_enabled: bool,
):
    """
    Reconcile this route's HeaderRule in the generic-route transformer plugin.

    The rule maps ``/model/proxy/<route_id>/...`` to this route's
    ``x-higress-llm-model``. When ``generic_proxy_enabled`` is False (generic
    proxy disabled, or the route deleted), the expected rule list is empty,
    so this route's rule is removed. Rules of other routes are untouched.

    Args:
        cfg: Server config; supplies the gateway namespace.
        model_route: The route whose rule is reconciled.
        effective_name: Fully-qualified model name. Non-platform Orgs include
            the Org slug prefix (e.g. ``org1/qwen3-0.6b``); the platform Org
            keeps the unprefixed ``model_route.name``.
        extensions_api: Higress extensions API client.
        generic_proxy_enabled: Whether this route's rule should exist.
    """
    route_id = model_route.id
    path_pattern = mcp_handler.build_generic_route_path_pattern(route_id)

    header_rules: List[Dict[str, Any]] = (
        [mcp_handler.build_generic_route_header_rule(route_id, effective_name)]
        if generic_proxy_enabled
        else []
    )

    diff = partial(
        mcp_handler.generic_route_transformer_diff_spec,
        expected_header_rules=header_rules,
        operating_path_pattern=path_pattern,
    )
    await mcp_handler.ensure_wasm_plugin(
        api=extensions_api,
        name=mcp_handler.gpustack_generic_route_transformer_name,
        namespace=cfg.gateway_namespace,
        spec_diff=diff,
    )
async def ensure_route_ai_proxy_config(
    cfg: Config,
    model_route_id: int,
    extensions_api: ExtensionsHigressIoV1Api,
    route_destinations: mcp_handler.DestinationTupleList,
    fallback_destinations: mcp_handler.DestinationTupleList,
):
    """
    Reconcile this route's provider and match rules in the AI-proxy plugin.

    Only destinations whose registry name does not start with
    ``mcp_handler.provider_id_prefix`` need the AI proxy (cross-provider
    routing). If any such destination exists, an OpenAI provider config keyed
    by the route's operating id is expected. One match rule is expected per
    ingress (primary and fallback) that has such services.
    """
    namespace = cfg.get_namespace()
    service_namespace_prefix = (
        "" if namespace == cfg.gateway_namespace else f"{namespace}/"
    )
    operating_id = mcp_handler.model_route_cleanup_prefix(model_route_id)
    ingress_name = mcp_handler.model_route_ingress_name(model_route_id)
    fallback_ingress_name = mcp_handler.fallback_ingress_name(ingress_name)

    def _proxied_services(destinations) -> List[str]:
        # Deduplicated and sorted. A plain set of strings iterates in an order
        # that depends on per-process hash randomization, which would make the
        # desired rules differ between server runs for identical input.
        return sorted(
            {
                registry.get_service_name()
                for _, _, registry in destinations
                if not registry.name.startswith(mcp_handler.provider_id_prefix)
            }
        )

    route_services = _proxied_services(route_destinations)
    fallback_services = _proxied_services(fallback_destinations)

    # Cross-provider routing needs the ai_proxy provider configured.
    expected_providers = []
    if route_services or fallback_services:
        expected_providers.append(
            mcp_handler.ai_proxy_openai_provider_config(operating_id)
        )

    # Same rule shape for the primary ingress and its fallback ingress.
    expected_match_rules = []
    for services, target_ingress in (
        (route_services, ingress_name),
        (fallback_services, fallback_ingress_name),
    ):
        if services:
            expected_match_rules.append(
                WasmPluginMatchRule(
                    config={
                        "activeProviderId": operating_id,
                    },
                    configDisable=False,
                    service=services,
                    ingress=[f"{service_namespace_prefix}{target_ingress}"],
                )
            )

    await mcp_handler.ensure_wasm_plugin(
        api=extensions_api,
        name=mcp_handler.gpustack_ai_proxy_name,
        namespace=cfg.gateway_namespace,
        spec_diff=partial(
            mcp_handler.ai_proxy_diff_spec,
            expected_providers=expected_providers,
            expected_match_rules=expected_match_rules,
            operating_id_prefix=operating_id,
        ),
    )
async def sync_gateway(
    session: AsyncSession,
    event: Event,
    cfg: Config,
    model_route: ModelRoute,
    networking_api: k8s_client.NetworkingV1Api,
    extensions_api: ExtensionsHigressIoV1Api,
    istio_networking_api: NetworkingIstioIoV1Alpha3Api,
):
    """
    Reconcile every gateway resource that backs ``model_route``.

    Resources are handled in this order:

    1. the model-route mapper;
    2. the main model ingress;
    3. the fallback ingress and its fallback filter;
    4. the generic-route transformer plugin;
    5. the ai-proxy plugin.

    The effective event type is ``event.type``, except when the route can no
    longer be loaded from the database; then it is forced to DELETED. For an
    effective DELETED event no destinations are computed, so every resource is
    reconciled with empty destination lists. The fallback ingress and filter
    are also reconciled as DELETED when no route target has non-empty
    ``fallback_status_codes``.
    """
    event_type = event.type
    model_route_from_db = await ModelRoute.one_by_id(
        session,
        model_route.id,
        options=[selectinload(ModelRoute.route_targets)],
    )
    if not model_route_from_db:
        # The route is gone from the database: tear everything down even if
        # this event is a stale non-DELETED one.
        event_type = EventType.DELETED
    targets: List[ModelRouteTarget] = (
        getattr(model_route_from_db, "route_targets", []) if model_route_from_db else []
    )
    has_fallback_target = any(target.fallback_status_codes for target in targets)

    destinations: mcp_handler.DestinationTupleList = []
    fallback_destinations: mcp_handler.DestinationTupleList = []
    # Use the *effective* event type. Otherwise a route missing from the DB
    # could still get destinations, and therefore ai-proxy match rules.
    if event_type != EventType.DELETED:
        destinations, fallback_destinations = await calculate_destinations(
            session, model_route
        )

    # Effective model name = `<org-slug>/<route.name>` for non-platform
    # Orgs (so two Orgs can use the same `route.name` without colliding
    # in Higress's AI proxy match rules), unprefixed for the platform Org
    # (backward compatible for existing clients).
    route_owner = await Principal.one_by_id(session, model_route.owner_principal_id)
    effective_name = effective_route_name(
        model_route.name,
        getattr(route_owner, "slug", None),
        getattr(route_owner, "id", None) == PLATFORM_PRINCIPAL_ID,
    )

    ingress_name = mcp_handler.model_route_ingress_name(model_route.id)
    await sync_model_route_mapper(
        cfg=cfg,
        extensions_api=extensions_api,
        ingress_name=ingress_name,
        route_name=effective_name,
        destinations=destinations,
        fallback_destinations=fallback_destinations,
    )

    # FIXME: Copy the fallback destination to the main ingress for now to make sure the fallback
    # route is always hit when fallback is configured, even if the main route has no valid
    # destination. This is to avoid potential misconfiguration that causes the main route to
    # have no destination and the fallback route is not hit at all.
    await mcp_handler.ensure_model_ingress(
        ingress_class_name=cfg.gateway_ingress_class,
        event_type=event_type,
        ingress_name=ingress_name,
        route_name=effective_name,
        namespace=cfg.get_namespace(),
        destinations=destinations if len(destinations) > 0 else fallback_destinations,
        networking_api=networking_api,
        included_generic_route=False,
        included_proxy_route=model_route.generic_proxy,
    )

    fallback_event_type = event_type if has_fallback_target else EventType.DELETED
    # Fallback ingress
    await mcp_handler.ensure_model_ingress(
        ingress_class_name=cfg.gateway_ingress_class,
        event_type=fallback_event_type,
        ingress_name=mcp_handler.fallback_ingress_name(ingress_name),
        route_name=effective_name,
        namespace=cfg.get_namespace(),
        destinations=fallback_destinations,
        networking_api=networking_api,
        included_generic_route=False,
        included_proxy_route=model_route.generic_proxy,
        extra_annotations=mcp_handler.higress_http_header_matcher(
            "exact", "x-higress-fallback-from", ingress_name
        ),
    )
    # Fallback filter
    await mcp_handler.ensure_fallback_filter(
        event_type=fallback_event_type,
        ingress_name=ingress_name,
        namespace=cfg.get_namespace(),
        networking_istio_api=istio_networking_api,
    )

    # Generic-route transformer: inject x-higress-llm-model when /model/proxy/<id>/
    # is hit, so the existing main ingress header matcher + fallback chain apply.
    await ensure_route_generic_transformer_config(
        cfg=cfg,
        model_route=model_route,
        effective_name=effective_name,
        extensions_api=extensions_api,
        generic_proxy_enabled=(
            event_type != EventType.DELETED and bool(model_route.generic_proxy)
        ),
    )

    # ensure ai proxy config
    await ensure_route_ai_proxy_config(
        cfg=cfg,
        model_route_id=model_route.id,
        extensions_api=extensions_api,
        route_destinations=destinations,
        fallback_destinations=fallback_destinations,
    )
def flatten_destinations(
    weight_to_count: List[Tuple[int, int, mcp_handler.DestinationTupleList]],
    max_weight: Optional[int] = 0,
) -> mcp_handler.DestinationTupleList:
    """
    Collapse weighted destination groups into one ``(percentage, model_name, registry)`` list.

    ``weight_to_count`` holds one ``(weight, count, group)`` triple per route
    target. ``mcp_handler.hamilton_calculate_weight`` receives the
    ``(weight, count)`` pairs plus ``max_weight``. Its result is read as a flat
    sequence of per-slot percentages, and each registry entry in a group takes
    as many consecutive slots as its own count. The slots of each entry are
    summed; entries whose total is 0 are dropped.
    """
    percentages = mcp_handler.hamilton_calculate_weight(
        [(group_weight, group_count) for group_weight, group_count, _ in weight_to_count],
        max_weight=max_weight,
    )

    flattened: mcp_handler.DestinationTupleList = []
    cursor = 0
    for _, _, group in weight_to_count:
        for slots, model_name, registry in group:
            share = sum(percentages[cursor : cursor + slots])
            cursor += slots
            if share == 0:
                continue
            flattened.append((share, model_name, registry))
    return flattened
async def calculate_destinations(
    session: AsyncSession,
    model_route: ModelRoute,
) -> Tuple[mcp_handler.DestinationTupleList, mcp_handler.DestinationTupleList]:
    """
    Resolve a route's ACTIVE targets into weighted destination lists.

    Returns ``(destinations, fallback_destinations)``. Each is a list of
    ``(percentage, model_name, registry)`` tuples produced by
    ``flatten_destinations``.

    * A target pointing at a model uses ``calculate_model_destinations``.
    * A target pointing at a provider uses ``provider_destinations``.
    * Targets that resolve to nothing (missing model, no destinations) are ignored.
    * A target with non-empty ``fallback_status_codes`` contributes to both
      lists, with the same weight.

    When no target yields a destination, both lists are empty.
    """
    primary: List[Tuple[int, int, mcp_handler.DestinationTupleList]] = []
    fallback: List[Tuple[int, int, mcp_handler.DestinationTupleList]] = []

    route_targets = await ModelRouteTarget.all_by_field(
        session, "route_id", model_route.id
    )
    for route_target in route_targets:
        if route_target.state != TargetStateEnum.ACTIVE:
            continue

        resolved: mcp_handler.DestinationTupleList = []
        if route_target.model_id is not None:
            model = await Model.one_by_id(session, route_target.model_id)
            if model is None:
                continue
            resolved = await calculate_model_destinations(session, model)
        elif route_target.provider_id is not None:
            resolved = await provider_destinations(
                session=session,
                provider_id=route_target.provider_id,
                provider_model_name=route_target.provider_model_name,
            )

        if not resolved:
            # nothing routable behind this target
            continue

        entry = (route_target.weight, sum(n for n, _, _ in resolved), resolved)
        primary.append(entry)
        if route_target.fallback_status_codes:
            fallback.append(entry)

    if not primary:
        return [], []

    destinations = flatten_destinations(primary)
    # Fallback targets may have a weight of 0, so max_weight is set to 1.
    fallback_destinations = (
        flatten_destinations(fallback, max_weight=1) if fallback else []
    )
    return destinations, fallback_destinations
async def provider_destinations(
    session: AsyncSession,
    provider_id: int,
    provider_model_name: str,
) -> mcp_handler.DestinationTupleList:
    """
    Build the destination list for a route target backed by a model provider.

    Returns a single ``(1, provider_model_name, registry)`` entry built from
    ``mcp_handler.provider_registry``. Returns an empty list when the provider
    row does not exist.
    """
    provider = await ModelProvider.one_by_id(session, provider_id)
    return (
        [(1, provider_model_name, mcp_handler.provider_registry(provider))]
        if provider is not None
        else []
    )
async def calculate_model_destinations(
    session: AsyncSession,
    model: Model,
) -> mcp_handler.DestinationTupleList:
    """
    Build destination entries for a route target that points at a deployed model.

    If ``get_cluster_registry`` returns a registry for the model's cluster, the
    model is addressed through it as a single ``(1, model.name, registry)``
    entry.

    Otherwise, the model's instances are filtered to those that are RUNNING
    and have a non-empty worker IP and a port. Those instances are paired with
    their non-deleted workers in the model's cluster, and both are passed to
    ``mcp_handler.model_instances_registry_list``.
    """
    cluster_registry = await get_cluster_registry(session, model.cluster_id)
    if cluster_registry is not None:
        return [(1, model.name, cluster_registry)]

    def _is_routable(candidate: ModelInstance) -> bool:
        return (
            candidate.worker_ip is not None
            and candidate.port is not None
            and candidate.worker_ip != ""
            and candidate.state == ModelInstanceStateEnum.RUNNING
        )

    all_instances = await ModelInstance.all_by_field(session, "model_id", model.id)
    routable = [candidate for candidate in all_instances if _is_routable(candidate)]

    hosting_worker_ids = [
        candidate.worker_id
        for candidate in routable
        if candidate.worker_id is not None
    ]
    worker_rows = await Worker.all_by_fields(
        session=session,
        fields={
            "cluster_id": model.cluster_id,
            "deleted_at": None,
        },
        extra_conditions=[Worker.id.in_(hosting_worker_ids)],
    )
    workers_by_id = {row.id: row for row in worker_rows}
    return mcp_handler.model_instances_registry_list(routable, workers_by_id)
class WorkerController:
    """
    Controller for Worker events.

    For each non-heartbeat Worker event it:

    1. reconciles the model instances placed on the worker;
    2. lets the provisioning controller reconcile;
    3. publishes UPDATED events for the worker's pool, cluster and hosted
       models when relevant fields changed.
    """

    def __init__(self, cfg: Config):
        self._provisioning = WorkerProvisioningController(cfg)

    async def start(self):
        """
        Start the controller.

        The three steps for one event run sequentially. If one raises, the
        error is logged, the remaining steps for that event are skipped, and
        the loop continues with the next event.
        """
        async for event in Worker.subscribe(source="worker_controller"):
            if event.type == EventType.HEARTBEAT:
                continue
            try:
                await self._reconcile(event)
                await self._provisioning._reconcile(event)
                await self._notify_relatives(event)
            except Exception as e:
                logger.error(f"Failed to reconcile worker: {e}")

    async def _reconcile(self, event: Event):
        """
        Delete or flag the instances placed on a worker, based on the worker state and event type.

        Only DELETED events and UPDATED events that change ``state`` are
        handled. Instances in the worker's cluster that
        ``get_model_instance_worker_match`` matches to this worker are:

        * deleted when the worker is DELETED;
        * marked unreachable when the worker is unreachable, UNREACHABLE or
          NOT_READY.
        """
        if event.type not in (EventType.UPDATED, EventType.DELETED):
            return

        worker: Worker = event.data
        if not worker:
            return

        if worker.state.is_provisioning and worker.state != WorkerStateEnum.DELETING:
            # Workers still going through provisioning are left to the
            # dedicated provisioning controller. DELETING workers are NOT
            # skipped: they fall through so their instances are handled below.
            return

        if event.type == EventType.UPDATED:
            changed_fields = event.changed_fields
            if not changed_fields or "state" not in changed_fields:
                # No state change
                return

        async with async_session() as session:
            all_instances = await ModelInstance.all_by_field(
                session, "cluster_id", worker.cluster_id
            )
            if not all_instances:
                return

            matched_instances = []
            for instance in all_instances:
                match = get_model_instance_worker_match(
                    instance,
                    worker_name=worker.name,
                    worker_id=worker.id,
                )
                if match.matched:
                    matched_instances.append((instance, match))

            if not matched_instances:
                return

            if event.type == EventType.DELETED:
                instance_names = await ModelInstanceService(session).batch_delete(
                    [instance for instance, _ in matched_instances]
                )
                if instance_names:
                    logger.info(
                        f"Delete instance {', '.join(instance_names)} "
                        f"since worker {worker.name} is deleted"
                    )
                return

            if (
                worker.unreachable
                or worker.state == WorkerStateEnum.UNREACHABLE
                or worker.state == WorkerStateEnum.NOT_READY
            ):
                await self.update_impacted_instance_states_to_unreachable(
                    session,
                    matched_instances,
                    worker.name,
                )
            return

    async def update_impacted_instance_states_to_unreachable(
        self,
        session,
        matched_instances,
        worker_name,
    ):
        """
        Mark the parts of matched instances that depend on an unreachable worker.

        ``matched_instances`` holds ``(instance, match)`` pairs from
        ``get_model_instance_worker_match``.

        * A RUNNING instance whose main worker is this worker moves to UNREACHABLE.
        * Each matched subordinate-worker entry that is not already UNREACHABLE
          is set to UNREACHABLE, and ``distributed_servers`` is flagged as modified.

        Changes are persisted through ``ModelInstanceService`` and summarized
        in the log.
        """
        instance_names = set()
        subordinate_worker_names = set()
        for instance, match in matched_instances:
            patch = {}
            distributed_servers_changed = False
            if (
                match.is_main_worker
                and instance.state == ModelInstanceStateEnum.RUNNING
            ):
                patch["state"] = ModelInstanceStateEnum.UNREACHABLE
                patch["state_message"] = "Worker is unreachable from the server"
                instance_names.add(instance.name)

            for index in match.subordinate_worker_indexes:
                subordinate_worker = instance.distributed_servers.subordinate_workers[
                    index
                ]
                if subordinate_worker.state == ModelInstanceStateEnum.UNREACHABLE:
                    continue
                subordinate_worker.state = ModelInstanceStateEnum.UNREACHABLE
                subordinate_worker.state_message = (
                    "Worker is unreachable from the server"
                )
                subordinate_worker_names.add(
                    f"{instance.name}:{subordinate_worker.worker_name}"
                )
                distributed_servers_changed = True

            if distributed_servers_changed:
                patch["distributed_servers"] = instance.distributed_servers
                flag_modified(instance, "distributed_servers")

            if patch:
                await ModelInstanceService(session).update(instance, patch)

        # Sorted so the log line is stable regardless of set ordering.
        if instance_names:
            logger.info(
                f"Marked instance {', '.join(sorted(instance_names))} unreachable "
                f"since worker {worker_name} is unreachable from the server"
            )
        if subordinate_worker_names:
            logger.info(
                f"Marked subordinate workers {', '.join(sorted(subordinate_worker_names))} unreachable "
                f"since worker {worker_name} is unreachable from the server"
            )

    async def _notify_relatives(self, event: Event):
        """
        Re-publish UPDATED events for objects related to this worker.

        * Parents (worker pool and cluster) are notified when ``state`` or
          ``proxy_mode`` changed, or when the worker was DELETED.
        * Children (the distinct models of instances on this worker) are
          notified when ``proxy_address`` or ``proxy_mode`` changed.

        Published payloads are fresh copies built from ``model_dump()``.
        """
        if event.type not in (EventType.UPDATED, EventType.DELETED):
            return

        worker: Worker = event.data
        changed_fields = event.changed_fields
        if not worker or (not changed_fields and event.type != EventType.DELETED):
            return

        state_changed: Optional[Tuple[Any, Any]] = (changed_fields or {}).get(
            "state", None
        )
        proxy_mode_changed: Optional[Tuple[Any, Any]] = (changed_fields or {}).get(
            "proxy_mode", None
        )
        should_notify_parents = (
            state_changed is not None
            or proxy_mode_changed is not None
            or event.type == EventType.DELETED
        )
        proxy_address_changed: Optional[Tuple[Any, Any]] = (changed_fields or {}).get(
            "proxy_address", None
        )
        should_notify_children = (
            proxy_address_changed is not None or proxy_mode_changed is not None
        )
        if not should_notify_parents and not should_notify_children:
            return

        async with async_session() as session:
            if should_notify_parents and worker.worker_pool_id is not None:
                worker_pool = await WorkerPool.one_by_id(
                    session,
                    worker.worker_pool_id,
                    options=[selectinload(WorkerPool.pool_workers)],
                )
                if worker_pool is not None:
                    copied_pool = WorkerPool(**worker_pool.model_dump())
                    await event_bus.publish(
                        copied_pool.__class__.__name__.lower(),
                        Event(
                            type=EventType.UPDATED,
                            data=copied_pool,
                        ),
                    )

            if should_notify_parents and worker.cluster_id is not None:
                cluster = await Cluster.one_by_id(
                    session,
                    worker.cluster_id,
                    options=[
                        selectinload(Cluster.cluster_workers),
                        selectinload(Cluster.cluster_models),
                    ],
                )
                if cluster is not None:
                    copied_cluster = Cluster(**cluster.model_dump())
                    await event_bus.publish(
                        copied_cluster.__class__.__name__.lower(),
                        Event(
                            type=EventType.UPDATED,
                            data=copied_cluster,
                        ),
                    )

            if should_notify_children:
                instances = await ModelInstance.all_by_fields(
                    session,
                    fields={"worker_id": worker.id},
                    options=[selectinload(ModelInstance.model)],
                )
                notified_model = set()
                for instance in instances:
                    # Skip instances whose model relation did not load (e.g.
                    # the model row is gone) instead of failing on
                    # `None.model_dump()` and aborting the remaining models.
                    if instance.model is None or instance.model_id in notified_model:
                        continue
                    notified_model.add(instance.model_id)
                    copied_model = Model(**instance.model.model_dump())
                    await event_bus.publish(
                        copied_model.__class__.__name__.lower(),
                        Event(
                            type=EventType.UPDATED,
                            data=copied_model,
                        ),
                    )
class InferenceBackendController:
    """
    Inference backend controller initializes built-in and community backends in the database.
    """

    async def start(self):
        """Seed built-in backends, then sync community backends, in one session."""
        async with async_session() as session:
            # Initialize built-in backends
            await self._init_built_in_backends(session)
            # Initialize community backends
            await self._init_community_backends(session)

    async def _init_built_in_backends(self, session: AsyncSession):
        """
        Initialize built-in backends in the database.

        Every built-in backend except CUSTOM gets a Platform-scoped row
        (``owner_principal_id`` IS NULL).

        * Missing rows are created with ``backend_source=BUILT_IN`` and
          ``enabled=True``.
        * Existing rows without a ``backend_source`` are back-filled with
          BUILT_IN, and enabled if ``enabled`` is unset.
        * Any other existing row is left untouched.
        """
        for built_in_backend in get_built_in_backend():
            if built_in_backend.backend_name == BackendEnum.CUSTOM.value:
                continue

            # Built-in backends always seed as Platform (owner_principal_id IS NULL).
            # Per-Org overrides live in additional rows created by Org owners /
            # managers; those are managed via the inference_backend routes.
            backend = await InferenceBackend.one_by_fields(
                session,
                {
                    "backend_name": built_in_backend.backend_name,
                    "owner_principal_id": None,
                },
            )

            if not backend:
                # Create new built-in backend with backend_source
                built_in_backend.backend_source = BackendSourceEnum.BUILT_IN
                built_in_backend.enabled = True
                await InferenceBackend.create(session, built_in_backend)
                logger.info(
                    f"Init built-in backend {built_in_backend.backend_name} in database"
                )
            elif backend.backend_source is None:
                # Back-fill a row created before backend_source existed; an
                # unset `enabled` defaults to True.
                enabled = True if backend.enabled is None else backend.enabled
                backend.backend_source = BackendSourceEnum.BUILT_IN
                backend.enabled = enabled
                await backend.update(
                    session,
                    {
                        "backend_source": BackendSourceEnum.BUILT_IN,
                        "enabled": enabled,
                    },
                )
                logger.info(
                    f"Updated backend_source for existing built-in backend {backend.backend_name}"
                )

    async def _init_community_backends(self, session: AsyncSession):  # noqa: C901
        """
        Load community backends from community-inference-backends.yaml into database.

        The YAML bundled in ``gpustack.assets`` must be a list of backend
        mappings. Each valid entry is upserted as a Platform-scoped COMMUNITY
        backend; non-mapping entries are skipped with a warning.

        Platform COMMUNITY rows whose name is no longer in the YAML are then
        retired:

        * enabled ones become disabled CUSTOM backends, and their built-in
          framework versions become custom-framework versions;
        * disabled ones are deleted.

        Any failure is logged, not raised.
        """
        try:
            # Get the path to community-inference-backends.yaml
            yaml_file = files("gpustack.assets").joinpath(
                "community-inference-backends.yaml"
            )
            if not yaml_file.is_file():
                logger.debug(
                    "community-inference-backends.yaml not found, skipping community backend initialization"
                )
                return

            yaml_data = yaml.safe_load(yaml_file.read_text())
            if not yaml_data:
                logger.debug(
                    "No community backends found in community-inference-backends.yaml"
                )
                return

            if not isinstance(yaml_data, list):
                logger.error(
                    f"Invalid community-inference-backends.yaml format: expected list, got {type(yaml_data).__name__}"
                )
                return

            # Upsert each entry and collect the names the catalog still lists.
            # A malformed entry is skipped rather than raising. Raising would
            # abort the whole sync, including the stale-backend cleanup below.
            yaml_backend_names = set()
            for backend_config in yaml_data:
                if not isinstance(backend_config, dict):
                    logger.warning(
                        f"Skipping invalid community backend entry: expected mapping, got {type(backend_config).__name__}"
                    )
                    continue
                backend_name = backend_config.get("backend_name")
                if backend_name:
                    yaml_backend_names.add(backend_name)
                await self._upsert_community_backend(session, backend_config)

            # Query all community backends from database. Only Platform
            # rows are owned by the catalog yaml; Org-private community
            # additions stay untouched.
            all_backends = await InferenceBackend.all(session)
            db_community_backends = [
                backend
                for backend in all_backends
                if backend.backend_source == BackendSourceEnum.COMMUNITY
                and backend.owner_principal_id is None
            ]

            # Retire community backends that are no longer in YAML
            for backend in db_community_backends:
                if backend.backend_name in yaml_backend_names:
                    continue

                if backend.enabled:
                    # The backend was enabled, so keep it: convert it into a
                    # (disabled) custom backend, turning every
                    # built_in_frameworks version into a custom_framework version.
                    converted_versions = {}
                    if backend.version_configs and backend.version_configs.root:
                        for version, config in backend.version_configs.root.items():
                            config_data = config.model_dump()
                            if config_data.get("built_in_frameworks"):
                                config_data["custom_framework"] = config_data[
                                    "built_in_frameworks"
                                ][0]
                                config_data["built_in_frameworks"] = None
                            converted_versions[version] = VersionConfig(**config_data)

                    # Prepare update data
                    update_data = {
                        "backend_source": BackendSourceEnum.CUSTOM,
                        "enabled": False,
                        "version_configs": VersionConfigDict(root=converted_versions),
                    }
                    flag_modified(backend, "version_configs")
                    await backend.update(session, update_data)
                    logger.info(
                        f"Converted community backend '{backend.backend_name}' to custom backend"
                    )
                else:
                    # The backend was never enabled: delete it outright.
                    await backend.delete(session)
                    logger.info(
                        f"Deleted community backend '{backend.backend_name}' "
                        f"(no longer in community-inference-backends.yaml)"
                    )

            logger.debug(
                "Community backends initialized from community-inference-backends.yaml"
            )

        except (ModuleNotFoundError, FileNotFoundError):
            # community_backends directory or yaml file does not exist
            logger.debug(
                "Community backends directory or file not found, skipping initialization"
            )
        except Exception as e:
            logger.error(f"Failed to initialize community backends: {e}")

    async def _upsert_community_backend(self, session: AsyncSession, config: dict):
        """
        Create or update a community backend from YAML configuration.

        Only the allowed keys are copied. Every YAML version is normalized into
        a predefined version: ``built_in_frameworks`` is set (an empty list when
        no framework is given) and ``custom_framework`` is cleared.

        If a Platform row already exists, user state is merged in:

        * custom versions not in the YAML (``built_in_frameworks is None``) are kept;
        * an enabled flag stays enabled;
        * user ``default_env`` keys are kept, and YAML keys win on conflict.
        """
        backend_name = config.get("backend_name")
        if not backend_name:
            return

        # Prepare backend data
        allowed_keys = [
            "backend_name",
            "version_configs",
            "default_version",
            "default_backend_param",
            "default_run_command",
            "default_entrypoint",
            "health_check_path",
            "description",
            "icon",
            "default_env",
        ]
        backend_data = {k: config[k] for k in allowed_keys if k in config}

        # Set backend source
        backend_data["backend_source"] = BackendSourceEnum.COMMUNITY
        backend_data["enabled"] = False

        # Convert version_configs to VersionConfigDict
        if 'version_configs' in backend_data and backend_data['version_configs']:
            version_config_dict = {}
            for version, raw_ver_config in backend_data['version_configs'].items():
                # A version declared without a body (e.g. `v1:`) loads as None;
                # treat it as an empty mapping instead of failing on `in None`.
                ver_config = dict(raw_ver_config or {})
                # All versions loaded from YAML are predefined versions
                # Convert framework information to built_in_frameworks
                frameworks = None
                if 'built_in_frameworks' in ver_config:
                    frameworks = ver_config['built_in_frameworks']
                elif (
                    'custom_framework' in ver_config and ver_config['custom_framework']
                ):
                    # Even if YAML uses custom_framework, convert it to built_in_frameworks
                    frameworks = [ver_config['custom_framework']]

                # Set built_in_frameworks and clear custom_framework
                if frameworks:
                    ver_config['built_in_frameworks'] = (
                        frameworks if isinstance(frameworks, list) else [frameworks]
                    )
                else:
                    # If no framework specified, use empty list to mark as predefined version
                    ver_config['built_in_frameworks'] = []
                # Ensure custom_framework is None (predefined versions should not have custom_framework)
                ver_config['custom_framework'] = None

                version_config_dict[version] = VersionConfig(**ver_config)
            backend_data['version_configs'] = VersionConfigDict(
                root=version_config_dict
            )

        # Upsert: update if exists, create if not. Community backends seed
        # at the Platform scope (owner_principal_id IS NULL) — Org-private
        # extensions live in additional rows owned by Orgs.
        existing = await InferenceBackend.one_by_fields(
            session, {"backend_name": backend_name, "owner_principal_id": None}
        )

        if existing:
            # Smart merge logic to preserve user customizations
            # 1. Merge version_configs: preserve user custom versions, update YAML versions
            if 'version_configs' in backend_data and backend_data['version_configs']:
                yaml_versions = backend_data['version_configs'].root
                existing_versions = (
                    existing.version_configs.root if existing.version_configs else {}
                )

                # YAML versions first (they overwrite same-named old versions)
                merged_versions = dict(yaml_versions)
                # Then user custom versions (built_in_frameworks is None) not in YAML
                for version, version_config in existing_versions.items():
                    if (
                        version_config.built_in_frameworks is None
                        and version not in yaml_versions
                    ):
                        merged_versions[version] = version_config

                backend_data['version_configs'] = VersionConfigDict(
                    root=merged_versions
                )

            # 2. Preserve user-modified enabled status (if user enabled it, don't reset to False)
            if existing.enabled:
                backend_data['enabled'] = True

            # 3. Merge default_env (preserve user-added environment variables)
            if existing.default_env:
                if 'default_env' in backend_data and backend_data['default_env']:
                    # Merge: YAML environment variables + user-added environment variables
                    merged_env = dict(existing.default_env)
                    merged_env.update(backend_data['default_env'])
                    backend_data['default_env'] = merged_env
                else:
                    # YAML doesn't define it, preserve user's
                    backend_data['default_env'] = existing.default_env

            # 4. Update database
            await existing.update(session, backend_data)
        else:
            backend = InferenceBackend(**backend_data)
            await InferenceBackend.create(session, backend)
class ModelFileController:
    """
    Propagates ModelFile download state to every model instance that uses the file.
    """

    async def start(self):
        """
        Start the controller.

        Only CREATED and UPDATED ModelFile events are reconciled; others are ignored.
        """
        reconcilable_types = (EventType.CREATED, EventType.UPDATED)
        async for event in ModelFile.subscribe(source="model_file_controller"):
            if event.type not in reconcilable_types:
                continue
            await self._reconcile(event)

    async def _reconcile(self, event: Event):
        """
        Sync one file's state into its instances and draft instances.

        The file is re-loaded with both relations. A file that no longer exists
        is ignored. Each affected instance is synced in its own fresh session.
        Any error is logged, not raised.
        """
        file: ModelFile = event.data
        try:
            async with async_session() as session:
                file = await ModelFile.one_by_id(
                    session,
                    file.id,
                    options=[
                        selectinload(ModelFile.instances),
                        selectinload(ModelFile.draft_instances),
                    ],
                )

            if not file:
                # The file was deleted in the meantime.
                return

            affected_instances = [*file.instances, *file.draft_instances]
            for affected in affected_instances:
                async with async_session() as session:
                    await sync_instance_files_state(session, affected, [file])
        except Exception as e:
            logger.error(f"Failed to reconcile model file {file.id}: {e}")
async def sync_instance_files_state(
    session: AsyncSession, instance: ModelInstance, files: List[ModelFile]
):
    """
    Apply each file's download state to ``instance``.

    * A file on the instance's main worker updates either the main model
      progress or, when ``_is_draft_model_file`` says so, the draft model progress.
    * A file on any other worker is handled as a distributed
      (subordinate-worker) file.
    """
    for model_file in files:
        if model_file.worker_id != instance.worker_id:
            await sync_distributed_model_file_state(session, model_file, instance)
            continue
        await sync_main_worker_model_file_state(
            session,
            model_file,
            instance,
            is_draft_model=_is_draft_model_file(model_file, instance),
        )
- def _is_draft_model_file(file: ModelFile, instance: ModelInstance) -> bool:
- """
- Check if the model file is the draft model file for the given model instance.
- """
- if not instance.draft_model_source:
- return False
- if file.model_source_index == instance.draft_model_source.model_source_index:
- return True
- # The model uses a local path as its draft source, but the model file may come from a remote source.
- # Match by resolved path.
- if (
- instance.draft_model_source.source == SourceEnum.LOCAL_PATH
- and file.resolved_paths
- and file.resolved_paths[0] == instance.draft_model_source.local_path
- ):
- return True
- return False
async def sync_main_worker_model_file_state(
    session: AsyncSession,
    file: ModelFile,
    instance: ModelInstance,
    is_draft_model: bool = False,
):
    """
    Sync the state of a model file on the instance's main worker into the instance.

    Instances already in ERROR are left alone.

    * DOWNLOADING: an INITIALIZING instance moves to DOWNLOADING with progress
      0. A DOWNLOADING instance mirrors the file's progress into either
      ``draft_model_download_progress`` (when ``is_draft_model``) or
      ``download_progress``. A value that already reached 100 is not overwritten.
    * READY, for a DOWNLOADING/INITIALIZING instance: the matching progress is
      set to 100 and its resolved path is recorded. The instance then moves to
      STARTING if every tracked file is complete; otherwise, if it was
      INITIALIZING, it moves to DOWNLOADING.
    * ERROR: the instance moves to ERROR with the file's message.

    The instance is persisted only when something changed.
    """
    if instance.state == ModelInstanceStateEnum.ERROR:
        return

    logger.trace(
        f"Syncing model file {file.id} with model instance {instance.id}, file state: {file.state}, "
        f"progress: {file.download_progress}, message: {file.state_message}, instance state: {instance.state}"
    )

    need_update = False

    # Downloading
    if file.state == ModelFileStateEnum.DOWNLOADING:
        if instance.state == ModelInstanceStateEnum.INITIALIZING:
            # Download started
            instance.state = ModelInstanceStateEnum.DOWNLOADING
            instance.download_progress = 0
            instance.state_message = ""
            need_update = True
        elif instance.state == ModelInstanceStateEnum.DOWNLOADING:
            # Update download progress. The draft and main cases are kept
            # strictly apart so a draft file never overwrites the main
            # model's download_progress (e.g. when the draft value is unchanged).
            if is_draft_model:
                if (
                    file.download_progress != instance.draft_model_download_progress
                    and instance.draft_model_download_progress != 100
                ):
                    instance.draft_model_download_progress = file.download_progress
                    need_update = True
            elif (
                file.download_progress != instance.download_progress
                and instance.download_progress != 100
            ):
                instance.download_progress = file.download_progress
                need_update = True

    # Download completed
    elif file.state == ModelFileStateEnum.READY and (
        instance.state == ModelInstanceStateEnum.DOWNLOADING
        or instance.state == ModelInstanceStateEnum.INITIALIZING
    ):
        if is_draft_model and (
            instance.draft_model_download_progress != 100
            or not instance.draft_model_resolved_path
        ):
            # Download completed for the draft model file
            instance.draft_model_download_progress = 100
            instance.draft_model_resolved_path = file.resolved_paths[0]
            need_update = True
        elif not is_draft_model and (
            instance.download_progress != 100 or not instance.resolved_path
        ):
            # Download completed for the main model file
            instance.download_progress = 100
            instance.resolved_path = file.resolved_paths[0]
            need_update = True

        if model_instance_download_completed(instance):
            # All files are downloaded
            instance.state = ModelInstanceStateEnum.STARTING
            instance.state_message = ""
            need_update = True
        elif instance.state == ModelInstanceStateEnum.INITIALIZING:
            # one but not all files downloaded, turn to DOWNLOADING state
            instance.state = ModelInstanceStateEnum.DOWNLOADING
            instance.state_message = ""
            need_update = True

    # Download error
    elif file.state == ModelFileStateEnum.ERROR:
        instance.state = ModelInstanceStateEnum.ERROR
        instance.state_message = file.state_message
        need_update = True

    if need_update:
        await ModelInstanceService(session).update(instance)
async def sync_distributed_model_file_state(  # noqa: C901
    session: AsyncSession, file: ModelFile, instance: ModelInstance
):
    """
    Sync a subordinate worker's model file state into ``instance``.

    Runs only for non-ERROR instances whose ``distributed_servers`` has
    ``download_model_files`` set. For each subordinate entry whose
    ``worker_id`` matches the file:

    * DOWNLOADING: the file's progress is copied onto the entry.
    * READY: the entry is marked 100%, and the instance moves to STARTING if
      every tracked file is now complete.
    * ERROR: the instance moves to ERROR with the file's message.

    On change, ``distributed_servers`` is flagged as modified and the
    instance is persisted.
    """
    if instance.state == ModelInstanceStateEnum.ERROR:
        return
    servers = instance.distributed_servers
    if not servers or not servers.download_model_files:
        return

    logger.trace(
        f"Syncing distributed model file {file.id} with model instance {instance.name}, file state: {file.state}, "
        f"progress: {file.download_progress}, message: {file.state_message}, instance state: {instance.state}"
    )

    changed = False
    matching_entries = [
        entry
        for entry in servers.subordinate_workers or []
        if entry.worker_id == file.worker_id
    ]
    for entry in matching_entries:
        file_state = file.state
        if file_state == ModelFileStateEnum.DOWNLOADING:
            if file.download_progress != entry.download_progress:
                entry.download_progress = file.download_progress
                changed = True
        elif file_state == ModelFileStateEnum.READY:
            if entry.download_progress != 100:
                entry.download_progress = 100
                if model_instance_download_completed(instance):
                    # Every tracked file is complete now.
                    instance.state = ModelInstanceStateEnum.STARTING
                    instance.state_message = ""
                changed = True
        elif file_state == ModelFileStateEnum.ERROR:
            instance.state = ModelInstanceStateEnum.ERROR
            instance.state_message = file.state_message
            changed = True

    if not changed:
        return
    flag_modified(instance, "distributed_servers")
    await ModelInstanceService(session).update(instance)
def model_instance_download_completed(instance: ModelInstance):
    """
    Return True once every file the instance depends on is fully downloaded.

    Checks, in order:

    1. the main model progress;
    2. the draft model progress, only when a ``draft_model_source`` is set;
    3. each subordinate worker's progress, only when ``distributed_servers``
       has ``download_model_files`` set.

    All of them must be exactly 100.
    """
    if instance.download_progress != 100:
        return False

    draft_pending = (
        bool(instance.draft_model_source)
        and instance.draft_model_download_progress != 100
    )
    if draft_pending:
        return False

    servers = instance.distributed_servers
    if not (servers and servers.download_model_files):
        return True
    return all(
        subworker.download_progress == 100
        for subworker in servers.subordinate_workers or []
    )
- def _get_worker_ids_for_file_download(
- instance: ModelInstance,
- ) -> List[str]:
- """
- Get the all worker IDs of the model instance that are
- responsible for downloading the model files,
- including the main worker and distributed workers.
- """
- worker_ids = [instance.worker_id] if instance.worker_id else []
- if (
- instance.distributed_servers
- and instance.distributed_servers.download_model_files
- ):
- worker_ids += [
- item.worker_id
- for item in instance.distributed_servers.subordinate_workers or []
- if item.worker_id
- ]
- return worker_ids
async def new_workers_from_pool(
    session: AsyncSession, pool: WorkerPool
) -> List[Worker]:
    """
    Build (but do not persist) the Worker rows needed to bring ``pool`` up to ``pool.replicas``.

    Non-deleted workers of the pool that are not DELETING count toward the
    replica target.

    When ``pool.batch_size`` is set:

    * at most ``batch_size`` workers are built per call;
    * nothing is built while at least ``batch_size`` workers are still
      PROVISIONING.

    A ``None`` batch size means no throttling. Previously that case raised
    TypeError in the provisioning comparison.

    New workers start PENDING with a random ``pool-<id>-xxxxxxxx`` name. Their
    labels hold the cluster provider and instance type, then the pool's own
    labels, which override on key conflicts.
    """
    fields = {"deleted_at": None, "worker_pool_id": pool.id}
    current_workers = [
        worker
        for worker in await Worker.all_by_fields(session, fields=fields)
        if worker.state not in [WorkerStateEnum.DELETING]
    ]

    # if has enough workers, no need to create more
    if len(current_workers) >= pool.replicas:
        return []

    delta = pool.replicas - len(current_workers)
    if pool.batch_size is not None:
        delta = min(delta, pool.batch_size)
        provisioning_count = sum(
            1
            for worker in current_workers
            if worker.state in [WorkerStateEnum.PROVISIONING]
        )
        # if has enough provisioning workers, no need to create more
        if pool.batch_size <= provisioning_count:
            return []

    pool_labels = pool.labels or {}
    new_workers = []
    for _ in range(delta):
        new_worker = Worker(
            hostname="",
            ip="",
            ifname="",
            port=0,
            worker_uuid="",
            cluster=pool.cluster,
            worker_pool=pool,
            provider=pool.cluster.provider,
            name=f"pool-{pool.id}-"
            + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8)),
            labels={
                "provider": pool.cluster.provider.value,
                "instance_type": pool.instance_type or "unknown",
                **pool_labels,
            },
            state=WorkerStateEnum.PENDING,
            status=WorkerStatus.get_default_status(),
        )
        new_workers.append(new_worker)
    return new_workers
class WorkerPoolController:
    """Worker pool controller creates new workers based on the worker pool configuration."""

    async def start(self):
        """
        Consume worker pool events and reconcile each one.

        Heartbeat events are ignored. A failure while reconciling one event is
        logged with its traceback and the loop moves on to the next event.
        """
        async for event in WorkerPool.subscribe(source="worker_pool_controller"):
            if event.type == EventType.HEARTBEAT:
                continue
            try:
                await self._reconcile(event)
            except Exception as e:
                # logger.exception keeps the traceback, matching the other
                # controllers' subscription loops.
                logger.exception(f"Failed to reconcile worker pool: {e}")

    async def _reconcile(self, event: Event):
        """
        Reconcile the worker pool state with the current event.

        Creates the workers the pool is missing (see ``new_workers_from_pool``)
        and, if the pool's cluster is still PENDING, moves it to PROVISIONING.
        Both changes are committed together. Missing or soft-deleted pools are
        ignored.
        """
        logger.info(f"Reconcile worker pool {event.data.id} with event {event.type}")
        async with async_session() as session:
            pool = await WorkerPool.one_by_id(
                session, event.data.id, options=[selectinload(WorkerPool.cluster)]
            )
            if pool is None or pool.deleted_at is not None:
                return
            # Capture these now so they can still be logged after the commit
            # without reading ORM state again.
            cluster_name = pool.cluster.name
            cluster = pool.cluster
            pool_id = pool.id
            workers = await new_workers_from_pool(session, pool)
            if not workers:
                return
            ids = []
            for worker in workers:
                created_worker: Worker = await Worker.create(
                    session=session, source=worker, auto_commit=False
                )
                ids.append(created_worker.id)
            if cluster.state == ClusterStateEnum.PENDING:
                cluster.state = ClusterStateEnum.PROVISIONING
                cluster.state_message = None
                await cluster.update(session=session, auto_commit=False)
            await session.commit()
            logger.info(
                f"Created {len(ids)} new workers {ids} for cluster {cluster_name} worker pool {pool_id}"
            )
class WorkerProvisioningController:
    """
    Provisions and tears down workers that run on cloud-provider instances.

    ``_reconcile`` moves a worker forward by about one step per call and
    saves the result, so provisioning progresses over successive events.
    """

    def __init__(self, cfg: Config):
        self._cfg = cfg

    @classmethod
    async def _create_ssh_key(
        cls,
        session: AsyncSession,
        client: ProviderClientBase,
        worker: Worker,
    ) -> int:
        """
        Generate a new ssh key pair and register its public key with the
        cloud provider under the worker's name.

        A Credential record (recorded as ED25519) holding both keys and the
        provider's key id is added to ``session`` without committing.

        Returns:
            The id of the new Credential record.
        """
        logger.info(f"Creating ssh key for worker {worker.name}")
        private_key, public_key = generate_ssh_key_pair()
        ssh_key = Credential(
            credential_type=CredentialType.SSH,
            public_key=public_key,
            encoded_private_key=private_key,
            ssh_key_options=SSHKeyOptions(
                algorithm="ED25519",
                length=0,
            ),
        )
        ssh_key_id = await client.create_ssh_key(worker.name, public_key)
        ssh_key.external_id = str(ssh_key_id)
        ssh_key_rtn = await Credential.create(session, ssh_key, auto_commit=False)
        return ssh_key_rtn.id

    @classmethod
    async def _create_instances(
        cls,
        session: AsyncSession,
        client: ProviderClientBase,
        worker: Worker,
        cfg: Config,
    ) -> str:
        """
        Build the instance user data for ``worker`` and ask the provider to
        create its cloud instance.

        Only the sensitive predefined fields of the cluster's worker config
        are passed through as ``secret_configs``.

        Returns:
            The provider's id for the new instance.

        Raises:
            ValueError: if the worker's ssh key Credential does not exist.
        """
        secret_fields = set(SensitivePredefinedConfig.model_fields.keys())
        secret_configs = (
            worker.cluster.worker_config.model_dump(include=secret_fields)
            if worker.cluster.worker_config
            else {}
        )
        user_data = await client.construct_user_data(
            server_url=worker.cluster.server_url or cfg.server_external_url,
            token=worker.cluster.registration_token,
            image_name=get_cluster_image_name(worker.cluster.worker_config),
            os_image=worker.worker_pool.os_image,
            secret_configs=secret_configs,
            worker_name=worker.name,
        )
        ssh_key = await Credential.one_by_id(session, worker.ssh_key_id)
        if ssh_key is None:
            raise ValueError(f"SSH key {worker.ssh_key_id} not found")
        to_create = construct_cloud_instance(worker, ssh_key, user_data.format())
        logger.info(f"Creating cloud instance for worker {worker.name}")
        logger.debug(f"Cloud instance configuration: {to_create}")
        return await client.create_instance(to_create)

    @classmethod
    async def _provisioning_started(
        cls,
        session: AsyncSession,
        client: ProviderClientBase,
        worker: Worker,
        instance: CloudInstance,
    ) -> bool:
        """
        Phase II of provisioning, run once the cloud instance is RUNNING.

        Performs the first applicable step per call:
        1. wait for the instance's public ip and store it as advertise_address;
        2. create and attach the volumes configured on the worker pool;
        3. once volume ids match the configured volumes, move the worker to
           INITIALIZING and mark the cluster PROVISIONED.

        Returns:
            True if the worker was (or may have been) modified and should be
            saved; False when no step applied.
        """
        changed = True
        # Work on a copy so that reassigning worker.provider_config below is
        # a new value rather than an in-place mutation of the stored one.
        provider_config = dict(worker.provider_config or {})
        volumes = list(
            (getattr(worker.worker_pool.cloud_options, "volumes", None) or [])
        )
        # Tolerate an explicit None stored under "volume_ids".
        volume_ids = provider_config.get("volume_ids") or []
        if worker.advertise_address is None or worker.advertise_address == "":
            try:
                instance = await client.wait_for_public_ip(worker.external_id)
                worker.advertise_address = (
                    instance.ip_address if instance.ip_address else ""
                )
                worker.state_message = "Waiting for volumes to attach"
            except Exception as e:
                logger.warning(
                    f"Failed to wait for instance {worker.external_id} to get public ip: {e}"
                )
        elif len(volumes) != len(volume_ids) and len(volumes) > 0:
            volume_ids = await client.create_volumes_and_attach(
                worker.id, worker.external_id, worker.cluster.region, *volumes
            )
            provider_config["volume_ids"] = volume_ids
            worker.provider_config = provider_config
        elif (
            len(volumes) == len(volume_ids)
            and worker.state == WorkerStateEnum.PROVISIONING
        ):
            # Keep the ids of already-attached volumes (or record an empty
            # list). A key lookup is required here: hasattr() on a dict is
            # always False and would wipe the recorded volume ids.
            provider_config["volume_ids"] = volume_ids
            worker.provider_config = provider_config
            worker.state = WorkerStateEnum.INITIALIZING
            if worker.cluster.state != ClusterStateEnum.PROVISIONED:
                worker.cluster.state = ClusterStateEnum.PROVISIONED
                await worker.cluster.update(session=session, auto_commit=False)
            worker.state_message = "Initializing: installing required drivers and software. The worker will start automatically after setup."
        else:
            changed = False
        return changed

    @classmethod
    async def _provisioning_before_started(
        cls,
        session: AsyncSession,
        client: ProviderClientBase,
        worker: Worker,
        cfg: Config,
    ) -> Tuple[Optional[CloudInstance], bool]:
        """
        Phase I of provisioning: make sure the cloud instance exists and runs.

        Advances at most one step per call: PENDING -> PROVISIONING, create the
        ssh key, create the cloud instance, then wait for it to start.

        Returns:
            ``(instance, changed)``: the cloud instance as last seen (None when
            it was not fetched or not found), and whether ``worker`` was
            modified and should be saved.
        """
        instance = None
        changed = False
        if worker.external_id is not None:
            instance = await client.get_instance(worker.external_id)
            # TODO should handle instance not exist problem
            if instance is None or instance.status == InstanceState.RUNNING:
                return instance, changed
        changed = True
        if worker.state == WorkerStateEnum.PENDING:
            worker.state = WorkerStateEnum.PROVISIONING
            worker.state_message = "Creating SSH key"
        elif worker.ssh_key_id is None:
            worker.ssh_key_id = await cls._create_ssh_key(session, client, worker)
            worker.state_message = "Creating cloud instance"
        elif worker.external_id is None:
            worker.external_id = await cls._create_instances(
                session, client, worker, cfg
            )
            worker.state_message = "Waiting for cloud instance started"
        else:
            # The instance exists but is not running yet.
            try:
                # suppress the timeout exception; the next reconcile retries
                instance = await client.wait_for_started(worker.external_id)
                worker.state_message = "Waiting for instance's public ip"
            except Exception as e:
                logger.warning(
                    f"Failed to wait for instance {worker.external_id} to start: {e}"
                )
        return instance, changed

    @classmethod
    async def _provisioning_instance(
        cls,
        session: AsyncSession,
        client: ProviderClientBase,
        worker: Worker,
        cfg: Config,
    ):
        """
        Run one provisioning step for ``worker`` and stage the worker update.

        Phase I (``_provisioning_before_started``) runs first. Phase II
        (``_provisioning_started``) runs only when Phase I changed nothing and
        the instance is RUNNING. Any change is saved through WorkerService
        without committing.
        """
        # Phase I is to ensure instance running.
        instance, changed = await cls._provisioning_before_started(
            session, client, worker, cfg
        )
        if (
            not changed
            and instance is not None
            and instance.status == InstanceState.RUNNING
        ):
            # Phase II is to wait for instance information and attach volumes.
            changed = await cls._provisioning_started(session, client, worker, instance)
        if changed:
            await WorkerService(session).update(
                worker=worker, source=None, auto_commit=False
            )

    @classmethod
    async def _deleting_instance(
        cls,
        session: AsyncSession,
        client: ProviderClientBase,
        worker: Worker,
    ):
        """
        Tear down the cloud resources of a worker being deleted.

        Deletes the cloud instance and the provider-side ssh key. Failures
        there are logged and otherwise ignored. Then removes the local ssh key
        Credential and, if the worker has ``deleted_at`` set, the worker record.
        Nothing is committed here.
        """
        if worker.external_id is None:
            # NOTE(review): nothing is cleaned up when no instance was created,
            # including a previously created ssh key; confirm this is handled
            # elsewhere.
            return
        ssh_key = await Credential.one_by_id(session, worker.ssh_key_id)
        try:
            await client.delete_instance(worker.external_id)
            if ssh_key and ssh_key.external_id:
                await client.delete_ssh_key(ssh_key.external_id)
        except Exception as e:
            logger.error(f"Failed to delete instance {worker.external_id}: {e}")
        # if using soft delete here, skip deletion and remove external_id
        if ssh_key:
            await ssh_key.delete(session, auto_commit=False)
        if worker.deleted_at is not None:
            await WorkerService(session).delete(worker, auto_commit=False)

    async def check_server_external_url(self, cluster_server_url: Optional[str] = None):
        """
        Verify that the server URL workers will register against is reachable.

        Uses ``cluster_server_url`` when given, otherwise the configured
        ``server_external_url``, and requires ``GET <url>/healthz`` to answer
        with status 200 within 10 seconds.

        Raises:
            ValueError: if no URL is configured, the request fails, or the
                endpoint answers with a non-200 status.
        """
        server_url = cluster_server_url or self._cfg.server_external_url
        if server_url is None or server_url == "":
            raise ValueError(
                "Cluster's server_url is not configured, Please edit cluster first."
            )
        import aiohttp
        from yarl import URL

        healthz_url = str(URL(server_url) / "healthz")
        try:
            async with aiohttp.ClientSession() as http_session:
                async with http_session.get(
                    healthz_url, timeout=aiohttp.ClientTimeout(total=10)
                ) as resp:
                    status = resp.status
        except Exception as e:
            raise ValueError(
                f"Failed to check external server healthz url {healthz_url}: {e}"
            ) from e
        # Checked outside the try so this error is not re-wrapped above.
        if status != 200:
            raise ValueError(
                f"External server healthz url {healthz_url} is not reachable, status code: {status}"
            )

    async def _reconcile(self, event: Event):
        """
        Advance a cloud-provisioned worker by one reconcile step.

        Only workers in these states are handled:
        - PENDING - initial state for workers created by a pool; the server URL
          is checked, then the worker moves to PROVISIONING
        - PROVISIONING - ssh key, cloud instance, public ip and volumes are set
          up step by step; the worker then moves to INITIALIZING
        - DELETING - the cloud instance and related resources are removed

        Any failure rolls back the session and moves the worker to ERROR with
        the failure message.
        """
        worker: Worker = event.data
        if not worker:
            return
        if worker.state not in [
            WorkerStateEnum.PENDING,
            WorkerStateEnum.PROVISIONING,
            WorkerStateEnum.DELETING,
        ]:
            return
        logger.info(
            f"Reconcile provisioning worker {event.data.name} with event {event.type}"
        )
        async with async_session() as session:
            # Fetch the worker from the database
            worker: Worker = await Worker.one_by_id(
                session,
                worker.id,
                options=[
                    selectinload(Worker.cluster),
                    selectinload(Worker.worker_pool),
                ],
            )
            if not worker:
                return
            credential: CloudCredential = await CloudCredential.one_by_id(
                session, worker.cluster.credential_id
            )
            client = get_client_from_provider(
                worker.cluster.provider,
                credential=credential,
            )
            try:
                if worker.state == WorkerStateEnum.PENDING:
                    await self.check_server_external_url(worker.cluster.server_url)
                if worker.state in [
                    WorkerStateEnum.PENDING,
                    WorkerStateEnum.PROVISIONING,
                ]:
                    await self._provisioning_instance(
                        session, client, worker, self._cfg
                    )
                if worker.state == WorkerStateEnum.DELETING:
                    await self._deleting_instance(session, client, worker)
                await session.commit()
            except Exception as e:
                # Build the message before rollback expires the worker's state.
                message = f"Failed to provision or delete worker {worker.name}: {e}"
                logger.exception(message)
                await session.rollback()
                await session.refresh(worker)
                worker.state = WorkerStateEnum.ERROR
                worker.state_message = message
                await WorkerService(session).update(
                    worker=worker, source=None, auto_commit=True
                )
class ClusterController:
    """
    Reconciles clusters.

    Resets clusters (other than Kubernetes/Docker ones) that report no workers
    to PENDING and, when the gateway is enabled, cleans up the per-cluster
    worker entries of the MCPBridge.
    """

    def __init__(self, cfg: Config):
        self._cfg = cfg
        self._disable_gateway = cfg.gateway_mode == GatewayModeEnum.disabled
        self._k8s_config = get_async_k8s_config(cfg=cfg)

    async def start(self):
        """
        Start the controller.

        Builds the Higress networking client when the gateway is enabled, then
        reconciles every non-heartbeat cluster event. A failure on one event is
        logged with its traceback and the loop continues.
        """
        if not self._disable_gateway:
            base_client = k8s_client.ApiClient(configuration=self._k8s_config)
            self._higress_network_api = NetworkingHigressIoV1Api(base_client)
        async for event in Cluster.subscribe(source="cluster_controller"):
            if event.type == EventType.HEARTBEAT:
                continue
            try:
                await self._reconcile(event)
            except Exception as e:
                logger.exception(f"Failed to reconcile cluster: {e}")

    async def _reconcile(self, event: Event):
        """
        Reconcile the cluster state, then the gateway MCPBridge if enabled.
        """
        await self._sync_cluster_state(event)
        if self._disable_gateway:
            return
        await self._ensure_worker_mcp_bridge(event)

    async def _sync_cluster_state(self, event: Event):
        """
        Move a cluster with zero workers back to PENDING.

        Skips DELETED events, missing clusters and clusters whose provider is
        Kubernetes or Docker. The change is committed immediately.
        """
        if event.type == EventType.DELETED:
            return
        cluster: Cluster = event.data
        if not cluster:
            return
        async with async_session() as session:
            cluster: Cluster = await Cluster.one_by_id(
                session, cluster.id, options=[selectinload(Cluster.cluster_workers)]
            )
            if not cluster or cluster.provider in [
                ClusterProvider.Kubernetes,
                ClusterProvider.Docker,
            ]:
                return
            if cluster.workers == 0 and cluster.state != ClusterStateEnum.PENDING:
                cluster.state = ClusterStateEnum.PENDING
                cluster.state_message = (
                    "No workers have been provisioned for this cluster yet."
                )
                await cluster.update(session=session, auto_commit=True)

    async def _ensure_worker_mcp_bridge(self, event: Event):
        """
        The worker registry list for cluster is no longer needed.
        Use empty list to trigger MCPBridge controller to clean up the worker registries
        and proxies when cluster is created or deleted.

        Raises whatever ``mcp_handler.ensure_mcp_bridge`` raises, after logging.
        """
        if self._disable_gateway:
            return
        cluster: Cluster = event.data
        # Same guard as _sync_cluster_state: events without data are ignored
        # instead of failing on cluster.id below.
        if not cluster:
            return
        mcp_resource_name = mcp_handler.default_mcp_bridge_name
        desired_registries = []
        to_delete_prefix = mcp_handler.cluster_worker_prefix(cluster.id)
        try:
            await mcp_handler.ensure_mcp_bridge(
                client=self._higress_network_api,
                namespace=self._cfg.gateway_namespace,
                mcp_bridge_name=mcp_resource_name,
                desired_registries=desired_registries,
                to_delete_prefix=to_delete_prefix,
            )
        except Exception as e:
            logger.error(f"Failed to ensure MCPBridge for cluster {cluster.name}: {e}")
            raise
async def notify_model_route_target(session: AsyncSession, model: Model, event: Event):
    """
    Republish a model's route targets when its replica counts change.

    For a non-DELETED event whose ``changed_fields`` include ``ready_replicas``
    or ``replicas``, every ModelRouteTarget of the model is published on the
    event bus as UPDATED, carrying the model's current id, name and replica
    counts in ``changed_fields["model"]``.
    """
    if event.type == EventType.DELETED:
        return
    changed_fields = event.changed_fields
    related_fields = ("ready_replicas", "replicas")
    if changed_fields is None or not any(
        field in changed_fields for field in related_fields
    ):
        # Nothing would be published, so skip the database lookup entirely.
        return
    model = await Model.one_by_id(
        session=session,
        id=model.id,
        options=[
            selectinload(Model.model_route_targets),
        ],
    )
    if not model:
        return
    for target in model.model_route_targets:
        target_copy = ModelRouteTarget(**target.model_dump())
        await event_bus.publish(
            target_copy.__class__.__name__.lower(),
            Event(
                type=EventType.UPDATED,
                data=target_copy,
                changed_fields={
                    "model": (
                        {},
                        {
                            "id": model.id,
                            "name": model.name,
                            "ready_replicas": model.ready_replicas,
                            "replicas": model.replicas,
                        },
                    )
                },
            ),
        )
async def sync_categories_and_meta(session: AsyncSession, model: Model, event: Event):
    """
    Copy a model's ``categories`` and ``meta`` onto the routes it created.

    Routes whose ``created_by_model`` is falsy are left alone. Every other
    route whose categories or meta differ from the model's is updated and
    committed right away. DELETED events and missing models are ignored.
    """
    if event.type == EventType.DELETED:
        return
    loaded = await Model.one_by_id(
        session=session,
        id=model.id,
        options=[selectinload(Model.model_routes)],
    )
    if not loaded:
        return

    for route in loaded.model_routes:
        # created_by_model counts as false when unset; such routes are
        # managed independently of the model.
        if not route.created_by_model:
            continue
        up_to_date = (
            route.categories == loaded.categories and route.meta == loaded.meta
        )
        if up_to_date:
            continue
        await ModelRouteService(session).update(
            model_route=route,
            source={"categories": loaded.categories, "meta": loaded.meta},
            auto_commit=True,
        )
class ModelProviderController:
    """
    Keeps the gateway in sync with model providers.

    Covers the MCPBridge registries/proxies and the ai-proxy wasm plugin. It
    also republishes model routes when a provider's endpoint-related config
    changes.
    """

    def __init__(self, cfg: Config):
        self._config = cfg
        self._disable_gateway = cfg.gateway_mode == GatewayModeEnum.disabled
        self._k8s_config = get_async_k8s_config(cfg=cfg)

    async def start(self):
        """
        Subscribe to model provider events and reconcile each one.

        Does nothing when the gateway is disabled. A failure on one event is
        logged with its traceback and the loop continues.
        """
        if self._disable_gateway:
            return
        base_client = k8s_client.ApiClient(configuration=self._k8s_config)
        self._higress_network_api = NetworkingHigressIoV1Api(base_client)
        self._higress_extension_api = ExtensionsHigressIoV1Api(base_client)
        async for event in ModelProvider.subscribe(source="model_provider_controller"):
            try:
                await self._reconcile(event)
            except Exception as e:
                logger.exception(f"Failed to reconcile model provider: {e}")

    async def _ensure_provider_registry(
        self,
        model_provider: ModelProvider,
        event: Event,
    ):
        """
        Add, update or remove the provider's MCPBridge registry and proxy.

        An entry is removed by its id prefix when the event is DELETED or the
        provider yields no registry/proxy. Otherwise the provider's entry is
        the desired state. Errors are logged and re-raised.
        """
        provider_registry = mcp_handler.provider_registry(model_provider)
        registry_to_remove = (
            provider_registry is None or event.type == EventType.DELETED
        )
        to_delete_prefix = (
            f"{mcp_handler.provider_id_prefix}{model_provider.id}"
            if registry_to_remove
            else None
        )
        desired_registries = [] if registry_to_remove else [provider_registry]
        provider_proxy = mcp_handler.provider_proxy(model_provider)
        proxy_to_remove = provider_proxy is None or event.type == EventType.DELETED
        to_delete_proxy_prefix = (
            f"proxy-{model_provider.id}" if proxy_to_remove else None
        )
        desired_proxies = [] if proxy_to_remove else [provider_proxy]
        try:
            await mcp_handler.ensure_mcp_bridge(
                client=self._higress_network_api,
                namespace=self._config.gateway_namespace,
                mcp_bridge_name=mcp_handler.default_mcp_bridge_name,
                desired_registries=desired_registries,
                desired_proxies=desired_proxies,
                to_delete_prefix=to_delete_prefix,
                to_delete_proxies_prefix=to_delete_proxy_prefix,
            )
        except Exception as e:
            logger.error(
                f"Failed to ensure MCPRegistry for model provider {model_provider.name}: {e}"
            )
            raise

    async def _ensure_provider_ai_proxy_config(self):
        """
        Rebuild the ai-proxy wasm plugin from all non-deleted providers.

        Errors are logged and re-raised.
        """
        try:
            async with async_session() as session:
                providers = await ModelProvider.all_by_field(
                    session,
                    "deleted_at",
                    None,
                )
            provider_config_list, match_rules = (
                mcp_handler.provider_proxy_plugin_spec(*providers)
            )
            await mcp_handler.ensure_wasm_plugin(
                api=self._higress_extension_api,
                name=mcp_handler.gpustack_ai_proxy_name,
                namespace=self._config.gateway_namespace,
                spec_diff=partial(
                    mcp_handler.ai_proxy_diff_spec,
                    expected_providers=provider_config_list,
                    expected_match_rules=match_rules,
                    operating_id_prefix=mcp_handler.provider_id_prefix,
                ),
            )
        except Exception as e:
            logger.error(f"Failed to ensure provider's ai_proxy config: {e}")
            raise

    @staticmethod
    def _config_value_as_dict(values) -> dict:
        """
        Extract a provider config from one side of a ``changed_fields`` entry.

        ``values`` is expected to be a sequence holding a single config value
        (a pydantic model or a dict). An empty sequence or a None value yields
        an empty dict so field lookups stay safe.
        """
        value = values[0] if values else None
        if isinstance(value, BaseModel):
            return value.model_dump()
        return value or {}

    async def _notify_provider_model_routes(
        self, session: AsyncSession, model_provider: ModelProvider, event: Event
    ):
        """
        Republish every model route that targets this provider when an
        endpoint-related config field changed.

        Only UPDATED events trigger this, and only when ``changed_fields["config"]``
        differs in ``openaiCustomUrl``, ``ollamaServerHost`` or ``difyApiUrl``.
        Each route is published once, as an UPDATED event.
        """
        if event.type != EventType.UPDATED:
            return
        changed_fields = event.changed_fields or {}
        # The "config" entry is (old values, new values). Each side is a list
        # with a single element, since config is a normal field rather than a
        # relationship field. Without both sides there is nothing to compare.
        config_change = changed_fields.get("config")
        if not config_change or len(config_change) < 2:
            return
        old_config = self._config_value_as_dict(config_change[0])
        new_config = self._config_value_as_dict(config_change[1])
        # Hardcoded ProviderConfigType fields that affect the registry type of
        # the provider_registry. A change requires notifying ingress to
        # regenerate the registry destination.
        related_fields = (
            "openaiCustomUrl",
            "ollamaServerHost",
            "difyApiUrl",
        )
        if not any(
            old_config.get(field) != new_config.get(field) for field in related_fields
        ):
            return
        targets = await ModelRouteTarget.all_by_fields(
            session=session,
            fields={"provider_id": model_provider.id},
            options=[selectinload(ModelRouteTarget.model_route)],
        )
        unique_routes = {
            target.model_route.id: target.model_route
            for target in targets
            if target.model_route is not None
        }
        for route in unique_routes.values():
            route_copy = ModelRoute.model_validate(route.model_dump())
            await event_bus.publish(
                route_copy.__class__.__name__.lower(),
                Event(type=EventType.UPDATED, data=route_copy),
            )

    async def _reconcile(self, event: Event):
        """
        Reconcile the model provider.

        DELETED events only clean up gateway configuration. Other events
        reload the provider, then sync the gateway and notify affected routes.
        """
        model_provider: ModelProvider = event.data
        if not model_provider:
            return
        if event.type == EventType.DELETED:
            await self._ensure_provider_registry(model_provider, event)
            await self._ensure_provider_ai_proxy_config()
            return
        async with async_session() as session:
            model_provider: ModelProvider = await ModelProvider.one_by_id(
                session, model_provider.id
            )
            if not model_provider:
                return
            await self._ensure_provider_registry(model_provider, event)
            await self._ensure_provider_ai_proxy_config()
            await self._notify_provider_model_routes(session, model_provider, event)
class ModelRouteTargetController:
    """
    Reconciles model route targets.

    Derives each target's state from its provider or model, notifies the
    owning route of relevant changes, and hands a model-created route over to
    manual ownership when its target is deleted while the model still exists.
    """

    def __init__(self, config: Config):
        self._config = config

    async def start(self):
        """Subscribe to model route target events and reconcile each one, logging failures."""
        async for event in ModelRouteTarget.subscribe(
            source="model_route_target_controller"
        ):
            try:
                await self._reconcile(event)
            except Exception as e:
                logger.exception(f"Failed to reconcile model route target: {e}")

    async def _notify_parents(
        self, session: AsyncSession, target: ModelRouteTarget, event: Event
    ):
        """
        Publish an UPDATED event for the target's route.

        This happens when the target was deleted, or updated in a field that
        affects routing. Failures while loading or publishing the route are
        logged, not raised.
        """
        if event.type not in (EventType.UPDATED, EventType.DELETED):
            return
        changed_fields = event.changed_fields
        if not target or (not changed_fields and event.type != EventType.DELETED):
            return
        should_notify_fields = [
            "state",
            "provider_id",
            "model_id",
            "provider_model_name",
            "model",
        ]
        should_notify = event.type == EventType.DELETED or any(
            field in (changed_fields or {}) for field in should_notify_fields
        )
        if not should_notify:
            return
        try:
            model_route: ModelRoute = await ModelRoute.one_by_id(
                session, target.route_id
            )
            if not model_route:
                return
            copied_route = ModelRoute.model_validate(model_route.model_dump())
            await event_bus.publish(
                ModelRoute.__name__.lower(),
                Event(type=EventType.UPDATED, data=copied_route),
            )
        except Exception as e:
            logger.error(f"Failed to notify model route for target {target.name}: {e}")

    async def _sync_state(
        self, session: AsyncSession, target: ModelRouteTarget, event: Event
    ):
        """
        Recompute and persist the target's state.

        A provider-backed target is ACTIVE. A model-backed target is ACTIVE
        when the model has ready replicas, otherwise UNAVAILABLE; the model
        takes precedence when both are set. Targets referencing neither are
        left unchanged.
        """
        if event.type == EventType.DELETED:
            return
        # Handle ID-only events from distributed mode
        target_id = (
            target.id
            if hasattr(target, 'id')
            else target.get('id') if isinstance(target, dict) else None
        )
        if not target_id:
            return
        target: ModelRouteTarget = await ModelRouteTarget.one_by_id(session, target_id)
        if not target:
            return
        target_state = None
        if target.provider_id is not None:
            target_state = TargetStateEnum.ACTIVE
        if target.model_id is not None:
            model = await Model.one_by_id(session, target.model_id)
            if not model:
                return
            target_state = (
                TargetStateEnum.ACTIVE
                if model.ready_replicas > 0
                else TargetStateEnum.UNAVAILABLE
            )
        if target_state is None:
            # No provider or model to derive a state from; leave it as is.
            return
        if target.state != target_state:
            target.state = target_state
            await target.update(session=session, auto_commit=True)

    async def _update_orphan_route(
        self, session: AsyncSession, target: ModelRouteTarget, event: Event
    ) -> bool:
        """
        Hand over a model-created route whose target was deleted while its model
        still exists.

        The route's ``created_by_model`` is cleared so that it is kept rather
        than removed together with the model's targets.

        Returns:
            False only when the route was handed over successfully. Otherwise
            True, meaning the caller should notify the parent route and sync
            the target state.
        """
        if event.type != EventType.DELETED:
            return True
        if target.model_id is None:
            return True
        model = await Model.one_by_id(session, target.model_id)
        if not model or model.deleted_at is not None:
            return True
        # The model still exists: hand the route over to a non model-created
        # route to avoid service disruption.
        orphan_route = await ModelRoute.one_by_id(session=session, id=target.route_id)
        if (
            not orphan_route
            or orphan_route.deleted_at is not None
            or not orphan_route.created_by_model
        ):
            # The route is already deleted or not created by model; nothing to
            # hand over. Returning True still lets the caller notify the parent
            # and sync the target state.
            return True
        try:
            route_service = ModelRouteService(session=session)
            await route_service.update(
                orphan_route, source={"created_by_model": False}, auto_commit=True
            )
        except Exception as e:
            logger.error(f"Failed to transfer model route {orphan_route.id}: {e}")
            return True
        return False

    async def _reconcile(self, event: Event):
        """Hand over orphaned routes, then (when needed) notify the parent route and sync the target state."""
        target: ModelRouteTarget = event.data
        if not target:
            return
        async with async_session() as session:
            should_notify_parents = await self._update_orphan_route(
                session, target, event
            )
            if should_notify_parents:
                await self._notify_parents(session, target, event)
                await self._sync_state(session, target, event)
class ModelRouteController:
    """
    Reconciles model routes.

    Keeps each route's target counters current and deletes model-created
    routes that have no targets left. When the gateway is enabled, it syncs
    the gateway resources. Models are then distributed to users.
    """

    def __init__(self, cfg: Config):
        self._config = cfg
        self._gateway_namespace = cfg.gateway_namespace
        self._k8s_config = get_async_k8s_config(cfg=cfg)
        self._disable_gateway = cfg.gateway_mode == GatewayModeEnum.disabled

    async def start(self):
        """
        Build the Kubernetes API clients (gateway enabled only), then reconcile
        every model route event. A failure on one event is logged with its
        traceback and the loop continues.
        """
        if not self._disable_gateway:
            api_client = k8s_client.ApiClient(configuration=self._k8s_config)
            self._networking_api = k8s_client.NetworkingV1Api(api_client)
            self._higress_extension_api = ExtensionsHigressIoV1Api(api_client)
            self._networking_istio_api = NetworkingIstioIoV1Alpha3Api(api_client)

        events = ModelRoute.subscribe(source="model_route_controller")
        async for event in events:
            try:
                await self._reconcile(event)
            except Exception as e:
                logger.exception(f"Failed to reconcile model route: {e}")

    async def _sync_targets(self, session: AsyncSession, event: Event) -> bool:
        """
        Recount the route's targets and persist the result.

        Returns:
            True when the route record was written: deleted because a
            model-created route has no targets, or updated because its
            ``targets``/``ready_targets`` counters were stale. Each write is
            committed immediately. False otherwise, including for DELETED
            events and routes that cannot be found.
        """
        if event.type == EventType.DELETED:
            return False
        payload = event.data
        if not payload:
            return False

        # Distributed mode may deliver only an id (inside a dict).
        if hasattr(payload, 'id'):
            route_id = payload.id
        elif isinstance(payload, dict):
            route_id = payload.get('id')
        else:
            route_id = None
        if not route_id:
            return False

        model_route: ModelRoute = await ModelRoute.one_by_id(
            session,
            route_id,
            options=[selectinload(ModelRoute.route_targets)],
        )
        if not model_route:
            return False

        route_targets = model_route.route_targets
        target_total = len(route_targets)
        ready_target_total = sum(
            1 for t in route_targets if t.state == TargetStateEnum.ACTIVE
        )

        service = ModelRouteService(session=session)
        if target_total == 0 and model_route.created_by_model:
            await service.delete(model_route, auto_commit=True)
            return True

        counters_current = (
            model_route.targets == target_total
            and model_route.ready_targets == ready_target_total
        )
        if counters_current:
            return False

        model_route.targets = target_total
        model_route.ready_targets = ready_target_total
        await service.update(model_route, auto_commit=True)
        return True

    async def _reconcile(self, event: Event):
        """
        Reconcile the model route.

        The gateway sync is skipped when the gateway is disabled or when
        ``_sync_targets`` just wrote the route record. Models are distributed
        to users in every case.
        """
        model_route: ModelRoute = event.data
        if not model_route:
            return
        async with async_session() as session:
            # _sync_targets writes the route record, so it must run before
            # anything else reads it.
            route_written = await self._sync_targets(session, event)
            if not (self._disable_gateway or route_written):
                await sync_gateway(
                    cfg=self._config,
                    session=session,
                    event=event,
                    networking_api=self._networking_api,
                    extensions_api=self._higress_extension_api,
                    model_route=model_route,
                    istio_networking_api=self._networking_istio_api,
                )
            await distribute_models_to_user(session, model_route, event)
|