models.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. import logging
  2. import math
  3. from typing import List, Optional, Union
  4. from fastapi import APIRouter, Depends, Query, Request
  5. from fastapi.responses import RedirectResponse, StreamingResponse
  6. from urllib.parse import urlencode
  7. from gpustack_runtime.detector import ManufacturerEnum
  8. from sqlalchemy.orm import selectinload
  9. from sqlmodel import and_, or_
  10. from sqlmodel.ext.asyncio.session import AsyncSession
  11. from enum import Enum
  12. from gpustack.api.exceptions import (
  13. AlreadyExistsException,
  14. InternalServerErrorException,
  15. BadRequestException,
  16. )
  17. from gpustack.schemas.common import Pagination
  18. from gpustack.schemas.inference_backend import is_custom_backend
  19. from gpustack.schemas.models import (
  20. ModelInstance,
  21. ModelInstancesPublic,
  22. BackendEnum,
  23. ModelListParams,
  24. )
  25. from gpustack.schemas.clusters import Cluster
  26. from gpustack.schemas.workers import GPUDeviceStatus, Worker
  27. from gpustack.api.tenant import (
  28. bypass_tenant_filter,
  29. assert_resource_visible,
  30. tenant_list_conditions,
  31. )
  32. from gpustack.server.db import async_session
  33. from gpustack.server.deps import ListParamsDep, SessionDep, TenantContextDep
  34. from gpustack.schemas.models import (
  35. Model,
  36. ModelCreate,
  37. ModelSpecBase,
  38. ModelUpdate,
  39. ModelPublic,
  40. ModelsPublic,
  41. )
  42. from gpustack.schemas.model_routes import (
  43. ModelRoute,
  44. ModelRouteTarget,
  45. TargetStateEnum,
  46. )
  47. from gpustack.server.services import (
  48. ModelService,
  49. WorkerService,
  50. revoke_model_access_cache,
  51. )
  52. from gpustack.utils.command import find_parameter
  53. from gpustack.utils.convert import safe_int
  54. from gpustack.utils.gpu import parse_gpu_id
  55. from gpustack.routes.model_common import (
  56. build_category_conditions,
  57. categories_filter,
  58. )
  59. from gpustack.config.config import get_global_config
  60. from gpustack.utils.grafana import resolve_grafana_base_url
