model_routes.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. import re
  2. from enum import Enum
  3. from typing import ClassVar, Optional, Dict, Any, List, Set
  4. from pydantic import BaseModel, field_validator, model_validator
  5. from sqlmodel import (
  6. Field,
  7. Relationship,
  8. Column,
  9. SQLModel,
  10. Integer,
  11. ForeignKey,
  12. JSON,
  13. )
  14. from typing import TYPE_CHECKING
  15. from gpustack.mixins import BaseModelMixin
  16. from gpustack.schemas.common import (
  17. ListParams,
  18. PaginatedList,
  19. PublicFields,
  20. ItemList,
  21. )
  22. from gpustack.schemas.organizations import PLATFORM_ORGANIZATION_ID
  23. if TYPE_CHECKING:
  24. from gpustack.schemas.models import Model
  25. from gpustack.schemas.model_provider import ModelProvider
  26. # Route names intentionally exclude `/` — the dispatch parser
  27. # (`UserService.get_model_ids_by_model_route_name`) splits the inbound
  28. # `model` string on the first `/` to separate Org slug from raw name.
  29. # Allowing `/` inside route names would create irresolvable ambiguity
  30. # (e.g. literal route "a/b" in platform Org vs. route "b" in Org with
  31. # slug "a"). Keep the two char sets disjoint.
  32. name_pattern = r'^[A-Za-z](?:[A-Za-z0-9_\-\.]*[A-Za-z0-9])?$'
  33. def effective_route_name(
  34. route_name: str,
  35. org_slug: Optional[str],
  36. is_platform_org: bool,
  37. ) -> str:
  38. """The model name clients see and gateways route on.
  39. The platform Org keeps unprefixed names (backward compat — existing
  40. clients calling `model: "qwen3-0.6b"` keep working). Other Orgs get
  41. a slug prefix (`org1/qwen3-0.6b`) so two Orgs can use the same route
  42. name without colliding in Higress's AI proxy match rules.
  43. Format follows the OpenAI / HuggingFace / OpenRouter convention
  44. (`namespace/model`); slug is already constrained to
  45. `^[a-z](?:[a-z0-9\\-]*[a-z0-9])?$` and route names exclude `/` (see
  46. ``name_pattern``) so the joined string parses unambiguously.
  47. """
  48. if is_platform_org or not org_slug:
  49. return route_name
  50. return f"{org_slug}/{route_name}"
  51. class AccessPolicyEnum(str, Enum):
  52. PUBLIC = "public"
  53. AUTHED = "authed"
  54. # ORG = scoped to members of the route's owning Organization. The
  55. # default for new routes in non-platform Orgs — semantically the
  56. # "team-private" scope, no principal table involvement.
  57. ORG = "org"
  58. # Per-user grants. The OSS UI surfaces only this policy for explicit
  59. # access lists since it doesn't expose Org / Group concepts; rows
  60. # are stored in ``model_route_principals`` with ``principal_id``
  61. # pointing at a USER-kind principal.
  62. ALLOWED_USERS = "allowed_users"
  63. # Per-principal grants (user / org / group) via
  64. # ``model_route_principals``. Surfaced by the enterprise UI.
  65. ALLOWED_PRINCIPALS = "allowed_principals"
  66. class TargetStateEnum(str, Enum):
  67. ACTIVE = "active"
  68. UNAVAILABLE = "unavailable"
  69. class FallbackStatusEnum(str, Enum):
  70. ERROR_400 = "4xx"
  71. ERROR_500 = "5xx"
  72. class ModelRouteTargetUpdate(SQLModel):
  73. provider_model_name: Optional[str] = Field(default=None, nullable=True)
  74. weight: int = Field(default=0, nullable=False, ge=0)
  75. model_id: Optional[int] = Field(
  76. default=None,
  77. sa_column=Column(
  78. Integer,
  79. ForeignKey(
  80. "models.id",
  81. ondelete="CASCADE",
  82. ),
  83. nullable=True,
  84. ),
  85. )
  86. provider_id: Optional[int] = Field(
  87. default=None,
  88. sa_column=Column(
  89. Integer,
  90. ForeignKey(
  91. "model_providers.id",
  92. ondelete="CASCADE",
  93. ),
  94. nullable=True,
  95. ),
  96. )
  97. @model_validator(mode="after")
  98. def check_provider_or_model(self):
  99. both_set = self.provider_id is not None and self.model_id is not None
  100. both_none = self.provider_id is None and self.model_id is None
  101. name_missing = self.provider_model_name is None and self.provider_id is not None
  102. invalid_name = (
  103. self.provider_model_name is not None and self.model_id is not None
  104. )
  105. if both_none:
  106. raise ValueError("Either provider_id or model_id must be provided.")
  107. if both_set:
  108. raise ValueError("Only one of provider_id or model_id can be provided.")
  109. if name_missing:
  110. raise ValueError(
  111. "provider_model_name must be provided when provider_id is set."
  112. )
  113. if invalid_name:
  114. raise ValueError("provider_model_name must be None when model_id is set.")
  115. return self
  116. class ModelRouteTargetCreate(ModelRouteTargetUpdate):
  117. fallback_status_codes: Optional[List[str]] = Field(
  118. default=None,
  119. sa_column=Column(
  120. JSON,
  121. nullable=True,
  122. ),
  123. )
  124. @field_validator("fallback_status_codes", mode="before")
  125. def validate_fallback_status_codes(cls, v):
  126. if v is None:
  127. return v
  128. deduped: Set[str] = set(v)
  129. for status in deduped:
  130. if status not in [
  131. FallbackStatusEnum.ERROR_400,
  132. FallbackStatusEnum.ERROR_500,
  133. ]:
  134. raise ValueError(f"Invalid fallback status code: {status}")
  135. return list(deduped)
  136. class ModelRouteTargetBase(ModelRouteTargetCreate):
  137. name: str = Field(nullable=False)
  138. route_name: str = Field(nullable=False)
  139. route_id: int = Field(
  140. sa_column=Column(
  141. Integer,
  142. ForeignKey(
  143. "model_routes.id",
  144. ondelete="CASCADE",
  145. ),
  146. nullable=False,
  147. )
  148. )
  149. state: TargetStateEnum = Field(default=TargetStateEnum.ACTIVE, nullable=False)
  150. @field_validator("route_name", mode="before")
  151. def validate_route_name(cls, v):
  152. if not isinstance(v, str):
  153. raise ValueError("route_name must be a string")
  154. if not re.match(name_pattern, v):
  155. raise ValueError(
  156. "route_name must start with a letter, only contain letters, numbers, hyphens, underscores, and not end with hyphen or underscore"
  157. )
  158. return v
