inference_backend.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603
  1. import logging
  2. import math
  3. from copy import deepcopy
  4. from typing import List, Tuple, Optional, Dict
  5. import yaml
  6. from fastapi import APIRouter, Body
  7. from gpustack_runner.runner import ServiceVersionedRunner, ServiceRunner
  8. from gpustack_runtime.deployer.__utils__ import compare_versions
  9. from pydantic import ValidationError
  10. from starlette.responses import StreamingResponse
  11. from gpustack.api.exceptions import (
  12. InternalServerErrorException,
  13. NotFoundException,
  14. BadRequestException,
  15. )
  16. from gpustack.api.tenant import (
  17. assert_org_owned_writable,
  18. validate_owner_principal,
  19. )
  20. from gpustack.schemas import Worker
  21. from gpustack.schemas.common import Pagination
  22. from gpustack.schemas.inference_backend import (
  23. InferenceBackend,
  24. InferenceBackendCreate,
  25. InferenceBackendListItem,
  26. InferenceBackendResponse,
  27. InferenceBackendUpdate,
  28. InferenceBackendsPublic,
  29. VersionConfig,
  30. VersionConfigDict,
  31. get_built_in_backend,
  32. InferenceBackendPublic,
  33. VersionListItem,
  34. is_built_in_backend,
  35. )
  36. from gpustack.schemas.models import BackendEnum, Model, BackendSourceEnum
  37. from gpustack.server.db import async_session
  38. from gpustack.server.deps import ListParamsDep, SessionDep, TenantContextDep
  39. from gpustack_runner import list_service_runners
  40. from gpustack_runtime.detector.ascend import get_ascend_cann_variant
  41. from gpustack_runtime.detector import ManufacturerEnum
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router that the endpoint decorators in this module register against.
router = APIRouter()
  44. def filter_yaml_fields(yaml_data: Dict, filter_keys: List[str]) -> Dict: # noqa: C901
  45. """
  46. Recursively remove specified keys from a nested YAML dict.
  47. Args:
  48. yaml_data: Dictionary parsed from YAML content.
  49. filter_keys: List of keys to remove wherever they appear.
  50. Returns:
  51. The same dict instance after filtering.
  52. """
  53. if not isinstance(yaml_data, dict):
  54. return yaml_data
  55. def _filter_in_place(obj: Dict):
  56. # Delete keys that should be filtered
  57. for key in list(obj.keys()):
  58. if key in filter_keys:
  59. try:
  60. del obj[key]
  61. except Exception:
  62. # Silently ignore any deletion issues
  63. pass
  64. continue
  65. # Recurse into nested dicts
  66. val = obj.get(key)
  67. if isinstance(val, dict):
  68. _filter_in_place(val)
  69. elif isinstance(val, list):
  70. for item in val:
  71. if isinstance(item, dict):
  72. _filter_in_place(item)
  73. _filter_in_place(yaml_data)
  74. return yaml_data
  75. async def check_backend_in_use(
  76. session: SessionDep, backend_name: str, backend_version: Optional[str] = None
  77. ) -> Tuple[bool, List[str]]:
  78. """
  79. Check if a backend or specific backend version is being used by any models.
  80. Args:
  81. session: Database session
  82. backend_name: The name of the backend to check
  83. backend_version: Optional specific version to check. If None, checks all versions.
  84. Returns:
  85. A tuple containing:
  86. - Boolean indicating if the backend/version is in use
  87. - List of model names that are using the backend/version
  88. """
  89. try:
  90. # Query models that use the specified backend
  91. if backend_version:
  92. # Check for specific backend and version combination
  93. models = await Model.all_by_fields(
  94. session, {"backend": backend_name, "backend_version": backend_version}
  95. )
  96. else:
  97. # Check for any models using this backend (any version)
  98. models = await Model.all_by_field(session, "backend", backend_name)
  99. models = [model for model in models if model.replicas > 0]
  100. model_names = [model.name for model in models]
  101. is_in_use = len(models) > 0
  102. return is_in_use, model_names
  103. except Exception as e:
  104. logger.error(f"Error checking backend usage: {e}")
  105. return False, []
  106. def get_lower_version_runners(
  107. runners: list[ServiceRunner], backend_version: str
  108. ) -> list[ServiceRunner]:
  109. """
  110. Filter runners whose version is less than or equal to the given backend_version.
  111. Rebuilds the list[ServiceRunner] structure with only the matching elements.
  112. Args:
  113. runners: List of ServiceRunner objects to filter
  114. backend_version: The version to compare against (only runners with versions <= this will be kept)
  115. Returns:
  116. List of ServiceRunner objects with filtered versions/backends
  117. """
  118. filtered_runners = []
  119. for runner in runners:
  120. # Create a new runner with filtered structure
  121. new_runner = deepcopy(runner)
  122. # Filter versions in backends
  123. for version in new_runner.versions:
  124. for backend in version.backends:
  125. # Filter backend versions that are <= backend_version
  126. backend.versions = [
  127. bv
  128. for bv in backend.versions
  129. if compare_versions(bv.version, backend_version) <= 0
  130. ]
  131. # Remove backends with no matching versions
  132. for version in new_runner.versions:
  133. version.backends = [
  134. backend for backend in version.backends if backend.versions
  135. ]
  136. # Remove versions with no matching backends
  137. new_runner.versions = [
  138. version for version in new_runner.versions if version.backends
  139. ]
  140. # Only add runner if it has matching versions
  141. if new_runner.versions:
  142. filtered_runners.append(new_runner)
  143. return filtered_runners
  144. def get_runner_versions_and_configs(
  145. backend_name: str, backend_version: Optional[str], **kwargs
  146. ) -> Tuple[Dict[str, ServiceVersionedRunner], VersionConfigDict, Optional[str]]:
  147. """
  148. Get runner versions and version configs for a given backend.
  149. Args:
  150. backend_name: The name of the backend service
  151. kwargs: Others keyword arguments to pass to list_service_runners()
  152. Returns:
  153. A tuple containing:
  154. - List of version strings
  155. - VersionConfigDict with version configurations
  156. - Default version (first available version or None)
  157. """
  158. runners_list = list_service_runners(
  159. service=backend_name.lower(),
  160. **kwargs,
  161. )
  162. if backend_version:
  163. runners_list = get_lower_version_runners(runners_list, backend_version)
  164. runner_versions: Dict[str, ServiceVersionedRunner] = {}
  165. version_configs = VersionConfigDict()
  166. default_version = None
  167. if runners_list and len(runners_list) > 0:
  168. for version in runners_list[0].versions:
  169. if version.version:
  170. runner_versions[version.version] = version
  171. backend_list = [
  172. f"{backend_runner.backend}" for backend_runner in version.backends
  173. ]
  174. version_configs.root[version.version] = VersionConfig(
  175. built_in_frameworks=backend_list,
  176. )
  177. if default_version is None:
  178. default_version = version.version
  179. return runner_versions, version_configs, default_version
  180. def deduplicate_versions(versions: List[VersionListItem]) -> List[VersionListItem]:
  181. seen = set()
  182. result = []
  183. for item in versions:
  184. key = (item.version, item.is_deprecated)
  185. if key not in seen:
  186. seen.add(key)
  187. result.append(item)
  188. return result
  189. def get_runner_deprecate(runners: List[ServiceVersionedRunner]) -> bool:
  190. """
  191. Check if all runners are deprecated.
  192. Args:
  193. runners: List of ServiceVersionedRunner objects
  194. Returns:
  195. True if all runners are deprecated, False otherwise.
  196. Returns False if the list is empty.
  197. """
  198. if not runners:
  199. return False
  200. return all(
  201. runner.backends[0].versions[0].variants[0].deprecated for runner in runners
  202. )
  203. def merge_list_runners( # noqa: C901
  204. backend_name: str, workers: List[Worker]
  205. ) -> Tuple[Dict[str, List[ServiceVersionedRunner]], VersionConfigDict, Optional[str]]:
  206. """
  207. Merge runner versions and configs from multiple workers.
  208. Extracts gpu.type and gpu.runtime_version from each worker's GPU devices
  209. and uses them as query conditions for list_service_runners.
  210. Args:
  211. backend_name: The name of the backend service
  212. workers: List of workers to extract GPU information from
  213. Returns:
  214. A tuple containing:
  215. - Dict[str, List[ServiceVersionedRunner]]: Merged runner versions, grouped by version
  216. - VersionConfigDict: Merged version configurations
  217. - Optional[str]: Default version (from first query)
  218. """
  219. # Collect unique query conditions from all workers
  220. query_conditions = set()
  221. for worker in workers:
  222. if worker.status and worker.status.gpu_devices:
  223. for gpu in worker.status.gpu_devices:
  224. # Extract variant for Ascend GPUs
  225. variant = None
  226. if gpu.vendor == ManufacturerEnum.ASCEND and gpu.arch_family:
  227. variant = get_ascend_cann_variant(gpu.arch_family).lower()
  228. # Add (type, runtime_version, variant) tuple to set
  229. # Use None for runtime_version if not available
  230. query_conditions.add((gpu.type, gpu.runtime_version, variant))
  231. merged_runner_versions: Dict[str, List[ServiceVersionedRunner]] = {}
  232. merged_version_configs = VersionConfigDict()
  233. merged_default_version = None
  234. # Loop through each unique query condition
  235. for idx, (gpu_type, runtime_version, variant) in enumerate(query_conditions):
  236. # Build kwargs for get_runner_versions_and_configs
  237. kwargs = {"backend": gpu_type}
  238. if variant:
  239. kwargs["backend_variant"] = variant
  240. # Get runner versions and configs for this condition
  241. runner_versions, version_configs, default_version = (
  242. get_runner_versions_and_configs(backend_name, runtime_version, **kwargs)
  243. )
  244. # For the first condition, use its results as base
  245. if idx == 0:
  246. # Convert Dict[str, ServiceVersionedRunner] to Dict[str, List[ServiceVersionedRunner]]
  247. merged_runner_versions = {
  248. version: [runner] for version, runner in runner_versions.items()
  249. }
  250. merged_version_configs = version_configs
  251. merged_default_version = default_version
  252. else:
  253. # Merge runner versions (append to list if exists)
  254. for version, runner in runner_versions.items():
  255. if version in merged_runner_versions:
  256. merged_runner_versions[version].append(runner)
  257. else:
  258. merged_runner_versions[version] = [runner]
  259. # Merge version configs
  260. for version, config in version_configs.root.items():
  261. if version not in merged_version_configs.root:
  262. # Add new version
  263. merged_version_configs.root[version] = config
  264. else:
  265. # Merge built_in_frameworks (deduplicate)
  266. existing_frameworks = (
  267. merged_version_configs.root[version].built_in_frameworks or []
  268. )
  269. new_frameworks = config.built_in_frameworks or []
  270. merged_frameworks = list(set(existing_frameworks + new_frameworks))
  271. merged_version_configs.root[version].built_in_frameworks = (
  272. merged_frameworks
  273. )
  274. return merged_runner_versions, merged_version_configs, merged_default_version