# Router object on which the model endpoints of this module are registered.
router = APIRouter()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  63. class ModelStateFilterEnum(str, Enum):
  64. READY = "ready"
  65. NOT_READY = "not_ready"
  66. STOPPED = "stopped"
  67. @router.get("", response_model=ModelsPublic)
  68. async def get_models(
  69. ctx: TenantContextDep,
  70. params: ModelListParams = Depends(),
  71. state: Optional[ModelStateFilterEnum] = Query(
  72. default=None,
  73. description="Filter by model state.",
  74. ),
  75. search: str = None,
  76. categories: Optional[List[str]] = Query(None, description="Filter by categories."),
  77. cluster_id: int = None,
  78. backend: Optional[str] = Query(None, description="Filter by backend."),
  79. ):
  80. fuzzy_fields = {}
  81. if search:
  82. fuzzy_fields = {"name": search}
  83. fields = {}
  84. if cluster_id:
  85. fields["cluster_id"] = cluster_id
  86. if backend:
  87. fields["backend"] = backend
  88. # Streaming uses field-equality only; scope by current org so non-admin
  89. # users never see cross-org rows via the live stream. Admin without an
  90. # explicit org context keeps the unfiltered cross-org stream. System
  91. # users (workers / cluster accounts) bypass — they need the cross-org
  92. # view to handle instances scheduled to them on clusters outside their
  93. # default Org.
  94. if ctx.current_principal_id is not None and not bypass_tenant_filter(ctx):
  95. fields["owner_principal_id"] = ctx.current_principal_id
  96. if params.watch:
  97. return StreamingResponse(
  98. Model.streaming(
  99. fields=fields,
  100. fuzzy_fields=fuzzy_fields,
  101. filter_func=lambda data: categories_filter(data, categories),
  102. ),
  103. media_type="text/event-stream",
  104. )
  105. async with async_session() as session:
  106. extra_conditions = list(tenant_list_conditions(ctx, Model))
  107. if categories:
  108. conditions = build_category_conditions(session, Model, categories)
  109. extra_conditions.append(or_(*conditions))
  110. if state is None:
  111. pass
  112. elif state == ModelStateFilterEnum.READY:
  113. extra_conditions.append(Model.ready_replicas > 0)
  114. elif state == ModelStateFilterEnum.NOT_READY:
  115. extra_conditions.append(and_(Model.ready_replicas == 0, Model.replicas > 0))
  116. elif state == ModelStateFilterEnum.STOPPED:
  117. extra_conditions.append(Model.replicas == 0)
  118. order_by = params.order_by
  119. if order_by:
  120. # When sorting by "source", add additional sorting fields for deterministic ordering
  121. new_order_by = []
  122. for field, direction in order_by:
  123. new_order_by.append((field, direction))
  124. if field == "source":
  125. new_order_by.append(("huggingface_repo_id", direction))
  126. new_order_by.append(("huggingface_filename", direction))
  127. new_order_by.append(("model_scope_model_id", direction))
  128. new_order_by.append(("model_scope_file_path", direction))
  129. new_order_by.append(("local_path", direction))
  130. order_by = new_order_by
  131. return await Model.paginated_by_query(
  132. session=session,
  133. fuzzy_fields=fuzzy_fields,
  134. extra_conditions=extra_conditions,
  135. page=params.page,
  136. per_page=params.perPage,
  137. fields=fields,
  138. order_by=order_by,
  139. )
  140. @router.get("/{id}", response_model=ModelPublic)
  141. async def get_model(
  142. session: SessionDep,
  143. ctx: TenantContextDep,
  144. id: int,
  145. ):
  146. return await _get_model(session=session, ctx=ctx, id=id)
  147. @router.get("/{id}/dashboard")
  148. async def get_model_dashboard(
  149. session: SessionDep,
  150. ctx: TenantContextDep,
  151. id: int,
  152. request: Request,
  153. ):
  154. model = await _get_model(session=session, ctx=ctx, id=id)
  155. cfg = get_global_config()
  156. if not cfg.get_grafana_url() or not cfg.grafana_model_dashboard_uid:
  157. raise InternalServerErrorException(
  158. message="Grafana dashboard settings are not configured"
  159. )
  160. cluster = None
  161. if model.cluster_id is not None:
  162. cluster = await Cluster.one_by_id(session, model.cluster_id)
  163. query_params = {}
  164. if cluster is not None:
  165. query_params["var-cluster_name"] = cluster.name
  166. query_params["var-model_name"] = model.name
  167. grafana_base = resolve_grafana_base_url(cfg, request)
  168. slug = "gpustack-model"
  169. dashboard_url = f"{grafana_base}/d/{cfg.grafana_model_dashboard_uid}/{slug}"
  170. if query_params:
  171. dashboard_url = f"{dashboard_url}?{urlencode(query_params)}"
  172. return RedirectResponse(url=dashboard_url, status_code=302)
  173. async def _get_model(
  174. session: SessionDep,
  175. ctx,
  176. id: int,
  177. ):
  178. model = await Model.one_by_id(session, id)
  179. assert_resource_visible(ctx, model, not_found_message="Model not found")
  180. return model
  181. @router.get("/{id}/instances", response_model=ModelInstancesPublic)
  182. async def get_model_instances(ctx: TenantContextDep, id: int, params: ListParamsDep):
  183. if params.watch:
  184. fields = {"model_id": id}
  185. return StreamingResponse(
  186. ModelInstance.streaming(fields=fields),
  187. media_type="text/event-stream",
  188. )
  189. async with async_session() as session:
  190. model = await Model.one_by_id(
  191. session, id, options=[selectinload(Model.instances)]
  192. )
  193. assert_resource_visible(ctx, model, not_found_message="Model not found")
  194. instances = model.instances
  195. count = len(instances)
  196. total_page = math.ceil(count / params.perPage)
  197. pagination = Pagination(
  198. page=params.page,
  199. perPage=params.perPage,
  200. total=count,
  201. totalPage=total_page,
  202. )
  203. return ModelInstancesPublic(items=instances, pagination=pagination)
  204. async def validate_model_in(
  205. session: SessionDep,
  206. model_in: Union[ModelCreate, ModelUpdate, ModelSpecBase],
  207. *,
  208. cluster_id: Optional[int] = None,
  209. ):
  210. if model_in.gpu_selector is not None and model_in.replicas > 0:
  211. await validate_gpu_ids(session, model_in, cluster_id=cluster_id)
  212. if is_custom_backend(model_in.backend):
  213. logger.info("Skip model validation for custom backend")
  214. return
  215. if model_in.backend_parameters:
  216. param_gpu_layers = find_parameter(
  217. model_in.backend_parameters, ["ngl", "gpu-layers", "n-gpu-layers"]
  218. )
  219. if param_gpu_layers:
  220. int_param_gpu_layers = safe_int(param_gpu_layers, None)
  221. if (
  222. not param_gpu_layers.isdigit()
  223. or int_param_gpu_layers < 0
  224. or int_param_gpu_layers > 999
  225. ):
  226. raise BadRequestException(
  227. message="Invalid backend parameter --gpu-layers. Please provide an integer in the range 0-999 (inclusive)."
  228. )
  229. if (
  230. int_param_gpu_layers == 0
  231. and model_in.gpu_selector is not None
  232. and len(model_in.gpu_selector.gpu_ids) > 0
  233. ):
  234. raise BadRequestException(
  235. message="Cannot set --gpu-layers to 0 and manually select GPUs at the same time. Setting --gpu-layers to 0 means running on CPU only."
  236. )
  237. unsupported_params = [
  238. (
  239. ["port"],
  240. (
  241. "Setting the port using --port is not supported. Ports are "
  242. "automatically allocated by GPUStack."
  243. ),
  244. ),
  245. (
  246. ["api-key"],
  247. (
  248. "Setting the API key using --api-key is not supported. API keys "
  249. "are managed by GPUStack."
  250. ),
  251. ),
  252. (
  253. ["served-model-name"],
  254. (
  255. "Setting the served model name using --served-model-name is not "
  256. "supported. The model name is automatically set from your "
  257. "deployment configuration."
  258. ),
  259. ),
  260. ]
  261. for param_names, error_message in unsupported_params:
  262. if find_parameter(model_in.backend_parameters, param_names):
  263. raise BadRequestException(message=error_message)
  264. async def validate_gpu_ids( # noqa: C901
  265. session: SessionDep,
  266. model_in: Union[ModelCreate, ModelUpdate, ModelSpecBase],
  267. *,
  268. cluster_id: Optional[int] = None,
  269. ):
  270. effective_cluster_id = (
  271. cluster_id if cluster_id is not None else getattr(model_in, "cluster_id", None)
  272. )
  273. if (
  274. model_in.gpu_selector
  275. and model_in.gpu_selector.gpu_ids
  276. and model_in.gpu_selector.gpus_per_replica
  277. ):
  278. if len(model_in.gpu_selector.gpu_ids) < model_in.gpu_selector.gpus_per_replica:
  279. raise BadRequestException(
  280. message="The number of selected GPUs must be greater than or equal to gpus_per_replica."
  281. )
  282. model_backend = model_in.backend
  283. if model_backend == BackendEnum.VOX_BOX and (
  284. len(model_in.gpu_selector.gpu_ids) > 1
  285. or (
  286. model_in.gpu_selector.gpus_per_replica is not None
  287. and model_in.gpu_selector.gpus_per_replica > 1
  288. )
  289. ):
  290. raise BadRequestException(
  291. message="The vox-box backend is restricted to execution on a single NVIDIA GPU."
  292. )
  293. worker_name_set = set()
  294. for gpu_id in model_in.gpu_selector.gpu_ids:
  295. is_valid, matched = parse_gpu_id(gpu_id)
  296. if not is_valid:
  297. raise BadRequestException(message=f"Invalid GPU ID: {gpu_id}")
  298. worker_name = matched.get("worker_name")
  299. gpu_index = safe_int(matched.get("gpu_index"), -1)
  300. worker_name_set.add(worker_name)
  301. if effective_cluster_id is None:
  302. raise BadRequestException(
  303. message=f"A cluster context is required for manual GPU selection, but was not provided. Cannot validate worker '{worker_name}'."
  304. )
  305. worker = await WorkerService(session).get_by_cluster_id_name(
  306. effective_cluster_id, worker_name
  307. )
  308. if not worker:
  309. raise BadRequestException(message=f"Worker {worker_name} not found")
  310. gpu = (
  311. next(
  312. (gpu for gpu in worker.status.gpu_devices if gpu.index == gpu_index),
  313. None,
  314. )
  315. if worker.status and worker.status.gpu_devices
  316. else None
  317. )
  318. if gpu:
  319. validate_gpu(gpu, model_backend=model_backend)
  320. if model_backend == BackendEnum.VLLM and len(worker_name_set) > 1:
  321. await validate_distributed_vllm_limit_per_worker(session, model_in, worker)
  322. if (
  323. is_custom_backend(model_backend)
  324. and len(worker_name_set) > 1
  325. and model_in.replicas == 1
  326. ):
  327. raise BadRequestException(
  328. message="Distributed inference across multiple workers is not supported for custom backends."
  329. )
  330. def validate_gpu(gpu_device: GPUDeviceStatus, model_backend: str = ""):
  331. if (
  332. model_backend == BackendEnum.VOX_BOX
  333. and gpu_device.vendor != ManufacturerEnum.NVIDIA.value
  334. ):
  335. raise BadRequestException(
  336. "The vox-box backend is supported only on NVIDIA GPUs."
  337. )
  338. if (
  339. model_backend == BackendEnum.ASCEND_MINDIE
  340. and gpu_device.vendor != ManufacturerEnum.ASCEND.value
  341. ):
  342. raise BadRequestException(
  343. f"Ascend MindIE backend requires Ascend NPUs. Selected {gpu_device.vendor} GPU is not supported."
  344. )
  345. async def validate_distributed_vllm_limit_per_worker(
  346. session: AsyncSession, model: Union[ModelCreate, ModelUpdate], worker: Worker
  347. ):
  348. """
  349. Validate that there is no more than one distributed vLLM instance per worker.
  350. """
  351. instances = await ModelInstance.all_by_field(session, "worker_id", worker.id)
  352. for instance in instances:
  353. if (
  354. instance.distributed_servers
  355. and instance.distributed_servers.subordinate_workers
  356. and instance.model_name != model.name
  357. ):
  358. raise BadRequestException(
  359. message=f"Each worker can run only one distributed vLLM instance. Worker '{worker.name}' already has '{instance.name}'."
  360. )
  361. @router.post("", response_model=ModelPublic)
  362. async def create_model(
  363. session: SessionDep, ctx: TenantContextDep, model_in: ModelCreate
  364. ):
  365. # Model & ModelRoute names are unique within their Org. Two Orgs
  366. # can each have a "llama3" without colliding.
  367. org_scope = ctx.current_principal_id
  368. existing = await Model.one_by_fields(
  369. session,
  370. {"name": model_in.name, "owner_principal_id": org_scope},
  371. )
  372. if existing:
  373. raise AlreadyExistsException(
  374. message=f"Model '{model_in.name}' already exists. "
  375. "Please choose a different name or check the existing model."
  376. )
  377. should_create_route = (
  378. model_in.enable_model_route is not None and model_in.enable_model_route
  379. )
  380. if should_create_route:
  381. existing_route = await ModelRoute.one_by_fields(
  382. session,
  383. {"name": model_in.name, "owner_principal_id": org_scope},
  384. )
  385. if existing_route:
  386. raise AlreadyExistsException(
  387. message=f"Model route '{model_in.name}' already exists. "
  388. "Please choose a different name or check the existing model route."
  389. )
  390. await validate_model_in(session, model_in)
  391. model_in_dict = model_in.model_dump(exclude={"enable_model_route"})
  392. # Stamp tenant scope. ModelBase has owner_principal_id defaulted to
  393. # PLATFORM_PRINCIPAL_ID, so `model_dump()` always emits the key —
  394. # `setdefault` would silently leave it at 1 even when the caller is
  395. # acting under a different Org. Override directly:
  396. # - Caller has a current Org context → that Org wins
  397. # - Caller is admin in "All" mode → fall back to the chosen
  398. # cluster's owner Org so the model lives where it actually runs
  399. # (otherwise it'd land in Platform/Default and the cluster's Org
  400. # couldn't see / manage it)
  401. target_org_id = ctx.current_principal_id
  402. if target_org_id is None and model_in.cluster_id is not None:
  403. cluster = await Cluster.one_by_id(session, model_in.cluster_id)
  404. if cluster is not None:
  405. target_org_id = cluster.owner_principal_id
  406. if target_org_id is not None:
  407. model_in_dict["owner_principal_id"] = target_org_id
  408. try:
  409. model: Model = await Model.create(
  410. session, source=model_in_dict, auto_commit=(not should_create_route)
  411. )
  412. if should_create_route:
  413. model_route = ModelRoute(
  414. name=model.name,
  415. description=model.description,
  416. categories=model.categories,
  417. generic_proxy=model.generic_proxy,
  418. created_by_model=True,
  419. access_policy=model.access_policy,
  420. owner_principal_id=model.owner_principal_id,
  421. )
  422. model_route: ModelRoute = await ModelRoute.create(
  423. session, source=model_route, auto_commit=False
  424. )
  425. model_route_target = ModelRouteTarget(
  426. name=f"{model.name}-deployment",
  427. route_name=model_route.name,
  428. generic_proxy=model.generic_proxy,
  429. model_route=model_route,
  430. model=model,
  431. weight=100,
  432. state=TargetStateEnum.UNAVAILABLE,
  433. )
  434. await ModelRouteTarget.create(
  435. session,
  436. source=model_route_target,
  437. auto_commit=False,
  438. )
  439. await session.commit()
  440. await revoke_model_access_cache(session=session)
  441. except Exception as e:
  442. raise InternalServerErrorException(message=f"Failed to create model: {e}")
  443. return model
  444. @router.put("/{id}", response_model=ModelPublic)
  445. async def update_model(
  446. session: SessionDep, ctx: TenantContextDep, id: int, model_in: ModelUpdate
  447. ):
  448. model = await Model.one_by_id(session, id)
  449. assert_resource_visible(ctx, model, not_found_message="Model not found")
  450. await validate_model_in(session, model_in)
  451. if model_in.backend != BackendEnum.CUSTOM.value and (
  452. model.run_command or model.image_name
  453. ):
  454. patch = model_in.model_dump(exclude_unset=True)
  455. patch["run_command"] = None
  456. patch["image_name"] = None
  457. model_in = patch
  458. try:
  459. await ModelService(session).update(model, model_in)
  460. except Exception as e:
  461. raise InternalServerErrorException(message=f"Failed to update model: {e}")
  462. return model
  463. @router.delete("/{id}")
  464. async def delete_model(session: SessionDep, ctx: TenantContextDep, id: int):
  465. model = await Model.one_by_id(
  466. session,
  467. id,
  468. options=[
  469. selectinload(Model.instances),
  470. selectinload(Model.model_route_targets),
  471. ],
  472. )
  473. assert_resource_visible(ctx, model, not_found_message="Model not found")
  474. try:
  475. await ModelService(session).delete(model)
  476. except Exception as e:
  477. raise InternalServerErrorException(message=f"Failed to delete model: {e}")