| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603 |
- import logging
- import math
- from copy import deepcopy
- from typing import List, Tuple, Optional, Dict
- import yaml
- from fastapi import APIRouter, Body
- from gpustack_runner.runner import ServiceVersionedRunner, ServiceRunner
- from gpustack_runtime.deployer.__utils__ import compare_versions
- from pydantic import ValidationError
- from starlette.responses import StreamingResponse
- from gpustack.api.exceptions import (
- InternalServerErrorException,
- NotFoundException,
- BadRequestException,
- )
- from gpustack.api.tenant import (
- assert_org_owned_writable,
- validate_owner_principal,
- )
- from gpustack.schemas import Worker
- from gpustack.schemas.common import Pagination
- from gpustack.schemas.inference_backend import (
- InferenceBackend,
- InferenceBackendCreate,
- InferenceBackendListItem,
- InferenceBackendResponse,
- InferenceBackendUpdate,
- InferenceBackendsPublic,
- VersionConfig,
- VersionConfigDict,
- get_built_in_backend,
- InferenceBackendPublic,
- VersionListItem,
- is_built_in_backend,
- )
- from gpustack.schemas.models import BackendEnum, Model, BackendSourceEnum
- from gpustack.server.db import async_session
- from gpustack.server.deps import ListParamsDep, SessionDep, TenantContextDep
- from gpustack_runner import list_service_runners
- from gpustack_runtime.detector.ascend import get_ascend_cann_variant
- from gpustack_runtime.detector import ManufacturerEnum
# Module logger, named after this module.
logger = logging.getLogger(__name__)
# Router that this module's endpoints (the @router.get handlers) register on.
router = APIRouter()
def filter_yaml_fields(yaml_data: Dict, filter_keys: List[str]) -> Dict:
    """
    Recursively remove specified keys from a nested YAML dict, in place.

    Every dict reachable from ``yaml_data`` is filtered. That includes dicts
    held as dict values and dicts inside lists, even lists nested in lists.
    Any key listed in ``filter_keys`` is deleted. Each container is processed
    at most once. Structures that share objects or refer back to themselves
    (as YAML anchors/aliases can produce) are therefore handled without
    infinite recursion.

    Args:
        yaml_data: Dictionary parsed from YAML content. Non-dict input is
            returned unchanged.
        filter_keys: Keys to remove wherever they appear.

    Returns:
        The same ``yaml_data`` instance after filtering.
    """
    if not isinstance(yaml_data, dict):
        return yaml_data

    # ids of containers already processed. Every container stays alive
    # (reachable from yaml_data) for the whole walk, so ids are stable.
    visited = set()
    pending = [yaml_data]
    while pending:
        node = pending.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))

        if isinstance(node, dict):
            # Snapshot the keys: deleting while iterating a live view raises.
            for key in list(node.keys()):
                if key in filter_keys:
                    del node[key]
                    continue
                child = node[key]
                if isinstance(child, (dict, list)):
                    pending.append(child)
        else:
            # Lists: descend into any nested dicts *and* nested lists.
            pending.extend(item for item in node if isinstance(item, (dict, list)))
    return yaml_data
async def check_backend_in_use(
    session: SessionDep, backend_name: str, backend_version: Optional[str] = None
) -> Tuple[bool, List[str]]:
    """
    Check if a backend or specific backend version is being used by any models.

    Only models with a positive replica count are counted. A model whose
    ``replicas`` is 0 (or None) is ignored.

    Args:
        session: Database session.
        backend_name: The name of the backend to check.
        backend_version: Optional specific version to check. If falsy, models
            using any version of the backend are considered.

    Returns:
        A tuple containing:
        - whether the backend/version is in use
        - the names of the models using it

    Any exception during the lookup is logged with its traceback and
    reported as ``(False, [])``.
    NOTE(review): this fails open. A caller that blocks deletion based on
    the result will allow it when the lookup itself errors. Confirm intent.
    """
    try:
        if backend_version:
            # Specific backend + version combination.
            models = await Model.all_by_fields(
                session, {"backend": backend_name, "backend_version": backend_version}
            )
        else:
            # Any version of this backend.
            models = await Model.all_by_field(session, "backend", backend_name)
        # `or 0` stops a None replica count from raising TypeError. Without it,
        # one such model would turn the whole check into the (False, []) fallback.
        model_names = [model.name for model in models if (model.replicas or 0) > 0]
        return bool(model_names), model_names
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error checking backend usage: %s", e)
        return False, []
def get_lower_version_runners(
    runners: list[ServiceRunner], backend_version: str
) -> list[ServiceRunner]:
    """
    Keep only runner backend versions that are not newer than ``backend_version``.

    Each runner is deep-copied, so the input runners are never mutated.
    Within each copy the pruning runs in three steps:
    1. Backend versions that compare greater than ``backend_version``
       (per ``compare_versions``) are dropped.
    2. Backends left with no versions are dropped.
    3. Runner versions left with no backends are dropped.
    Runners with nothing left are omitted entirely.

    Args:
        runners: ServiceRunner objects to filter.
        backend_version: Inclusive upper bound for backend versions.

    Returns:
        Pruned ServiceRunner copies, in input order.
    """

    def _not_newer(backend_entry) -> bool:
        return compare_versions(backend_entry.version, backend_version) <= 0

    pruned = []
    for source_runner in runners:
        runner_copy = deepcopy(source_runner)
        for runner_version in runner_copy.versions:
            for backend in runner_version.backends:
                backend.versions = [bv for bv in backend.versions if _not_newer(bv)]
            runner_version.backends = [
                backend for backend in runner_version.backends if backend.versions
            ]
        runner_copy.versions = [rv for rv in runner_copy.versions if rv.backends]
        if runner_copy.versions:
            pruned.append(runner_copy)
    return pruned
def get_runner_versions_and_configs(
    backend_name: str, backend_version: Optional[str], **kwargs
) -> Tuple[Dict[str, ServiceVersionedRunner], VersionConfigDict, Optional[str]]:
    """
    Collect the runner versions published for one backend service.

    Queries ``list_service_runners`` for ``backend_name`` (lower-cased).
    When ``backend_version`` is given, it keeps only runner backend versions
    not newer than it (via ``get_lower_version_runners``). Only the first
    runner returned is inspected.

    Args:
        backend_name: The name of the backend service.
        backend_version: Optional inclusive upper bound for runner backend
            versions. ``merge_list_runners`` passes a GPU runtime version
            here. Falsy means no bound.
        **kwargs: Extra keyword arguments forwarded to ``list_service_runners``.

    Returns:
        A tuple containing:
        - dict mapping each version string to its ``ServiceVersionedRunner``
        - ``VersionConfigDict`` listing each version's runner backends as
          ``built_in_frameworks``
        - the first version encountered, or None if there were none
    """
    runners = list_service_runners(service=backend_name.lower(), **kwargs)
    if backend_version:
        runners = get_lower_version_runners(runners, backend_version)

    versioned = {}
    configs = VersionConfigDict()
    first_version = None

    if not runners:
        return versioned, configs, first_version

    for entry in runners[0].versions:
        name = entry.version
        if not name:
            continue
        versioned[name] = entry
        configs.root[name] = VersionConfig(
            built_in_frameworks=[f"{b.backend}" for b in entry.backends],
        )
        if first_version is None:
            first_version = name
    return versioned, configs, first_version
def deduplicate_versions(versions: List[VersionListItem]) -> List[VersionListItem]:
    """
    Drop repeated version entries, keeping the first occurrence of each.

    Two items count as duplicates only when both ``version`` and
    ``is_deprecated`` match. The same version string can therefore appear
    twice if one copy is flagged deprecated and the other is not.
    Input order is preserved.
    """
    first_by_key = {}
    for entry in versions:
        first_by_key.setdefault((entry.version, entry.is_deprecated), entry)
    return list(first_by_key.values())
