| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381 |
- import re
- from enum import Enum
- from typing import ClassVar, Optional, Dict, Any, List, Set
- from pydantic import BaseModel, field_validator, model_validator
- from sqlmodel import (
- Field,
- Relationship,
- Column,
- SQLModel,
- Integer,
- ForeignKey,
- JSON,
- )
- from typing import TYPE_CHECKING
- from gpustack.mixins import BaseModelMixin
- from gpustack.schemas.common import (
- ListParams,
- PaginatedList,
- PublicFields,
- ItemList,
- )
- from gpustack.schemas.organizations import PLATFORM_ORGANIZATION_ID
- if TYPE_CHECKING:
- from gpustack.schemas.models import Model
- from gpustack.schemas.model_provider import ModelProvider
# Shape of a valid route name: it starts with an ASCII letter; a single
# letter is valid on its own, otherwise the middle may contain letters,
# digits, `_`, `-` and `.`, and the last character must be a letter or digit.
#
# `/` is deliberately NOT allowed. The dispatch parser
# (`UserService.get_model_ids_by_model_route_name`) splits the inbound
# `model` string on the first `/` to separate the Org slug from the raw
# route name. If route names could contain `/`, a literal route "a/b" in
# the platform Org could not be told apart from route "b" in an Org whose
# slug is "a". Keep the two character sets disjoint.
#
# NOTE: under `re.match`, the trailing `$` also matches just before a final
# newline, so "abc\n" is accepted; use `re.fullmatch` where that matters.
name_pattern = r'^[A-Za-z](?:[A-Za-z0-9_\-\.]*[A-Za-z0-9])?$'
def effective_route_name(
    route_name: str,
    org_slug: Optional[str],
    is_platform_org: bool,
) -> str:
    """Return the model name that clients send and gateways route on.

    Routes owned by the platform Org keep their bare name, so existing
    clients calling `model: "qwen3-0.6b"` keep working. Routes owned by any
    other Org are prefixed with that Org's slug (`org1/qwen3-0.6b`), which
    lets two Orgs reuse the same route name without colliding in Higress's
    AI proxy match rules.

    The `namespace/model` form follows the OpenAI / HuggingFace / OpenRouter
    convention. Slugs are already constrained to
    `^[a-z](?:[a-z0-9\\-]*[a-z0-9])?$` and route names never contain `/`
    (see ``name_pattern``), so the joined string parses unambiguously.

    Args:
        route_name: The raw route name.
        org_slug: Slug of the owning Org; ``None`` or empty means no prefix.
        is_platform_org: Whether the owning Org is the platform Org.

    Returns:
        ``route_name`` unchanged, or ``"<org_slug>/<route_name>"``.
    """
    needs_prefix = bool(org_slug) and not is_platform_org
    return f"{org_slug}/{route_name}" if needs_prefix else route_name
class AccessPolicyEnum(str, Enum):
    # Access scopes a model route can be configured with. The `str` mixin
    # makes every member compare equal to its literal string value.

    # NOTE(review): PUBLIC / AUTHED presumably mean "no authentication
    # required" and "any authenticated user" — confirm against the
    # authorization code, which is not part of this file.
    PUBLIC = "public"
    AUTHED = "authed"

    # Restricted to members of the Organization that owns the route. This is
    # the default for new routes in non-platform Orgs (the "team-private"
    # scope); no principal-table rows are involved.
    ORG = "org"

    # Explicit per-user grants. Rows are stored in ``model_route_principals``
    # with ``principal_id`` pointing at a USER-kind principal. The OSS UI
    # exposes only this policy for explicit access lists because it does not
    # surface Org / Group concepts.
    ALLOWED_USERS = "allowed_users"

    # Explicit per-principal grants (user / org / group), also stored in
    # ``model_route_principals``. Surfaced by the enterprise UI.
    ALLOWED_PRINCIPALS = "allowed_principals"
class TargetStateEnum(str, Enum):
    # State of a single route target. Members are plain strings (str mixin).
    ACTIVE = "active"
    UNAVAILABLE = "unavailable"
