controllers.py 100 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655
  1. import logging
  2. import random
  3. import string
  4. import asyncio
  5. import yaml
  6. from importlib.resources import files
  7. from functools import partial
  8. from typing import Any, Dict, List, Tuple, Optional, Set
  9. from pydantic import BaseModel
  10. from sqlmodel import select
  11. from sqlmodel.ext.asyncio.session import AsyncSession
  12. from sqlalchemy.orm import selectinload
  13. from sqlalchemy.orm.attributes import flag_modified
  14. from gpustack.config.config import (
  15. Config,
  16. get_cluster_image_name,
  17. )
  18. from gpustack.policies.scorers.offload_layer_scorer import OffloadLayerScorer
  19. from gpustack.policies.scorers.placement_scorer import PlacementScorer, ScaleTypeEnum
  20. from gpustack.policies.scorers.score_chain import (
  21. ModelInstanceScoreChain,
  22. )
  23. from gpustack.policies.base import ModelInstanceScore
  24. from gpustack.policies.scorers.status_scorer import StatusScorer
  25. from gpustack.schemas.inference_backend import (
  26. InferenceBackend,
  27. get_built_in_backend,
  28. VersionConfig,
  29. VersionConfigDict,
  30. )
  31. from gpustack.schemas.links import ModelRoutePrincipalLink
  32. from gpustack.schemas.model_files import ModelFile, ModelFileStateEnum
  33. from gpustack.schemas.model_routes import (
  34. ModelRoute,
  35. ModelRouteTarget,
  36. MyModel,
  37. TargetStateEnum,
  38. effective_route_name,
  39. )
  40. from gpustack.schemas.principals import (
  41. Principal,
  42. PLATFORM_PRINCIPAL_ID,
  43. )
  44. from gpustack.schemas.models import (
  45. BackendEnum,
  46. BackendSourceEnum,
  47. ModelSource,
  48. Model,
  49. ModelInstance,
  50. ModelInstanceCreate,
  51. ModelInstanceStateEnum,
  52. SourceEnum,
  53. get_backend,
  54. )
  55. from gpustack.schemas.config import (
  56. GatewayModeEnum,
  57. SensitivePredefinedConfig,
  58. )
  59. from gpustack.schemas.workers import (
  60. Worker,
  61. WorkerStateEnum,
  62. WorkerStatus,
  63. )
  64. from gpustack.schemas.clusters import (
  65. Cluster,
  66. WorkerPool,
  67. CloudCredential,
  68. Credential,
  69. CredentialType,
  70. ClusterStateEnum,
  71. SSHKeyOptions,
  72. ClusterProvider,
  73. )
  74. from gpustack.schemas.users import (
  75. User,
  76. is_default_cluster_user,
  77. )
  78. from gpustack.server.bus import Event, EventType, event_bus
  79. from gpustack.utils.model_source import get_draft_model_source
  80. from gpustack import envs
  81. from gpustack.server.db import async_session
  82. from gpustack.server.services import (
  83. ModelFileService,
  84. ModelInstanceService,
  85. ModelService,
  86. WorkerService,
  87. ModelRouteService,
  88. )
  89. from gpustack.utils.model_instance_workers import get_model_instance_worker_match
  90. from gpustack.cloud_providers.common import (
  91. get_client_from_provider,
  92. construct_cloud_instance,
  93. generate_ssh_key_pair,
  94. )
  95. from gpustack.cloud_providers.abstract import (
  96. ProviderClientBase,
  97. CloudInstance,
  98. InstanceState,
  99. )
  100. from kubernetes_asyncio import client as k8s_client
  101. from gpustack.gateway.client.networking_higress_io_v1_api import (
  102. NetworkingHigressIoV1Api,
  103. McpBridgeRegistry,
  104. )
  105. from gpustack.gateway.client.extensions_higress_io_v1_api import (
  106. ExtensionsHigressIoV1Api,
  107. WasmPluginMatchRule,
  108. WasmPluginSpec,
  109. )
  110. from gpustack.gateway.client.networking_istio_io_v1alpha3_api import (
  111. NetworkingIstioIoV1Alpha3Api,
  112. )
  113. from gpustack.gateway import utils as mcp_handler
  114. from gpustack.gateway import get_async_k8s_config
  115. from gpustack.schemas.model_provider import (
  116. ModelProvider,
  117. )
  118. logger = logging.getLogger(__name__)
  119. class ModelController:
  120. def __init__(self, cfg: Config):
  121. self._config = cfg
  122. self._k8s_config = get_async_k8s_config(cfg=cfg)
  123. self._disable_gateway = cfg.gateway_mode == GatewayModeEnum.disabled
  124. pass
  125. async def start(self):
  126. """
  127. Start the controller.
  128. """
  129. if not self._disable_gateway:
  130. base_client = k8s_client.ApiClient(configuration=self._k8s_config)
  131. self._higress_network_api = NetworkingHigressIoV1Api(base_client)
  132. async for event in Model.subscribe(source="model_controller"):
  133. if event.type == EventType.HEARTBEAT:
  134. continue
  135. await self._reconcile(event)
  136. async def _ensure_model_mcp_bridge(
  137. self, session: AsyncSession, event_type: EventType, model: Model
  138. ):
  139. if self._disable_gateway:
  140. return
  141. model_instances = await ModelInstance.all_by_fields(
  142. session,
  143. fields={"model_id": model.id, "deleted_at": None},
  144. )
  145. worker_by_id = None
  146. worker_ids = {
  147. instance.worker_id for instance in model_instances if instance.worker_id
  148. }
  149. if worker_ids:
  150. workers = await Worker.all_by_fields(
  151. session,
  152. extra_conditions=[
  153. Worker.id.in_(worker_ids),
  154. ],
  155. )
  156. worker_by_id = {worker.id: worker for worker in workers}
  157. await mcp_handler.ensure_model_mcp_bridge(
  158. event_type=event_type,
  159. model_id=model.id,
  160. model_instances=model_instances,
  161. networking_higress_api=self._higress_network_api,
  162. namespace=self._config.gateway_namespace,
  163. cluster_id=model.cluster_id,
  164. workers=worker_by_id,
  165. )
  166. async def _reconcile(self, event: Event):
  167. """
  168. Reconcile the model.
  169. """
  170. model: Model = event.data
  171. try:
  172. async with async_session() as session:
  173. await sync_replicas(session, model)
  174. await notify_model_route_target(
  175. session=session, model=model, event=event
  176. )
  177. await sync_categories_and_meta(session, model, event)
  178. await self._ensure_model_mcp_bridge(session, event.type, model)
  179. except Exception as e:
  180. logger.error(f"Failed to reconcile model {model.name}: {e}")
  181. class ModelInstanceController:
  182. def __init__(self, cfg: Config):
  183. self._config = cfg
  184. pass
  185. async def start(self):
  186. """
  187. Start the controller.
  188. """
  189. async for event in ModelInstance.subscribe(source="model_instance_controller"):
  190. if event.type == EventType.HEARTBEAT:
  191. continue
  192. await self._reconcile(event)
  193. async def _reconcile(self, event: Event):
  194. """
  195. Reconcile the model.
  196. """
  197. model_instance: ModelInstance = event.data
  198. try:
  199. async with async_session() as session:
  200. model = await Model.one_by_id(session, model_instance.model_id)
  201. if not model:
  202. return
  203. model_deleting = model.deleted_at is not None
  204. if event.type == EventType.DELETED:
  205. # trigger model replica sync, but only if model is not deleted
  206. if not model_deleting:
  207. copied_model = Model.model_validate(model.model_dump())
  208. asyncio.create_task(
  209. event_bus.publish(
  210. Model.__name__.lower(),
  211. Event(type=EventType.UPDATED, data=copied_model),
  212. )
  213. )
  214. elif model_instance.state == ModelInstanceStateEnum.INITIALIZING:
  215. await ensure_instance_model_file(session, model_instance)
  216. return
  217. if model_deleting:
  218. return
  219. await model.refresh(session)
  220. await sync_ready_replicas(session, model)
  221. except Exception as e:
  222. logger.error(
  223. f"Failed to reconcile model instance {model_instance.name}: {e}"
  224. )
  225. async def sync_replicas(session: AsyncSession, model: Model):
  226. """
  227. Synchronize the replicas.
  228. """
  229. # Re-fetch model from database to ensure we have latest state
  230. # (event data may be from a different session or stale)
  231. fresh_model = await Model.one_by_id(session, model.id)
  232. if not fresh_model or fresh_model.deleted_at is not None:
  233. return
  234. model = fresh_model
  235. instances = await ModelInstance.all_by_field(session, "model_id", model.id)
  236. if len(instances) < model.replicas:
  237. for _ in range(model.replicas - len(instances)):
  238. name_prefix = ''.join(
  239. random.choices(string.ascii_letters + string.digits, k=5)
  240. )
  241. instance = ModelInstanceCreate(
  242. name=f"{model.name}-{name_prefix}",
  243. model_id=model.id,
  244. model_name=model.name,
  245. source=model.source,
  246. huggingface_repo_id=model.huggingface_repo_id,
  247. huggingface_filename=model.huggingface_filename,
  248. model_scope_model_id=model.model_scope_model_id,
  249. model_scope_file_path=model.model_scope_file_path,
  250. local_path=model.local_path,
  251. state=ModelInstanceStateEnum.PENDING,
  252. cluster_id=model.cluster_id,
  253. # Inherit the parent Model's tenant binding — the schema
  254. # default of PLATFORM_PRINCIPAL_ID would otherwise
  255. # land instances of a non-Default-Org Model in Default.
  256. owner_principal_id=model.owner_principal_id,
  257. draft_model_source=get_draft_model_source(model),
  258. backend=get_backend(model),
  259. backend_version=model.backend_version,
  260. )
  261. await ModelInstanceService(session).create(instance)
  262. logger.debug(f"Created model instance for model {model.name}")
  263. elif len(instances) > model.replicas:
  264. # Get instances for update lock, to avoid race condition with scheduler
  265. instances = await ModelInstance.all_by_field(
  266. session, "model_id", model.id, for_update=True
  267. )
  268. candidates = await find_scale_down_candidates(instances, model)
  269. scale_down_count = len(candidates) - model.replicas
  270. if scale_down_count > 0:
  271. scale_down_instances = []
  272. for candidate in candidates[:scale_down_count]:
  273. scale_down_instances.append(candidate.model_instance)
  274. scale_down_instance_names = await ModelInstanceService(
  275. session
  276. ).batch_delete(scale_down_instances)
  277. if scale_down_instance_names:
  278. logger.debug(f"Deleted model instances: {scale_down_instance_names}")
  279. async def distribute_models_to_user(
  280. session: AsyncSession, model: ModelRoute, event: Event
  281. ):
  282. if len(event.changed_fields) == 0 and event.type == EventType.CREATED:
  283. return
  284. model_dict = model.model_dump(exclude={"instances", "users", "cluster"})
  285. model_id = model.id
  286. to_delete_model_user_ids: Set[int] = set()
  287. to_update_model_user_ids: Set[int] = set()
  288. to_create_model_user_ids: Set[int] = set()
  289. if event.type == EventType.DELETED:
  290. users = await User.all_by_fields(
  291. session, fields={"deleted_at": None, "is_admin": False}
  292. )
  293. for user in users:
  294. to_delete_model_user_ids.add(user.id)
  295. if event.type == EventType.UPDATED:
  296. changed_fields = event.changed_fields.copy()
  297. changed_users = changed_fields.pop("users", None)
  298. if changed_users is not None:
  299. old_users, new_users = changed_users
  300. old_user_ids = {user.id for user in old_users}
  301. new_user_ids = {user.id for user in new_users}
  302. to_create_model_user_ids = new_user_ids - old_user_ids
  303. to_delete_model_user_ids = old_user_ids - new_user_ids
  304. if len(changed_fields) > 0:
  305. users = await User.all_by_fields(
  306. session,
  307. fields={"deleted_at": None, "is_admin": False},
  308. extra_conditions=[
  309. User.principal_id.in_(
  310. select(ModelRoutePrincipalLink.principal_id).where(
  311. ModelRoutePrincipalLink.route_id == model.id
  312. )
  313. )
  314. ],
  315. )
  316. current_user_ids = {user.id for user in users}
  317. to_update_model_user_ids = current_user_ids - to_create_model_user_ids
  318. if event.type == EventType.CREATED:
  319. users = await User.all_by_fields(
  320. session,
  321. fields={"deleted_at": None, "is_admin": False},
  322. extra_conditions=[
  323. User.principal_id.in_(
  324. select(ModelRoutePrincipalLink.principal_id).where(
  325. ModelRoutePrincipalLink.route_id == model.id
  326. )
  327. )
  328. ],
  329. )
  330. for user in users:
  331. to_create_model_user_ids.add(user.id)
  332. tasks = []
  333. for event_type, ids in [
  334. (EventType.CREATED, to_create_model_user_ids),
  335. (EventType.DELETED, to_delete_model_user_ids),
  336. (EventType.UPDATED, to_update_model_user_ids),
  337. ]:
  338. for user_id in ids:
  339. my_model = MyModel(
  340. pid=f"{model_id}:{user_id}",
  341. user_id=user_id,
  342. **model_dict,
  343. )
  344. tasks.append(
  345. event_bus.publish(
  346. MyModel.__name__.lower(), Event(type=event_type, data=my_model)
  347. )
  348. )
  349. if tasks:
  350. await asyncio.gather(*tasks)
  351. async def ensure_instance_model_file(session: AsyncSession, instance: ModelInstance):
  352. """
  353. Synchronize the model file of the model instance.
  354. """
  355. if instance.worker_id is None:
  356. # Not scheduled yet
  357. return
  358. instance = await ModelInstance.one_by_id(
  359. session,
  360. instance.id,
  361. options=[
  362. selectinload(ModelInstance.model_files),
  363. ],
  364. )
  365. if not instance:
  366. return
  367. if len(instance.model_files) > 0:
  368. await sync_instance_files_state(session, instance, instance.model_files)
  369. return
  370. retry_model_files = []
  371. model_files = await get_or_create_model_files_for_instance(session, instance)
  372. draft_model_files = []
  373. if instance.draft_model_source:
  374. draft_model_files = await get_or_create_model_files_for_instance(
  375. session, instance, is_draft_model=True
  376. )
  377. for model_file in model_files + draft_model_files:
  378. if model_file.state == ModelFileStateEnum.ERROR:
  379. # Retry the download
  380. retry_model_files.append(model_file.readable_source)
  381. model_file.state = ModelFileStateEnum.DOWNLOADING
  382. model_file.download_progress = 0
  383. model_file.state_message = ""
  384. await model_file.update(session, auto_commit=False)
  385. if retry_model_files:
  386. await session.commit()
  387. logger.info(
  388. f"Retrying download for model files {retry_model_files} for model instance {instance.name}"
  389. )
  390. instance = await ModelInstance.one_by_id(session, instance.id)
  391. instance.model_files = model_files
  392. instance.draft_model_files = draft_model_files
  393. await sync_instance_files_state(session, instance, model_files + draft_model_files)
  394. async def get_or_create_model_files_for_instance(
  395. session: AsyncSession, instance: ModelInstance, is_draft_model: bool = False
  396. ) -> List[ModelFile]:
  397. """
  398. Get or create model files for the given model instance.
  399. If is_draft_model is True, get or create model files for the draft model.
  400. """
  401. model_files = await get_model_files_for_instance(session, instance, is_draft_model)
  402. worker_ids = _get_worker_ids_for_file_download(instance)
  403. # Return early if all model files are already created for the workers
  404. if len(model_files) == len(worker_ids):
  405. return model_files
  406. # Get the worker IDs that are missing model files.
  407. missing_worker_ids = set(worker_ids) - {
  408. model_file.worker_id for model_file in model_files
  409. }
  410. if not missing_worker_ids:
  411. return model_files
  412. model_source = instance
  413. if is_draft_model:
  414. model_source = instance.draft_model_source
  415. # Create model files for the missing worker IDs.
  416. for worker_id in missing_worker_ids:
  417. model_file = ModelFile(
  418. source=model_source.source,
  419. huggingface_repo_id=model_source.huggingface_repo_id,
  420. huggingface_filename=model_source.huggingface_filename,
  421. model_scope_model_id=model_source.model_scope_model_id,
  422. model_scope_file_path=model_source.model_scope_file_path,
  423. local_path=model_source.local_path,
  424. state=ModelFileStateEnum.DOWNLOADING,
  425. worker_id=worker_id,
  426. source_index=model_source.model_source_index,
  427. )
  428. await ModelFile.create(session, model_file)
  429. logger.info(
  430. f"Created model file for model instance {instance.name} and worker {worker_id}"
  431. )
  432. # After creating the model files, fetch them again to return the complete list.
  433. return await get_model_files_for_instance(session, instance, is_draft_model)
  434. async def get_model_files_for_instance(
  435. session: AsyncSession, instance: ModelInstance, is_draft_model: bool = False
  436. ) -> List[ModelFile]:
  437. """
  438. Get the model files for the given model instance.
  439. If draft_model is provided, get the model files for the draft model.
  440. """
  441. worker_ids = _get_worker_ids_for_file_download(instance)
  442. model_source: ModelSource = instance
  443. if is_draft_model:
  444. model_source = instance.draft_model_source
  445. model_files = await ModelFileService(session).get_by_source_index(
  446. model_source.model_source_index
  447. )
  448. model_files = [
  449. model_file for model_file in model_files if model_file.worker_id in worker_ids
  450. ]
  451. if model_source.source == SourceEnum.LOCAL_PATH and model_source.local_path:
  452. # If the source is local path, get the model files with the same local path.
  453. local_path_model_files = await ModelFileService(session).get_by_resolved_path(
  454. model_source.local_path
  455. )
  456. local_path_model_files = [
  457. model_file
  458. for model_file in local_path_model_files
  459. if model_file.worker_id in worker_ids
  460. ]
  461. existing_worker_ids = {mf.worker_id for mf in model_files}
  462. additional_files = [
  463. model_file
  464. for model_file in local_path_model_files
  465. if model_file.worker_id not in existing_worker_ids
  466. ]
  467. model_files.extend(additional_files)
  468. return model_files
  469. async def find_scale_down_candidates(
  470. instances: List[ModelInstance],
  471. model: Model,
  472. *,
  473. status_max_score: Optional[float] = None,
  474. offload_max_score: Optional[float] = None,
  475. placement_max_score: Optional[float] = None,
  476. total_max_score: Optional[float] = None,
  477. ) -> List[ModelInstanceScore]:
  478. try:
  479. if status_max_score is None:
  480. status_max_score = envs.SCHEDULER_SCALE_DOWN_STATUS_MAX_SCORE
  481. if offload_max_score is None:
  482. offload_max_score = envs.SCHEDULER_SCALE_DOWN_OFFLOAD_MAX_SCORE
  483. if placement_max_score is None:
  484. placement_max_score = envs.SCHEDULER_SCALE_DOWN_PLACEMENT_MAX_SCORE
  485. chain = ModelInstanceScoreChain(
  486. scorers=[
  487. StatusScorer(model, max_score=status_max_score),
  488. OffloadLayerScorer(model, max_score=offload_max_score),
  489. PlacementScorer(
  490. model,
  491. instances,
  492. scale_type=ScaleTypeEnum.SCALE_DOWN,
  493. max_score=placement_max_score,
  494. ),
  495. ],
  496. total_max_score=total_max_score,
  497. )
  498. final_candidates = await chain.score(instances)
  499. final_candidates = sorted(
  500. final_candidates, key=lambda x: x.score, reverse=False
  501. )
  502. return final_candidates
  503. except Exception as e:
  504. state_message = (
  505. f"Failed to find scale down candidates for model {model.name}: {e}"
  506. )
  507. logger.error(state_message)
  508. return []
  509. async def sync_ready_replicas(session: AsyncSession, model: Model):
  510. """
  511. Synchronize the ready replicas.
  512. """
  513. if model.deleted_at is not None:
  514. return
  515. instances = await ModelInstance.all_by_field(session, "model_id", model.id)
  516. ready_replicas: int = 0
  517. for _, instance in enumerate(instances):
  518. if instance.state == ModelInstanceStateEnum.RUNNING:
  519. ready_replicas += 1
  520. if model.ready_replicas != ready_replicas:
  521. model.ready_replicas = ready_replicas
  522. await ModelService(session).update(model)
  523. async def get_cluster_registry(
  524. session: AsyncSession, cluster_id: int
  525. ) -> Optional[McpBridgeRegistry]:
  526. cluster_user = await User.one_by_field(
  527. session=session,
  528. field="cluster_id",
  529. value=cluster_id,
  530. options=[selectinload(User.cluster)],
  531. )
  532. if is_default_cluster_user(cluster_user):
  533. return None
  534. cluster_registry = mcp_handler.cluster_registry(cluster_user.cluster)
  535. if cluster_registry is None:
  536. return None
  537. return cluster_registry
  538. async def sync_model_route_mapper(
  539. cfg: Config,
  540. extensions_api: ExtensionsHigressIoV1Api,
  541. ingress_name: str,
  542. route_name: str,
  543. destinations: mcp_handler.DestinationTupleList,
  544. fallback_destinations: mcp_handler.DestinationTupleList,
  545. ):
  546. """
  547. Synchronize the model route mapper.
  548. """
  549. ingress_prefix = f"{cfg.get_namespace()}/"
  550. if cfg.get_namespace() == cfg.gateway_namespace:
  551. ingress_prefix = ""
  552. model_name_to_registries: Dict[str, List[str]] = {}
  553. for _, model_name, registry in destinations:
  554. if route_name == model_name:
  555. # Skip self mapping
  556. continue
  557. registries = model_name_to_registries.setdefault(model_name, [])
  558. registries.append(registry.get_service_name())
  559. fallback_model_name_to_registries: Dict[str, List[str]] = {}
  560. for _, model_name, registry in fallback_destinations:
  561. registries = fallback_model_name_to_registries.setdefault(model_name, [])
  562. registries.append(registry.get_service_name())
  563. expected_rules = mcp_handler.get_expected_match_list(
  564. route_name=route_name,
  565. ingress_prefix=ingress_prefix,
  566. ingress_name=ingress_name,
  567. model_name_to_registries=model_name_to_registries,
  568. fallback_model_name_to_registries=fallback_model_name_to_registries,
  569. )
  570. def spec_diff(current_spec: Optional[WasmPluginSpec]) -> WasmPluginSpec:
  571. # the current spec must exist. If not, it means the plugin has been deleted manually,
  572. # we should not recreate it until next update event to avoid potential misconfiguration.
  573. if current_spec is None:
  574. return current_spec
  575. to_keep_rules: List[WasmPluginMatchRule] = []
  576. full_ingress_name = f"{ingress_prefix}{ingress_name}"
  577. for rule in current_spec.matchRules or []:
  578. if full_ingress_name not in rule.ingress:
  579. to_keep_rules.append(rule)
  580. to_keep_rules.extend(expected_rules)
  581. to_keep_rules.sort(key=lambda r: r.ingress[0] if r.ingress else "")
  582. current_spec.matchRules = to_keep_rules
  583. return current_spec
  584. await mcp_handler.ensure_wasm_plugin(
  585. api=extensions_api,
  586. name=mcp_handler.gpustack_model_mapper_name,
  587. namespace=cfg.gateway_namespace,
  588. spec_diff=spec_diff,
  589. )
  590. async def ensure_route_generic_transformer_config(
  591. cfg: Config,
  592. model_route: ModelRoute,
  593. effective_name: str,
  594. extensions_api: ExtensionsHigressIoV1Api,
  595. generic_proxy_enabled: bool,
  596. ):
  597. """
  598. Reconcile the single HeaderRule that maps /model/proxy/<route_id>/... to this
  599. route's x-higress-llm-model. When generic_proxy_enabled is False (generic proxy
  600. disabled or route deleted), the rule is removed and other routes are untouched.
  601. ``effective_name`` is the fully-qualified model name including the
  602. Org slug prefix (e.g. ``org1/qwen3-0.6b``) for non-platform Orgs;
  603. platform Org keeps the unprefixed ``model_route.name``.
  604. """
  605. operating_path_pattern = mcp_handler.build_generic_route_path_pattern(
  606. model_route.id
  607. )
  608. expected_header_rules: List[Dict[str, Any]] = []
  609. if generic_proxy_enabled:
  610. expected_header_rules.append(
  611. mcp_handler.build_generic_route_header_rule(model_route.id, effective_name)
  612. )
  613. await mcp_handler.ensure_wasm_plugin(
  614. api=extensions_api,
  615. name=mcp_handler.gpustack_generic_route_transformer_name,
  616. namespace=cfg.gateway_namespace,
  617. spec_diff=partial(
  618. mcp_handler.generic_route_transformer_diff_spec,
  619. expected_header_rules=expected_header_rules,
  620. operating_path_pattern=operating_path_pattern,
  621. ),
  622. )
  623. async def ensure_route_ai_proxy_config(
  624. cfg: Config,
  625. model_route_id: int,
  626. extensions_api: ExtensionsHigressIoV1Api,
  627. route_destinations: mcp_handler.DestinationTupleList,
  628. fallback_destinations: mcp_handler.DestinationTupleList,
  629. ):
  630. service_namespace_prefix = cfg.get_namespace() + "/"
  631. if cfg.get_namespace() == cfg.gateway_namespace:
  632. service_namespace_prefix = ""
  633. operating_id = mcp_handler.model_route_cleanup_prefix(model_route_id)
  634. ingress_name = mcp_handler.model_route_ingress_name(model_route_id)
  635. fallback_ingress_name = mcp_handler.fallback_ingress_name(ingress_name)
  636. expected_providers = []
  637. expected_match_rules = []
  638. # cross provider needs to configure ai_proxy
  639. unique_registry_services: Set[str] = set(
  640. registry.get_service_name()
  641. for _, _, registry in route_destinations
  642. if (not registry.name.startswith(mcp_handler.provider_id_prefix))
  643. )
  644. unique_fallback_registry_services: Set[str] = set(
  645. registry.get_service_name()
  646. for _, _, registry in fallback_destinations
  647. if (not registry.name.startswith(mcp_handler.provider_id_prefix))
  648. )
  649. if len(unique_registry_services) + len(unique_fallback_registry_services) > 0:
  650. expected_providers.append(
  651. mcp_handler.ai_proxy_openai_provider_config(operating_id)
  652. )
  653. if len(unique_registry_services) > 0:
  654. expected_match_rules.append(
  655. WasmPluginMatchRule(
  656. config={
  657. "activeProviderId": operating_id,
  658. },
  659. configDisable=False,
  660. service=list(unique_registry_services),
  661. ingress=[f"{service_namespace_prefix}{ingress_name}"],
  662. )
  663. )
  664. # same logic for fallback
  665. if len(unique_fallback_registry_services) > 0:
  666. expected_match_rules.append(
  667. WasmPluginMatchRule(
  668. config={
  669. "activeProviderId": operating_id,
  670. },
  671. configDisable=False,
  672. service=list(unique_fallback_registry_services),
  673. ingress=[f"{service_namespace_prefix}{fallback_ingress_name}"],
  674. )
  675. )
  676. await mcp_handler.ensure_wasm_plugin(
  677. api=extensions_api,
  678. name=mcp_handler.gpustack_ai_proxy_name,
  679. namespace=cfg.gateway_namespace,
  680. spec_diff=partial(
  681. mcp_handler.ai_proxy_diff_spec,
  682. expected_providers=expected_providers,
  683. expected_match_rules=expected_match_rules,
  684. operating_id_prefix=operating_id,
  685. ),
  686. )
  687. async def sync_gateway(
  688. session: AsyncSession,
  689. event: Event,
  690. cfg: Config,
  691. model_route: ModelRoute,
  692. networking_api: k8s_client.NetworkingV1Api,
  693. extensions_api: ExtensionsHigressIoV1Api,
  694. istio_networking_api: NetworkingIstioIoV1Alpha3Api,
  695. ):
  696. event_type = event.type
  697. model_route_from_db = await ModelRoute.one_by_id(
  698. session,
  699. model_route.id,
  700. options=[selectinload(ModelRoute.route_targets)],
  701. )
  702. targets: List[ModelRouteTarget] = (
  703. getattr(model_route_from_db, "route_targets", []) if model_route_from_db else []
  704. )
  705. has_fallback_target = any(
  706. target
  707. for target in targets
  708. if target.fallback_status_codes and len(target.fallback_status_codes) > 0
  709. )
  710. destinations = []
  711. fallback_destinations = []
  712. if not model_route_from_db:
  713. event_type = EventType.DELETED
  714. if event.type != EventType.DELETED:
  715. destinations, fallback_destinations = await calculate_destinations(
  716. session, model_route
  717. )
  718. # Effective model name = `<org-slug>/<route.name>` for non-platform
  719. # Orgs (so two Orgs can use the same `route.name` without colliding
  720. # in Higress's AI proxy match rules), unprefixed for the platform Org
  721. # (backward compatible for existing clients).
  722. route_owner = await Principal.one_by_id(session, model_route.owner_principal_id)
  723. effective_name = effective_route_name(
  724. model_route.name,
  725. getattr(route_owner, "slug", None),
  726. getattr(route_owner, "id", None) == PLATFORM_PRINCIPAL_ID,
  727. )
  728. ingress_name = mcp_handler.model_route_ingress_name(model_route.id)
  729. await sync_model_route_mapper(
  730. cfg=cfg,
  731. extensions_api=extensions_api,
  732. ingress_name=ingress_name,
  733. route_name=effective_name,
  734. destinations=destinations,
  735. fallback_destinations=fallback_destinations,
  736. )
  737. # FIXME: Copy the fallback destination to the main ingress for now to make sure the fallback
  738. # route is always hit when fallback is configured, even if the main route has no valid
  739. # destination. This is to avoid potential misconfiguration that causes the main route to
  740. # have no destination and the fallback route is not hit at all.
  741. await mcp_handler.ensure_model_ingress(
  742. ingress_class_name=cfg.gateway_ingress_class,
  743. event_type=event_type,
  744. ingress_name=ingress_name,
  745. route_name=effective_name,
  746. namespace=cfg.get_namespace(),
  747. destinations=destinations if len(destinations) > 0 else fallback_destinations,
  748. networking_api=networking_api,
  749. included_generic_route=False,
  750. included_proxy_route=model_route.generic_proxy,
  751. )
  752. fallback_event_type = event_type
  753. if not has_fallback_target:
  754. fallback_event_type = EventType.DELETED
  755. # Fallback ingress
  756. await mcp_handler.ensure_model_ingress(
  757. ingress_class_name=cfg.gateway_ingress_class,
  758. event_type=fallback_event_type,
  759. ingress_name=mcp_handler.fallback_ingress_name(ingress_name),
  760. route_name=effective_name,
  761. namespace=cfg.get_namespace(),
  762. destinations=fallback_destinations,
  763. networking_api=networking_api,
  764. included_generic_route=False,
  765. included_proxy_route=model_route.generic_proxy,
  766. extra_annotations=mcp_handler.higress_http_header_matcher(
  767. "exact", "x-higress-fallback-from", ingress_name
  768. ),
  769. )
  770. # Fallback filter
  771. await mcp_handler.ensure_fallback_filter(
  772. event_type=fallback_event_type,
  773. ingress_name=ingress_name,
  774. namespace=cfg.get_namespace(),
  775. networking_istio_api=istio_networking_api,
  776. )
  777. # Generic-route transformer: inject x-higress-llm-model when /model/proxy/<id>/
  778. # is hit, so the existing main ingress header matcher + fallback chain apply.
  779. await ensure_route_generic_transformer_config(
  780. cfg=cfg,
  781. model_route=model_route,
  782. effective_name=effective_name,
  783. extensions_api=extensions_api,
  784. generic_proxy_enabled=(
  785. event_type != EventType.DELETED and bool(model_route.generic_proxy)
  786. ),
  787. )
  788. # ensure ai proxy config
  789. await ensure_route_ai_proxy_config(
  790. cfg=cfg,
  791. model_route_id=model_route.id,
  792. extensions_api=extensions_api,
  793. route_destinations=destinations,
  794. fallback_destinations=fallback_destinations,
  795. )
  796. def flatten_destinations(
  797. weight_to_count: List[Tuple[int, int, mcp_handler.DestinationTupleList]],
  798. max_weight: Optional[int] = 0,
  799. ) -> mcp_handler.DestinationTupleList:
  800. persentage_list = mcp_handler.hamilton_calculate_weight(
  801. [(weight, count) for weight, count, _ in weight_to_count],
  802. max_weight=max_weight,
  803. )
  804. flatten_registry_list: mcp_handler.DestinationTupleList = []
  805. index = 0
  806. for _, _, registry_list_part in weight_to_count:
  807. for count, model_name, registry in registry_list_part:
  808. total_percentage = sum(persentage_list[index : index + count])
  809. index += count
  810. if total_percentage != 0:
  811. flatten_registry_list.append((total_percentage, model_name, registry))
  812. return flatten_registry_list
  813. async def calculate_destinations(
  814. session: AsyncSession,
  815. model_route: ModelRoute,
  816. ) -> Tuple[mcp_handler.DestinationTupleList, mcp_handler.DestinationTupleList]:
  817. """
  818. return persentage Tuple for each registry with model name and the fallback registry
  819. """
  820. weight_to_count: List[Tuple[int, int, mcp_handler.DestinationTupleList]] = []
  821. fallback_weight_to_count: List[
  822. Tuple[int, int, mcp_handler.DestinationTupleList]
  823. ] = []
  824. targets = await ModelRouteTarget.all_by_field(session, "route_id", model_route.id)
  825. for target in targets:
  826. if target.state != TargetStateEnum.ACTIVE:
  827. continue
  828. to_extend: mcp_handler.DestinationTupleList = []
  829. if target.model_id is not None:
  830. model = await Model.one_by_id(session, target.model_id)
  831. if model is None:
  832. continue
  833. to_extend = await calculate_model_destinations(session, model)
  834. elif target.provider_id is not None:
  835. to_extend = await provider_destinations(
  836. session=session,
  837. provider_id=target.provider_id,
  838. provider_model_name=target.provider_model_name,
  839. )
  840. if to_extend is None or len(to_extend) == 0:
  841. # no valid destination found
  842. continue
  843. count = sum([count for count, _, _ in to_extend])
  844. weight_to_count.append((target.weight, count, to_extend))
  845. if (
  846. target.fallback_status_codes is not None
  847. and len(target.fallback_status_codes) > 0
  848. ):
  849. fallback_weight_to_count.append((target.weight, count, to_extend))
  850. if len(weight_to_count) == 0:
  851. return [], []
  852. flatten_registry_list = flatten_destinations(weight_to_count)
  853. fallback_registry_list = []
  854. if len(fallback_weight_to_count) > 0:
  855. # fallback might have 0 weight, so set max_weight to 1
  856. fallback_registry_list = flatten_destinations(
  857. fallback_weight_to_count, max_weight=1
  858. )
  859. return flatten_registry_list, fallback_registry_list
  860. async def provider_destinations(
  861. session: AsyncSession,
  862. provider_id: int,
  863. provider_model_name: str,
  864. ) -> mcp_handler.DestinationTupleList:
  865. """
  866. return count dict for provider registry
  867. """
  868. provider = await ModelProvider.one_by_id(session, provider_id)
  869. if provider is None:
  870. return []
  871. return [(1, provider_model_name, mcp_handler.provider_registry(provider))]
  872. async def calculate_model_destinations(
  873. session: AsyncSession,
  874. model: Model,
  875. ) -> mcp_handler.DestinationTupleList:
  876. """
  877. return count dict for each registry
  878. """
  879. # find out is handling default cluster's model
  880. cluster_registry = await get_cluster_registry(session, model.cluster_id)
  881. if cluster_registry is not None:
  882. return [(1, model.name, cluster_registry)]
  883. instances = await ModelInstance.all_by_field(session, "model_id", model.id)
  884. instances = [
  885. instance
  886. for instance in instances
  887. if instance.worker_ip is not None
  888. and instance.port is not None
  889. and instance.worker_ip != ""
  890. and instance.state == ModelInstanceStateEnum.RUNNING
  891. ]
  892. worker_list = await Worker.all_by_fields(
  893. session=session,
  894. fields={
  895. "cluster_id": model.cluster_id,
  896. "deleted_at": None,
  897. },
  898. extra_conditions=[
  899. Worker.id.in_(
  900. [
  901. instance.worker_id
  902. for instance in instances
  903. if instance.worker_id is not None
  904. ]
  905. )
  906. ],
  907. )
  908. workers = {worker.id: worker for worker in worker_list}
  909. registry_list = mcp_handler.model_instances_registry_list(instances, workers)
  910. return registry_list
  911. class WorkerController:
  912. def __init__(self, cfg: Config):
  913. self._provisioning = WorkerProvisioningController(cfg)
  914. async def start(self):
  915. """
  916. Start the controller.
  917. """
  918. async for event in Worker.subscribe(source="worker_controller"):
  919. if event.type == EventType.HEARTBEAT:
  920. continue
  921. try:
  922. await self._reconcile(event)
  923. await self._provisioning._reconcile(event)
  924. await self._notify_relatives(event)
  925. except Exception as e:
  926. logger.error(f"Failed to reconcile worker: {e}")
  927. async def _reconcile(self, event: Event):
  928. """
  929. Delete instances base on the worker state and event type.
  930. """
  931. if event.type not in (EventType.UPDATED, EventType.DELETED):
  932. return
  933. worker: Worker = event.data
  934. if not worker:
  935. return
  936. if worker.state.is_provisioning and worker.state != WorkerStateEnum.DELETING:
  937. # Skip reconciliation for provisioning and deleting workers.
  938. # There is a dedicated controller to handle provisioning.
  939. return
  940. if event.type == EventType.UPDATED:
  941. changed_fields = event.changed_fields
  942. if not changed_fields or "state" not in changed_fields:
  943. # No state change
  944. return
  945. async with async_session() as session:
  946. all_instances = await ModelInstance.all_by_field(
  947. session, "cluster_id", worker.cluster_id
  948. )
  949. if not all_instances:
  950. return
  951. matched_instances = []
  952. for instance in all_instances:
  953. match = get_model_instance_worker_match(
  954. instance,
  955. worker_name=worker.name,
  956. worker_id=worker.id,
  957. )
  958. if match.matched:
  959. matched_instances.append((instance, match))
  960. if not matched_instances:
  961. return
  962. if event.type == EventType.DELETED:
  963. instance_names = await ModelInstanceService(session).batch_delete(
  964. [instance for instance, _ in matched_instances]
  965. )
  966. if instance_names:
  967. logger.info(
  968. f"Delete instance {', '.join(instance_names)} "
  969. f"since worker {worker.name} is deleted"
  970. )
  971. return
  972. if (
  973. worker.unreachable
  974. or worker.state == WorkerStateEnum.UNREACHABLE
  975. or worker.state == WorkerStateEnum.NOT_READY
  976. ):
  977. await self.update_impacted_instance_states_to_unreachable(
  978. session,
  979. matched_instances,
  980. worker.name,
  981. )
  982. return
  983. async def update_impacted_instance_states_to_unreachable(
  984. self,
  985. session,
  986. matched_instances,
  987. worker_name,
  988. ):
  989. instance_names = set()
  990. subordinate_worker_names = set()
  991. for instance, match in matched_instances:
  992. patch = {}
  993. distributed_servers_changed = False
  994. if (
  995. match.is_main_worker
  996. and instance.state == ModelInstanceStateEnum.RUNNING
  997. ):
  998. patch["state"] = ModelInstanceStateEnum.UNREACHABLE
  999. patch["state_message"] = "Worker is unreachable from the server"
  1000. instance_names.add(instance.name)
  1001. for index in match.subordinate_worker_indexes:
  1002. subordinate_worker = instance.distributed_servers.subordinate_workers[
  1003. index
  1004. ]
  1005. if subordinate_worker.state == ModelInstanceStateEnum.UNREACHABLE:
  1006. continue
  1007. subordinate_worker.state = ModelInstanceStateEnum.UNREACHABLE
  1008. subordinate_worker.state_message = (
  1009. "Worker is unreachable from the server"
  1010. )
  1011. subordinate_worker_names.add(
  1012. f"{instance.name}:{subordinate_worker.worker_name}"
  1013. )
  1014. distributed_servers_changed = True
  1015. if distributed_servers_changed:
  1016. patch["distributed_servers"] = instance.distributed_servers
  1017. flag_modified(instance, "distributed_servers")
  1018. if patch:
  1019. await ModelInstanceService(session).update(instance, patch)
  1020. if instance_names:
  1021. logger.info(
  1022. f"Marked instance {', '.join(instance_names)} unreachable "
  1023. f"since worker {worker_name} is unreachable from the server"
  1024. )
  1025. if subordinate_worker_names:
  1026. logger.info(
  1027. f"Marked subordinate workers {', '.join(subordinate_worker_names)} unreachable "
  1028. f"since worker {worker_name} is unreachable from the server"
  1029. )
  1030. async def _notify_relatives(self, event: Event):
  1031. if event.type not in (EventType.UPDATED, EventType.DELETED):
  1032. return
  1033. worker: Worker = event.data
  1034. changed_fields = event.changed_fields
  1035. if not worker or (not changed_fields and event.type != EventType.DELETED):
  1036. return
  1037. state_changed: Optional[Tuple[Any, Any]] = (changed_fields or {}).get(
  1038. "state", None
  1039. )
  1040. proxy_mode_changed: Optional[Tuple[Any, Any]] = (changed_fields or {}).get(
  1041. "proxy_mode", None
  1042. )
  1043. should_notify_parents = (
  1044. state_changed is not None
  1045. or proxy_mode_changed is not None
  1046. or event.type == EventType.DELETED
  1047. )
  1048. proxy_address_changed: Optional[Tuple[Any, Any]] = (changed_fields or {}).get(
  1049. "proxy_address", None
  1050. )
  1051. should_notify_children = (
  1052. proxy_address_changed is not None or proxy_mode_changed is not None
  1053. )
  1054. if not should_notify_parents and not should_notify_children:
  1055. return
  1056. async with async_session() as session:
  1057. if should_notify_parents and worker.worker_pool_id is not None:
  1058. worker_pool = await WorkerPool.one_by_id(
  1059. session,
  1060. worker.worker_pool_id,
  1061. options=[selectinload(WorkerPool.pool_workers)],
  1062. )
  1063. if worker_pool is not None:
  1064. copied_pool = WorkerPool(**worker_pool.model_dump())
  1065. await event_bus.publish(
  1066. copied_pool.__class__.__name__.lower(),
  1067. Event(
  1068. type=EventType.UPDATED,
  1069. data=copied_pool,
  1070. ),
  1071. )
  1072. if should_notify_parents and worker.cluster_id is not None:
  1073. cluster = await Cluster.one_by_id(
  1074. session,
  1075. worker.cluster_id,
  1076. options=[
  1077. selectinload(Cluster.cluster_workers),
  1078. selectinload(Cluster.cluster_models),
  1079. ],
  1080. )
  1081. if cluster is not None:
  1082. copied_cluster = Cluster(**cluster.model_dump())
  1083. await event_bus.publish(
  1084. copied_cluster.__class__.__name__.lower(),
  1085. Event(
  1086. type=EventType.UPDATED,
  1087. data=copied_cluster,
  1088. ),
  1089. )
  1090. if should_notify_children:
  1091. instances = await ModelInstance.all_by_fields(
  1092. session,
  1093. fields={"worker_id": worker.id},
  1094. options=[selectinload(ModelInstance.model)],
  1095. )
  1096. notified_model = set()
  1097. for instance in instances:
  1098. if instance.model_id in notified_model:
  1099. continue
  1100. notified_model.add(instance.model_id)
  1101. copied_model = Model(**instance.model.model_dump())
  1102. await event_bus.publish(
  1103. copied_model.__class__.__name__.lower(),
  1104. Event(
  1105. type=EventType.UPDATED,
  1106. data=copied_model,
  1107. ),
  1108. )
  1109. class InferenceBackendController:
  1110. """
  1111. Inference backend controller initializes built-in and community backends in the database.
  1112. """
  1113. async def start(self):
  1114. async with async_session() as session:
  1115. # Initialize built-in backends
  1116. await self._init_built_in_backends(session)
  1117. # Initialize community backends
  1118. await self._init_community_backends(session)
  1119. async def _init_built_in_backends(self, session: AsyncSession):
  1120. """Initialize built-in backends in the database."""
  1121. for built_in_backend in get_built_in_backend():
  1122. if built_in_backend.backend_name == BackendEnum.CUSTOM.value:
  1123. continue
  1124. # Built-in backends always seed as Platform (owner_principal_id IS NULL).
  1125. # Per-Org overrides live in additional rows created by Org owners /
  1126. # managers; those are managed via the inference_backend routes.
  1127. backend = await InferenceBackend.one_by_fields(
  1128. session,
  1129. {
  1130. "backend_name": built_in_backend.backend_name,
  1131. "owner_principal_id": None,
  1132. },
  1133. )
  1134. if not backend:
  1135. # Create new built-in backend with backend_source
  1136. built_in_backend.backend_source = BackendSourceEnum.BUILT_IN
  1137. built_in_backend.enabled = True
  1138. await InferenceBackend.create(session, built_in_backend)
  1139. logger.info(
  1140. f"Init built-in backend {built_in_backend.backend_name} in database"
  1141. )
  1142. elif backend.backend_source is None:
  1143. # Update existing backend without backend_source
  1144. backend.backend_source = BackendSourceEnum.BUILT_IN
  1145. if backend.enabled is None:
  1146. backend.enabled = True
  1147. await backend.update(
  1148. session,
  1149. {
  1150. "backend_source": BackendSourceEnum.BUILT_IN,
  1151. "enabled": (
  1152. backend.enabled if backend.enabled is not None else True
  1153. ),
  1154. },
  1155. )
  1156. logger.info(
  1157. f"Updated backend_source for existing built-in backend {backend.backend_name}"
  1158. )
  1159. async def _init_community_backends(self, session: AsyncSession): # noqa: C901
  1160. """Load community backends from community-inference-backends.yaml into database."""
  1161. try:
  1162. # Get the path to community-inference-backends.yaml
  1163. yaml_file = files("gpustack.assets").joinpath(
  1164. "community-inference-backends.yaml"
  1165. )
  1166. if not yaml_file.is_file():
  1167. logger.debug(
  1168. "community-inference-backends.yaml not found, skipping community backend initialization"
  1169. )
  1170. return
  1171. yaml_data = yaml.safe_load(yaml_file.read_text())
  1172. if not yaml_data:
  1173. logger.debug(
  1174. "No community backends found in community-inference-backends.yaml"
  1175. )
  1176. return
  1177. if not isinstance(yaml_data, list):
  1178. logger.error(
  1179. f"Invalid community-inference-backends.yaml format: expected list, got {type(yaml_data).__name__}"
  1180. )
  1181. return
  1182. # Collect backend names from YAML
  1183. yaml_backend_names = set()
  1184. for backend_config in yaml_data:
  1185. backend_name = backend_config.get("backend_name")
  1186. if backend_name:
  1187. yaml_backend_names.add(backend_name)
  1188. await self._upsert_community_backend(session, backend_config)
  1189. # Query all community backends from database. Only Platform
  1190. # rows are owned by the catalog yaml; Org-private community
  1191. # additions stay untouched.
  1192. all_backends = await InferenceBackend.all(session)
  1193. db_community_backends = [
  1194. backend
  1195. for backend in all_backends
  1196. if backend.backend_source == BackendSourceEnum.COMMUNITY
  1197. and backend.owner_principal_id is None
  1198. ]
  1199. # Delete community backends that are no longer in YAML
  1200. for backend in db_community_backends:
  1201. if backend.backend_name in yaml_backend_names:
  1202. continue
  1203. if backend.enabled:
  1204. # Convert to custom backend to preserve user's custom versions
  1205. # Convert all built_in_frameworks versions to custom_framework versions
  1206. converted_versions = {}
  1207. if backend.version_configs and backend.version_configs.root:
  1208. for version, config in backend.version_configs.root.items():
  1209. config_data = config.model_dump()
  1210. if config_data.get("built_in_frameworks"):
  1211. config_data["custom_framework"] = config_data[
  1212. "built_in_frameworks"
  1213. ][0]
  1214. config_data["built_in_frameworks"] = None
  1215. converted_versions[version] = VersionConfig(**config_data)
  1216. # Prepare update data
  1217. update_data = {
  1218. "backend_source": BackendSourceEnum.CUSTOM,
  1219. "enabled": False,
  1220. "version_configs": VersionConfigDict(root=converted_versions),
  1221. }
  1222. flag_modified(backend, "version_configs")
  1223. await backend.update(session, update_data)
  1224. logger.info(
  1225. f"Converted community backend '{backend.backend_name}' to custom backend"
  1226. )
  1227. else:
  1228. # Delete if no custom versions
  1229. await backend.delete(session)
  1230. logger.info(
  1231. f"Deleted community backend '{backend.backend_name}' "
  1232. f"(no longer in community-inference-backends.yaml)"
  1233. )
  1234. logger.debug(
  1235. "Community backends initialized from community-inference-backends.yaml"
  1236. )
  1237. except (ModuleNotFoundError, FileNotFoundError):
  1238. # community_backends directory or yaml file does not exist
  1239. logger.debug(
  1240. "Community backends directory or file not found, skipping initialization"
  1241. )
  1242. except Exception as e:
  1243. logger.error(f"Failed to initialize community backends: {e}")
  1244. async def _upsert_community_backend(self, session: AsyncSession, config: dict):
  1245. """Create or update a community backend from YAML configuration."""
  1246. backend_name = config.get("backend_name")
  1247. if not backend_name:
  1248. return
  1249. # Prepare backend data
  1250. allowed_keys = [
  1251. "backend_name",
  1252. "version_configs",
  1253. "default_version",
  1254. "default_backend_param",
  1255. "default_run_command",
  1256. "default_entrypoint",
  1257. "health_check_path",
  1258. "description",
  1259. "icon",
  1260. "default_env",
  1261. ]
  1262. backend_data = {k: config[k] for k in allowed_keys if k in config}
  1263. # Set backend source
  1264. backend_data["backend_source"] = BackendSourceEnum.COMMUNITY
  1265. backend_data["enabled"] = False
  1266. # Convert version_configs to VersionConfigDict
  1267. if 'version_configs' in backend_data and backend_data['version_configs']:
  1268. version_config_dict = {}
  1269. for version, ver_config in backend_data['version_configs'].items():
  1270. # All versions loaded from YAML are predefined versions
  1271. # Convert framework information to built_in_frameworks
  1272. frameworks = None
  1273. if 'built_in_frameworks' in ver_config:
  1274. frameworks = ver_config['built_in_frameworks']
  1275. elif (
  1276. 'custom_framework' in ver_config and ver_config['custom_framework']
  1277. ):
  1278. # Even if YAML uses custom_framework, convert it to built_in_frameworks
  1279. frameworks = [ver_config['custom_framework']]
  1280. # Set built_in_frameworks and clear custom_framework
  1281. if frameworks:
  1282. ver_config['built_in_frameworks'] = (
  1283. frameworks if isinstance(frameworks, list) else [frameworks]
  1284. )
  1285. else:
  1286. # If no framework specified, use empty list to mark as predefined version
  1287. ver_config['built_in_frameworks'] = []
  1288. # Ensure custom_framework is None (predefined versions should not have custom_framework)
  1289. ver_config['custom_framework'] = None
  1290. version_config_dict[version] = VersionConfig(**ver_config)
  1291. backend_data['version_configs'] = VersionConfigDict(
  1292. root=version_config_dict
  1293. )
  1294. # Upsert: update if exists, create if not. Community backends seed
  1295. # at the Platform scope (owner_principal_id IS NULL) — Org-private
  1296. # extensions live in additional rows owned by Orgs.
  1297. existing = await InferenceBackend.one_by_fields(
  1298. session, {"backend_name": backend_name, "owner_principal_id": None}
  1299. )
  1300. if existing:
  1301. # Smart merge logic to preserve user customizations
  1302. # 1. Merge version_configs: preserve user custom versions, update YAML versions
  1303. if 'version_configs' in backend_data and backend_data['version_configs']:
  1304. yaml_versions = backend_data['version_configs'].root
  1305. existing_versions = (
  1306. existing.version_configs.root if existing.version_configs else {}
  1307. )
  1308. # Create merged version dictionary
  1309. merged_versions = {}
  1310. # First add all YAML versions (overwrite old versions with same name)
  1311. for version, config in yaml_versions.items():
  1312. merged_versions[version] = config
  1313. # Then add user custom versions (built_in_frameworks is None)
  1314. for version, config in existing_versions.items():
  1315. if (
  1316. config.built_in_frameworks is None
  1317. and version not in yaml_versions
  1318. ):
  1319. # This is a user custom version not in YAML, preserve it
  1320. merged_versions[version] = config
  1321. backend_data['version_configs'] = VersionConfigDict(
  1322. root=merged_versions
  1323. )
  1324. # 2. Preserve user-modified enabled status (if user enabled it, don't reset to False)
  1325. if existing.enabled:
  1326. backend_data['enabled'] = True
  1327. # 3. Merge default_env (preserve user-added environment variables)
  1328. if existing.default_env:
  1329. if 'default_env' in backend_data and backend_data['default_env']:
  1330. # Merge: YAML environment variables + user-added environment variables
  1331. merged_env = dict(existing.default_env)
  1332. merged_env.update(backend_data['default_env'])
  1333. backend_data['default_env'] = merged_env
  1334. else:
  1335. # YAML doesn't define it, preserve user's
  1336. backend_data['default_env'] = existing.default_env
  1337. # 4. Update database
  1338. await existing.update(session, backend_data)
  1339. else:
  1340. backend = InferenceBackend(**backend_data)
  1341. await InferenceBackend.create(session, backend)
  1342. class ModelFileController:
  1343. """
  1344. Model file controller syncs the model file download status to related model instances.
  1345. """
  1346. async def start(self):
  1347. """
  1348. Start the controller.
  1349. """
  1350. async for event in ModelFile.subscribe(source="model_file_controller"):
  1351. if event.type == EventType.CREATED or event.type == EventType.UPDATED:
  1352. await self._reconcile(event)
  1353. async def _reconcile(self, event: Event):
  1354. """
  1355. Reconcile the model file.
  1356. """
  1357. file: ModelFile = event.data
  1358. try:
  1359. async with async_session() as session:
  1360. file = await ModelFile.one_by_id(
  1361. session,
  1362. file.id,
  1363. options=[
  1364. selectinload(ModelFile.instances),
  1365. selectinload(ModelFile.draft_instances),
  1366. ],
  1367. )
  1368. if not file:
  1369. # In case the file is deleted
  1370. return
  1371. for instance in file.instances + file.draft_instances:
  1372. async with async_session() as session:
  1373. await sync_instance_files_state(session, instance, [file])
  1374. except Exception as e:
  1375. logger.error(f"Failed to reconcile model file {file.id}: {e}")
  1376. async def sync_instance_files_state(
  1377. session: AsyncSession, instance: ModelInstance, files: List[ModelFile]
  1378. ):
  1379. for file in files:
  1380. if file.worker_id == instance.worker_id:
  1381. is_draft_model = _is_draft_model_file(file, instance)
  1382. if is_draft_model:
  1383. await sync_main_worker_model_file_state(
  1384. session, file, instance, is_draft_model=True
  1385. )
  1386. else:
  1387. await sync_main_worker_model_file_state(session, file, instance)
  1388. else:
  1389. await sync_distributed_model_file_state(session, file, instance)
  1390. def _is_draft_model_file(file: ModelFile, instance: ModelInstance) -> bool:
  1391. """
  1392. Check if the model file is the draft model file for the given model instance.
  1393. """
  1394. if not instance.draft_model_source:
  1395. return False
  1396. if file.model_source_index == instance.draft_model_source.model_source_index:
  1397. return True
  1398. # The model uses a local path as its draft source, but the model file may come from a remote source.
  1399. # Match by resolved path.
  1400. if (
  1401. instance.draft_model_source.source == SourceEnum.LOCAL_PATH
  1402. and file.resolved_paths
  1403. and file.resolved_paths[0] == instance.draft_model_source.local_path
  1404. ):
  1405. return True
  1406. return False
  1407. async def sync_main_worker_model_file_state(
  1408. session: AsyncSession,
  1409. file: ModelFile,
  1410. instance: ModelInstance,
  1411. is_draft_model: bool = False,
  1412. ):
  1413. """
  1414. Sync the model file state to the related model instance.
  1415. """
  1416. if instance.state == ModelInstanceStateEnum.ERROR:
  1417. return
  1418. logger.trace(
  1419. f"Syncing model file {file.id} with model instance {instance.id}, file state: {file.state}, "
  1420. f"progress: {file.download_progress}, message: {file.state_message}, instance state: {instance.state}"
  1421. )
  1422. need_update = False
  1423. # Downloading
  1424. if file.state == ModelFileStateEnum.DOWNLOADING:
  1425. if instance.state == ModelInstanceStateEnum.INITIALIZING:
  1426. # Download started
  1427. instance.state = ModelInstanceStateEnum.DOWNLOADING
  1428. instance.download_progress = 0
  1429. instance.state_message = ""
  1430. need_update = True
  1431. elif instance.state == ModelInstanceStateEnum.DOWNLOADING:
  1432. # Update download progress
  1433. if (
  1434. is_draft_model
  1435. and file.download_progress != instance.draft_model_download_progress
  1436. and instance.draft_model_download_progress != 100
  1437. ):
  1438. # For the draft model file
  1439. instance.draft_model_download_progress = file.download_progress
  1440. need_update = True
  1441. elif (
  1442. file.download_progress != instance.download_progress
  1443. and instance.download_progress != 100
  1444. ):
  1445. # For the main model file
  1446. instance.download_progress = file.download_progress
  1447. need_update = True
  1448. # Download completed
  1449. elif file.state == ModelFileStateEnum.READY and (
  1450. instance.state == ModelInstanceStateEnum.DOWNLOADING
  1451. or instance.state == ModelInstanceStateEnum.INITIALIZING
  1452. ):
  1453. if is_draft_model and (
  1454. instance.draft_model_download_progress != 100
  1455. or not instance.draft_model_resolved_path
  1456. ):
  1457. # Download completed for the draft model file
  1458. instance.draft_model_download_progress = 100
  1459. instance.draft_model_resolved_path = file.resolved_paths[0]
  1460. need_update = True
  1461. elif not is_draft_model and (
  1462. instance.download_progress != 100 or not instance.resolved_path
  1463. ):
  1464. # Download completed for the main model file
  1465. instance.download_progress = 100
  1466. instance.resolved_path = file.resolved_paths[0]
  1467. need_update = True
  1468. if model_instance_download_completed(instance):
  1469. # All files are downloaded
  1470. instance.state = ModelInstanceStateEnum.STARTING
  1471. instance.state_message = ""
  1472. need_update = True
  1473. elif instance.state == ModelInstanceStateEnum.INITIALIZING:
  1474. # one but not all files downloaded, turn to DOWNLOADING state
  1475. instance.state = ModelInstanceStateEnum.DOWNLOADING
  1476. instance.state_message = ""
  1477. need_update = True
  1478. # Download error
  1479. elif file.state == ModelFileStateEnum.ERROR:
  1480. instance.state = ModelInstanceStateEnum.ERROR
  1481. instance.state_message = file.state_message
  1482. need_update = True
  1483. if need_update:
  1484. await ModelInstanceService(session).update(instance)
  1485. async def sync_distributed_model_file_state( # noqa: C901
  1486. session: AsyncSession, file: ModelFile, instance: ModelInstance
  1487. ):
  1488. """
  1489. Sync the model file state to the related model instance.
  1490. """
  1491. if instance.state == ModelInstanceStateEnum.ERROR:
  1492. return
  1493. if (
  1494. not instance.distributed_servers
  1495. or not instance.distributed_servers.download_model_files
  1496. ):
  1497. return
  1498. logger.trace(
  1499. f"Syncing distributed model file {file.id} with model instance {instance.name}, file state: {file.state}, "
  1500. f"progress: {file.download_progress}, message: {file.state_message}, instance state: {instance.state}"
  1501. )
  1502. need_update = False
  1503. for item in instance.distributed_servers.subordinate_workers or []:
  1504. if item.worker_id == file.worker_id:
  1505. if (
  1506. file.state == ModelFileStateEnum.DOWNLOADING
  1507. and file.download_progress != item.download_progress
  1508. ):
  1509. item.download_progress = file.download_progress
  1510. need_update = True
  1511. elif (
  1512. file.state == ModelFileStateEnum.READY and item.download_progress != 100
  1513. ):
  1514. item.download_progress = 100
  1515. if model_instance_download_completed(instance):
  1516. # All files are downloaded
  1517. instance.state = ModelInstanceStateEnum.STARTING
  1518. instance.state_message = ""
  1519. need_update = True
  1520. elif file.state == ModelFileStateEnum.ERROR:
  1521. instance.state = ModelInstanceStateEnum.ERROR
  1522. instance.state_message = file.state_message
  1523. need_update = True
  1524. if need_update:
  1525. flag_modified(instance, "distributed_servers")
  1526. await ModelInstanceService(session).update(instance)
  1527. def model_instance_download_completed(instance: ModelInstance):
  1528. if instance.download_progress != 100:
  1529. return False
  1530. if instance.draft_model_source and instance.draft_model_download_progress != 100:
  1531. return False
  1532. if (
  1533. instance.distributed_servers
  1534. and instance.distributed_servers.download_model_files
  1535. ):
  1536. for subworker in instance.distributed_servers.subordinate_workers or []:
  1537. if subworker.download_progress != 100:
  1538. return False
  1539. return True
  1540. def _get_worker_ids_for_file_download(
  1541. instance: ModelInstance,
  1542. ) -> List[str]:
  1543. """
  1544. Get the all worker IDs of the model instance that are
  1545. responsible for downloading the model files,
  1546. including the main worker and distributed workers.
  1547. """
  1548. worker_ids = [instance.worker_id] if instance.worker_id else []
  1549. if (
  1550. instance.distributed_servers
  1551. and instance.distributed_servers.download_model_files
  1552. ):
  1553. worker_ids += [
  1554. item.worker_id
  1555. for item in instance.distributed_servers.subordinate_workers or []
  1556. if item.worker_id
  1557. ]
  1558. return worker_ids
  1559. async def new_workers_from_pool(
  1560. session: AsyncSession, pool: WorkerPool
  1561. ) -> List[Worker]:
  1562. fields = {"deleted_at": None, "worker_pool_id": pool.id}
  1563. current_workers = await Worker.all_by_fields(session, fields=fields)
  1564. current_workers = [
  1565. worker
  1566. for worker in current_workers
  1567. if worker.state not in [WorkerStateEnum.DELETING]
  1568. ]
  1569. # if has enough workers, no need to create more
  1570. if len(current_workers) >= pool.replicas:
  1571. return []
  1572. delta = pool.replicas - len(current_workers)
  1573. if pool.batch_size is not None and delta > pool.batch_size:
  1574. delta = pool.batch_size
  1575. provisioning_workers = [
  1576. worker
  1577. for worker in current_workers
  1578. if worker.state in [WorkerStateEnum.PROVISIONING]
  1579. ]
  1580. # if has enough provisioning workers, no need to create more
  1581. if pool.batch_size <= len(provisioning_workers):
  1582. return []
  1583. new_workers = []
  1584. for _ in range(delta):
  1585. new_worker = Worker(
  1586. hostname="",
  1587. ip="",
  1588. ifname="",
  1589. port=0,
  1590. worker_uuid="",
  1591. cluster=pool.cluster,
  1592. worker_pool=pool,
  1593. provider=pool.cluster.provider,
  1594. name=f"pool-{pool.id}-"
  1595. + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8)),
  1596. labels={
  1597. "provider": pool.cluster.provider.value,
  1598. "instance_type": pool.instance_type or "unknown",
  1599. **pool.labels,
  1600. },
  1601. state=WorkerStateEnum.PENDING,
  1602. status=WorkerStatus.get_default_status(),
  1603. )
  1604. new_workers.append(new_worker)
  1605. return new_workers
  1606. class WorkerPoolController:
  1607. """Worker pool controller creates new workers based on the worker pool configuration."""
  1608. async def start(self):
  1609. async for event in WorkerPool.subscribe(source="worker_pool_controller"):
  1610. if event.type == EventType.HEARTBEAT:
  1611. continue
  1612. try:
  1613. await self._reconcile(event)
  1614. except Exception as e:
  1615. logger.error(f"Failed to reconcile worker pool: {e}")
  1616. async def _reconcile(self, event: Event):
  1617. """
  1618. Reconcile the worker pool state with the current event.
  1619. """
  1620. logger.info(f"Reconcile worker pool {event.data.id} with event {event.type}")
  1621. async with async_session() as session:
  1622. pool = await WorkerPool.one_by_id(
  1623. session, event.data.id, options=[selectinload(WorkerPool.cluster)]
  1624. )
  1625. if pool is None or pool.deleted_at is not None:
  1626. return
  1627. # mark the data to avoid read after commit
  1628. cluster_name = pool.cluster.name
  1629. cluster = pool.cluster
  1630. pool_id = pool.id
  1631. workers = await new_workers_from_pool(session, pool)
  1632. if len(workers) == 0:
  1633. return
  1634. ids = []
  1635. for worker in workers:
  1636. created_worker: Worker = await Worker.create(
  1637. session=session, source=worker, auto_commit=False
  1638. )
  1639. ids.append(created_worker.id)
  1640. if cluster.state == ClusterStateEnum.PENDING:
  1641. cluster.state = ClusterStateEnum.PROVISIONING
  1642. cluster.state_message = None
  1643. await cluster.update(session=session, auto_commit=False)
  1644. await session.commit()
  1645. logger.info(
  1646. f"Created {len(ids)} new workers {ids} for cluster {cluster_name} worker pool {pool_id}"
  1647. )
  1648. class WorkerProvisioningController:
  1649. def __init__(self, cfg: Config):
  1650. self._cfg = cfg
  1651. @classmethod
  1652. async def _create_ssh_key(
  1653. cls,
  1654. session: AsyncSession,
  1655. client: ProviderClientBase,
  1656. worker: Worker,
  1657. ) -> int:
  1658. """
  1659. Generate a new ssh key pair,
  1660. And Create ssh_key in cloud provider.
  1661. Create SSHKey record without commit and returns it.
  1662. """
  1663. logger.info(f"Creating ssh key for worker {worker.name}")
  1664. private_key, public_key = generate_ssh_key_pair()
  1665. ssh_key = Credential(
  1666. credential_type=CredentialType.SSH,
  1667. public_key=public_key,
  1668. encoded_private_key=private_key,
  1669. ssh_key_options=SSHKeyOptions(
  1670. algorithm="ED25519",
  1671. length=0,
  1672. ),
  1673. )
  1674. ssh_key_id = await client.create_ssh_key(worker.name, public_key)
  1675. ssh_key.external_id = str(ssh_key_id)
  1676. ssh_key_rtn = await Credential.create(session, ssh_key, auto_commit=False)
  1677. return ssh_key_rtn.id
  1678. @classmethod
  1679. async def _create_instances(
  1680. cls,
  1681. session: AsyncSession,
  1682. client: ProviderClientBase,
  1683. worker: Worker,
  1684. cfg: Config,
  1685. ) -> str:
  1686. secret_fields = set(SensitivePredefinedConfig.model_fields.keys())
  1687. secret_configs = (
  1688. worker.cluster.worker_config.model_dump(include=secret_fields)
  1689. if worker.cluster.worker_config
  1690. else {}
  1691. )
  1692. user_data = await client.construct_user_data(
  1693. server_url=worker.cluster.server_url or cfg.server_external_url,
  1694. token=worker.cluster.registration_token,
  1695. image_name=get_cluster_image_name(worker.cluster.worker_config),
  1696. os_image=worker.worker_pool.os_image,
  1697. secret_configs=secret_configs,
  1698. worker_name=worker.name,
  1699. )
  1700. ssh_key = await Credential.one_by_id(session, worker.ssh_key_id)
  1701. if ssh_key is None:
  1702. raise ValueError(f"SSH key {worker.ssh_key_id} not found")
  1703. to_create = construct_cloud_instance(worker, ssh_key, user_data.format())
  1704. logger.info(f"Creating cloud instance for worker {worker.name}")
  1705. logger.debug(f"Cloud instance configuration: {to_create}")
  1706. return await client.create_instance(to_create)
  1707. @classmethod
  1708. async def _provisioning_started(
  1709. cls,
  1710. session: AsyncSession,
  1711. client: ProviderClientBase,
  1712. worker: Worker,
  1713. instance: CloudInstance,
  1714. ) -> bool:
  1715. changed = True
  1716. provider_config = worker.provider_config or {}
  1717. volumes = list(
  1718. (getattr(worker.worker_pool.cloud_options, "volumes", None) or [])
  1719. )
  1720. volume_ids = provider_config.get("volume_ids", [])
  1721. if worker.advertise_address is None or worker.advertise_address == "":
  1722. try:
  1723. instance = await client.wait_for_public_ip(worker.external_id)
  1724. worker.advertise_address = (
  1725. instance.ip_address if instance.ip_address else ""
  1726. )
  1727. worker.state_message = "Waiting for volumes to attach"
  1728. except Exception as e:
  1729. logger.warning(
  1730. f"Failed to wait for instance {worker.external_id} to get public ip: {e}"
  1731. )
  1732. elif len(volumes) != len(volume_ids) and len(volumes) > 0:
  1733. volume_ids = await client.create_volumes_and_attach(
  1734. worker.id, worker.external_id, worker.cluster.region, *volumes
  1735. )
  1736. provider_config["volume_ids"] = volume_ids
  1737. worker.provider_config = provider_config
  1738. elif (
  1739. len(volumes) == len(volume_ids)
  1740. and worker.state == WorkerStateEnum.PROVISIONING
  1741. ):
  1742. if not hasattr(provider_config, "volume_ids"):
  1743. provider_config["volume_ids"] = []
  1744. worker.provider_config = provider_config
  1745. worker.state = WorkerStateEnum.INITIALIZING
  1746. if worker.cluster.state != ClusterStateEnum.PROVISIONED:
  1747. worker.cluster.state = ClusterStateEnum.PROVISIONED
  1748. await worker.cluster.update(session=session, auto_commit=False)
  1749. worker.state_message = "Initializing: installing required drivers and software. The worker will start automatically after setup."
  1750. else:
  1751. changed = False
  1752. return changed
  1753. @classmethod
  1754. async def _provisioning_before_started(
  1755. cls,
  1756. session: AsyncSession,
  1757. client: ProviderClientBase,
  1758. worker: Worker,
  1759. cfg: Config,
  1760. ) -> Tuple[Optional[CloudInstance], bool]:
  1761. """
  1762. return started and changed
  1763. """
  1764. instance = None
  1765. changed = False
  1766. if worker.external_id is not None:
  1767. instance = await client.get_instance(worker.external_id)
  1768. # TODO should handle instance not exist problem
  1769. if instance is None or instance.status == InstanceState.RUNNING:
  1770. return instance, changed
  1771. changed = True
  1772. if worker.state == WorkerStateEnum.PENDING:
  1773. worker.state = WorkerStateEnum.PROVISIONING
  1774. worker.state_message = "Creating SSH key"
  1775. elif worker.ssh_key_id is None:
  1776. worker.ssh_key_id = await cls._create_ssh_key(session, client, worker)
  1777. worker.state_message = "Creating cloud instance"
  1778. elif worker.external_id is None:
  1779. worker.external_id = await cls._create_instances(
  1780. session, client, worker, cfg
  1781. )
  1782. worker.state_message = "Waiting for cloud instance started"
  1783. elif worker.external_id is not None:
  1784. try:
  1785. # depress the timeout exception
  1786. instance = await client.wait_for_started(worker.external_id)
  1787. worker.state_message = "Waiting for instance's public ip"
  1788. except Exception as e:
  1789. logger.warning(
  1790. f"Failed to wait for instance {worker.external_id} to start: {e}"
  1791. )
  1792. return instance, changed
  1793. @classmethod
  1794. async def _provisioning_instance(
  1795. cls,
  1796. session: AsyncSession,
  1797. client: ProviderClientBase,
  1798. worker: Worker,
  1799. cfg: Config,
  1800. ):
  1801. # provider_config = worker.provider_config or {}
  1802. # Phase I is to ensure instance running.
  1803. instance, changed = await cls._provisioning_before_started(
  1804. session, client, worker, cfg
  1805. )
  1806. if (
  1807. not changed
  1808. and instance is not None
  1809. and instance.status == InstanceState.RUNNING
  1810. ):
  1811. # Phase II is to wait for instance infomation and attach volume.
  1812. changed = await cls._provisioning_started(session, client, worker, instance)
  1813. if changed:
  1814. await WorkerService(session).update(
  1815. worker=worker, source=None, auto_commit=False
  1816. )
  1817. @classmethod
  1818. async def _deleting_instance(
  1819. cls,
  1820. session: AsyncSession,
  1821. client: ProviderClientBase,
  1822. worker: Worker,
  1823. ):
  1824. if worker.external_id is None:
  1825. return
  1826. ssh_key = await Credential.one_by_id(session, worker.ssh_key_id)
  1827. try:
  1828. await client.delete_instance(worker.external_id)
  1829. if ssh_key and ssh_key.external_id:
  1830. await client.delete_ssh_key(ssh_key.external_id)
  1831. except Exception as e:
  1832. logger.error(f"Failed to delete instance {worker.external_id}: {e}")
  1833. # if using soft delete here, skip deletion and remove external_id
  1834. if ssh_key:
  1835. await ssh_key.delete(session, auto_commit=False)
  1836. if worker.deleted_at is not None:
  1837. await WorkerService(session).delete(worker, auto_commit=False)
  1838. async def check_server_external_url(self, cluster_server_url: Optional[str] = None):
  1839. server_url = cluster_server_url or self._cfg.server_external_url
  1840. if server_url is None or server_url == "":
  1841. raise ValueError(
  1842. "Cluster's server_url is not configured, Please edit cluster first."
  1843. )
  1844. import aiohttp
  1845. from yarl import URL
  1846. healthz_url = str(URL(server_url) / "healthz")
  1847. try:
  1848. async with aiohttp.ClientSession() as session:
  1849. async with session.get(healthz_url, timeout=10) as resp:
  1850. if resp.status != 200:
  1851. raise ValueError(
  1852. f"External server healthz url {healthz_url} is not reachable, status code: {resp.status}"
  1853. )
  1854. except Exception as e:
  1855. raise ValueError(
  1856. f"Failed to check external server healthz url {healthz_url}: {e}"
  1857. )
  1858. async def _reconcile(self, event: Event):
  1859. """
  1860. When provisioning a worker, the state will transition from following steps:
  1861. - PENDING - initial state for worker created by pool, the next state is PROVISIONING
  1862. - PROVISIONING - begin provisioning with related info updated in worker object, the next state is PROVISIONED
  1863. - PROVISIONED - done provisioning and waiting for worker to register
  1864. - DELETING - worker is being deleted
  1865. - ERROR - an error occurred during provisioning
  1866. """
  1867. worker: Worker = event.data
  1868. if not worker:
  1869. return
  1870. if worker.state not in [
  1871. WorkerStateEnum.PENDING,
  1872. WorkerStateEnum.PROVISIONING,
  1873. WorkerStateEnum.DELETING,
  1874. ]:
  1875. return
  1876. logger.info(
  1877. f"Reconcile provisioning worker {event.data.name} with event {event.type}"
  1878. )
  1879. async with async_session() as session:
  1880. # Fetch the worker from the database
  1881. worker: Worker = await Worker.one_by_id(
  1882. session,
  1883. worker.id,
  1884. options=[
  1885. selectinload(Worker.cluster),
  1886. selectinload(Worker.worker_pool),
  1887. ],
  1888. )
  1889. if not worker:
  1890. return
  1891. credential: CloudCredential = await CloudCredential.one_by_id(
  1892. session, worker.cluster.credential_id
  1893. )
  1894. client = get_client_from_provider(
  1895. worker.cluster.provider,
  1896. credential=credential,
  1897. )
  1898. try:
  1899. if worker.state == WorkerStateEnum.PENDING:
  1900. await self.check_server_external_url(worker.cluster.server_url)
  1901. if worker.state in [
  1902. WorkerStateEnum.PENDING,
  1903. WorkerStateEnum.PROVISIONING,
  1904. ]:
  1905. await self._provisioning_instance(
  1906. session, client, worker, self._cfg
  1907. )
  1908. if worker.state == WorkerStateEnum.DELETING:
  1909. await self._deleting_instance(session, client, worker)
  1910. await session.commit()
  1911. except Exception as e:
  1912. message = f"Failed to provision or delete worker {worker.name}: {e}"
  1913. logger.exception(message)
  1914. await session.rollback()
  1915. await session.refresh(worker)
  1916. worker.state = WorkerStateEnum.ERROR
  1917. worker.state_message = message
  1918. await WorkerService(session).update(
  1919. worker=worker, source=None, auto_commit=True
  1920. )
  1921. class ClusterController:
  1922. def __init__(self, cfg: Config):
  1923. self._cfg = cfg
  1924. self._disable_gateway = cfg.gateway_mode == GatewayModeEnum.disabled
  1925. self._k8s_config = get_async_k8s_config(cfg=cfg)
  1926. pass
  1927. async def start(self):
  1928. """
  1929. Start the controller.
  1930. """
  1931. if self._cfg.gateway_mode != GatewayModeEnum.disabled:
  1932. base_client = k8s_client.ApiClient(configuration=self._k8s_config)
  1933. self._higress_network_api = NetworkingHigressIoV1Api(base_client)
  1934. async for event in Cluster.subscribe(source="cluster_controller"):
  1935. if event.type == EventType.HEARTBEAT:
  1936. continue
  1937. try:
  1938. await self._reconcile(event)
  1939. except Exception as e:
  1940. logger.error(f"Failed to reconcile cluster: {e}")
  1941. async def _reconcile(self, event: Event):
  1942. """
  1943. Reconcile the cluster state.
  1944. """
  1945. await self._sync_cluster_state(event)
  1946. if self._disable_gateway:
  1947. return
  1948. await self._ensure_worker_mcp_bridge(event)
  1949. async def _sync_cluster_state(self, event: Event):
  1950. if event.type == EventType.DELETED:
  1951. return
  1952. cluster: Cluster = event.data
  1953. if not cluster:
  1954. return
  1955. async with async_session() as session:
  1956. cluster: Cluster = await Cluster.one_by_id(
  1957. session, cluster.id, options=[selectinload(Cluster.cluster_workers)]
  1958. )
  1959. if not cluster or cluster.provider in [
  1960. ClusterProvider.Kubernetes,
  1961. ClusterProvider.Docker,
  1962. ]:
  1963. return
  1964. if cluster.workers == 0 and cluster.state != ClusterStateEnum.PENDING:
  1965. cluster.state = ClusterStateEnum.PENDING
  1966. cluster.state_message = (
  1967. "No workers have been provisioned for this cluster yet."
  1968. )
  1969. await cluster.update(session=session, auto_commit=True)
  1970. async def _ensure_worker_mcp_bridge(self, event: Event):
  1971. """
  1972. The worker registry list for cluster is no longer needed.
  1973. Use empty list to trigger MCPBridge controller to clean up the worker registries
  1974. and proxies when cluster is created or deleted.
  1975. """
  1976. if self._cfg.gateway_mode == GatewayModeEnum.disabled:
  1977. return
  1978. cluster: Cluster = event.data
  1979. mcp_resource_name = mcp_handler.default_mcp_bridge_name
  1980. desired_registries = []
  1981. to_delete_prefix = mcp_handler.cluster_worker_prefix(cluster.id)
  1982. try:
  1983. await mcp_handler.ensure_mcp_bridge(
  1984. client=self._higress_network_api,
  1985. namespace=self._cfg.gateway_namespace,
  1986. mcp_bridge_name=mcp_resource_name,
  1987. desired_registries=desired_registries,
  1988. to_delete_prefix=to_delete_prefix,
  1989. )
  1990. except Exception as e:
  1991. logger.error(f"Failed to ensure MCPBridge for cluster {cluster.name}: {e}")
  1992. raise
  1993. async def notify_model_route_target(session: AsyncSession, model: Model, event: Event):
  1994. if event.type == EventType.DELETED:
  1995. return
  1996. should_notify = False
  1997. if event.changed_fields is not None:
  1998. related_fields = ["ready_replicas", "replicas"]
  1999. for field in related_fields:
  2000. if field in event.changed_fields:
  2001. should_notify = True
  2002. break
  2003. model: Model = await Model.one_by_id(
  2004. session=session,
  2005. id=model.id,
  2006. options=[
  2007. selectinload(Model.model_route_targets),
  2008. ],
  2009. )
  2010. if not model:
  2011. return
  2012. targets = model.model_route_targets
  2013. for target in targets:
  2014. if should_notify:
  2015. target_copy = ModelRouteTarget(**target.model_dump())
  2016. await event_bus.publish(
  2017. target_copy.__class__.__name__.lower(),
  2018. Event(
  2019. type=EventType.UPDATED,
  2020. data=target_copy,
  2021. changed_fields={
  2022. "model": (
  2023. {},
  2024. {
  2025. "id": model.id,
  2026. "name": model.name,
  2027. "ready_replicas": model.ready_replicas,
  2028. "replicas": model.replicas,
  2029. },
  2030. )
  2031. },
  2032. ),
  2033. )
  2034. async def sync_categories_and_meta(session: AsyncSession, model: Model, event: Event):
  2035. if event.type == EventType.DELETED:
  2036. return
  2037. model: Model = await Model.one_by_id(
  2038. session=session,
  2039. id=model.id,
  2040. options=[
  2041. selectinload(Model.model_routes),
  2042. ],
  2043. )
  2044. if not model:
  2045. return
  2046. routes = model.model_routes
  2047. for route in routes:
  2048. # created_by_model default to false if not set
  2049. if not route.created_by_model:
  2050. continue
  2051. if route.categories != model.categories or route.meta != model.meta:
  2052. await ModelRouteService(session).update(
  2053. model_route=route,
  2054. source={"categories": model.categories, "meta": model.meta},
  2055. auto_commit=True,
  2056. )
  2057. class ModelProviderController:
  2058. def __init__(self, cfg: Config):
  2059. self._config = cfg
  2060. self._disable_gateway = cfg.gateway_mode == GatewayModeEnum.disabled
  2061. self._k8s_config = get_async_k8s_config(cfg=cfg)
  2062. async def start(self):
  2063. if self._disable_gateway:
  2064. return
  2065. if not self._disable_gateway:
  2066. base_client = k8s_client.ApiClient(configuration=self._k8s_config)
  2067. self._higress_network_api = NetworkingHigressIoV1Api(base_client)
  2068. self._higress_extension_api = ExtensionsHigressIoV1Api(base_client)
  2069. async for event in ModelProvider.subscribe(source="model_provider_controller"):
  2070. try:
  2071. await self._reconcile(event)
  2072. except Exception as e:
  2073. logger.exception(f"Failed to reconcile model provider: {e}")
  2074. async def _ensure_provider_registry(
  2075. self,
  2076. model_provider: ModelProvider,
  2077. event: Event,
  2078. ):
  2079. provider_registry = mcp_handler.provider_registry(model_provider)
  2080. registry_to_remove = (
  2081. provider_registry is None or event.type == EventType.DELETED
  2082. )
  2083. to_delete_prefix = (
  2084. f"{mcp_handler.provider_id_prefix}{model_provider.id}"
  2085. if registry_to_remove
  2086. else None
  2087. )
  2088. desired_registries = [] if registry_to_remove else [provider_registry]
  2089. provider_proxy = mcp_handler.provider_proxy(model_provider)
  2090. proxy_to_remove = provider_proxy is None or event.type == EventType.DELETED
  2091. to_delete_proxy_prefix = (
  2092. f"proxy-{model_provider.id}" if proxy_to_remove else None
  2093. )
  2094. desired_proxies = [] if proxy_to_remove else [provider_proxy]
  2095. try:
  2096. await mcp_handler.ensure_mcp_bridge(
  2097. client=self._higress_network_api,
  2098. namespace=self._config.gateway_namespace,
  2099. mcp_bridge_name=mcp_handler.default_mcp_bridge_name,
  2100. desired_registries=desired_registries,
  2101. desired_proxies=desired_proxies,
  2102. to_delete_prefix=to_delete_prefix,
  2103. to_delete_proxies_prefix=to_delete_proxy_prefix,
  2104. )
  2105. except Exception as e:
  2106. logger.error(
  2107. f"Failed to ensure MCPRegistry for model provider {model_provider.name}: {e}"
  2108. )
  2109. raise
  2110. async def _ensure_provider_ai_proxy_config(self):
  2111. try:
  2112. async with async_session() as session:
  2113. providers = await ModelProvider.all_by_field(
  2114. session,
  2115. "deleted_at",
  2116. None,
  2117. )
  2118. provider_config_list, match_rules = (
  2119. mcp_handler.provider_proxy_plugin_spec(*providers)
  2120. )
  2121. await mcp_handler.ensure_wasm_plugin(
  2122. api=self._higress_extension_api,
  2123. name=mcp_handler.gpustack_ai_proxy_name,
  2124. namespace=self._config.gateway_namespace,
  2125. spec_diff=partial(
  2126. mcp_handler.ai_proxy_diff_spec,
  2127. expected_providers=provider_config_list,
  2128. expected_match_rules=match_rules,
  2129. operating_id_prefix=mcp_handler.provider_id_prefix,
  2130. ),
  2131. )
  2132. except Exception as e:
  2133. logger.error(f"Failed to ensure provider's ai_proxy config: {e}")
  2134. raise
  2135. async def _notify_provider_model_routes(
  2136. self, session: AsyncSession, model_provider: ModelProvider, event: Event
  2137. ):
  2138. if event.type != EventType.UPDATED:
  2139. return
  2140. changed_fields = event.changed_fields or {}
  2141. should_notify = False
  2142. if "config" not in changed_fields:
  2143. return
  2144. # the changed field "config" must have old and new value, otherwise it's not a valid update event for config change.
  2145. # index 0 of the tuple is the old value, index 1 is the new value.
  2146. # each value must be a list with only 1 element as it is a norman field instead of relationship field.
  2147. old_config = changed_fields["config"][0][0]
  2148. if isinstance(changed_fields["config"][0][0], BaseModel):
  2149. old_config = changed_fields["config"][0][0].model_dump()
  2150. new_config = changed_fields["config"][1][0]
  2151. if isinstance(changed_fields["config"][1][0], BaseModel):
  2152. new_config = changed_fields["config"][1][0].model_dump()
  2153. # use hardcoded fields to determine whether to notify.
  2154. # For ProviderConfigType, including:
  2155. # - openaiCustomUrl
  2156. # - ollamaServerHost
  2157. # - difyApiUrl
  2158. # The above fields will affect the registry type of the provider_registry,
  2159. # it requires notifying ingress to regenerate registry destination.
  2160. related_fields = [
  2161. "openaiCustomUrl",
  2162. "ollamaServerHost",
  2163. "difyApiUrl",
  2164. ]
  2165. for field in related_fields:
  2166. if old_config.get(field) != new_config.get(field):
  2167. should_notify = True
  2168. break
  2169. if not should_notify:
  2170. return
  2171. targets = await ModelRouteTarget.all_by_fields(
  2172. session=session,
  2173. fields={"provider_id": model_provider.id},
  2174. options=[selectinload(ModelRouteTarget.model_route)],
  2175. )
  2176. unique_routes = {
  2177. target.model_route.id: target.model_route
  2178. for target in targets
  2179. if target.model_route is not None
  2180. }
  2181. for route in unique_routes.values():
  2182. route_copy = ModelRoute.model_validate(route.model_dump())
  2183. await event_bus.publish(
  2184. route_copy.__class__.__name__.lower(),
  2185. Event(type=EventType.UPDATED, data=route_copy),
  2186. )
  2187. async def _reconcile(self, event: Event):
  2188. """
  2189. Reconcile the model provider.
  2190. """
  2191. model_provider: ModelProvider = event.data
  2192. if not model_provider:
  2193. return
  2194. if event.type == EventType.DELETED:
  2195. await self._ensure_provider_registry(model_provider, event)
  2196. await self._ensure_provider_ai_proxy_config()
  2197. return
  2198. async with async_session() as session:
  2199. model_provider: ModelProvider = await ModelProvider.one_by_id(
  2200. session, model_provider.id
  2201. )
  2202. if not model_provider:
  2203. return
  2204. await self._ensure_provider_registry(model_provider, event)
  2205. await self._ensure_provider_ai_proxy_config()
  2206. await self._notify_provider_model_routes(session, model_provider, event)
  2207. class ModelRouteTargetController:
  2208. def __init__(self, config: Config):
  2209. self._config = config
  2210. async def start(self):
  2211. async for event in ModelRouteTarget.subscribe(
  2212. source="model_route_target_controller"
  2213. ):
  2214. try:
  2215. await self._reconcile(event)
  2216. except Exception as e:
  2217. logger.exception(f"Failed to reconcile model route target: {e}")
  2218. async def _notify_parents(
  2219. self, session: AsyncSession, target: ModelRouteTarget, event: Event
  2220. ):
  2221. if event.type not in (EventType.UPDATED, EventType.DELETED):
  2222. return
  2223. changed_fields = event.changed_fields
  2224. if not target or (not changed_fields and event.type != EventType.DELETED):
  2225. return
  2226. should_notify_fields = [
  2227. "state",
  2228. "provider_id",
  2229. "model_id",
  2230. "provider_model_name",
  2231. "model",
  2232. ]
  2233. should_notify = event.type == EventType.DELETED
  2234. if not should_notify:
  2235. for field in should_notify_fields:
  2236. if field in (changed_fields or {}):
  2237. should_notify = True
  2238. break
  2239. if not should_notify:
  2240. return
  2241. try:
  2242. model_route: ModelRoute = await ModelRoute.one_by_id(
  2243. session, target.route_id
  2244. )
  2245. if not model_route:
  2246. return
  2247. copied_route = ModelRoute.model_validate(model_route.model_dump())
  2248. await event_bus.publish(
  2249. ModelRoute.__name__.lower(),
  2250. Event(type=EventType.UPDATED, data=copied_route),
  2251. )
  2252. except Exception as e:
  2253. logger.error(f"Failed to notify model route for target {target.name}: {e}")
  2254. async def _sync_state(
  2255. self, session: AsyncSession, target: ModelRouteTarget, event: Event
  2256. ):
  2257. if event.type == EventType.DELETED:
  2258. return
  2259. # Handle ID-only events from distributed mode
  2260. target_id = (
  2261. target.id
  2262. if hasattr(target, 'id')
  2263. else target.get('id') if isinstance(target, dict) else None
  2264. )
  2265. if not target_id:
  2266. return
  2267. target: ModelRouteTarget = await ModelRouteTarget.one_by_id(session, target_id)
  2268. if not target:
  2269. return
  2270. if target.provider_id is not None:
  2271. target_state = TargetStateEnum.ACTIVE
  2272. if target.model_id is not None:
  2273. model = await Model.one_by_id(session, target.model_id)
  2274. if not model:
  2275. return
  2276. target_state = (
  2277. TargetStateEnum.ACTIVE
  2278. if model.ready_replicas > 0
  2279. else TargetStateEnum.UNAVAILABLE
  2280. )
  2281. if target.state != target_state:
  2282. target.state = target_state
  2283. await target.update(session=session, auto_commit=True)
  2284. async def _update_orphan_route(
  2285. self, session: AsyncSession, target: ModelRouteTarget, event: Event
  2286. ) -> bool:
  2287. """
  2288. Update the orphan route if the target is deleted or has no associated model.
  2289. If the target model is not deleted, transfer model_route to a non model-created model.
  2290. """
  2291. if event.type != EventType.DELETED:
  2292. return True
  2293. if target.model_id is None:
  2294. return True
  2295. model = await Model.one_by_id(session, target.model_id)
  2296. if not model or model.deleted_at is not None:
  2297. return True
  2298. # If the model is not deleted, transfer the model route to a non model-created model route to avoid service disruption.
  2299. # The model route will be automatically deleted by the controller after the target is deleted.
  2300. orphan_route = await ModelRoute.one_by_id(session=session, id=target.route_id)
  2301. if (
  2302. not orphan_route
  2303. or orphan_route.deleted_at is not None
  2304. or not orphan_route.created_by_model
  2305. ):
  2306. # The route is already deleted or not created by model, no need to transfer.
  2307. # returns true to trigger parent notification and state sync to update the route state if needed.
  2308. return True
  2309. try:
  2310. route_service = ModelRouteService(session=session)
  2311. await route_service.update(
  2312. orphan_route, source={"created_by_model": False}, auto_commit=True
  2313. )
  2314. except Exception as e:
  2315. logger.error(f"Failed to transfer model route {orphan_route.id}: {e}")
  2316. return True
  2317. return False
  2318. async def _reconcile(self, event: Event):
  2319. target: ModelRouteTarget = event.data
  2320. if not target:
  2321. return
  2322. async with async_session() as session:
  2323. should_notify_parents = await self._update_orphan_route(
  2324. session, target, event
  2325. )
  2326. if should_notify_parents:
  2327. await self._notify_parents(session, target, event)
  2328. await self._sync_state(session, target, event)
  2329. class ModelRouteController:
  2330. def __init__(self, cfg: Config):
  2331. self._config = cfg
  2332. self._gateway_namespace = cfg.gateway_namespace
  2333. self._k8s_config = get_async_k8s_config(cfg=cfg)
  2334. self._disable_gateway = cfg.gateway_mode == GatewayModeEnum.disabled
  2335. async def start(self):
  2336. if not self._disable_gateway:
  2337. base_client = k8s_client.ApiClient(configuration=self._k8s_config)
  2338. self._networking_api = k8s_client.NetworkingV1Api(base_client)
  2339. self._higress_extension_api = ExtensionsHigressIoV1Api(base_client)
  2340. self._networking_istio_api = NetworkingIstioIoV1Alpha3Api(base_client)
  2341. async for event in ModelRoute.subscribe(source="model_route_controller"):
  2342. try:
  2343. await self._reconcile(event)
  2344. except Exception as e:
  2345. logger.exception(f"Failed to reconcile model route: {e}")
  2346. async def _sync_targets(self, session: AsyncSession, event: Event) -> bool:
  2347. if event.type == EventType.DELETED:
  2348. return False
  2349. model_route: ModelRoute = event.data
  2350. if not model_route:
  2351. return False
  2352. # Handle ID-only events from distributed mode
  2353. model_route_id = (
  2354. model_route.id
  2355. if hasattr(model_route, 'id')
  2356. else model_route.get('id') if isinstance(model_route, dict) else None
  2357. )
  2358. if not model_route_id:
  2359. return False
  2360. model_route: ModelRoute = await ModelRoute.one_by_id(
  2361. session,
  2362. model_route_id,
  2363. options=[selectinload(ModelRoute.route_targets)],
  2364. )
  2365. if not model_route:
  2366. return False
  2367. target_total = len(model_route.route_targets)
  2368. ready_target_total = len(
  2369. [
  2370. target
  2371. for target in model_route.route_targets
  2372. if target.state == TargetStateEnum.ACTIVE
  2373. ]
  2374. )
  2375. model_route_service = ModelRouteService(session=session)
  2376. if target_total == 0 and model_route.created_by_model:
  2377. await model_route_service.delete(model_route, auto_commit=True)
  2378. return True
  2379. if (
  2380. model_route.targets != target_total
  2381. or model_route.ready_targets != ready_target_total
  2382. ):
  2383. model_route.targets = target_total
  2384. model_route.ready_targets = ready_target_total
  2385. await model_route_service.update(model_route, auto_commit=True)
  2386. return True
  2387. return False
  2388. async def _reconcile(self, event: Event):
  2389. """
  2390. Reconcile the model route.
  2391. """
  2392. model_route: ModelRoute = event.data
  2393. if not model_route:
  2394. return
  2395. async with async_session() as session:
  2396. # sync targets will update model route record so make sure to do it before other operations
  2397. updated = await self._sync_targets(session, event)
  2398. if not self._disable_gateway and not updated:
  2399. await sync_gateway(
  2400. cfg=self._config,
  2401. session=session,
  2402. event=event,
  2403. networking_api=self._networking_api,
  2404. extensions_api=self._higress_extension_api,
  2405. model_route=model_route,
  2406. istio_networking_api=self._networking_istio_api,
  2407. )
  2408. await distribute_models_to_user(session, model_route, event)