def get_runner_deprecate(runners: List[ServiceVersionedRunner]) -> bool:
    """
    Check if all runners are deprecated.

    A runner's deprecation flag is read from the first variant of the first
    version of its first backend. A runner that has no such variant
    (an empty or None list at any level) counts as not deprecated instead of
    raising ``IndexError``/``TypeError``.

    Args:
        runners: List of ServiceVersionedRunner objects.

    Returns:
        True if the list is non-empty and every runner's first variant is
        deprecated, False otherwise (including for an empty list).
    """
    if not runners:
        return False

    def _first_variant_deprecated(runner) -> bool:
        try:
            variant = runner.backends[0].versions[0].variants[0]
        except (IndexError, TypeError):
            # Missing structure is not a deprecation signal.
            return False
        return bool(variant.deprecated)

    return all(_first_variant_deprecated(runner) for runner in runners)
def _collect_runner_query_conditions(workers: List[Worker]) -> List[tuple]:
    """Return unique (gpu_type, runtime_version, variant) tuples in first-seen order."""
    # A dict (not a set) removes duplicates while keeping first-seen order.
    # Set iteration order for string tuples depends on per-process hash
    # randomization, so a set would make "the first query" (and the default
    # version derived from it) change between server restarts.
    conditions: Dict[tuple, None] = {}
    for worker in workers:
        if not (worker.status and worker.status.gpu_devices):
            continue
        for gpu in worker.status.gpu_devices:
            # Ascend GPUs are additionally keyed by CANN variant.
            variant = None
            if gpu.vendor == ManufacturerEnum.ASCEND and gpu.arch_family:
                cann_variant = get_ascend_cann_variant(gpu.arch_family)
                variant = cann_variant.lower() if cann_variant else None
            # runtime_version may be None when the GPU did not report one.
            conditions[(gpu.type, gpu.runtime_version, variant)] = None
    return list(conditions)


def merge_list_runners(  # noqa: C901
    backend_name: str, workers: List[Worker]
) -> Tuple[Dict[str, List[ServiceVersionedRunner]], VersionConfigDict, Optional[str]]:
    """
    Merge runner versions and configs from multiple workers.

    Extracts gpu.type, gpu.runtime_version and, for Ascend GPUs, the CANN
    variant from each worker's GPU devices. Each unique combination becomes
    one ``get_runner_versions_and_configs`` query. Queries run in the order
    the combinations first appear across ``workers``.

    Args:
        backend_name: The name of the backend service.
        workers: List of workers to extract GPU information from.

    Returns:
        A tuple containing:
        - Dict[str, List[ServiceVersionedRunner]]: runner versions grouped by
          version string, one entry per query that produced that version
        - VersionConfigDict: merged version configs; ``built_in_frameworks``
          of a version seen by several queries are combined without
          duplicates, in first-seen order
        - Optional[str]: default version of the first query that has one,
          or None
    """
    merged_runner_versions: Dict[str, List[ServiceVersionedRunner]] = {}
    merged_version_configs = VersionConfigDict()
    merged_default_version: Optional[str] = None

    for gpu_type, runtime_version, variant in _collect_runner_query_conditions(
        workers
    ):
        kwargs = {"backend": gpu_type}
        if variant:
            kwargs["backend_variant"] = variant
        runner_versions, version_configs, default_version = (
            get_runner_versions_and_configs(backend_name, runtime_version, **kwargs)
        )

        if merged_default_version is None:
            merged_default_version = default_version

        for version, runner in runner_versions.items():
            merged_runner_versions.setdefault(version, []).append(runner)

        for version, config in version_configs.root.items():
            if version not in merged_version_configs.root:
                merged_version_configs.root[version] = config
                continue
            existing = merged_version_configs.root[version]
            # dict.fromkeys removes duplicates deterministically, unlike list(set(...)).
            existing.built_in_frameworks = list(
                dict.fromkeys(
                    (existing.built_in_frameworks or [])
                    + (config.built_in_frameworks or [])
                )
            )

    return merged_runner_versions, merged_version_configs, merged_default_version
def _select_list_visible_rows(
    all_rows: List[InferenceBackend], ctx
) -> List[InferenceBackend]:
    """Apply the ``/list`` Hybrid visibility filter.

    - Single-Org caller (member, or platform admin act-as): Platform rows
      (``owner_principal_id is None``) plus the caller's own Org rows.
    - Bypass mode (no ctx, platform admin "All", system users): Platform
      rows only. There is no single Org to merge with, so merging several
      Orgs' rows for one backend_name would be last-Org-wins. Also,
      ``InferenceBackendListItem`` has no owner field to tell such rows
      apart. Callers that need a specific Org's overrides should fetch by
      id or pass an Org context.
    """
    bypass_filter = (
        ctx is None
        or (ctx.is_platform_admin and ctx.current_principal_id is None)
        or getattr(getattr(ctx, "user", None), "is_system", False)
    )
    if bypass_filter:
        return [b for b in all_rows if b.owner_principal_id is None]
    return [
        b
        for b in all_rows
        if b.owner_principal_id is None
        or b.owner_principal_id == ctx.current_principal_id
    ]


def _group_list_rows_by_name(
    visible_rows: List[InferenceBackend],
) -> Tuple[List[InferenceBackend], Dict[int, VersionConfigDict], Dict[int, bool]]:
    """Collapse rows that share a backend_name into one logical backend.

    When two rows collide, the Org row (``owner_principal_id`` set) survives.
    Its version_configs override the other row's on key collisions, and
    ``enabled`` becomes ``Org.enabled OR other.enabled``. Merged values are
    returned in side dicts keyed by the surviving row's id rather than
    written onto the ORM rows, so a stray flush cannot persist them.

    Returns:
        (surviving rows in first-seen name order,
         merged version configs by row id,
         merged enabled flag by row id)
    """
    merged_versions_by_id: Dict[int, VersionConfigDict] = {}
    merged_enabled_by_id: Dict[int, bool] = {}
    grouped: Dict[str, InferenceBackend] = {}
    for b in visible_rows:
        name = b.backend_name
        existing = grouped.get(name)
        if existing is None:
            grouped[name] = b
            continue
        org_row = b if (b.owner_principal_id is not None) else existing
        other = existing if org_row is b else b
        merged_versions = {
            **(other.version_configs.root if other.version_configs else {}),
            **(org_row.version_configs.root if org_row.version_configs else {}),
        }
        merged_versions_by_id[org_row.id] = VersionConfigDict(root=merged_versions)
        merged_enabled_by_id[org_row.id] = bool(org_row.enabled) or bool(
            other.enabled
        )
        grouped[name] = org_row
    return list(grouped.values()), merged_versions_by_id, merged_enabled_by_id