class FallbackStatusEnum(str, Enum):
    # Status-code classes accepted in `fallback_status_codes`. Note that the
    # member names say 400 / 500 while the stored values are the class
    # strings "4xx" / "5xx".
    ERROR_400 = "4xx"
    ERROR_500 = "5xx"
class ModelRouteTargetUpdate(SQLModel):
    # Fields of a route target that can be updated. A target references
    # exactly one backend: either a row in `models` (via `model_id`) or a row
    # in `model_providers` (via `provider_id`, together with
    # `provider_model_name`). `check_provider_or_model` enforces this.

    provider_model_name: Optional[str] = Field(default=None, nullable=True)
    # Non-negative weight of this target within its route.
    weight: int = Field(default=0, nullable=False, ge=0)
    # Both foreign keys use ON DELETE CASCADE, so deleting the referenced
    # model / provider row also deletes the target row.
    model_id: Optional[int] = Field(
        default=None,
        sa_column=Column(
            Integer,
            ForeignKey("models.id", ondelete="CASCADE"),
            nullable=True,
        ),
    )
    provider_id: Optional[int] = Field(
        default=None,
        sa_column=Column(
            Integer,
            ForeignKey("model_providers.id", ondelete="CASCADE"),
            nullable=True,
        ),
    )

    @model_validator(mode="after")
    def check_provider_or_model(self):
        """Ensure the target references exactly one backend.

        Rules, checked in this order:
          1. one of ``provider_id`` / ``model_id`` must be set;
          2. they must not both be set;
          3. ``provider_id`` requires ``provider_model_name``;
          4. ``model_id`` forbids ``provider_model_name``.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If any rule above is violated.
        """
        has_provider = self.provider_id is not None
        has_model = self.model_id is not None
        has_name = self.provider_model_name is not None

        if not (has_provider or has_model):
            raise ValueError("Either provider_id or model_id must be provided.")
        if has_provider and has_model:
            raise ValueError("Only one of provider_id or model_id can be provided.")
        if has_provider and not has_name:
            raise ValueError(
                "provider_model_name must be provided when provider_id is set."
            )
        if has_model and has_name:
            raise ValueError("provider_model_name must be None when model_id is set.")
        return self
class ModelRouteTargetCreate(ModelRouteTargetUpdate):
    # Creation payload for a route target: the updatable fields plus the
    # optional list of status-code classes, stored as a nullable JSON column.

    fallback_status_codes: Optional[List[str]] = Field(
        default=None,
        sa_column=Column(
            JSON,
            nullable=True,
        ),
    )

    @field_validator("fallback_status_codes", mode="before")
    @classmethod
    def validate_fallback_status_codes(cls, v):
        """Validate and de-duplicate the fallback status codes.

        Runs before pydantic's own type coercion, so ``v`` is the raw input.

        Args:
            v: ``None`` or a list/tuple/set of status-code classes.

        Returns:
            ``None`` unchanged, otherwise a list of the distinct codes in
            first-seen order.

        Raises:
            ValueError: If ``v`` is not a list-like container, or if an item
                is not one of the ``FallbackStatusEnum`` values.
        """
        if v is None:
            return v
        # Reject strings / scalars explicitly so they surface as a validation
        # error instead of a TypeError or a confusing per-character failure.
        if not isinstance(v, (list, tuple, set, frozenset)):
            raise ValueError("fallback_status_codes must be a list of status codes")

        allowed = tuple(FallbackStatusEnum)
        unique_codes: List[str] = []
        for status in v:
            # Equality-based membership: works for unhashable junk input too,
            # and str-enum members compare equal to their string values.
            if status not in allowed:
                raise ValueError(f"Invalid fallback status code: {status}")
            # Keep first-seen order so the stored/returned list is
            # deterministic (a set round-trip depends on hash randomization).
            if status not in unique_codes:
                unique_codes.append(status)
        return unique_codes