@router.get("/list", response_model=InferenceBackendResponse)
async def list_backend_configs(  # noqa: C901
    session: SessionDep,
    ctx: TenantContextDep,
    cluster_id: Optional[int] = None,
):
    """
    Get list of available backend configurations with version information.

    Returns both built-in backends and custom backends from database.
    Built-in backends are identified and enhanced with runner versions.
    Each backend item includes available versions.

    Hybrid: when an Org row and a Platform row share the same backend_name,
    the Org row's metadata + version_configs win, then Platform versions
    are merged in for any keys the Org didn't define.

    Args:
        session: Database session.
        ctx: Tenant context of the caller; decides which rows are visible.
        cluster_id: When positive, only workers of this cluster are used to
            discover runner versions; otherwise all workers are used.

    Returns:
        InferenceBackendResponse holding the collected items. If processing
        raises, the error is logged and whatever was collected up to that
        point is returned.
    """
    items = []
    # Workers feed merge_list_runners() below for built-in backends.
    if cluster_id and cluster_id > 0:
        workers = await Worker.all_by_field(session, "cluster_id", cluster_id)
    else:
        workers = await Worker.all(session)
    # Process all backends from database (includes both built-in and custom backends)
    try:
        all_rows = await InferenceBackend.all(session)
        # Hybrid filter:
        # - Single-Org caller (member, or platform admin act-as): see
        #   Platform rows (NULL) + their own Org's rows. The merge below
        #   collapses these into one entry per backend_name with Org keys
        #   winning on collisions.
        # - Bypass mode (admin "All", system users): there's no single Org
        #   to merge with, so we fall back to Platform-only. Merging across
        #   multiple Org rows for the same backend_name would be
        #   ill-defined (last-Org-wins), and the response model
        #   (InferenceBackendListItem) has no owner_principal_id field to
        #   distinguish them anyway. Callers that need a specific Org's
        #   overrides — including workers running tenant-scoped deploys —
        #   should fetch by id or pass an org context.
        bypass_filter = (
            ctx is None
            or (ctx.is_platform_admin and ctx.current_principal_id is None)
            or getattr(getattr(ctx, "user", None), "is_system", False)
        )
        if bypass_filter:
            visible_rows = [b for b in all_rows if b.owner_principal_id is None]
        else:
            visible_rows = [
                b
                for b in all_rows
                if b.owner_principal_id is None
                or b.owner_principal_id == ctx.current_principal_id
            ]
        # Group by backend_name; collapse Platform + Org into one logical
        # backend with merged versions (Org wins on key collisions). With
        # the filter above, ``visible_rows`` contains at most one Org row
        # per backend_name, so the merge is well-defined.
        #
        # Stash merged values in side dicts keyed by db id rather than
        # mutating the ORM rows themselves — no ``expunge`` dance, no
        # risk of a stray flush persisting the read-time merge.
        merged_versions_by_id: Dict[int, VersionConfigDict] = {}
        grouped: Dict[str, InferenceBackend] = {}
        for b in visible_rows:
            name = b.backend_name
            existing = grouped.get(name)
            if existing is None:
                # First row seen for this name; nothing to merge yet.
                grouped[name] = b
                continue
            # Whichever of the two rows has an owner is treated as the Org row.
            org_row = b if (b.owner_principal_id is not None) else existing
            other = existing if org_row is b else b
            # Later unpacking wins, so Org entries override the other row's.
            merged_versions = {
                **(other.version_configs.root if other.version_configs else {}),
                **(org_row.version_configs.root if org_row.version_configs else {}),
            }
            merged_versions_by_id[org_row.id] = VersionConfigDict(root=merged_versions)
            grouped[name] = org_row
        inference_backends = list(grouped.values())
        for backend in inference_backends:
            # Prefer the read-time merged configs when this row took part in a merge.
            effective_version_configs = merged_versions_by_id.get(
                backend.id, backend.version_configs
            )
            # Get versions from version_config
            versions: List[VersionListItem] = []
            if effective_version_configs and effective_version_configs.root:
                versions = [
                    VersionListItem(
                        version=version, env=backend.get_backend_env(version)
                    )
                    for version in effective_version_configs.root.keys()
                ]
            if backend.is_built_in:
                # For built-in backends, add the versions discovered from runners
                runner_versions, version_configs, default_version = merge_list_runners(
                    backend.backend_name,
                    workers,
                )
                # Merge runner versions with existing versions
                for version, config in version_configs.root.items():
                    # Check if this version has any built-in frameworks
                    if config.built_in_frameworks:
                        # Versions are only marked deprecated when no worker is compatible with them.
                        is_deprecated = get_runner_deprecate(
                            runner_versions.get(version, [])
                        )
                        # Get environment for this specific version
                        version_env = backend.get_backend_env(version)
                        versions.append(
                            VersionListItem(
                                version=version,
                                is_deprecated=is_deprecated,
                                env=version_env,
                            )
                        )
                # Remove duplicates while preserving order
                versions = deduplicate_versions(versions)
                # Use the runner-derived default if the row didn't set one;
                # local var so we don't mutate the ORM object.
                effective_default_version = backend.default_version or default_version
                # Built-in items are always reported as enabled with BUILT_IN source.
                backend_item = InferenceBackendListItem(
                    backend_name=backend.backend_name,
                    default_version=effective_default_version,
                    default_backend_param=backend.default_backend_param,
                    versions=versions,
                    is_built_in=backend.is_built_in,
                    enabled=True,
                    backend_source=BackendSourceEnum.BUILT_IN,
                    default_env=backend.default_env,
                )
            else:
                # Disabled community backends are left out of the list entirely.
                if (
                    backend.backend_source == BackendSourceEnum.COMMUNITY
                    and not backend.enabled
                ):
                    continue
                # For custom backends, report the row's own fields as stored
                backend_item = InferenceBackendListItem(
                    backend_name=backend.backend_name,
                    default_version=backend.default_version,
                    default_backend_param=backend.default_backend_param,
                    versions=versions,
                    is_built_in=False,
                    enabled=backend.enabled,
                    backend_source=backend.backend_source,
                    default_env=backend.default_env,
                )
            items.append(backend_item)
        # Ensure Custom backend is always included even if not in database
        # NOTE(review): this append sits inside the try, so it is skipped when
        # an earlier step raises — confirm that is acceptable.
        custom_backend_item = InferenceBackendListItem(
            backend_name=BackendEnum.CUSTOM,
            default_version=None,
            default_backend_param=None,
            versions=[],
            is_built_in=False,
            enabled=True,
            backend_source=BackendSourceEnum.BUILT_IN,
            default_env=None,
        )
        items.append(custom_backend_item)
    except Exception as e:
        # Log error but don't fail the entire request
        logger.error(f"Failed to load backends from database: {e}")
    return InferenceBackendResponse(items=items)
  435. def _hybrid_backend_conditions(ctx) -> List:
  436. """Hybrid visibility filter for inference_backends.
  437. Platform rows (owner_principal_id IS NULL) are visible to everyone.
  438. Org rows are visible to:
  439. - their own Org's members (current_principal_id matches)
  440. - platform admin in "All" mode (no current_principal_id) — full bypass
  441. - system users (worker / cluster service accounts) — full bypass,
  442. since they need every Org's overrides to actually run a deploy
  443. whose backend version was customised at the Org level
  444. Platform admin in act-as mode (current_principal_id is set) follows the
  445. same scope as a non-admin caller in that Org: Platform NULL +
  446. that Org's rows only. They DON'T see other Orgs' rows while
  447. pretending to be in this one.
  448. """
  449. if ctx is None:
  450. return []
  451. if getattr(ctx.user, "is_system", False):
  452. return []
  453. if ctx.is_platform_admin and ctx.current_principal_id is None:
  454. return []
  455. from sqlalchemy import or_
  456. or_clauses = [InferenceBackend.owner_principal_id.is_(None)]
  457. if ctx.current_principal_id is not None:
  458. or_clauses.append(
  459. InferenceBackend.owner_principal_id == ctx.current_principal_id
  460. )
  461. return [or_(*or_clauses)]
  462. async def _fetch_visible_backend_rows(session, ctx) -> List[InferenceBackend]:
  463. """Hybrid-aware DB read: Platform rows always; Org rows scoped to ctx."""
  464. extra_conditions = _hybrid_backend_conditions(ctx)
  465. if extra_conditions:
  466. return await InferenceBackend.all_by_fields(
  467. session, fields={}, extra_conditions=extra_conditions
  468. )
  469. return await InferenceBackend.all(session)
  470. def _enrich_built_in_with_runner_versions(
  471. db_backend: InferenceBackendPublic,
  472. backend_name: str,
  473. with_deprecated: bool,
  474. ) -> None:
  475. """Layer runner-discovered versions on top of the DB row in place."""
  476. _, runner_versions, default_version = get_runner_versions_and_configs(
  477. backend_name,
  478. backend_version=None,
  479. with_deprecated=with_deprecated,
  480. )
  481. for runner_version, version_config in runner_versions.root.items():
  482. db_backend.built_in_version_configs[runner_version] = version_config
  483. if default_version and not db_backend.default_version:
  484. db_backend.default_version = default_version
  485. def _migrate_community_built_in_versions(db_backend: InferenceBackendPublic) -> None:
  486. """Move version_configs entries that carry built_in_frameworks into the
  487. dedicated built_in_version_configs map (community backends only)."""
  488. if (
  489. db_backend.backend_source != BackendSourceEnum.COMMUNITY
  490. or not db_backend.version_configs
  491. or not db_backend.version_configs.root
  492. ):
  493. return
  494. versions_to_move = {
  495. version: config
  496. for version, config in db_backend.version_configs.root.items()
  497. if config.built_in_frameworks
  498. }
  499. if not versions_to_move:
  500. return
  501. if not db_backend.built_in_version_configs:
  502. db_backend.built_in_version_configs = {}
  503. db_backend.built_in_version_configs.update(versions_to_move)
  504. for version in versions_to_move:
  505. del db_backend.version_configs.root[version]
  506. def _collapse_by_backend_name(
  507. db_result_sorted: List[InferenceBackend],
  508. ) -> List[InferenceBackendPublic]:
  509. """Collapse Platform + Org rows that share a backend_name into one
  510. public-model entry. Used for the non-admin single-card view.
  511. - Org row wins on metadata + version_configs (Org keys override
  512. Platform keys, missing Org keys fall back to Platform).
  513. - **Exception: ``enabled``**. Use ``Platform.enabled OR Org.enabled``
  514. so a stale or accidental Org row with ``enabled=False`` cannot
  515. shadow a Platform-enabled backend. The tradeoff is that an Org
  516. can no longer "disable" a Platform-shared community backend in
  517. its own scope — disabling has to happen at the Platform level.
  518. That's a deliberate choice: keeping the Hybrid view simple and
  519. avoiding "I didn't disable it but it's gone" confusion is worth
  520. more than per-Org opt-out, which can be re-introduced later via
  521. an explicit ``override_enabled`` flag if needed.
  522. Returns ``InferenceBackendPublic`` copies rather than ORM rows so the
  523. read-time merge can never be flushed back to the database. The caller
  524. pays one ``model_dump`` per row, which is cheap relative to the DB
  525. read this is feeding.
  526. """
  527. by_name: Dict[str, InferenceBackendPublic] = {}
  528. for backend in db_result_sorted:
  529. existing = by_name.get(backend.backend_name)
  530. if existing is None:
  531. by_name[backend.backend_name] = InferenceBackendPublic(
  532. **backend.model_dump()
  533. )
  534. continue
  535. # `existing` is the public copy of whatever we saw first; `backend`
  536. # is the new ORM row. Decide which side is the Org row and merge.
  537. if backend.owner_principal_id is not None:
  538. org_versions = backend.version_configs
  539. other_versions = existing.version_configs
  540. org_enabled = bool(backend.enabled)
  541. other_enabled = bool(existing.enabled)
  542. target = InferenceBackendPublic(**backend.model_dump())
  543. else:
  544. org_versions = existing.version_configs
  545. other_versions = backend.version_configs
  546. org_enabled = bool(existing.enabled)
  547. other_enabled = bool(backend.enabled)
  548. target = existing
  549. merged_versions = {
  550. **(other_versions.root if other_versions else {}),
  551. **(org_versions.root if org_versions else {}),
  552. }
  553. target.version_configs = VersionConfigDict(root=merged_versions)
  554. target.enabled = org_enabled or other_enabled
  555. by_name[backend.backend_name] = target
  556. return list(by_name.values())
  557. async def merge_runner_versions_to_db(
  558. session: SessionDep,
  559. with_deprecated: bool = True,
  560. *,
  561. ctx=None,
  562. ) -> List[InferenceBackendPublic]:
  563. """Backends visible to the caller, with runner versions enriched in.
  564. Hybrid display rules:
  565. - **Platform admin**: one row per DB row (no collapse). Admin needs
  566. to manage Platform rows and Org rows separately, so they show as
  567. distinct cards (typically distinguished by an Owner tag in the UI).
  568. - **Non-admin**: collapsed single-card view per backend_name —
  569. Platform + Org rows fold into one entry, Org wins on metadata,
  570. versions union (Org overrides Platform). Org owners don't need
  571. to know about the underlying two-row Hybrid storage.
  572. """
  573. db_result = await _fetch_visible_backend_rows(session, ctx)
  574. # Sort by id ascending so the Org row (created later, larger id)
  575. # naturally wins during the non-admin collapse.
  576. db_result_sorted = sorted(db_result, key=lambda x: x.id if x.id else 0)
  577. # Show uncollapsed rows for admin-style views (managing every row
  578. # independently). Admin act-as mode behaves like the Org member —
  579. # they're acting *inside* that Org and want the collapsed
  580. # single-card UX too.
  581. is_admin_view = ctx is None or (
  582. ctx.is_platform_admin and ctx.current_principal_id is None
  583. )
  584. if is_admin_view:
  585. publics = [
  586. InferenceBackendPublic(**row.model_dump()) for row in db_result_sorted
  587. ]
  588. else:
  589. publics = _collapse_by_backend_name(db_result_sorted)
  590. built_in_names = {
  591. b.backend_name
  592. for b in get_built_in_backend()
  593. if b.backend_name != BackendEnum.CUSTOM.value
  594. }
  595. merged_backends: List[InferenceBackendPublic] = []
  596. for public in publics:
  597. if public.backend_name in built_in_names:
  598. _enrich_built_in_with_runner_versions(
  599. public, public.backend_name, with_deprecated
  600. )
  601. else:
  602. _migrate_community_built_in_versions(public)
  603. merged_backends.append(public)
  604. return merged_backends
  605. def _generate_framework_index_map( # noqa: C901
  606. version_config_dicts: List[Dict[str, VersionConfig]]
  607. ) -> Dict[str, List[str]]:
  608. """
  609. Generate framework index map from a list of version config dictionaries.
  610. Args:
  611. version_config_dicts: List of dictionaries mapping version names to VersionConfig objects
  612. Returns:
  613. Dictionary mapping framework names to sorted lists of supported versions
  614. """
  615. framework_map = {}
  616. for version_configs in version_config_dicts:
  617. if not version_configs:
  618. continue
  619. for version, config in version_configs.items():
  620. if config.built_in_frameworks:
  621. for framework in config.built_in_frameworks:
  622. if framework not in framework_map:
  623. framework_map[framework] = []
  624. if version not in framework_map[framework]:
  625. framework_map[framework].append(version)
  626. if config.custom_framework:
  627. if config.custom_framework not in framework_map:
  628. framework_map[config.custom_framework] = []
  629. framework_map[config.custom_framework].append(version)
  630. # Sort versions for each framework
  631. for framework in framework_map:
  632. framework_map[framework].sort()
  633. return framework_map
  634. def _filter_community_backends(
  635. backends: List[InferenceBackendPublic],
  636. is_only_community: Optional[bool] = None,
  637. ) -> List[InferenceBackendPublic]:
  638. """
  639. Filter backends to only include community backends without custom frameworks.
  640. This function filters the backend list to only include backends with
  641. backend_source=COMMUNITY, and removes any versions that have custom_framework set.
  642. Args:
  643. backends: List of inference backends to filter
  644. Returns:
  645. List of community backends with non-custom framework versions only
  646. """
  647. filter_backends = []
  648. for backend in backends:
  649. if is_only_community:
  650. # using in community_backends catalog
  651. if backend.backend_source != BackendSourceEnum.COMMUNITY:
  652. continue
  653. backend.version_configs.root = {}
  654. else:
  655. # using in common inference_backends view
  656. if (
  657. backend.backend_source == BackendSourceEnum.COMMUNITY
  658. and not backend.enabled
  659. ):
  660. continue
  661. filter_backends.append(backend)
  662. return filter_backends
  663. @router.get("", response_model=InferenceBackendsPublic)
  664. async def get_inference_backends( # noqa: C901
  665. session: SessionDep,
  666. ctx: TenantContextDep,
  667. params: ListParamsDep,
  668. search: str = None,
  669. include_deprecated: bool = False,
  670. community: Optional[bool] = None,
  671. backend_source: Optional[str] = None,
  672. ):
  673. """
  674. Get paginated list of inference backends with optional search and filters.
  675. Args:
  676. session: Database session
  677. params: List parameters (page, perPage, watch, sort_by)
  678. search: Search keyword for backend_name and description
  679. include_deprecated: Include deprecated versions
  680. community: Filter community backends (True=community only with non-custom versions, False/None=all backends)
  681. backend_source: Filter by backend source (built-in, custom, or community)
  682. Returns:
  683. InferenceBackendsPublic: Paginated list of inference backends
  684. """
  685. fields = {}
  686. if params.watch:
  687. # Filter the streamed events with the same Hybrid visibility check.
  688. def _visible(b: InferenceBackend) -> bool:
  689. if ctx is None or (
  690. ctx.is_platform_admin and ctx.current_principal_id is None
  691. ):
  692. return True
  693. # System users (worker / cluster) need every Org's overrides
  694. # because they actually run the deploys.
  695. if getattr(getattr(ctx, "user", None), "is_system", False):
  696. return True
  697. org_id = getattr(b, "owner_principal_id", None)
  698. if org_id is None:
  699. return True
  700. return (
  701. ctx.current_principal_id is not None
  702. and org_id == ctx.current_principal_id
  703. )
  704. return StreamingResponse(
  705. InferenceBackend.streaming(fields=fields, filter_func=_visible),
  706. media_type="text/event-stream",
  707. )
  708. async with async_session() as session:
  709. merged_backends = await merge_runner_versions_to_db(
  710. session, with_deprecated=include_deprecated, ctx=ctx
  711. )
  712. # Get worker GPU information for framework sorting
  713. workers = await Worker.all(session)
  714. framework_list = set()
  715. for worker in workers:
  716. if worker.status and worker.status.gpu_devices:
  717. for gpu in worker.status.gpu_devices:
  718. framework_list.add(gpu.type)
  719. # Single-pass filtering and transformation pipeline:
  720. # 1. Framework sorting (data transformation)
  721. # 2. Search filter (early rejection)
  722. # 3. Community filter (early rejection)
  723. # 4. Backend source filter (early rejection)
  724. # 5. Framework index map generation (final transformation)
  725. filter_backends = []
  726. for backend in merged_backends:
  727. # 1. Sort frameworks by support status (must be first as it modifies data structure)
  728. sorted_version_configs = {}
  729. for version, config in backend.built_in_version_configs.items():
  730. if config.built_in_frameworks:
  731. supported = [
  732. framework
  733. for framework in config.built_in_frameworks
  734. if framework in framework_list
  735. ]
  736. unsupported = [
  737. framework
  738. for framework in config.built_in_frameworks
  739. if framework not in framework_list
  740. ]
  741. config.built_in_frameworks = supported + unsupported
  742. sorted_version_configs[version] = config
  743. backend.built_in_version_configs = sorted_version_configs
  744. # 2. Apply search filter (early rejection to reduce subsequent processing)
  745. if search:
  746. lower_search = search.lower()
  747. if not (
  748. lower_search in backend.backend_name.lower()
  749. or (backend.description and lower_search in backend.description.lower())
  750. ):
  751. continue # Skip backends that don't match search criteria
  752. # 3. Apply community filter (early rejection)
  753. if community is True:
  754. # Using in community_backends catalog
  755. if backend.backend_source != BackendSourceEnum.COMMUNITY:
  756. continue
  757. # Clear custom versions for community backends
  758. if backend.version_configs:
  759. backend.version_configs.root = {}
  760. else:
  761. # Using in common inference_backends view
  762. if (
  763. backend.backend_source == BackendSourceEnum.COMMUNITY
  764. and not backend.enabled
  765. ):
  766. continue
  767. # 4. Apply backend_source filter (early rejection)
  768. if backend_source:
  769. try:
  770. source_enum = BackendSourceEnum(backend_source)
  771. if backend.backend_source != source_enum:
  772. continue
  773. except ValueError:
  774. # Invalid backend_source value, log warning but don't filter
  775. logger.warning(f"Invalid backend_source value: {backend_source}")
  776. # 5. Generate framework_index_map (must be last as it depends on processed data)
  777. version_config_dicts = []
  778. if backend.built_in_version_configs:
  779. version_config_dicts.append(backend.built_in_version_configs)
  780. if backend.version_configs and backend.version_configs.root:
  781. version_config_dicts.append(backend.version_configs.root)
  782. backend.framework_index_map = _generate_framework_index_map(
  783. version_config_dicts
  784. )
  785. # Backend passed all filters, add to result list
  786. filter_backends.append(backend)
  787. # Apply pagination to merged results
  788. total = len(filter_backends)
  789. start_idx = (params.page - 1) * params.perPage
  790. end_idx = start_idx + params.perPage
  791. paginated_backends = filter_backends[start_idx:end_idx]
  792. pagination = Pagination(
  793. page=params.page,
  794. perPage=params.perPage,
  795. total=total,
  796. totalPage=max(math.ceil(total / params.perPage), 1),
  797. )
  798. # Create the response with the same structure as the original
  799. return InferenceBackendsPublic(
  800. items=paginated_backends,
  801. pagination=pagination,
  802. )
  803. @router.get("/all", response_model=List[InferenceBackend])
  804. async def get_all_inference_backends(
  805. session: SessionDep,
  806. ctx: TenantContextDep,
  807. ):
  808. backends = await merge_runner_versions_to_db(session, ctx=ctx)
  809. ret = []
  810. for backend in backends:
  811. if backend.backend_source == BackendSourceEnum.CUSTOM:
  812. ret.append(backend)
  813. continue
  814. for built_in_version, config in backend.built_in_version_configs.items():
  815. # if version in same, db version first
  816. if built_in_version not in backend.version_configs.root:
  817. backend.version_configs.root[built_in_version] = config
  818. ret.append(backend)
  819. return ret
  820. def _assert_backend_visible(ctx, backend):
  821. """Org member can see Platform (NULL) and own-Org rows. Admin sees
  822. everything in "All" mode; in act-as mode they're scoped just like
  823. a regular member of that Org (so a stale link to dev Org's row
  824. while admin is acting-as Default surfaces a 404, not a leak)."""
  825. if backend is None:
  826. raise NotFoundException(message="Inference backend not found")
  827. if ctx.is_platform_admin and ctx.current_principal_id is None:
  828. return
  829. org_id = backend.owner_principal_id
  830. if org_id is None:
  831. return # Platform row is visible to everyone
  832. if ctx.current_principal_id is not None and org_id == ctx.current_principal_id:
  833. return
  834. raise NotFoundException(message="Inference backend not found")
  835. @router.get("/{id}", response_model=InferenceBackend)
  836. async def get_inference_backend(session: SessionDep, ctx: TenantContextDep, id: int):
  837. """
  838. Get a specific inference backend by ID.
  839. """
  840. backend = await InferenceBackend.one_by_id(session, id)
  841. if not backend:
  842. raise BadRequestException(message=f"Inference backend {id} not found")
  843. _assert_backend_visible(ctx, backend)
  844. return backend
  845. @router.get("/backend_name/{backend_name}", response_model=InferenceBackend)
  846. async def get_inference_backend_by_name(
  847. session: SessionDep, ctx: TenantContextDep, backend_name: str
  848. ):
  849. """
  850. Get a specific inference backend by backend name. Resolves to the
  851. caller's Org row if one exists, else falls back to the Platform row.
  852. """
  853. if ctx.current_principal_id is not None and not ctx.is_platform_admin:
  854. org_row = await InferenceBackend.one_by_fields(
  855. session,
  856. {
  857. "backend_name": backend_name,
  858. "owner_principal_id": ctx.current_principal_id,
  859. },
  860. )
  861. if org_row is not None:
  862. return org_row
  863. backend = await InferenceBackend.one_by_fields(
  864. session,
  865. {"backend_name": backend_name, "owner_principal_id": None},
  866. )
  867. if not backend:
  868. raise BadRequestException(message=f"Inference backend {backend_name} not found")
  869. return backend
  870. @router.post("", response_model=InferenceBackend)
  871. async def create_inference_backend(
  872. session: SessionDep,
  873. ctx: TenantContextDep,
  874. backend_in: InferenceBackendCreate,
  875. ):
  876. """
  877. Create a new inference backend.
  878. Hybrid scope:
  879. - Platform admin: owner_principal_id NULL (Platform) or any Org id.
  880. - Org owner / manager: owner_principal_id locked to their current Org.
  881. Same backend_name as a Platform built-in IS allowed for an Org row
  882. (extension/override) — the case-insensitive duplicate check only
  883. bites when creating a Platform row that conflicts with a built-in.
  884. """
  885. target_org_id = getattr(backend_in, "owner_principal_id", None)
  886. validate_owner_principal(
  887. target_org_id,
  888. ctx,
  889. resource_label="inference backend",
  890. )
  891. # Platform-scoped rows can't shadow a built-in name (case-insensitive)
  892. # — the seeding controller owns those. Org-scoped rows MAY use the same
  893. # name to extend / override a built-in for that Org.
  894. if target_org_id is None and is_built_in_backend(backend_in.backend_name):
  895. raise BadRequestException(
  896. message=(
  897. f"Backend name {backend_in.backend_name} duplicates with built-in backends (case-insensitive). Please use another name."
  898. ),
  899. )
  900. backend_in.backend_source = BackendSourceEnum.CUSTOM
  901. backend_in.enabled = True
  902. # Composite unique on (backend_name, owner_principal_id) — uniqueness check
  903. # is scoped to the same tenant.
  904. existing = await InferenceBackend.one_by_fields(
  905. session,
  906. {
  907. "backend_name": backend_in.backend_name,
  908. "owner_principal_id": target_org_id,
  909. },
  910. )
  911. if existing:
  912. raise BadRequestException(
  913. message=f"Inference backend with name '{backend_in.backend_name}' already exists",
  914. )
  915. # Validate version names for custom backends before creating
  916. validate_custom_suffix(backend_in.backend_name, None)
  917. for version in backend_in.version_configs.root.keys():
  918. backend_in.version_configs.root[version].built_in_frameworks = None
  919. try:
  920. backend = InferenceBackend(
  921. backend_name=backend_in.backend_name,
  922. version_configs=backend_in.version_configs,
  923. default_version=backend_in.default_version,
  924. default_backend_param=backend_in.default_backend_param,
  925. default_run_command=backend_in.default_run_command,
  926. default_entrypoint=backend_in.default_entrypoint,
  927. health_check_path=backend_in.health_check_path,
  928. description=backend_in.description,
  929. default_env=backend_in.default_env,
  930. enabled=backend_in.enabled,
  931. backend_source=backend_in.backend_source,
  932. owner_principal_id=target_org_id,
  933. )
  934. backend = await InferenceBackend.create(session, backend)
  935. except Exception as e:
  936. raise InternalServerErrorException(
  937. message=f"Failed to create inference backend: {e}"
  938. )
  939. return backend
  940. async def _redirect_global_edit_to_org_row(
  941. session,
  942. ctx,
  943. backend: InferenceBackend,
  944. backend_in: InferenceBackendUpdate,
  945. ) -> Optional[InferenceBackend]:
  946. """If the caller is in an Org context and the target is a Global
  947. row, route the write to that Org's row. Applies to admin acting-as
  948. too — when admin has switched to Default Org, "enable community
  949. backend" should land in Default's scope, not modify Platform.
  950. Returns:
  951. - the existing Org row if found (caller continues the update on it), OR
  952. - the freshly created Org row (early return; caller should propagate).
  953. Returns ``None`` when no redirect is needed (target already
  954. belongs to the caller's Org, or caller is in "All" mode).
  955. """
  956. if backend.owner_principal_id is not None or ctx.current_principal_id is None:
  957. return None
  958. org_row = await InferenceBackend.one_by_fields(
  959. session,
  960. {
  961. "backend_name": backend.backend_name,
  962. "owner_principal_id": ctx.current_principal_id,
  963. },
  964. )
  965. if org_row is not None:
  966. return org_row
  967. # No Org row yet — seed one from the submitted payload. The Org row
  968. # inherits is_built_in / backend_source from the Platform row it
  969. # extends: an Org-scoped vLLM is still vLLM (a BUILT_IN backend),
  970. # not a freshly invented custom backend. That keeps suffix-validation
  971. # and other built-in-aware code paths firing identically.
  972. new_row = InferenceBackend(
  973. backend_name=backend_in.backend_name,
  974. version_configs=backend_in.version_configs,
  975. default_version=backend_in.default_version,
  976. default_backend_param=backend_in.default_backend_param,
  977. default_run_command=backend_in.default_run_command,
  978. default_entrypoint=backend_in.default_entrypoint,
  979. health_check_path=backend_in.health_check_path,
  980. description=backend_in.description,
  981. default_env=backend_in.default_env,
  982. enabled=True,
  983. is_built_in=backend.is_built_in,
  984. backend_source=backend.backend_source,
  985. owner_principal_id=ctx.current_principal_id,
  986. )
  987. return await InferenceBackend.create(session, new_row)
  988. @router.put("/{id}", response_model=InferenceBackend)
  989. async def update_inference_backend( # noqa: C901
  990. session: SessionDep,
  991. ctx: TenantContextDep,
  992. id: int,
  993. backend_in: InferenceBackendUpdate,
  994. ):
  995. """
  996. Update an existing inference backend.
  997. """
  998. backend = await InferenceBackend.one_by_id(session, id)
  999. if not backend:
  1000. raise NotFoundException(message=f"Inference backend {id} not found")
  1001. redirected = await _redirect_global_edit_to_org_row(
  1002. session, ctx, backend, backend_in
  1003. )
  1004. if redirected is not None:
  1005. # Continue the update flow against the Org row instead of the
  1006. # Global row the caller targeted. For a freshly created Org row
  1007. # the downstream update is effectively a no-op rewrite of the
  1008. # same payload — which is fine and keeps the response shape
  1009. # consistent for both branches.
  1010. backend = redirected
  1011. assert_org_owned_writable(ctx, backend, resource_label="inference backend")
  1012. # Check if updating to a name that already exists (excluding current backend)
  1013. if backend_in.backend_name != backend.backend_name:
  1014. raise BadRequestException(
  1015. message="The name of inference-backend can not be modified",
  1016. )
  1017. # Validate that built-in backends cannot have default_version set
  1018. if is_built_in_backend(backend.backend_name) and backend_in.default_version:
  1019. raise BadRequestException(
  1020. message=f"Built-in backend '{backend.backend_name}' cannot have default_version set. Default version is managed automatically.",
  1021. )
  1022. if backend_in.version_configs is not None:
  1023. await _validate_version_removal(session, backend, backend_in.version_configs)
  1024. # Validate version names for custom backends before updating
  1025. if backend.backend_source == BackendSourceEnum.CUSTOM or (
  1026. backend.backend_source is None and not backend.is_built_in
  1027. ):
  1028. validate_custom_suffix(backend_in.backend_name, None)
  1029. else:
  1030. validate_custom_suffix(None, backend_in.version_configs)
  1031. for version in backend_in.version_configs.root.keys():
  1032. backend_in.version_configs.root[version].built_in_frameworks = None
  1033. try:
  1034. # Use a dict for changes to prevent version_config serialization errors and None field overrides issues.
  1035. update_data = {
  1036. "backend_name": backend_in.backend_name,
  1037. "version_configs": backend_in.version_configs,
  1038. "default_version": backend_in.default_version,
  1039. "default_backend_param": backend_in.default_backend_param,
  1040. "default_run_command": backend_in.default_run_command,
  1041. "default_entrypoint": backend_in.default_entrypoint,
  1042. "health_check_path": backend_in.health_check_path,
  1043. "description": backend_in.description,
  1044. "default_env": backend_in.default_env,
  1045. "backend_source": backend_in.backend_source,
  1046. }
  1047. if backend_in.backend_source == BackendSourceEnum.COMMUNITY:
  1048. if backend_in.enabled is not None:
  1049. update_data["enabled"] = backend_in.enabled
  1050. built_in_version = {
  1051. k: v
  1052. for k, v in backend.version_configs.root.items()
  1053. if v.built_in_frameworks
  1054. }
  1055. # merge built-in versions with custom versions for update
  1056. built_in_version.update(update_data['version_configs'].root)
  1057. update_data['version_configs'].root = built_in_version
  1058. await backend.update(session, update_data)
  1059. except Exception as e:
  1060. raise InternalServerErrorException(
  1061. message=f"Failed to update inference backend: {e}"
  1062. )
  1063. return backend
  1064. @router.delete("/{id}")
  1065. async def delete_inference_backend(session: SessionDep, ctx: TenantContextDep, id: int):
  1066. """
  1067. Delete an inference backend.
  1068. """
  1069. backend = await InferenceBackend.one_by_id(session, id)
  1070. if not backend:
  1071. raise NotFoundException(message=f"Inference backend {id} not found")
  1072. assert_org_owned_writable(ctx, backend, resource_label="inference backend")
  1073. # Protect Platform-curated rows (built-in / community at the global
  1074. # scope). Org-scoped rows are always deletable by their owner — even
  1075. # when they're a vLLM extension carrying source=BUILT_IN — because
  1076. # they're the Org's own data, not platform-curated.
  1077. if (
  1078. backend.owner_principal_id is None
  1079. and backend.backend_source != BackendSourceEnum.CUSTOM
  1080. and backend.backend_source is not None
  1081. ):
  1082. raise BadRequestException(message="Cannot delete built-in or community backend")
  1083. # Check if the backend is being used by any models
  1084. is_in_use, model_names = await check_backend_in_use(session, backend.backend_name)
  1085. if is_in_use:
  1086. raise BadRequestException(
  1087. message=f"Cannot delete backend '{backend.backend_name}' because it is currently being used by the following models: {', '.join(model_names)}",
  1088. )
  1089. try:
  1090. await backend.delete(session)
  1091. except Exception as e:
  1092. raise InternalServerErrorException(
  1093. message=f"Failed to delete inference backend: {e}"
  1094. )
  1095. @router.post("/from-yaml", response_model=InferenceBackend)
  1096. async def create_inference_backend_from_yaml( # noqa: C901
  1097. session: SessionDep, ctx: TenantContextDep, payload: dict = Body(...)
  1098. ):
  1099. """
  1100. Create an inference backend from YAML configuration.
  1101. Expected YAML format:
  1102. ```yaml
  1103. backend_name: "my-custom-backend"
  1104. version_configs:
  1105. "v1.0.0":
  1106. image_name: "my-backend:v1.0.0"
  1107. run_command: "python server.py --port {{port}} --model {{model_path}}"
  1108. "v1.1.0":
  1109. image_name: "my-backend:v1.1.0"
  1110. run_command: "python server.py --port {{port}} --model {{model_path}}"
  1111. default_version: "v1.1.0"
  1112. default_backend_param: ["--max-tokens", "2048"]
  1113. default_run_command: "python server.py"
  1114. description: "My custom inference backend"
  1115. health_check_path: "/health"
  1116. allowed_proxy_uris: ["/v1/chat/completions", "/v1/completions"]
  1117. ```
  1118. """
  1119. try:
  1120. # Extract YAML content from JSON payload
  1121. yaml_content = payload.get("content")
  1122. if not yaml_content:
  1123. raise BadRequestException(message="Missing 'content' field in request body")
  1124. # Parse YAML content
  1125. req_yaml_data = yaml.safe_load(yaml_content)
  1126. # Validate required fields
  1127. if not req_yaml_data.get("backend_name"):
  1128. raise BadRequestException(message="backend_name is required in YAML")
  1129. target_org_id = req_yaml_data.get("owner_principal_id")
  1130. validate_owner_principal(
  1131. target_org_id,
  1132. ctx,
  1133. resource_label="inference backend",
  1134. )
  1135. # Platform rows can't shadow built-in names; Org rows may extend them.
  1136. if target_org_id is None and is_built_in_backend(req_yaml_data["backend_name"]):
  1137. raise BadRequestException(
  1138. message=(
  1139. f"Backend name {req_yaml_data['backend_name']} duplicates with built-in backends (case-insensitive). Please use another name."
  1140. ),
  1141. )
  1142. req_yaml_data["backend_source"] = BackendSourceEnum.CUSTOM
  1143. req_yaml_data["enabled"] = True
  1144. # Composite uniqueness — same backend_name allowed across tenants.
  1145. existing = await InferenceBackend.one_by_fields(
  1146. session,
  1147. {
  1148. "backend_name": req_yaml_data["backend_name"],
  1149. "owner_principal_id": target_org_id,
  1150. },
  1151. )
  1152. if existing:
  1153. raise BadRequestException(
  1154. message=f"Inference backend with name '{req_yaml_data['backend_name']}' already exists",
  1155. )
  1156. allowed_keys = [
  1157. "backend_name",
  1158. "version_configs",
  1159. "default_version",
  1160. "default_backend_param",
  1161. "default_run_command",
  1162. "health_check_path",
  1163. "description",
  1164. "default_env",
  1165. "enabled",
  1166. "backend_source",
  1167. ]
  1168. yaml_data = {k: req_yaml_data[k] for k in allowed_keys if k in req_yaml_data}
  1169. # Convert version_configs to VersionConfigDict if present
  1170. if 'version_configs' in yaml_data and yaml_data['version_configs']:
  1171. version_configs_dict = {}
  1172. for version, config in yaml_data['version_configs'].items():
  1173. if config.get('built_in_frameworks'):
  1174. config['built_in_frameworks'] = None
  1175. version_configs_dict[version] = VersionConfig(**config)
  1176. yaml_data['version_configs'] = VersionConfigDict(root=version_configs_dict)
  1177. # Validate version names for custom backends
  1178. validate_custom_suffix(yaml_data['backend_name'], None)
  1179. # Validate YAML data using Pydantic model to ensure field types are correct
  1180. try:
  1181. InferenceBackendCreate.model_validate(yaml_data)
  1182. except ValidationError as e:
  1183. raise BadRequestException(message=f"Invalid YAML data: {e}")
  1184. # Create the backend
  1185. backend = InferenceBackend(**yaml_data, owner_principal_id=target_org_id)
  1186. backend = await InferenceBackend.create(session, backend)
  1187. return backend
  1188. except yaml.YAMLError as e:
  1189. raise BadRequestException(message=f"Invalid YAML format: {e}")
  1190. except BadRequestException:
  1191. raise # Re-raise BadRequestException without wrapping
  1192. except Exception as e:
  1193. raise InternalServerErrorException(
  1194. message=f"Failed to create inference backend from YAML: {e.__str__()}"
  1195. )
  1196. @router.put("/{id}/from-yaml", response_model=InferenceBackend)
  1197. async def update_inference_backend_from_yaml( # noqa: C901
  1198. session: SessionDep,
  1199. ctx: TenantContextDep,
  1200. id: int,
  1201. payload: dict = Body(...),
  1202. ):
  1203. """
  1204. Update an existing inference backend from YAML configuration.
  1205. Expected JSON format:
  1206. ```json
  1207. {
  1208. "content": "backend_name: \"my-custom-backend\"\nversion_configs:\n \"v1.0.0\":\n image_name: \"my-backend:v1.0.0\"\n run_command: \"python server.py --port {{port}} --model {{model_path}}\"\n \"v1.1.0\":\n image_name: \"my-backend:v1.1.0\"\n run_command: \"python server.py --port {{port}} --model {{model_path}}\"\ndefault_version: \"v1.1.0\"\ndefault_backend_param: [\"--max-tokens\", \"2048\"]\ndefault_run_command: \"python server.py\"\ndescription: \"My custom inference backend\"\nhealth_check_path: \"/health\"\nallowed_proxy_uris: [\"/v1/chat/completions\", \"/v1/completions\"]"
  1209. }
  1210. """
  1211. backend = await InferenceBackend.one_by_id(session, id)
  1212. if not backend:
  1213. raise NotFoundException(message=f"Inference backend {id} not found")
  1214. assert_org_owned_writable(ctx, backend, resource_label="inference backend")
  1215. try:
  1216. # Extract YAML content from JSON payload
  1217. yaml_content = payload.get("content")
  1218. if not yaml_content:
  1219. raise BadRequestException(message="Missing 'content' field in request body")
  1220. # Parse YAML content
  1221. req_yaml_data = yaml.safe_load(yaml_content)
  1222. # Validate required fields
  1223. if not req_yaml_data.get("backend_name"):
  1224. raise BadRequestException(message="backend_name is required in YAML")
  1225. # Check if updating to a name that already exists (excluding current backend)
  1226. if req_yaml_data["backend_name"] != backend.backend_name:
  1227. raise BadRequestException(
  1228. message="The name of inference-backend can not be modified",
  1229. )
  1230. allowed_keys = [
  1231. "backend_name",
  1232. "version_configs",
  1233. "default_backend_param",
  1234. "default_run_command",
  1235. "default_entrypoint",
  1236. "health_check_path",
  1237. "description",
  1238. "default_env",
  1239. "enabled",
  1240. "backend_source",
  1241. ]
  1242. if not is_built_in_backend(backend.backend_name):
  1243. allowed_keys.append("default_version")
  1244. yaml_data = {k: req_yaml_data[k] for k in allowed_keys if k in req_yaml_data}
  1245. # Process version_configs if present
  1246. yaml_data['version_configs'] = _process_version_configs(
  1247. yaml_data.get('version_configs')
  1248. )
  1249. # Check if any versions are being removed and validate they're not in use
  1250. await _validate_version_removal(
  1251. session, backend, yaml_data.get('version_configs')
  1252. )
  1253. # Validate version names based on backend source
  1254. if backend.backend_source == BackendSourceEnum.CUSTOM or (
  1255. backend.backend_source is None and not backend.is_built_in
  1256. ):
  1257. validate_custom_suffix(yaml_data['backend_name'], None)
  1258. else:
  1259. validate_custom_suffix(None, yaml_data.get('version_configs'))
  1260. # Clear built_in_frameworks for all versions in yaml_data
  1261. _clear_built_in_frameworks(yaml_data.get('version_configs'))
  1262. # Merge built-in versions for COMMUNITY backends
  1263. if backend.backend_source == BackendSourceEnum.COMMUNITY:
  1264. yaml_data['version_configs'] = _merge_community_versions(
  1265. backend, yaml_data.get('version_configs')
  1266. )
  1267. # Validate YAML data using Pydantic model to ensure field types are correct
  1268. try:
  1269. InferenceBackendUpdate.model_validate(yaml_data)
  1270. except ValidationError as e:
  1271. raise BadRequestException(message=f"Invalid YAML data: {e}")
  1272. # Update the backend from YAML data (after normalization)
  1273. await backend.update(session, yaml_data)
  1274. return backend
  1275. except yaml.YAMLError as e:
  1276. raise BadRequestException(message=f"Invalid YAML format: {e}")
  1277. except BadRequestException:
  1278. raise # Re-raise BadRequestException without wrapping
  1279. except Exception as e:
  1280. raise InternalServerErrorException(
  1281. message=f"Failed to update inference backend from YAML: {e}"
  1282. )
  1283. def _process_version_configs(
  1284. version_configs_data: Optional[dict],
  1285. ) -> VersionConfigDict:
  1286. """
  1287. Convert raw version_configs dict to VersionConfigDict.
  1288. Returns None if version_configs_data is None or empty.
  1289. """
  1290. version_configs_dict = {}
  1291. for version, config in version_configs_data.items() if version_configs_data else []:
  1292. # Clear built_in_frameworks during initial processing
  1293. if config.get('built_in_frameworks'):
  1294. config['built_in_frameworks'] = None
  1295. version_configs_dict[version] = VersionConfig(**config)
  1296. return VersionConfigDict(root=version_configs_dict)
  1297. async def _validate_version_removal(
  1298. session,
  1299. backend: InferenceBackend,
  1300. new_version_configs: Optional[VersionConfigDict],
  1301. ):
  1302. """
  1303. Check if any versions are being removed and validate they're not in use.
  1304. """
  1305. # Get current versions (empty dict if none)
  1306. current_versions = {}
  1307. if backend.version_configs and backend.version_configs.root:
  1308. current_versions = {
  1309. v: config
  1310. for v, config in backend.version_configs.root.items()
  1311. if not config.built_in_frameworks
  1312. }
  1313. # Get new versions (empty dict if none)
  1314. new_versions = {}
  1315. if new_version_configs and new_version_configs.root:
  1316. new_versions = new_version_configs.root
  1317. # Find removed versions
  1318. removed_versions = set(current_versions.keys()) - set(new_versions.keys())
  1319. # Check if removed versions are in use
  1320. for version in removed_versions:
  1321. is_in_use, model_names = await check_backend_in_use(
  1322. session, backend.backend_name, version
  1323. )
  1324. if is_in_use:
  1325. raise BadRequestException(
  1326. message=f"Cannot remove version name '{version}' of backend '{backend.backend_name}' because it is currently being used by the following models: {', '.join(model_names)}",
  1327. )
  1328. def _clear_built_in_frameworks(version_configs: Optional[VersionConfigDict]):
  1329. """
  1330. Clear built_in_frameworks for all versions in version_configs.
  1331. """
  1332. if not version_configs or not version_configs.root:
  1333. return
  1334. for version_config in version_configs.root.values():
  1335. version_config.built_in_frameworks = None
  1336. def _merge_community_versions(
  1337. backend: InferenceBackend,
  1338. new_version_configs: Optional[VersionConfigDict],
  1339. ) -> VersionConfigDict:
  1340. """
  1341. Merge built-in versions with new versions for COMMUNITY backends.
  1342. Returns:
  1343. VersionConfigDict: Merged version configurations with built-in versions preserved
  1344. """
  1345. # Extract built-in versions from current backend
  1346. built_in_versions = {}
  1347. if backend.version_configs and backend.version_configs.root:
  1348. built_in_versions = {
  1349. k: v
  1350. for k, v in backend.version_configs.root.items()
  1351. if v.built_in_frameworks
  1352. }
  1353. if not new_version_configs or not new_version_configs.root:
  1354. return VersionConfigDict(root=built_in_versions)
  1355. # Merge: built-in versions + new versions (new versions take precedence)
  1356. built_in_versions.update(new_version_configs.root or {})
  1357. new_version_configs.root = built_in_versions
  1358. return new_version_configs
  1359. def validate_custom_suffix(
  1360. backend_name: Optional[str],
  1361. version_configs: Optional[VersionConfigDict],
  1362. ):
  1363. """
  1364. Validate custom suffix for backend names and version names.
  1365. Rules:
  1366. - Backend name: Must end with '-custom' if provided
  1367. - Version name: Must end with '-custom' ONLY if it's a user-defined version
  1368. (i.e., built_in_frameworks is None and custom_framework has value)
  1369. """
  1370. # Validate backend name
  1371. if backend_name and not backend_name.endswith("-custom"):
  1372. raise BadRequestException(
  1373. message=f"Custom backend name '{backend_name}' must end with '-custom'",
  1374. )
  1375. # Validate version names
  1376. if version_configs and version_configs.root:
  1377. for version, config in version_configs.root.items():
  1378. # Skip predefined versions (built_in_frameworks has value)
  1379. if config.built_in_frameworks:
  1380. continue
  1381. # User-defined versions must have -custom suffix
  1382. if not isinstance(version, str) or not version.endswith("-custom"):
  1383. raise BadRequestException(
  1384. message=f"Custom backend version '{version}' must end with '-custom'",
  1385. )