inference_backend.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
import re
import shlex
from datetime import datetime
from functools import cmp_to_key
from typing import Dict, List, Optional, Tuple

from gpustack_runtime.deployer.__utils__ import compare_versions
from pydantic import BaseModel, Field, RootModel
from sqlalchemy import JSON, Column, ForeignKey, Integer, Text, UniqueConstraint
from sqlmodel import SQLModel, Field as SQLField

from gpustack.mixins import BaseModelMixin

from .common import pydantic_column_type, PaginatedList
from .models import BackendEnum, BackendSourceEnum
  12. class ContainerEnvConfig(BaseModel):
  13. """Container environment configuration."""
  14. user: Optional[int] = None
  15. group: Optional[int] = None
  16. shm_size_gib: float = 10.0
  17. class VersionConfig(BaseModel):
  18. """
  19. Configuration for a specific version of an inference backend.
  20. Attributes:
  21. image_name: Docker image name for this version
  22. run_command: Command to run the inference server (Optional, uses default if not specified)
  23. entrypoint: Container entrypoint command that overrides the default image entrypoint. (Optional)
  24. built_in_frameworks: Only built-in backend will return this field, sourced from gpustack-runner configuration. (Optional)
  25. custom_framework: User-provided value (upon backend creation) used for deployment and compatibility checks. (Optional)
  26. env: Environment variables for this version (Optional, merges with default_env)
  27. """
  28. image_name: Optional[str] = Field(None)
  29. run_command: Optional[str] = Field(None)
  30. entrypoint: Optional[str] = Field(None)
  31. built_in_frameworks: Optional[List[str]] = Field(None)
  32. custom_framework: Optional[str] = Field(None)
  33. env: Optional[Dict[str, str]] = Field(None)
  34. class VersionConfigDict(RootModel[Dict[str, VersionConfig]]):
  35. """
  36. Wrapper model for version configs dictionary to enable proper JSON serialization.
  37. """
  38. root: Dict[str, VersionConfig] = Field(default_factory=dict)
  39. # Database Models
  40. class InferenceBackendBase(SQLModel):
  41. """
  42. Base model for inference backends.
  43. Attributes:
  44. backend_name: Name of the backend (e.g., 'SGLang')
  45. version_configs: Dictionary mapping version strings to their configurations
  46. default_version: Default version to use if not specified
  47. default_backend_param: Default parameters to pass to the backend
  48. default_run_command: Default command to run the inference server
  49. default_entrypoint: Default entrypoint to replace for the inference server
  50. description: Backend description
  51. health_check_path: Path for health check endpoint
  52. """
  53. # Backend name is unique within an Org scope: one Platform-NULL row
  54. # plus optional one row per Org with the same backend_name (Hybrid
  55. # model). Composite unique is declared on the table class below.
  56. backend_name: str = SQLField(index=True)
  57. # Tenant scope. NULL = global (admin-managed). Non-NULL = an Org's
  58. # extension/override of a built-in or its own custom backend.
  59. owner_principal_id: Optional[int] = SQLField(
  60. default=None,
  61. sa_column=Column(
  62. Integer, ForeignKey("principals.id", ondelete="CASCADE"), nullable=True
  63. ),
  64. )
  65. version_configs: VersionConfigDict = SQLField(
  66. sa_column=Column(pydantic_column_type(VersionConfigDict)()),
  67. default_factory=lambda: VersionConfigDict(root={}),
  68. )
  69. default_version: Optional[str] = SQLField(default=None)
  70. default_backend_param: Optional[List[str]] = SQLField(
  71. sa_column=Column(JSON), default=[]
  72. )
  73. default_run_command: Optional[str] = SQLField(
  74. sa_column=Column(Text, nullable=True), default=""
  75. )
  76. default_entrypoint: Optional[str] = SQLField(
  77. sa_column=Column(Text, nullable=True), default=""
  78. )
  79. is_built_in: bool = SQLField(default=False)
  80. description: Optional[str] = SQLField(
  81. default=None, sa_column=Column(Text, nullable=True)
  82. )
  83. health_check_path: Optional[str] = SQLField(default=None)
  84. backend_source: Optional[BackendSourceEnum] = SQLField(default=None)
  85. enabled: Optional[bool] = SQLField(default=None)
  86. icon: Optional[str] = SQLField(default=None)
  87. default_env: Optional[Dict[str, str]] = SQLField(
  88. sa_column=Column(JSON), default=None
  89. )
  90. def resolve_target_version(self, version: Optional[str] = None) -> Optional[str]:
  91. """
  92. Resolve the target version to use based on the requested version, default version,
  93. and available version configs.
  94. Logic:
  95. - If requested/default version exists in version_configs, return it.
  96. - If using a non-built-in backend and version_configs exist, return the latest version
  97. (by compare_versions, falling back to lexicographical sort).
  98. - Otherwise, return None.
  99. """
  100. version_configs_dict = self.version_configs.root
  101. target_version = version or self.default_version
  102. # 1) Requested/default version exists
  103. if target_version in version_configs_dict:
  104. return target_version
  105. # 2) For non-built-in backends, auto-select the latest available version
  106. if version_configs_dict and not self.is_built_in:
  107. try:
  108. version_list = list(version_configs_dict.keys())
  109. latest_version = version_list[0]
  110. for ver in version_list[1:]:
  111. if compare_versions(ver, latest_version) > 0:
  112. latest_version = ver
  113. return latest_version
  114. except Exception:
  115. sorted_versions = sorted(version_configs_dict.keys())
  116. return sorted_versions[-1] if sorted_versions else None
  117. # 3) No suitable version found
  118. return None
  119. def get_version_config(self, version: Optional[str] = None) -> (VersionConfig, str):
  120. """
  121. Get configuration for a specific version.
  122. Args:
  123. version: Version string, uses default_version if None
  124. Returns:
  125. VersionConfig for the resolved version, and the resolved version string
  126. Raises:
  127. KeyError: If the version cannot be resolved from version_configs
  128. """
  129. target_version = self.resolve_target_version(version)
  130. if target_version is None:
  131. raise KeyError(
  132. f"Version '{version or self.default_version}' not found in backend '{self.backend_name}'"
  133. )
  134. return self.version_configs.root[target_version], target_version
  135. def get_run_command(self, version: Optional[str] = None) -> str:
  136. if not version:
  137. version = self.default_version
  138. version_config, _ = self.get_version_config(version)
  139. return version_config.run_command or self.default_run_command
  140. def get_backend_env(self, version: Optional[str] = None):
  141. """
  142. backend.version.env > backend.default_env
  143. """
  144. env_dict = {}
  145. if self.default_env:
  146. for k, v in self.default_env.items():
  147. env_dict[k] = v
  148. if version:
  149. try:
  150. version_config, _ = self.get_version_config(version)
  151. if version_config.env:
  152. for k, v in version_config.env.items():
  153. env_dict[k] = v
  154. except Exception:
  155. # built-in version may not include version config
  156. pass
  157. return env_dict
  158. def replace_command_param(
  159. self,
  160. version: Optional[str],
  161. model_path: Optional[str],
  162. port: Optional[int],
  163. worker_ip: Optional[str] = None,
  164. model_name: Optional[str] = None,
  165. command: Optional[str] = None,
  166. env: Optional[Dict[str, str]] = None,
  167. ) -> str:
  168. if not command:
  169. command = self.get_run_command(version)
  170. if not command:
  171. return ""
  172. command = command.replace("{{model_path}}", model_path or "")
  173. command = command.replace("{{port}}", str(port))
  174. command = command.replace("{{worker_ip}}", worker_ip or "")
  175. command = command.replace("{{model_name}}", model_name or "")
  176. # Resolve environment variables using {{VAR_NAME}} syntax
  177. # Use provided env (from model) if available, otherwise fall back to backend env
  178. if env:
  179. command = self._resolve_env_vars(command, env)
  180. return command
  181. def _resolve_env_vars(self, command: str, env_dict: Dict[str, str]) -> str:
  182. """
  183. Resolve {{VAR_NAME}} placeholders in the command string using the provided environment dict.
  184. Args:
  185. command: The command string with {{VAR_NAME}} placeholders
  186. env_dict: Dictionary of environment variable names to values
  187. Returns:
  188. Command with placeholders replaced by their values.
  189. If a variable is not found in env_dict, the placeholder is left unchanged.
  190. """
  191. # Match valid variable names: start with letter or underscore, followed by alphanumeric or underscore
  192. pattern = r"\{\{([A-Za-z_][A-Za-z0-9_]*)\}\}"
  193. def replace_var(match):
  194. var_name = match.group(1)
  195. return env_dict.get(var_name, match.group(0))
  196. return re.sub(pattern, replace_var, command)
  197. def get_container_entrypoint(
  198. self, version: Optional[str] = None
  199. ) -> Optional[List[str]]:
  200. """
  201. Get container entrypoint for the specified version.
  202. Args:
  203. version: Desired backend version; falls back to `default_version` when None.
  204. Returns:
  205. The container entrypoint string, or None if not configured.
  206. """
  207. if self.backend_name == BackendEnum.CUSTOM.value:
  208. return None
  209. try:
  210. # Resolve concrete version and fetch its configuration
  211. version_config, _ = self.get_version_config(version)
  212. except KeyError:
  213. # Version not found or cannot be resolved
  214. return None
  215. entrypoint = version_config.entrypoint or self.default_entrypoint
  216. if entrypoint:
  217. return shlex.split(entrypoint)
  218. else:
  219. return None
  220. def get_image_name(self, version: Optional[str] = None) -> (str, str):
  221. """
  222. Resolve a user-configured container image for the specified backend version.
  223. Args:
  224. version: Desired backend version; falls back to `default_version` when None.
  225. Returns:
  226. A tuple of (image_name, version). Empty strings indicate no user-configured image.
  227. """
  228. # CUSTOM backend does not resolve here; image/command come from the model configuration
  229. if self.backend_name == BackendEnum.CUSTOM.value:
  230. return "", ""
  231. try:
  232. # Resolve concrete version and fetch its configuration
  233. version_config, version = self.get_version_config(version)
  234. except KeyError:
  235. # Version not found or cannot be resolved
  236. return "", ""
  237. if not version_config or not version_config.image_name:
  238. return "", ""
  239. # Only return image for custom version configs (no built-in frameworks) with explicit image
  240. if (
  241. self.backend_source == BackendSourceEnum.BUILT_IN
  242. and version_config.built_in_frameworks
  243. ):
  244. return "", ""
  245. return version_config.image_name, version
  246. class InferenceBackend(InferenceBackendBase, BaseModelMixin, table=True):
  247. __tablename__ = 'inference_backends'
  248. __table_args__ = (
  249. UniqueConstraint(
  250. "backend_name",
  251. "owner_principal_id",
  252. name="uix_inference_backends_name_org",
  253. ),
  254. )
  255. id: Optional[int] = SQLField(default=None, primary_key=True)
  256. class VersionListItem(BaseModel):
  257. version: str = Field(...)
  258. is_deprecated: bool = Field(default=False)
  259. env: Optional[Dict[str, str]] = Field(None)
  260. class InferenceBackendListItem(BaseModel):
  261. """Backend configuration item."""
  262. backend_name: str = Field(...)
  263. is_built_in: Optional[bool] = Field(None)
  264. default_version: Optional[str] = Field(None)
  265. default_backend_param: Optional[List[str]] = Field(None)
  266. versions: Optional[List[VersionListItem]] = Field(
  267. None, description="Available versions for this backend"
  268. )
  269. enabled: Optional[bool] = Field(None)
  270. backend_source: Optional[BackendSourceEnum] = Field(None)
  271. default_env: Optional[Dict[str, str]] = Field(None)
  272. class InferenceBackendResponse(BaseModel):
  273. """Response for backend configs list."""
  274. items: List[InferenceBackendListItem] = Field(...)
  275. # CRUD API Models
  276. class InferenceBackendCreate(InferenceBackendBase):
  277. pass
  278. class InferenceBackendUpdate(InferenceBackendBase):
  279. pass
  280. class InferenceBackendPublic(InferenceBackendBase):
  281. id: Optional[int]
  282. created_at: Optional[datetime]
  283. updated_at: Optional[datetime]
  284. built_in_version_configs: Optional[Dict[str, VersionConfig]] = {}
  285. framework_index_map: Optional[Dict[str, List[str]]] = {}