class ModelRouteTarget(ModelRouteTargetBase, BaseModelMixin, table=True):
    # ORM table `model_route_targets`: one row per target of a model route.
    # Per the inherited validator, each row references either a `models` row
    # or a `model_providers` row, never both.
    __tablename__: ClassVar[str] = "model_route_targets"
    id: Optional[int] = Field(default=None, primary_key=True)
    # All relationships below use lazy="noload": SQLAlchemy never loads them
    # implicitly, so callers must load them explicitly when they need them.
    model_route: "ModelRoute" = Relationship(
        back_populates="route_targets",
        sa_relationship_kwargs={"lazy": "noload"},
    )
    # Set when the target points at an external provider (`provider_id`).
    provider: Optional["ModelProvider"] = Relationship(
        back_populates="model_route_targets",
        sa_relationship_kwargs={"lazy": "noload"},
    )
    # Set when the target points at a `models` row (`model_id`).
    model: Optional["Model"] = Relationship(
        back_populates="model_route_targets",
        sa_relationship_kwargs={"lazy": "noload"},
    )
  174. class ModelRouteTargetPublic(ModelRouteTargetBase, PublicFields):
  175. pass
  176. ModelRouteTargetsPublic = PaginatedList[ModelRouteTargetPublic]
  177. class ModelRouteTargetListParams(ListParams):
  178. route_id: Optional[int] = None
  179. route_name: Optional[str] = None
  180. model_id: Optional[int] = None
  181. provider_id: Optional[int] = None
  182. sortable_fields: ClassVar[List[str]] = [
  183. "id",
  184. "created_at",
  185. "updated_at",
  186. "name",
  187. "weight",
  188. "state",
  189. ]
  190. class ModelRouteTargetUpdateItem(ModelRouteTargetCreate):
  191. id: Optional[int] = None
  192. class ModelRouteUpdateBase(SQLModel):
  193. name: str = Field(nullable=False)
  194. description: Optional[str] = Field(default=None, nullable=True)
  195. categories: List[str] = Field(sa_type=JSON, default=[])
  196. meta: Optional[Dict[str, Any]] = Field(sa_type=JSON, default={})
  197. generic_proxy: Optional[bool] = Field(default=False)
  198. @field_validator("categories", mode="before")
  199. def validate_categories(cls, v):
  200. if v is None:
  201. return v
  202. for category in v:
  203. if category not in [
  204. "llm",
  205. "embedding",
  206. "image",
  207. "reranker",
  208. "speech_to_text",
  209. "text_to_speech",
  210. "unknown",
  211. ]:
  212. raise ValueError(f"Invalid category: {category}")
  213. return v
  214. @field_validator("name", mode="before")
  215. def validate_name(cls, v):
  216. if not isinstance(v, str):
  217. raise ValueError("name must be a string")
  218. if not re.match(name_pattern, v):
  219. raise ValueError(
  220. "name must start with a letter, only contain letters, numbers, hyphens, underscores, and not end with hyphen or underscore"
  221. )
  222. return v
  223. class ModelRouteUpdate(ModelRouteUpdateBase):
  224. targets: Optional[List[ModelRouteTargetUpdateItem]] = Field(
  225. default=None, nullable=True
  226. )
  227. class ModelRouteCreate(ModelRouteUpdate):
  228. pass
  229. class ModelRouteBase(ModelRouteUpdateBase):
  230. created_by_model: Optional[bool] = Field(default=False, nullable=False)
  231. targets: int = Field(default=0, nullable=False, ge=0)
  232. ready_targets: int = Field(default=0, nullable=False, ge=0)
  233. access_policy: AccessPolicyEnum = Field(default=AccessPolicyEnum.AUTHED)
  234. owner_principal_id: int = Field(
  235. default=PLATFORM_ORGANIZATION_ID,
  236. foreign_key="principals.id",
  237. nullable=False,
  238. )