@router.get("/list", response_model=InferenceBackendResponse)
async def list_backend_configs(  # noqa: C901
    session: SessionDep,
    ctx: TenantContextDep,
    cluster_id: Optional[int] = None,
):
    """
    Get list of available backend configurations with version information.

    Returns both built-in and custom backends from the database. Built-in
    backends are enhanced with runner versions discovered for the selected
    workers: all workers, or only those in ``cluster_id`` when it is > 0.
    Each backend item includes its available versions. A "Custom" entry is
    always appended, even when loading the other backends fails.

    Hybrid: when an Org row and a Platform row share the same backend_name,
    the Org row's metadata + version_configs win, and Platform versions are
    merged in for any keys the Org didn't define. ``enabled`` is the
    exception: it is ``Platform.enabled OR Org.enabled``, the same rule
    ``_collapse_by_backend_name`` applies. A stale disabled Org row therefore
    cannot hide a Platform-enabled community backend in this view.
    """
    items = []
    if cluster_id and cluster_id > 0:
        workers = await Worker.all_by_field(session, "cluster_id", cluster_id)
    else:
        workers = await Worker.all(session)

    try:
        all_rows = await InferenceBackend.all(session)
        visible_rows = _select_list_visible_rows(all_rows, ctx)
        inference_backends, merged_versions_by_id, merged_enabled_by_id = (
            _group_list_rows_by_name(visible_rows)
        )

        for backend in inference_backends:
            effective_version_configs = merged_versions_by_id.get(
                backend.id, backend.version_configs
            )
            # Versions declared on the row (or on the Platform+Org merge).
            # NOTE(review): for versions inherited from a Platform row, env is
            # still looked up via the Org row's get_backend_env. Confirm that
            # it resolves them.
            versions: List[VersionListItem] = []
            if effective_version_configs and effective_version_configs.root:
                versions = [
                    VersionListItem(
                        version=version, env=backend.get_backend_env(version)
                    )
                    for version in effective_version_configs.root.keys()
                ]

            if backend.is_built_in:
                # Built-in backends also list runner versions for these workers.
                runner_versions, version_configs, default_version = (
                    merge_list_runners(backend.backend_name, workers)
                )
                for version, config in version_configs.root.items():
                    if config.built_in_frameworks:
                        # Deprecated only when every matching runner is deprecated.
                        is_deprecated = get_runner_deprecate(
                            runner_versions.get(version, [])
                        )
                        version_env = backend.get_backend_env(version)
                        versions.append(
                            VersionListItem(
                                version=version,
                                is_deprecated=is_deprecated,
                                env=version_env,
                            )
                        )
                # Remove duplicates while preserving order.
                versions = deduplicate_versions(versions)
                # Use the runner-derived default if the row didn't set one;
                # local var so we don't mutate the ORM object.
                effective_default_version = backend.default_version or default_version
                backend_item = InferenceBackendListItem(
                    backend_name=backend.backend_name,
                    default_version=effective_default_version,
                    default_backend_param=backend.default_backend_param,
                    versions=versions,
                    is_built_in=backend.is_built_in,
                    enabled=True,
                    backend_source=BackendSourceEnum.BUILT_IN,
                    default_env=backend.default_env,
                )
            else:
                effective_enabled = merged_enabled_by_id.get(
                    backend.id, backend.enabled
                )
                if (
                    backend.backend_source == BackendSourceEnum.COMMUNITY
                    and not effective_enabled
                ):
                    continue
                backend_item = InferenceBackendListItem(
                    backend_name=backend.backend_name,
                    default_version=backend.default_version,
                    default_backend_param=backend.default_backend_param,
                    versions=versions,
                    is_built_in=False,
                    enabled=effective_enabled,
                    backend_source=backend.backend_source,
                    default_env=backend.default_env,
                )
            items.append(backend_item)
    except Exception as e:
        # Log (with traceback) but don't fail the entire request.
        logger.exception("Failed to load backends from database: %s", e)

    # Ensure the Custom backend is always included, even if it is not in the
    # database and even if loading the other backends failed above.
    items.append(
        InferenceBackendListItem(
            backend_name=BackendEnum.CUSTOM,
            default_version=None,
            default_backend_param=None,
            versions=[],
            is_built_in=False,
            enabled=True,
            backend_source=BackendSourceEnum.BUILT_IN,
            default_env=None,
        )
    )
    return InferenceBackendResponse(items=items)
- def _hybrid_backend_conditions(ctx) -> List:
- """Hybrid visibility filter for inference_backends.
- Platform rows (owner_principal_id IS NULL) are visible to everyone.
- Org rows are visible to:
- - their own Org's members (current_principal_id matches)
- - platform admin in "All" mode (no current_principal_id) — full bypass
- - system users (worker / cluster service accounts) — full bypass,
- since they need every Org's overrides to actually run a deploy
- whose backend version was customised at the Org level
- Platform admin in act-as mode (current_principal_id is set) follows the
- same scope as a non-admin caller in that Org: Platform NULL +
- that Org's rows only. They DON'T see other Orgs' rows while
- pretending to be in this one.
- """
- if ctx is None:
- return []
- if getattr(ctx.user, "is_system", False):
- return []
- if ctx.is_platform_admin and ctx.current_principal_id is None:
- return []
- from sqlalchemy import or_
- or_clauses = [InferenceBackend.owner_principal_id.is_(None)]
- if ctx.current_principal_id is not None:
- or_clauses.append(
- InferenceBackend.owner_principal_id == ctx.current_principal_id
- )
- return [or_(*or_clauses)]
async def _fetch_visible_backend_rows(session, ctx) -> List[InferenceBackend]:
    """Load the inference_backends rows the caller may see.

    Uses ``_hybrid_backend_conditions``. When it yields no conditions
    (bypass callers), every row is returned. Otherwise the conditions are
    handed to ``all_by_fields`` with no field filters.
    """
    conditions = _hybrid_backend_conditions(ctx)
    if not conditions:
        return await InferenceBackend.all(session)
    return await InferenceBackend.all_by_fields(
        session, fields={}, extra_conditions=conditions
    )
def _enrich_built_in_with_runner_versions(
    db_backend: InferenceBackendPublic,
    backend_name: str,
    with_deprecated: bool,
) -> None:
    """Overlay runner-discovered versions onto a built-in backend, in place.

    The function writes each version config reported by
    ``get_runner_versions_and_configs`` into
    ``db_backend.built_in_version_configs``. Runners are not filtered by GPU
    here; ``with_deprecated`` is forwarded to the runner listing. A runner
    config replaces any existing entry for the same version. The runner's
    first version becomes ``default_version`` only when the backend has none.
    """
    _, runner_version_configs, runner_default_version = (
        get_runner_versions_and_configs(
            backend_name,
            backend_version=None,
            with_deprecated=with_deprecated,
        )
    )
    # Same defensive initialisation _migrate_community_built_in_versions uses.
    if db_backend.built_in_version_configs is None:
        db_backend.built_in_version_configs = {}
    for runner_version, version_config in runner_version_configs.root.items():
        db_backend.built_in_version_configs[runner_version] = version_config
    if runner_default_version and not db_backend.default_version:
        db_backend.default_version = runner_default_version
def _migrate_community_built_in_versions(db_backend: InferenceBackendPublic) -> None:
    """Move built-in-framework versions of a community backend into their own map.

    Applies only to ``BackendSourceEnum.COMMUNITY`` backends. Every
    ``version_configs`` entry with a non-empty ``built_in_frameworks`` is
    copied into ``built_in_version_configs``, which is created if missing.
    The entry is then removed from ``version_configs``. Other backends are
    left untouched.
    """
    if db_backend.backend_source != BackendSourceEnum.COMMUNITY:
        return
    configs = db_backend.version_configs
    if not configs or not configs.root:
        return

    to_move = {
        name: cfg for name, cfg in configs.root.items() if cfg.built_in_frameworks
    }
    if not to_move:
        return

    if not db_backend.built_in_version_configs:
        db_backend.built_in_version_configs = {}
    db_backend.built_in_version_configs.update(to_move)
    for name in to_move:
        del configs.root[name]