class ModelRouteTargetBase(ModelRouteTargetCreate):
    # Full set of stored fields for a route target: the creation fields plus
    # the target's own name, the owning route (by name and by id) and state.

    name: str = Field(nullable=False)
    route_name: str = Field(nullable=False)
    # ON DELETE CASCADE: deleting the route row deletes its target rows.
    route_id: int = Field(
        sa_column=Column(
            Integer,
            ForeignKey(
                "model_routes.id",
                ondelete="CASCADE",
            ),
            nullable=False,
        )
    )
    state: TargetStateEnum = Field(default=TargetStateEnum.ACTIVE, nullable=False)

    @field_validator("route_name", mode="before")
    @classmethod
    def validate_route_name(cls, v):
        """Check that ``route_name`` is a string matching ``name_pattern``.

        Args:
            v: Raw input value (the validator runs before type coercion).

        Returns:
            ``v`` unchanged when valid.

        Raises:
            ValueError: If ``v`` is not a string or does not match the
                pattern.
        """
        if not isinstance(v, str):
            raise ValueError("route_name must be a string")
        # fullmatch, not match: with match the pattern's trailing `$` also
        # matches just before a final newline, so "name\n" would be accepted.
        if re.fullmatch(name_pattern, v) is None:
            raise ValueError(
                "route_name must start with a letter, contain only letters, numbers, hyphens, underscores, and dots, and end with a letter or number"
            )
        return v
class ModelRouteTarget(ModelRouteTargetBase, BaseModelMixin, table=True):
    # ORM table `model_route_targets`: one row per target of a model route.
    #
    # Every relationship below is declared with lazy="noload", i.e.
    # SQLAlchemy does not load it on attribute access; callers that need the
    # related rows have to load them explicitly.

    __tablename__: ClassVar[str] = "model_route_targets"

    id: Optional[int] = Field(default=None, primary_key=True)

    # Owning route; counterpart of `ModelRoute.route_targets`.
    model_route: "ModelRoute" = Relationship(
        back_populates="route_targets",
        sa_relationship_kwargs={"lazy": "noload"},
    )
    # Backing provider, when the target is defined through `provider_id`.
    provider: Optional["ModelProvider"] = Relationship(
        back_populates="model_route_targets",
        sa_relationship_kwargs={"lazy": "noload"},
    )
    # Backing model, when the target is defined through `model_id`.
    model: Optional["Model"] = Relationship(
        back_populates="model_route_targets",
        sa_relationship_kwargs={"lazy": "noload"},
    )
class ModelRouteTargetPublic(ModelRouteTargetBase, PublicFields):
    # API representation of a route target: the stored fields combined with
    # whatever `PublicFields` contributes (defined in gpustack.schemas.common).
    pass


# Paginated list of route targets.
ModelRouteTargetsPublic = PaginatedList[ModelRouteTargetPublic]
class ModelRouteTargetListParams(ListParams):
    # List parameters for route targets. The four optional fields are
    # filters (presumably matched by equality — confirm in the list handler);
    # `sortable_fields` names the columns a caller may sort by.

    route_id: Optional[int] = None
    route_name: Optional[str] = None
    model_id: Optional[int] = None
    provider_id: Optional[int] = None

    sortable_fields: ClassVar[List[str]] = [
        "id",
        "created_at",
        "updated_at",
        "name",
        "weight",
        "state",
    ]
class ModelRouteTargetUpdateItem(ModelRouteTargetCreate):
    # One entry of `ModelRouteUpdate.targets`: the creation fields plus an
    # optional target id.
    # NOTE(review): presumably `id=None` means "create a new target" and a
    # set id means "update that target" — confirm in the route update service.
    id: Optional[int] = None