class ModelRoute(ModelRouteBase, BaseModelMixin, table=True):
    # ORM table `model_routes`.
    __tablename__: ClassVar[str] = "model_routes"
    id: Optional[int] = Field(default=None, primary_key=True)
    # Target rows of this route. cascade="delete" removes them through the
    # ORM when the route is deleted (ModelRouteTarget.route_id also declares
    # ON DELETE CASCADE). lazy="noload" means they are never loaded
    # implicitly.
    route_targets: List[ModelRouteTarget] = Relationship(
        back_populates="model_route",
        sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
    )
    # Many-to-many view of the `Model`s behind this route, using
    # model_route_targets as the link table. "overlaps" tells SQLAlchemy this
    # relationship intentionally shares columns with route_targets /
    # ModelRouteTarget.model_route / ModelRouteTarget.model, which silences
    # its overlap warning.
    models: List["Model"] = Relationship(
        back_populates="model_routes",
        link_model=ModelRouteTarget,
        sa_relationship_kwargs={
            "lazy": "noload",
            "overlaps": "route_targets,model_route,model",
        },
    )
  254. class ModelRoutePublic(ModelRouteBase, PublicFields):
  255. # The model name clients should send in their request body. Equals
  256. # `name` for the platform Org (backward compat); for other Orgs it
  257. # is `<org-slug>/<name>`. Frontends currently derive this themselves
  258. # via `effectiveRouteName(name, org)` since they have the owning Org
  259. # row in hand from a separate fetch — the field is reserved here so
  260. # a future server-side enrichment can populate it without breaking
  261. # consumers.
  262. effective_name: Optional[str] = None
  263. ModelRoutesPublic = PaginatedList[ModelRoutePublic]
  264. class ModelRouteListParams(ListParams):
  265. sortable_fields: ClassVar[List[str]] = [
  266. "id",
  267. "created_at",
  268. "updated_at",
  269. "name",
  270. "targets",
  271. "ready_targets",
  272. ]
  273. class SetFallbackTargetInput(BaseModel):
  274. fallback_status_codes: Optional[List[str]] = Field(
  275. default=None,
  276. sa_column=Column(
  277. JSON,
  278. nullable=True,
  279. ),
  280. )
  281. @field_validator("fallback_status_codes", mode="before")
  282. def validate_fallback_status_codes(cls, v):
  283. if v is None:
  284. return v
  285. deduped: Set[FallbackStatusEnum] = set(v)
  286. for status in deduped:
  287. if status not in [
  288. FallbackStatusEnum.ERROR_400,
  289. FallbackStatusEnum.ERROR_500,
  290. ]:
  291. raise ValueError(f"Invalid fallback status code: {status}")
  292. return list(deduped)
  293. class ModelUserAccess(BaseModel):
  294. id: int
  295. # More custom fields can be added here, e.g., quota, rate_limit, etc.
  296. class ModelAuthorizationUpdate(BaseModel):
  297. access_policy: Optional[AccessPolicyEnum] = None
  298. users: List[ModelUserAccess]
  299. class ModelUserAccessExtended(ModelUserAccess):
  300. username: Optional[str] = None
  301. full_name: Optional[str] = None
  302. avatar_url: Optional[str] = None
  303. # More user fields can be added here. e.g. quota, rate_limit, etc.
  304. ModelAuthorizationList = ItemList[ModelUserAccessExtended]
class MyModel(ModelRouteBase, BaseModelMixin, table=True):
    # Table model mapped onto `non_admin_user_models`: the route fields plus a
    # `user_id` column.
    # NOTE(review): presumably a database view listing routes visible to
    # non-admin users; confirm against the migration that creates it.
    __tablename__ = 'non_admin_user_models'
    # The mapper keys rows on `pid` instead of `id`.
    # NOTE(review): presumably because `id` (the route id) is not unique in
    # this relation; confirm.
    __mapper_args__ = {'primary_key': ["pid"]}
    pid: str
    id: int
    user_id: int = Field(default=0)
  311. class MyModelPublic(ModelRoutePublic):
  312. pass