def _collapse_by_backend_name(
    db_result_sorted: List[InferenceBackend],
) -> List[InferenceBackendPublic]:
    """Fold Platform + Org rows sharing a backend_name into one public entry.

    Used for the non-admin single-card view. Rows are processed in the
    given order, and each name keeps the position of its first row.

    - The Org row (``owner_principal_id`` set) wins on metadata. Its
      version_configs override the other row's keys, and keys it lacks
      fall back to the other row.
    - ``enabled`` is the exception: ``Platform.enabled OR Org.enabled``.
      A stale or accidental Org row with ``enabled=False`` therefore cannot
      shadow a Platform-enabled backend. As a deliberate tradeoff, an Org
      cannot disable a Platform-shared community backend in its own scope;
      that has to happen at the Platform level. A per-Org opt-out could
      return later via an explicit ``override_enabled`` flag.

    Results are ``InferenceBackendPublic`` copies (one ``model_dump`` per row),
    never ORM rows, so the read-time merge cannot be flushed to the database.
    """
    by_name: Dict[str, InferenceBackendPublic] = {}
    for row in db_result_sorted:
        name = row.backend_name
        seen = by_name.get(name)
        if seen is None:
            by_name[name] = InferenceBackendPublic(**row.model_dump())
            continue

        # `seen` is the public copy made earlier; `row` is the incoming ORM row.
        if row.owner_principal_id is not None:
            winner = InferenceBackendPublic(**row.model_dump())
            org_side, other_side = row, seen
        else:
            winner = seen
            org_side, other_side = seen, row

        base = other_side.version_configs.root if other_side.version_configs else {}
        override = org_side.version_configs.root if org_side.version_configs else {}
        winner.version_configs = VersionConfigDict(root={**base, **override})
        winner.enabled = bool(org_side.enabled) or bool(other_side.enabled)
        by_name[name] = winner
    return list(by_name.values())
async def merge_runner_versions_to_db(
    session: SessionDep,
    with_deprecated: bool = True,
    *,
    ctx=None,
) -> List[InferenceBackendPublic]:
    """Return the backends visible to the caller, with runner versions added.

    Display rules:
    - **Admin view** (no ctx, or platform admin with no current principal):
      one entry per DB row, uncollapsed. Platform and Org rows are managed
      separately, so they show as distinct cards.
    - **Everyone else**, including platform admin in act-as mode: one
      collapsed entry per backend_name. Platform + Org rows fold together,
      with the Org winning on metadata and version keys.

    Built-in backends (everything from ``get_built_in_backend`` except
    Custom) get runner versions layered in. All other backends go through
    ``_migrate_community_built_in_versions``.

    NOTE(review): system users get every Org's rows from
    ``_fetch_visible_backend_rows``, but they are not treated as an admin
    view here. Rows from different Orgs that share a backend_name are
    therefore collapsed together (the higher id wins on version keys).
    Confirm whether they should get the uncollapsed view.
    """
    rows = await _fetch_visible_backend_rows(session, ctx)
    # Ascending id: an Org row, normally created after the Platform row,
    # lands later and wins the collapse. A missing id sorts as 0.
    ordered = sorted(rows, key=lambda row: row.id or 0)

    uncollapsed = ctx is None or (
        ctx.is_platform_admin and ctx.current_principal_id is None
    )
    if uncollapsed:
        publics = [InferenceBackendPublic(**row.model_dump()) for row in ordered]
    else:
        publics = _collapse_by_backend_name(ordered)

    custom_name = BackendEnum.CUSTOM.value
    built_in_names = {
        b.backend_name for b in get_built_in_backend() if b.backend_name != custom_name
    }
    for public in publics:
        if public.backend_name in built_in_names:
            _enrich_built_in_with_runner_versions(
                public, public.backend_name, with_deprecated
            )
        else:
            _migrate_community_built_in_versions(public)
    return publics
def _generate_framework_index_map(
    version_config_dicts: List[Dict[str, VersionConfig]]
) -> Dict[str, List[str]]:
    """
    Generate framework index map from a list of version config dictionaries.

    For every version, each entry of ``built_in_frameworks`` and the
    ``custom_framework`` (if set) are indexed under that version. A version
    is listed at most once per framework, even when it appears in several of
    the input dictionaries. This applies to custom frameworks as well as
    built-in ones.

    Args:
        version_config_dicts: Dictionaries mapping version names to
            VersionConfig objects. Falsy entries (None/empty) are skipped.

    Returns:
        Mapping of framework name to its versions, sorted as plain strings
        (so e.g. "0.10.0" sorts before "0.9.0"). Frameworks appear in the
        order first seen.
    """
    # framework -> ordered set of versions (dict keys)
    framework_versions: Dict[str, Dict[str, None]] = {}
    for version_configs in version_config_dicts:
        if not version_configs:
            continue
        for version, config in version_configs.items():
            frameworks = list(config.built_in_frameworks or [])
            if config.custom_framework:
                frameworks.append(config.custom_framework)
            for framework in frameworks:
                framework_versions.setdefault(framework, {})[version] = None

    return {
        framework: sorted(versions)
        for framework, versions in framework_versions.items()
    }
def _filter_community_backends(
    backends: List[InferenceBackendPublic],
    is_only_community: Optional[bool] = None,
) -> List[InferenceBackendPublic]:
    """
    Filter backends for either the community catalog or the regular view.

    Args:
        backends: Backends to filter. Kept entries are the same objects, not
            copies.
        is_only_community: Truthy selects catalog mode. Only
            ``BackendSourceEnum.COMMUNITY`` backends are kept, and every
            ``version_configs`` entry of each kept backend is cleared in
            place. NOTE(review): this assumes versions with
            ``built_in_frameworks`` were already moved to
            ``built_in_version_configs`` (see
            ``_migrate_community_built_in_versions``), so only non-built-in
            versions are dropped. Confirm for every caller.
            Falsy selects the regular view: every backend is kept except
            community backends that are disabled.

    Returns:
        The filtered backends, in input order.
    """
    filtered = []
    for backend in backends:
        is_community = backend.backend_source == BackendSourceEnum.COMMUNITY
        if is_only_community:
            # Community backends catalog.
            if not is_community:
                continue
            # version_configs may be None; nothing to clear in that case.
            if backend.version_configs is not None:
                backend.version_configs.root = {}
        elif is_community and not backend.enabled:
            # Common inference_backends view hides disabled community backends.
            continue
        filtered.append(backend)
    return filtered
def _sort_frameworks_by_support(backend, available_frameworks) -> None:
    """Move worker-available frameworks to the front of each built-in version.

    For every built-in version config that has ``built_in_frameworks``, the
    frameworks contained in ``available_frameworks`` are placed first and the
    rest after them; relative order inside each group is preserved. Configs
    are mutated in place and ``backend.built_in_version_configs`` is rebound
    to a new dict with the same keys in the same order.
    """
    sorted_version_configs = {}
    for version, config in backend.built_in_version_configs.items():
        if config.built_in_frameworks:
            supported = [
                framework
                for framework in config.built_in_frameworks
                if framework in available_frameworks
            ]
            unsupported = [
                framework
                for framework in config.built_in_frameworks
                if framework not in available_frameworks
            ]
            config.built_in_frameworks = supported + unsupported
        sorted_version_configs[version] = config
    backend.built_in_version_configs = sorted_version_configs