# Paginated collection type whose items are InferenceBackendPublic models.
InferenceBackendsPublic = PaginatedList[InferenceBackendPublic]
  287. # built-in backend configurations
  288. def get_built_in_backend() -> List[InferenceBackend]:
  289. return [
  290. InferenceBackend(backend_name=BackendEnum.VLLM.value, is_built_in=True),
  291. InferenceBackend(backend_name=BackendEnum.SGLANG.value, is_built_in=True),
  292. InferenceBackend(
  293. backend_name=BackendEnum.ASCEND_MINDIE.value, is_built_in=True
  294. ),
  295. InferenceBackend(backend_name=BackendEnum.VOX_BOX.value, is_built_in=True),
  296. InferenceBackend(backend_name=BackendEnum.CUSTOM.value, is_built_in=True),
  297. ]
  298. def is_built_in_backend(backend_name: Optional[str]) -> bool:
  299. """
  300. Check if a backend is a built-in backend.
  301. Args:
  302. backend_name: The name of the backend to check
  303. Returns:
  304. True if the backend is built-in, False otherwise
  305. """
  306. if not backend_name:
  307. return False
  308. built_in_backends = get_built_in_backend()
  309. built_in_backend_names = {
  310. backend.backend_name.lower() for backend in built_in_backends
  311. }
  312. return backend_name.lower() in built_in_backend_names
  313. def is_custom_backend(backend_name: Optional[str]) -> bool:
  314. """
  315. Check if a backend is a custom backend, i.e., not built-in or explicitly marked as CUSTOM.
  316. Args:
  317. backend_name: The name of the backend to check
  318. Returns:
  319. True if the backend is custom, False otherwise
  320. """
  321. if not backend_name:
  322. return False
  323. return (
  324. not is_built_in_backend(backend_name)
  325. or backend_name == BackendEnum.CUSTOM.value
  326. )
  327. def is_built_in_backend_custom_version(
  328. backend_name: Optional[str],
  329. backend_version: Optional[str],
  330. image_name: Optional[str],
  331. ) -> bool:
  332. """
  333. True when a built-in backend uses user-defined runner configuration that is
  334. outside gpustack-runner catalogs: explicit model image, or an inference
  335. backend version key ending with '-custom' (see validate_custom_suffix).
  336. """
  337. if not is_built_in_backend(backend_name):
  338. return False
  339. if image_name:
  340. return True
  341. if backend_version and backend_version.lower().endswith("-custom"):
  342. return True
  343. return False