class ModelRouteUpdateBase(SQLModel):
    # User-editable attributes of a model route. `name` must satisfy
    # `name_pattern`; every entry of `categories` must be one of the values
    # listed in `validate_categories`.

    name: str = Field(nullable=False)
    description: Optional[str] = Field(default=None, nullable=True)
    categories: List[str] = Field(sa_type=JSON, default=[])
    meta: Optional[Dict[str, Any]] = Field(sa_type=JSON, default={})
    generic_proxy: Optional[bool] = Field(default=False)

    @field_validator("categories", mode="before")
    @classmethod
    def validate_categories(cls, v):
        """Reject any category outside the supported set.

        Args:
            v: Raw input value (the validator runs before type coercion).

        Returns:
            ``v`` unchanged (``None`` is passed through as-is).

        Raises:
            ValueError: On the first entry that is not a supported category.
        """
        if v is None:
            return v
        # Built once per call rather than once per item.
        allowed_categories = (
            "llm",
            "embedding",
            "image",
            "reranker",
            "speech_to_text",
            "text_to_speech",
            "unknown",
        )
        for category in v:
            if category not in allowed_categories:
                raise ValueError(f"Invalid category: {category}")
        return v

    @field_validator("name", mode="before")
    @classmethod
    def validate_name(cls, v):
        """Check that ``name`` is a string matching ``name_pattern``.

        Args:
            v: Raw input value (the validator runs before type coercion).

        Returns:
            ``v`` unchanged when valid.

        Raises:
            ValueError: If ``v`` is not a string or does not match the
                pattern.
        """
        if not isinstance(v, str):
            raise ValueError("name must be a string")
        # fullmatch, not match: with match the pattern's trailing `$` also
        # matches just before a final newline, so "name\n" would be accepted.
        if re.fullmatch(name_pattern, v) is None:
            raise ValueError(
                "name must start with a letter, contain only letters, numbers, hyphens, underscores, and dots, and end with a letter or number"
            )
        return v
class ModelRouteUpdate(ModelRouteUpdateBase):
    # Update payload for a route: the editable attributes plus an optional
    # list of target items. In this payload `targets` is a list of items,
    # whereas `ModelRouteBase.targets` is an integer.
    targets: Optional[List[ModelRouteTargetUpdateItem]] = Field(
        default=None,
        nullable=True,
    )
class ModelRouteCreate(ModelRouteUpdate):
    # Creation payload for a route; it adds nothing to `ModelRouteUpdate`.
    pass
class ModelRouteBase(ModelRouteUpdateBase):
    # Stored attributes of a route beyond the user-editable ones.

    # NOTE(review): presumably marks routes that were created automatically
    # alongside a model — confirm where this flag is set.
    created_by_model: Optional[bool] = Field(default=False, nullable=False)
    # Non-negative integers. NOTE(review): presumably the number of targets
    # and the number of usable targets — confirm where they are maintained.
    targets: int = Field(default=0, nullable=False, ge=0)
    ready_targets: int = Field(default=0, nullable=False, ge=0)
    access_policy: AccessPolicyEnum = Field(default=AccessPolicyEnum.AUTHED)
    # Owning principal (foreign key to `principals.id`); defaults to the
    # platform Organization.
    owner_principal_id: int = Field(
        default=PLATFORM_ORGANIZATION_ID,
        foreign_key="principals.id",
        nullable=False,
    )
class ModelRoute(ModelRouteBase, BaseModelMixin, table=True):
    # ORM table `model_routes`. Both relationships are declared with
    # lazy="noload", i.e. SQLAlchemy does not load them on attribute access.

    __tablename__: ClassVar[str] = "model_routes"

    id: Optional[int] = Field(default=None, primary_key=True)

    # Targets of this route; the ORM-level "delete" cascade removes them
    # together with the route.
    route_targets: List[ModelRouteTarget] = Relationship(
        back_populates="model_route",
        sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
    )
    # Models reachable through this route, using `model_route_targets` as
    # the link table. `overlaps` names the relationships that write to the
    # same columns.
    models: List["Model"] = Relationship(
        back_populates="model_routes",
        link_model=ModelRouteTarget,
        sa_relationship_kwargs={
            "lazy": "noload",
            "overlaps": "route_targets,model_route,model",
        },
    )