@router.get("", response_model=InferenceBackendsPublic)
async def get_inference_backends(  # noqa: C901
    session: SessionDep,
    ctx: TenantContextDep,
    params: ListParamsDep,
    search: Optional[str] = None,
    include_deprecated: bool = False,
    community: Optional[bool] = None,
    backend_source: Optional[str] = None,
):
    """
    Get paginated list of inference backends with optional search and filters.

    Args:
        session: Request database session. The listing itself runs in a
            dedicated ``async_session()``.
        ctx: Tenant context used for Org-scoped visibility.
        params: List parameters (page, perPage, watch, sort_by). When
            ``watch`` is set, an event stream is returned instead of a page.
        search: Case-insensitive substring matched against backend_name and
            description.
        include_deprecated: Include deprecated versions.
        community: True = community backends only, with their custom
            versions cleared; False/None = all backends except disabled
            community ones.
        backend_source: Filter by backend source (built-in, custom, or
            community). An unknown value is logged once and ignored.

    Returns:
        InferenceBackendsPublic: Paginated list of inference backends.
    """
    fields = {}
    if params.watch:
        # Filter the streamed events with the same Hybrid visibility check.
        def _visible(b: InferenceBackend) -> bool:
            if ctx is None or (
                ctx.is_platform_admin and ctx.current_principal_id is None
            ):
                return True
            # System users (worker / cluster) need every Org's overrides
            # because they actually run the deploys.
            if getattr(getattr(ctx, "user", None), "is_system", False):
                return True
            org_id = getattr(b, "owner_principal_id", None)
            if org_id is None:
                return True
            return (
                ctx.current_principal_id is not None
                and org_id == ctx.current_principal_id
            )

        return StreamingResponse(
            InferenceBackend.streaming(fields=fields, filter_func=_visible),
            media_type="text/event-stream",
        )

    # Resolve request-level filter values once instead of once per backend;
    # an invalid backend_source is warned about a single time and ignored.
    lower_search = search.lower() if search else None
    source_enum = None
    if backend_source:
        try:
            source_enum = BackendSourceEnum(backend_source)
        except ValueError:
            logger.warning(f"Invalid backend_source value: {backend_source}")

    async with async_session() as session:
        merged_backends = await merge_runner_versions_to_db(
            session, with_deprecated=include_deprecated, ctx=ctx
        )
        # GPU types reported by workers decide which frameworks sort first.
        workers = await Worker.all(session)
        framework_list = {
            gpu.type
            for worker in workers
            if worker.status and worker.status.gpu_devices
            for gpu in worker.status.gpu_devices
        }

        filter_backends = []
        for backend in merged_backends:
            # 1. Search filter. The filters run before framework sorting so
            #    rejected backends are not reprocessed.
            if lower_search and not (
                lower_search in backend.backend_name.lower()
                or (
                    backend.description
                    and lower_search in backend.description.lower()
                )
            ):
                continue

            # 2. Community filter.
            if community is True:
                # Community catalog: community backends only, custom
                # versions hidden.
                if backend.backend_source != BackendSourceEnum.COMMUNITY:
                    continue
                if backend.version_configs:
                    backend.version_configs.root = {}
            elif (
                backend.backend_source == BackendSourceEnum.COMMUNITY
                and not backend.enabled
            ):
                # Regular view: hide community backends that are not enabled.
                continue

            # 3. Backend source filter (skipped when the value was invalid).
            if source_enum is not None and backend.backend_source != source_enum:
                continue

            # 4. Sort frameworks, then build the index map from the sorted data.
            _sort_frameworks_by_support(backend, framework_list)
            version_config_dicts = []
            if backend.built_in_version_configs:
                version_config_dicts.append(backend.built_in_version_configs)
            if backend.version_configs and backend.version_configs.root:
                version_config_dicts.append(backend.version_configs.root)
            backend.framework_index_map = _generate_framework_index_map(
                version_config_dicts
            )
            filter_backends.append(backend)

    # Apply pagination to the filtered results.
    total = len(filter_backends)
    start_idx = (params.page - 1) * params.perPage
    end_idx = start_idx + params.perPage
    paginated_backends = filter_backends[start_idx:end_idx]
    pagination = Pagination(
        page=params.page,
        perPage=params.perPage,
        total=total,
        totalPage=max(math.ceil(total / params.perPage), 1),
    )
    return InferenceBackendsPublic(
        items=paginated_backends,
        pagination=pagination,
    )
@router.get("/all", response_model=List[InferenceBackend])
async def get_all_inference_backends(
    session: SessionDep,
    ctx: TenantContextDep,
):
    """
    Return every backend from ``merge_runner_versions_to_db`` (called with the
    caller's ``ctx``) with built-in versions folded into ``version_configs``.

    CUSTOM backends are returned untouched. For all others, each built-in
    version is copied into ``version_configs.root`` unless a version with the
    same name is already there, in which case the stored version wins.
    """
    merged = await merge_runner_versions_to_db(session, ctx=ctx)
    for backend in merged:
        if backend.backend_source != BackendSourceEnum.CUSTOM:
            for version_name, built_in_cfg in backend.built_in_version_configs.items():
                # setdefault keeps an existing (DB) entry for the same name.
                backend.version_configs.root.setdefault(version_name, built_in_cfg)
    return list(merged)
- def _assert_backend_visible(ctx, backend):
- """Org member can see Platform (NULL) and own-Org rows. Admin sees
- everything in "All" mode; in act-as mode they're scoped just like
- a regular member of that Org (so a stale link to dev Org's row
- while admin is acting-as Default surfaces a 404, not a leak)."""
- if backend is None:
- raise NotFoundException(message="Inference backend not found")
- if ctx.is_platform_admin and ctx.current_principal_id is None:
- return
- org_id = backend.owner_principal_id
- if org_id is None:
- return # Platform row is visible to everyone
- if ctx.current_principal_id is not None and org_id == ctx.current_principal_id:
- return
- raise NotFoundException(message="Inference backend not found")
@router.get("/{id}", response_model=InferenceBackend)
async def get_inference_backend(session: SessionDep, ctx: TenantContextDep, id: int):
    """
    Return one inference backend looked up by primary key.

    A missing row yields BadRequestException. A row owned by an Org outside
    the caller's scope is rejected by ``_assert_backend_visible``.
    """
    row = await InferenceBackend.one_by_id(session, id)
    if not row:
        raise BadRequestException(message=f"Inference backend {id} not found")
    _assert_backend_visible(ctx, row)
    return row
@router.get("/backend_name/{backend_name}", response_model=InferenceBackend)
async def get_inference_backend_by_name(
    session: SessionDep, ctx: TenantContextDep, backend_name: str
):
    """
    Get a specific inference backend by backend name.

    Resolves to the caller's current Org row if one exists, else falls back
    to the Platform row (``owner_principal_id`` NULL). A platform admin who
    is acting as an Org is scoped like a member of that Org, consistent with
    ``_assert_backend_visible``. In "All" mode (no current principal), only
    the Platform row is considered.

    Raises:
        BadRequestException: neither an Org row nor a Platform row exists.
    """
    # Previously admins acting-as an Org skipped this lookup and always got
    # the Platform row, even though their edits are redirected to the Org row.
    if ctx.current_principal_id is not None:
        org_row = await InferenceBackend.one_by_fields(
            session,
            {
                "backend_name": backend_name,
                "owner_principal_id": ctx.current_principal_id,
            },
        )
        if org_row is not None:
            return org_row
    backend = await InferenceBackend.one_by_fields(
        session,
        {"backend_name": backend_name, "owner_principal_id": None},
    )
    if not backend:
        raise BadRequestException(message=f"Inference backend {backend_name} not found")
    return backend
@router.post("", response_model=InferenceBackend)
async def create_inference_backend(
    session: SessionDep,
    ctx: TenantContextDep,
    backend_in: InferenceBackendCreate,
):
    """
    Create a new inference backend (always CUSTOM and enabled).

    Hybrid scope:
    - Platform admin: owner_principal_id NULL (Platform) or any Org id.
    - Org owner / manager: owner_principal_id locked to their current Org.

    The case-insensitive built-in name check applies only to Platform rows;
    an Org row may reuse a built-in name. Name uniqueness is enforced within
    the same owner (``backend_name`` + ``owner_principal_id``).

    Raises:
        BadRequestException: built-in name clash, duplicate name in the same
            scope, or backend name without the '-custom' suffix.
        InternalServerErrorException: the database insert fails.
    """
    owner_id = getattr(backend_in, "owner_principal_id", None)
    validate_owner_principal(
        owner_id,
        ctx,
        resource_label="inference backend",
    )

    name = backend_in.backend_name
    # Platform rows must not shadow a built-in name (the seeding controller
    # owns those); Org rows may, to extend a built-in for that Org.
    if owner_id is None and is_built_in_backend(name):
        raise BadRequestException(
            message=(
                f"Backend name {name} duplicates with built-in backends (case-insensitive). Please use another name."
            ),
        )

    backend_in.backend_source = BackendSourceEnum.CUSTOM
    backend_in.enabled = True

    duplicate = await InferenceBackend.one_by_fields(
        session,
        {"backend_name": name, "owner_principal_id": owner_id},
    )
    if duplicate:
        raise BadRequestException(
            message=f"Inference backend with name '{name}' already exists",
        )

    # Custom backends only need the '-custom' suffix on the backend name;
    # version names are not checked here (version_configs is passed as None).
    validate_custom_suffix(name, None)
    # Built-in framework lists are platform-managed; drop any client value.
    for version_cfg in backend_in.version_configs.root.values():
        version_cfg.built_in_frameworks = None

    try:
        backend = await InferenceBackend.create(
            session,
            InferenceBackend(
                backend_name=backend_in.backend_name,
                version_configs=backend_in.version_configs,
                default_version=backend_in.default_version,
                default_backend_param=backend_in.default_backend_param,
                default_run_command=backend_in.default_run_command,
                default_entrypoint=backend_in.default_entrypoint,
                health_check_path=backend_in.health_check_path,
                description=backend_in.description,
                default_env=backend_in.default_env,
                enabled=backend_in.enabled,
                backend_source=backend_in.backend_source,
                owner_principal_id=owner_id,
            ),
        )
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to create inference backend: {e}"
        )
    return backend
async def _redirect_global_edit_to_org_row(
    session,
    ctx,
    backend: InferenceBackend,
    backend_in: InferenceBackendUpdate,
) -> Optional[InferenceBackend]:
    """Route an Org-context edit of a Platform (global) row to the Org's row.

    Applies to an admin who is acting as an Org too: when the admin has
    switched to an Org, "enable community backend" should land in that
    Org's scope rather than modify the Platform row.

    Returns:
        - None when no redirect is needed: the target already has an owner,
          or the caller has no current principal ("All" mode);
        - the existing Org row with the same backend_name, if any;
        - otherwise a newly created Org row seeded from ``backend_in``.
    """
    if backend.owner_principal_id is not None or ctx.current_principal_id is None:
        return None
    org_row = await InferenceBackend.one_by_fields(
        session,
        {
            "backend_name": backend.backend_name,
            "owner_principal_id": ctx.current_principal_id,
        },
    )
    if org_row is not None:
        return org_row
    # No Org row yet: seed one from the submitted payload. The Org row
    # inherits is_built_in / backend_source from the Platform row it extends,
    # so an Org-scoped vLLM is still vLLM (BUILT_IN), not a new custom
    # backend, and suffix validation behaves identically.
    new_row = InferenceBackend(
        # Always the Platform row's name. Taking it from the payload would let
        # a rename slip through as a brand-new, differently named Org row,
        # which is also what the Org-row lookup above keys on.
        backend_name=backend.backend_name,
        version_configs=backend_in.version_configs,
        default_version=backend_in.default_version,
        default_backend_param=backend_in.default_backend_param,
        default_run_command=backend_in.default_run_command,
        default_entrypoint=backend_in.default_entrypoint,
        health_check_path=backend_in.health_check_path,
        description=backend_in.description,
        default_env=backend_in.default_env,
        enabled=True,
        is_built_in=backend.is_built_in,
        backend_source=backend.backend_source,
        owner_principal_id=ctx.current_principal_id,
    )
    return await InferenceBackend.create(session, new_row)
@router.put("/{id}", response_model=InferenceBackend)
async def update_inference_backend(  # noqa: C901
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
    backend_in: InferenceBackendUpdate,
):
    """
    Update an existing inference backend.

    When the caller is in an Org context and targets a Platform row, the
    write is redirected to that Org's row for the same backend_name (see
    ``_redirect_global_edit_to_org_row``). The Org row is created from the
    payload when missing.

    Raises:
        NotFoundException: no backend with ``id``.
        BadRequestException: the payload renames the backend, sets
            ``default_version`` on a built-in backend, removes a version
            that is still in use, or breaks the '-custom' suffix rules.
        InternalServerErrorException: the update itself fails.
    """
    backend = await InferenceBackend.one_by_id(session, id)
    if not backend:
        raise NotFoundException(message=f"Inference backend {id} not found")

    def _reject_immutable_changes(target) -> None:
        # backend_name can never be changed through an update.
        if backend_in.backend_name != target.backend_name:
            raise BadRequestException(
                message="The name of inference-backend can not be modified",
            )
        # Built-in backends manage their default version automatically.
        if is_built_in_backend(target.backend_name) and backend_in.default_version:
            raise BadRequestException(
                message=f"Built-in backend '{target.backend_name}' cannot have default_version set. Default version is managed automatically.",
            )

    if backend.owner_principal_id is None:
        # The redirect below may persist a new Org row seeded from this
        # payload, so reject forbidden changes before it runs. Platform rows
        # are visible to every caller, so checking ahead of the permission
        # assertion reveals nothing.
        _reject_immutable_changes(backend)

    redirected = await _redirect_global_edit_to_org_row(
        session, ctx, backend, backend_in
    )
    if redirected is not None:
        # Continue against the Org row. For a freshly created Org row the
        # update below rewrites the same payload, which keeps the response
        # shape identical for both branches.
        backend = redirected
    assert_org_owned_writable(ctx, backend, resource_label="inference backend")
    _reject_immutable_changes(backend)

    new_version_configs = backend_in.version_configs
    if new_version_configs is not None:
        await _validate_version_removal(session, backend, new_version_configs)

    # CUSTOM (or legacy source-less non-built-in) backends only need the
    # suffix on the backend name; others need it on user-defined versions.
    if backend.backend_source == BackendSourceEnum.CUSTOM or (
        backend.backend_source is None and not backend.is_built_in
    ):
        validate_custom_suffix(backend_in.backend_name, None)
    else:
        validate_custom_suffix(None, new_version_configs)

    # Built-in framework lists are platform-managed; drop any client value.
    # Guarded because version_configs may be None (see the check above).
    if new_version_configs is not None:
        for version_cfg in new_version_configs.root.values():
            version_cfg.built_in_frameworks = None

    try:
        # Use a dict for changes to prevent version_config serialization
        # errors and None field override issues.
        update_data = {
            "backend_name": backend_in.backend_name,
            "version_configs": new_version_configs,
            "default_version": backend_in.default_version,
            "default_backend_param": backend_in.default_backend_param,
            "default_run_command": backend_in.default_run_command,
            "default_entrypoint": backend_in.default_entrypoint,
            "health_check_path": backend_in.health_check_path,
            "description": backend_in.description,
            "default_env": backend_in.default_env,
            "backend_source": backend_in.backend_source,
        }
        if backend_in.backend_source == BackendSourceEnum.COMMUNITY:
            if backend_in.enabled is not None:
                update_data["enabled"] = backend_in.enabled
            if new_version_configs is not None:
                # Keep the stored built-in versions; submitted versions win
                # on name clashes.
                stored_root = (
                    backend.version_configs.root if backend.version_configs else {}
                )
                merged = {k: v for k, v in stored_root.items() if v.built_in_frameworks}
                merged.update(new_version_configs.root)
                new_version_configs.root = merged
        await backend.update(session, update_data)
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to update inference backend: {e}"
        ) from e
    return backend