class ModelRoutePublic(ModelRouteBase, PublicFields):
    # API representation of a route.
    #
    # `effective_name` is the model name clients should put in their request
    # body: equal to `name` for the platform Org (backward compat) and
    # `<org-slug>/<name>` for every other Org. It is currently only reserved:
    # frontends derive the value themselves via
    # `effectiveRouteName(name, org)` because they already hold the owning
    # Org row from a separate fetch. Keeping the field here lets a future
    # server-side enrichment populate it without breaking consumers.
    effective_name: Optional[str] = None


# Paginated list of routes.
ModelRoutesPublic = PaginatedList[ModelRoutePublic]
class ModelRouteListParams(ListParams):
    # List parameters for routes; `sortable_fields` names the columns a
    # caller may sort by.
    sortable_fields: ClassVar[List[str]] = [
        "id",
        "created_at",
        "updated_at",
        "name",
        "targets",
        "ready_targets",
    ]
class SetFallbackTargetInput(BaseModel):
    # Request body carrying the status-code classes for a fallback target.
    # NOTE(review): this is a plain pydantic BaseModel, so the `sa_column`
    # argument presumably has no effect here — confirm before removing it.

    fallback_status_codes: Optional[List[str]] = Field(
        default=None,
        sa_column=Column(
            JSON,
            nullable=True,
        ),
    )

    @field_validator("fallback_status_codes", mode="before")
    @classmethod
    def validate_fallback_status_codes(cls, v):
        """Validate and de-duplicate the fallback status codes.

        Runs before pydantic's own type coercion, so ``v`` is the raw input.

        Args:
            v: ``None`` or a list/tuple/set of status-code classes.

        Returns:
            ``None`` unchanged, otherwise a list of the distinct codes in
            first-seen order.

        Raises:
            ValueError: If ``v`` is not a list-like container, or if an item
                is not one of the ``FallbackStatusEnum`` values.
        """
        if v is None:
            return v
        # Reject strings / scalars explicitly so they surface as a validation
        # error instead of a TypeError or a confusing per-character failure.
        if not isinstance(v, (list, tuple, set, frozenset)):
            raise ValueError("fallback_status_codes must be a list of status codes")

        allowed = tuple(FallbackStatusEnum)
        unique_codes: List[str] = []
        for status in v:
            # Equality-based membership: works for unhashable junk input too,
            # and str-enum members compare equal to their string values.
            if status not in allowed:
                raise ValueError(f"Invalid fallback status code: {status}")
            # Keep first-seen order so the returned list is deterministic
            # (a set round-trip depends on hash randomization).
            if status not in unique_codes:
                unique_codes.append(status)
        return unique_codes
class ModelUserAccess(BaseModel):
    # One entry of an access list; `id` presumably identifies the user
    # (see `ModelAuthorizationUpdate.users`) — confirm against the caller.
    id: int
    # Further per-user settings (e.g. quota, rate_limit) can be added here.
class ModelAuthorizationUpdate(BaseModel):
    # Request body for changing authorization: an optional new access policy
    # and the (required) list of user access entries.
    access_policy: Optional[AccessPolicyEnum] = None
    users: List[ModelUserAccess]
class ModelUserAccessExtended(ModelUserAccess):
    # Access entry enriched with optional display attributes of the user.
    username: Optional[str] = None
    full_name: Optional[str] = None
    avatar_url: Optional[str] = None
    # Further user fields (e.g. quota, rate_limit) can be added here.


# Item list of enriched access entries.
ModelAuthorizationList = ItemList[ModelUserAccessExtended]
class MyModel(ModelRouteBase, BaseModelMixin, table=True):
    # ORM mapping onto `non_admin_user_models`, keyed by the string `pid`
    # (declared as primary key through `__mapper_args__`, not on the field).
    # NOTE(review): presumably a database view listing the routes visible to
    # each non-admin user, with `pid` unique per (route, user) — confirm
    # against the migration that creates it.
    __tablename__ = 'non_admin_user_models'
    __mapper_args__ = {'primary_key': ["pid"]}

    pid: str
    id: int
    user_id: int = Field(default=0)
class MyModelPublic(ModelRoutePublic):
    # API representation of a `MyModel` row; it adds nothing to
    # `ModelRoutePublic`.
    pass
|