@router.delete("/{id}")
async def delete_inference_backend(session: SessionDep, ctx: TenantContextDep, id: int):
    """
    Delete an inference backend.

    Raises:
        NotFoundException: no backend with ``id``.
        BadRequestException: the row is a Platform-scoped built-in/community
            backend, or models still use the backend.
        InternalServerErrorException: the database delete fails.
    """
    target = await InferenceBackend.one_by_id(session, id)
    if not target:
        raise NotFoundException(message=f"Inference backend {id} not found")
    assert_org_owned_writable(ctx, target, resource_label="inference backend")

    # Only Platform-curated rows (no owner, non-CUSTOM, non-null source) are
    # protected. Org-scoped rows belong to the Org and stay deletable even
    # when they extend a built-in and carry source=BUILT_IN.
    if target.owner_principal_id is None and target.backend_source not in (
        BackendSourceEnum.CUSTOM,
        None,
    ):
        raise BadRequestException(message="Cannot delete built-in or community backend")

    in_use, users = await check_backend_in_use(session, target.backend_name)
    if in_use:
        raise BadRequestException(
            message=f"Cannot delete backend '{target.backend_name}' because it is currently being used by the following models: {', '.join(users)}",
        )

    try:
        await target.delete(session)
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to delete inference backend: {e}"
        )
@router.post("/from-yaml", response_model=InferenceBackend)
async def create_inference_backend_from_yaml(  # noqa: C901
    session: SessionDep, ctx: TenantContextDep, payload: dict = Body(...)
):
    """
    Create a custom inference backend from YAML configuration.

    The request body is JSON whose ``content`` field holds the YAML document.
    Only the keys in ``allowed_keys`` below are used; any other key (for
    example ``allowed_proxy_uris``) is ignored. An optional
    ``owner_principal_id`` selects the owning Org, subject to
    ``validate_owner_principal``. The backend name must end with ``-custom``.

    Expected YAML format:
    ```yaml
    backend_name: "my-backend-custom"
    version_configs:
      "v1.0.0":
        image_name: "my-backend:v1.0.0"
        run_command: "python server.py --port {{port}} --model {{model_path}}"
      "v1.1.0":
        image_name: "my-backend:v1.1.0"
        run_command: "python server.py --port {{port}} --model {{model_path}}"
    default_version: "v1.1.0"
    default_backend_param: ["--max-tokens", "2048"]
    default_run_command: "python server.py"
    default_entrypoint: "..."
    description: "My custom inference backend"
    health_check_path: "/health"
    ```

    Raises:
        BadRequestException: missing/invalid content, name rules violated,
            duplicate name in the same scope, or invalid field values.
        InternalServerErrorException: any other failure while creating.
    """
    try:
        # Extract YAML content from JSON payload
        yaml_content = payload.get("content")
        if not yaml_content:
            raise BadRequestException(message="Missing 'content' field in request body")
        # NOTE: safe_load (not yaml.load) because the content is untrusted.
        req_yaml_data = yaml.safe_load(yaml_content)
        # A valid YAML scalar/list/empty document is not a backend definition;
        # reject it as a client error instead of crashing on .get() (500).
        if not isinstance(req_yaml_data, dict):
            raise BadRequestException(
                message="YAML content must be a mapping of backend fields"
            )
        if not req_yaml_data.get("backend_name"):
            raise BadRequestException(message="backend_name is required in YAML")

        target_org_id = req_yaml_data.get("owner_principal_id")
        validate_owner_principal(
            target_org_id,
            ctx,
            resource_label="inference backend",
        )
        # Platform rows can't shadow built-in names; Org rows may extend them.
        if target_org_id is None and is_built_in_backend(req_yaml_data["backend_name"]):
            raise BadRequestException(
                message=(
                    f"Backend name {req_yaml_data['backend_name']} duplicates with built-in backends (case-insensitive). Please use another name."
                ),
            )
        req_yaml_data["backend_source"] = BackendSourceEnum.CUSTOM
        req_yaml_data["enabled"] = True

        # Composite uniqueness: the same backend_name is allowed across tenants.
        existing = await InferenceBackend.one_by_fields(
            session,
            {
                "backend_name": req_yaml_data["backend_name"],
                "owner_principal_id": target_org_id,
            },
        )
        if existing:
            raise BadRequestException(
                message=f"Inference backend with name '{req_yaml_data['backend_name']}' already exists",
            )

        # default_entrypoint is accepted here as in the JSON create endpoint
        # and the YAML update endpoint; it was previously dropped silently.
        allowed_keys = [
            "backend_name",
            "version_configs",
            "default_version",
            "default_backend_param",
            "default_run_command",
            "default_entrypoint",
            "health_check_path",
            "description",
            "default_env",
            "enabled",
            "backend_source",
        ]
        yaml_data = {k: req_yaml_data[k] for k in allowed_keys if k in req_yaml_data}

        # Convert version_configs to VersionConfigDict if present. Bad shapes
        # or invalid fields are client errors (400), not server errors.
        if yaml_data.get("version_configs"):
            version_configs_dict = {}
            try:
                for version, config in yaml_data["version_configs"].items():
                    # Built-in framework lists are platform-managed.
                    if config.get("built_in_frameworks"):
                        config["built_in_frameworks"] = None
                    version_configs_dict[version] = VersionConfig(**config)
            except (AttributeError, TypeError, ValidationError) as e:
                raise BadRequestException(message=f"Invalid YAML data: {e}") from e
            yaml_data["version_configs"] = VersionConfigDict(root=version_configs_dict)

        # Custom backends only need the '-custom' suffix on the backend name.
        validate_custom_suffix(yaml_data["backend_name"], None)

        # Validate field types with the same model the JSON endpoint uses.
        try:
            InferenceBackendCreate.model_validate(yaml_data)
        except ValidationError as e:
            raise BadRequestException(message=f"Invalid YAML data: {e}") from e

        backend = InferenceBackend(**yaml_data, owner_principal_id=target_org_id)
        backend = await InferenceBackend.create(session, backend)
        return backend
    except yaml.YAMLError as e:
        raise BadRequestException(message=f"Invalid YAML format: {e}") from e
    except BadRequestException:
        raise  # Re-raise BadRequestException without wrapping
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to create inference backend from YAML: {e.__str__()}"
        ) from e
@router.put("/{id}/from-yaml", response_model=InferenceBackend)
async def update_inference_backend_from_yaml(  # noqa: C901
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
    payload: dict = Body(...),
):
    """
    Update an existing inference backend from YAML configuration.

    The request body is JSON of the form ``{"content": "<YAML document>"}``.
    The YAML uses the same fields as the create-from-YAML endpoint, for
    example:

    ```yaml
    backend_name: "my-backend-custom"   # must equal the current name
    version_configs:
      "v1.1.0":
        image_name: "my-backend:v1.1.0"
        run_command: "python server.py --port {{port}} --model {{model_path}}"
    default_version: "v1.1.0"            # ignored for built-in backends
    description: "My custom inference backend"
    health_check_path: "/health"
    ```

    ``version_configs`` is always written. Omitting it replaces the
    user-defined versions with an empty set; for COMMUNITY backends the
    stored built-in versions are kept. Keys outside ``allowed_keys`` are
    ignored.

    Raises:
        NotFoundException: no backend with ``id``.
        BadRequestException: missing/invalid content, rename attempt, an
            in-use version removed, suffix rules violated, or invalid data.
        InternalServerErrorException: any other failure while updating.
    """
    backend = await InferenceBackend.one_by_id(session, id)
    if not backend:
        raise NotFoundException(message=f"Inference backend {id} not found")
    assert_org_owned_writable(ctx, backend, resource_label="inference backend")
    try:
        # Extract YAML content from JSON payload
        yaml_content = payload.get("content")
        if not yaml_content:
            raise BadRequestException(message="Missing 'content' field in request body")
        # NOTE: safe_load (not yaml.load) because the content is untrusted.
        req_yaml_data = yaml.safe_load(yaml_content)
        # A valid YAML scalar/list/empty document is not a backend definition;
        # reject it as a client error instead of crashing on .get() (500).
        if not isinstance(req_yaml_data, dict):
            raise BadRequestException(
                message="YAML content must be a mapping of backend fields"
            )
        if not req_yaml_data.get("backend_name"):
            raise BadRequestException(message="backend_name is required in YAML")
        # backend_name is immutable.
        if req_yaml_data["backend_name"] != backend.backend_name:
            raise BadRequestException(
                message="The name of inference-backend can not be modified",
            )

        allowed_keys = [
            "backend_name",
            "version_configs",
            "default_backend_param",
            "default_run_command",
            "default_entrypoint",
            "health_check_path",
            "description",
            "default_env",
            "enabled",
            "backend_source",
        ]
        # Built-in backends manage default_version automatically.
        if not is_built_in_backend(backend.backend_name):
            allowed_keys.append("default_version")
        yaml_data = {k: req_yaml_data[k] for k in allowed_keys if k in req_yaml_data}

        # Bad version_configs shapes or fields are client errors (400).
        try:
            yaml_data["version_configs"] = _process_version_configs(
                yaml_data.get("version_configs")
            )
        except (AttributeError, TypeError, ValidationError) as e:
            raise BadRequestException(message=f"Invalid YAML data: {e}") from e

        # Refuse to drop versions that models still use.
        await _validate_version_removal(
            session, backend, yaml_data.get("version_configs")
        )

        # Validate version names based on backend source.
        if backend.backend_source == BackendSourceEnum.CUSTOM or (
            backend.backend_source is None and not backend.is_built_in
        ):
            validate_custom_suffix(yaml_data["backend_name"], None)
        else:
            validate_custom_suffix(None, yaml_data.get("version_configs"))

        _clear_built_in_frameworks(yaml_data.get("version_configs"))

        # COMMUNITY backends keep their stored built-in versions.
        if backend.backend_source == BackendSourceEnum.COMMUNITY:
            yaml_data["version_configs"] = _merge_community_versions(
                backend, yaml_data.get("version_configs")
            )

        # Validate field types with the same model the JSON endpoint uses.
        try:
            InferenceBackendUpdate.model_validate(yaml_data)
        except ValidationError as e:
            raise BadRequestException(message=f"Invalid YAML data: {e}") from e

        await backend.update(session, yaml_data)
        return backend
    except yaml.YAMLError as e:
        raise BadRequestException(message=f"Invalid YAML format: {e}") from e
    except BadRequestException:
        raise  # Re-raise BadRequestException without wrapping
    except Exception as e:
        raise InternalServerErrorException(
            message=f"Failed to update inference backend from YAML: {e}"
        ) from e
def _process_version_configs(
    version_configs_data: Optional[dict],
) -> VersionConfigDict:
    """
    Convert a raw ``version_configs`` mapping (e.g. parsed YAML) into a
    VersionConfigDict.

    A truthy user-supplied ``built_in_frameworks`` is replaced with None
    before each VersionConfig is built, because built-in framework lists
    are platform-managed. The caller's dicts are not modified.

    Returns:
        VersionConfigDict: always an instance, never None. It is empty when
        ``version_configs_data`` is None or empty.
    """
    version_configs_dict = {}
    for version, config in (version_configs_data or {}).items():
        if config.get('built_in_frameworks'):
            # Copy rather than mutate the caller's parsed data.
            config = {**config, 'built_in_frameworks': None}
        version_configs_dict[version] = VersionConfig(**config)
    return VersionConfigDict(root=version_configs_dict)
async def _validate_version_removal(
    session,
    backend: InferenceBackend,
    new_version_configs: Optional[VersionConfigDict],
):
    """
    Refuse to drop user-defined versions that models still reference.

    Only stored versions without ``built_in_frameworks`` are considered.
    Each one missing from ``new_version_configs`` is checked with
    ``check_backend_in_use``; a None or empty ``new_version_configs`` counts
    as "no versions". The first version found in use raises
    BadRequestException listing the models.
    """
    stored = backend.version_configs
    user_defined = (
        {
            name: cfg
            for name, cfg in stored.root.items()
            if not cfg.built_in_frameworks
        }
        if stored and stored.root
        else {}
    )
    incoming = (
        new_version_configs.root
        if new_version_configs and new_version_configs.root
        else {}
    )
    for version in set(user_defined.keys()) - set(incoming.keys()):
        in_use, users = await check_backend_in_use(
            session, backend.backend_name, version
        )
        if not in_use:
            continue
        raise BadRequestException(
            message=f"Cannot remove version name '{version}' of backend '{backend.backend_name}' because it is currently being used by the following models: {', '.join(users)}",
        )
def _clear_built_in_frameworks(version_configs: Optional[VersionConfigDict]):
    """
    Set ``built_in_frameworks`` to None on every version in ``version_configs``.

    Mutates the configs in place. A None or empty mapping is a no-op.
    """
    if version_configs and version_configs.root:
        for config in version_configs.root.values():
            config.built_in_frameworks = None
def _merge_community_versions(
    backend: InferenceBackend,
    new_version_configs: Optional[VersionConfigDict],
) -> VersionConfigDict:
    """
    Combine a COMMUNITY backend's stored built-in versions with new versions.

    Stored versions that carry ``built_in_frameworks`` are kept. Submitted
    versions are layered on top and win on name clashes.

    Returns:
        VersionConfigDict: a new instance holding only the built-in versions
        when ``new_version_configs`` is None or empty; otherwise
        ``new_version_configs`` itself, with its ``root`` replaced by the
        merged mapping.
    """
    stored = backend.version_configs
    merged = (
        {name: cfg for name, cfg in stored.root.items() if cfg.built_in_frameworks}
        if stored and stored.root
        else {}
    )
    if not (new_version_configs and new_version_configs.root):
        return VersionConfigDict(root=merged)
    merged.update(new_version_configs.root)
    new_version_configs.root = merged
    return new_version_configs
def validate_custom_suffix(
    backend_name: Optional[str],
    version_configs: Optional[VersionConfigDict],
):
    """
    Validate the '-custom' suffix on backend and version names.

    Rules:
    - Backend name: a non-empty ``backend_name`` must end with '-custom'.
    - Version names: every version whose ``built_in_frameworks`` is empty
      or None must be a string ending with '-custom'. Versions that carry
      ``built_in_frameworks`` are predefined and skipped. The version's
      ``custom_framework`` is not consulted.

    Pass None (or an empty value) for either argument to skip that check.

    Raises:
        BadRequestException: on the first name that violates a rule.
    """
    # Validate backend name
    if backend_name and not backend_name.endswith("-custom"):
        raise BadRequestException(
            message=f"Custom backend name '{backend_name}' must end with '-custom'",
        )
    # Validate version names
    if version_configs and version_configs.root:
        for version, config in version_configs.root.items():
            # Skip predefined versions (built_in_frameworks has value)
            if config.built_in_frameworks:
                continue
            # User-defined versions must have -custom suffix
            if not isinstance(version, str) or not version.endswith("-custom"):
                raise BadRequestException(
                    message=f"Custom backend version '{version}' must end with '-custom'",
                )
|