models.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. import logging
  2. import math
  3. from typing import List, Optional, Union
  4. from fastapi import APIRouter, Depends, Query, Request
  5. from fastapi.responses import RedirectResponse, StreamingResponse
  6. from urllib.parse import urlencode
  7. from gpustack_runtime.detector import ManufacturerEnum
  8. from sqlalchemy.orm import selectinload
  9. from sqlmodel import select, and_, or_
  10. from sqlmodel.ext.asyncio.session import AsyncSession
  11. from enum import Enum
  12. from gpustack.api.exceptions import (
  13. AlreadyExistsException,
  14. InternalServerErrorException,
  15. BadRequestException,
  16. )
  17. from gpustack.schemas.common import Pagination
  18. from gpustack.schemas.inference_backend import is_custom_backend
  19. from gpustack.schemas.models import (
  20. ModelInstance,
  21. ModelInstancesPublic,
  22. BackendEnum,
  23. ModelListParams,
  24. )
  25. from gpustack.schemas.clusters import Cluster
  26. from gpustack.schemas.workers import GPUDeviceStatus, Worker
  27. from gpustack.api.tenant import (
  28. bypass_tenant_filter,
  29. assert_resource_visible,
  30. tenant_list_conditions,
  31. )
  32. from gpustack.server.db import async_session
  33. from gpustack.server.deps import ListParamsDep, SessionDep, TenantContextDep
  34. from gpustack.schemas.models import (
  35. Model,
  36. ModelCreate,
  37. ModelSpecBase,
  38. ModelUpdate,
  39. ModelPublic,
  40. ModelsPublic,
  41. )
  42. from gpustack.schemas.model_routes import (
  43. ModelRoute,
  44. ModelRouteTarget,
  45. TargetStateEnum,
  46. )
  47. from gpustack.server.services import (
  48. ModelService,
  49. WorkerService,
  50. revoke_model_access_cache,
  51. )
  52. from gpustack.utils.command import find_parameter
  53. from gpustack.utils.convert import safe_int
  54. from gpustack.utils.gpu import parse_gpu_id
  55. from gpustack.routes.model_common import (
  56. build_category_conditions,
  57. categories_filter,
  58. )
  59. from gpustack.config.config import get_global_config
  60. from gpustack.utils.grafana import resolve_grafana_base_url
# Router that the endpoint decorators in this module register on.
router = APIRouter()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  63. class ModelStateFilterEnum(str, Enum):
  64. READY = "ready"
  65. NOT_READY = "not_ready"
  66. STOPPED = "stopped"
  67. @router.get("", response_model=ModelsPublic)
  68. async def get_models(
  69. ctx: TenantContextDep,
  70. params: ModelListParams = Depends(),
  71. state: Optional[ModelStateFilterEnum] = Query(
  72. default=None,
  73. description="Filter by model state.",
  74. ),
  75. search: str = None,
  76. categories: Optional[List[str]] = Query(None, description="Filter by categories."),
  77. cluster_id: int = None,
  78. backend: Optional[str] = Query(None, description="Filter by backend."),
  79. ):
  80. fuzzy_fields = {}
  81. if search:
  82. fuzzy_fields = {"name": search}
  83. fields = {}
  84. if cluster_id:
  85. fields["cluster_id"] = cluster_id
  86. if backend:
  87. fields["backend"] = backend
  88. # Streaming uses field-equality only; scope by current org so non-admin
  89. # users never see cross-org rows via the live stream. Admin without an
  90. # explicit org context keeps the unfiltered cross-org stream. System
  91. # users (workers / cluster accounts) bypass — they need the cross-org
  92. # view to handle instances scheduled to them on clusters outside their
  93. # default Org.
  94. if ctx.current_principal_id is not None and not bypass_tenant_filter(ctx):
  95. fields["owner_principal_id"] = ctx.current_principal_id
  96. if params.watch:
  97. return StreamingResponse(
  98. Model.streaming(
  99. fields=fields,
  100. fuzzy_fields=fuzzy_fields,
  101. filter_func=lambda data: categories_filter(data, categories),
  102. ),
  103. media_type="text/event-stream",
  104. )
  105. async with async_session() as session:
  106. extra_conditions = list(tenant_list_conditions(ctx, Model))
  107. if categories:
  108. conditions = build_category_conditions(session, Model, categories)
  109. extra_conditions.append(or_(*conditions))
  110. if state is None:
  111. pass
  112. elif state == ModelStateFilterEnum.READY:
  113. extra_conditions.append(Model.ready_replicas > 0)
  114. elif state == ModelStateFilterEnum.NOT_READY:
  115. extra_conditions.append(and_(Model.ready_replicas == 0, Model.replicas > 0))
  116. elif state == ModelStateFilterEnum.STOPPED:
  117. extra_conditions.append(Model.replicas == 0)
  118. order_by = params.order_by
  119. if order_by:
  120. # When sorting by "source", add additional sorting fields for deterministic ordering
  121. new_order_by = []
  122. for field, direction in order_by:
  123. new_order_by.append((field, direction))
  124. if field == "source":
  125. new_order_by.append(("huggingface_repo_id", direction))
  126. new_order_by.append(("huggingface_filename", direction))
  127. new_order_by.append(("model_scope_model_id", direction))
  128. new_order_by.append(("model_scope_file_path", direction))
  129. new_order_by.append(("local_path", direction))
  130. order_by = new_order_by
  131. return await Model.paginated_by_query(
  132. session=session,
  133. fuzzy_fields=fuzzy_fields,
  134. extra_conditions=extra_conditions,
  135. page=params.page,
  136. per_page=params.perPage,
  137. fields=fields,
  138. order_by=order_by,
  139. )
  140. @router.get("/{id}", response_model=ModelPublic)
  141. async def get_model(
  142. session: SessionDep,
  143. ctx: TenantContextDep,
  144. id: int,
  145. ):
  146. return await _get_model(session=session, ctx=ctx, id=id)
  147. @router.get("/{id}/dashboard")
  148. async def get_model_dashboard(
  149. session: SessionDep,
  150. ctx: TenantContextDep,
  151. id: int,
  152. request: Request,
  153. ):
  154. model = await _get_model(session=session, ctx=ctx, id=id)
  155. cfg = get_global_config()
  156. if not cfg.get_grafana_url() or not cfg.grafana_model_dashboard_uid:
  157. raise InternalServerErrorException(
  158. message="Grafana dashboard settings are not configured"
  159. )
  160. cluster = None
  161. if model.cluster_id is not None:
  162. cluster = await Cluster.one_by_id(session, model.cluster_id)
  163. query_params = {}
  164. if cluster is not None:
  165. query_params["var-cluster_name"] = cluster.name
  166. query_params["var-model_name"] = model.name
  167. grafana_base = resolve_grafana_base_url(cfg, request)
  168. slug = "gpustack-model"
  169. dashboard_url = f"{grafana_base}/d/{cfg.grafana_model_dashboard_uid}/{slug}"
  170. if query_params:
  171. dashboard_url = f"{dashboard_url}?{urlencode(query_params)}"
  172. return RedirectResponse(url=dashboard_url, status_code=302)
  173. async def _get_model(
  174. session: SessionDep,
  175. ctx,
  176. id: int,
  177. ):
  178. model = await Model.one_by_id(session, id)
  179. assert_resource_visible(ctx, model, not_found_message="Model not found")
  180. return model
  181. @router.get("/{id}/instances", response_model=ModelInstancesPublic)
  182. async def get_model_instances(ctx: TenantContextDep, id: int, params: ListParamsDep):
  183. if params.watch:
  184. fields = {"model_id": id}
  185. return StreamingResponse(
  186. ModelInstance.streaming(fields=fields),
  187. media_type="text/event-stream",
  188. )
  189. async with async_session() as session:
  190. model = await Model.one_by_id(
  191. session, id, options=[selectinload(Model.instances)]
  192. )
  193. assert_resource_visible(ctx, model, not_found_message="Model not found")
  194. instances = model.instances
  195. count = len(instances)
  196. total_page = math.ceil(count / params.perPage)
  197. pagination = Pagination(
  198. page=params.page,
  199. perPage=params.perPage,
  200. total=count,
  201. totalPage=total_page,
  202. )
  203. return ModelInstancesPublic(items=instances, pagination=pagination)
  204. async def validate_model_in(
  205. session: SessionDep,
  206. model_in: Union[ModelCreate, ModelUpdate, ModelSpecBase],
  207. *,
  208. cluster_id: Optional[int] = None,
  209. ):
  210. if model_in.gpu_selector is not None and model_in.replicas > 0:
  211. await validate_gpu_ids(session, model_in, cluster_id=cluster_id)
  212. if is_custom_backend(model_in.backend):
  213. logger.info("Skip model validation for custom backend")
  214. return
  215. if model_in.backend_parameters:
  216. param_gpu_layers = find_parameter(
  217. model_in.backend_parameters, ["ngl", "gpu-layers", "n-gpu-layers"]
  218. )
  219. if param_gpu_layers:
  220. int_param_gpu_layers = safe_int(param_gpu_layers, None)
  221. if (
  222. not param_gpu_layers.isdigit()
  223. or int_param_gpu_layers < 0
  224. or int_param_gpu_layers > 999
  225. ):
  226. raise BadRequestException(
  227. message="Invalid backend parameter --gpu-layers. Please provide an integer in the range 0-999 (inclusive)."
  228. )
  229. if (
  230. int_param_gpu_layers == 0
  231. and model_in.gpu_selector is not None
  232. and len(model_in.gpu_selector.gpu_ids) > 0
  233. ):
  234. raise BadRequestException(
  235. message="Cannot set --gpu-layers to 0 and manually select GPUs at the same time. Setting --gpu-layers to 0 means running on CPU only."
  236. )
  237. unsupported_params = [
  238. (
  239. ["port"],
  240. (
  241. "Setting the port using --port is not supported. Ports are "
  242. "automatically allocated by GPUStack."
  243. ),
  244. ),
  245. (
  246. ["api-key"],
  247. (
  248. "Setting the API key using --api-key is not supported. API keys "
  249. "are managed by GPUStack."
  250. ),
  251. ),
  252. (
  253. ["served-model-name"],
  254. (
  255. "Setting the served model name using --served-model-name is not "
  256. "supported. The model name is automatically set from your "
  257. "deployment configuration."
  258. ),
  259. ),
  260. ]
  261. for param_names, error_message in unsupported_params:
  262. if find_parameter(model_in.backend_parameters, param_names):
  263. raise BadRequestException(message=error_message)
  264. async def validate_gpu_ids( # noqa: C901
  265. session: SessionDep,
  266. model_in: Union[ModelCreate, ModelUpdate, ModelSpecBase],
  267. *,
  268. cluster_id: Optional[int] = None,
  269. ):
  270. effective_cluster_id = (
  271. cluster_id if cluster_id is not None else getattr(model_in, "cluster_id", None)
  272. )
  273. if (
  274. model_in.gpu_selector
  275. and model_in.gpu_selector.gpu_ids
  276. and model_in.gpu_selector.gpus_per_replica
  277. ):
  278. if len(model_in.gpu_selector.gpu_ids) < model_in.gpu_selector.gpus_per_replica:
  279. raise BadRequestException(
  280. message="The number of selected GPUs must be greater than or equal to gpus_per_replica."
  281. )
  282. model_backend = model_in.backend
  283. if model_backend == BackendEnum.VOX_BOX and (
  284. len(model_in.gpu_selector.gpu_ids) > 1
  285. or (
  286. model_in.gpu_selector.gpus_per_replica is not None
  287. and model_in.gpu_selector.gpus_per_replica > 1
  288. )
  289. ):
  290. raise BadRequestException(
  291. message="The vox-box backend is restricted to execution on a single NVIDIA GPU."
  292. )
  293. worker_name_set = set()
  294. for gpu_id in model_in.gpu_selector.gpu_ids:
  295. is_valid, matched = parse_gpu_id(gpu_id)
  296. if not is_valid:
  297. raise BadRequestException(message=f"Invalid GPU ID: {gpu_id}")
  298. worker_name = matched.get("worker_name")
  299. gpu_index = safe_int(matched.get("gpu_index"), -1)
  300. worker_name_set.add(worker_name)
  301. if effective_cluster_id is None:
  302. raise BadRequestException(
  303. message=f"A cluster context is required for manual GPU selection, but was not provided. Cannot validate worker '{worker_name}'."
  304. )
  305. worker = await WorkerService(session).get_by_cluster_id_name(
  306. effective_cluster_id, worker_name
  307. )
  308. if not worker:
  309. raise BadRequestException(message=f"Worker {worker_name} not found")
  310. gpu = (
  311. next(
  312. (gpu for gpu in worker.status.gpu_devices if gpu.index == gpu_index),
  313. None,
  314. )
  315. if worker.status and worker.status.gpu_devices
  316. else None
  317. )
  318. if gpu:
  319. validate_gpu(gpu, model_backend=model_backend)
  320. if model_backend == BackendEnum.VLLM and len(worker_name_set) > 1:
  321. await validate_distributed_vllm_limit_per_worker(session, model_in, worker)
  322. if (
  323. is_custom_backend(model_backend)
  324. and len(worker_name_set) > 1
  325. and model_in.replicas == 1
  326. ):
  327. raise BadRequestException(
  328. message="Distributed inference across multiple workers is not supported for custom backends."
  329. )
  330. def validate_gpu(gpu_device: GPUDeviceStatus, model_backend: str = ""):
  331. if (
  332. model_backend == BackendEnum.VOX_BOX
  333. and gpu_device.vendor != ManufacturerEnum.NVIDIA.value
  334. ):
  335. raise BadRequestException(
  336. "The vox-box backend is supported only on NVIDIA GPUs."
  337. )
  338. if (
  339. model_backend == BackendEnum.ASCEND_MINDIE
  340. and gpu_device.vendor != ManufacturerEnum.ASCEND.value
  341. ):
  342. raise BadRequestException(
  343. f"Ascend MindIE backend requires Ascend NPUs. Selected {gpu_device.vendor} GPU is not supported."
  344. )
  345. async def validate_distributed_vllm_limit_per_worker(
  346. session: AsyncSession, model: Union[ModelCreate, ModelUpdate], worker: Worker
  347. ):
  348. """
  349. Validate that there is no more than one distributed vLLM instance per worker.
  350. """
  351. instances = await ModelInstance.all_by_field(session, "worker_id", worker.id)
  352. for instance in instances:
  353. if (
  354. instance.distributed_servers
  355. and instance.distributed_servers.subordinate_workers
  356. and instance.model_name != model.name
  357. ):
  358. raise BadRequestException(
  359. message=f"Each worker can run only one distributed vLLM instance. Worker '{worker.name}' already has '{instance.name}'."
  360. )
  361. @router.post("", response_model=ModelPublic)
  362. async def create_model(
  363. session: SessionDep, ctx: TenantContextDep, model_in: ModelCreate
  364. ):
  365. # Model & ModelRoute names are unique within their Org. Two Orgs
  366. # can each have a "llama3" without colliding.
  367. org_scope = ctx.current_principal_id
  368. # Check for ANY existing record with the same name (including soft-deleted)
  369. # to avoid unique constraint violations
  370. statement = select(Model).where(Model.name == model_in.name)
  371. result = await session.exec(statement)
  372. any_existing = result.first()
  373. if any_existing:
  374. if any_existing.deleted_at is not None:
  375. # Soft-deleted record found
  376. if not model_in.overwrite_deleted:
  377. # Prompt user to confirm overwrite
  378. raise AlreadyExistsException(
  379. message=f"Model '{model_in.name}' was previously deleted. "
  380. "Do you want to overwrite it? Set 'overwrite_deleted=true' to confirm."
  381. )
  382. # User confirmed overwrite - permanently delete the soft-deleted record
  383. await session.delete(any_existing)
  384. await session.flush()
  385. else:
  386. # Active record found - check if it's in the same org scope
  387. if any_existing.owner_principal_id == org_scope:
  388. raise AlreadyExistsException(
  389. message=f"Model '{model_in.name}' already exists. "
  390. "Please choose a different name or check the existing model."
  391. )
  392. else:
  393. # Different org - still a conflict due to unique constraint on name
  394. raise AlreadyExistsException(
  395. message=f"Model name '{model_in.name}' is already in use by another organization. "
  396. "Please choose a different name."
  397. )
  398. # Double-check for the specific org scope (defensive programming)
  399. existing = await Model.one_by_fields(
  400. session,
  401. {"name": model_in.name, "owner_principal_id": org_scope},
  402. )
  403. if existing:
  404. raise AlreadyExistsException(
  405. message=f"Model '{model_in.name}' already exists. "
  406. "Please choose a different name or check the existing model."
  407. )
  408. should_create_route = (
  409. model_in.enable_model_route is not None and model_in.enable_model_route
  410. )
  411. if should_create_route:
  412. existing_route = await ModelRoute.one_by_fields(
  413. session,
  414. {"name": model_in.name, "owner_principal_id": org_scope},
  415. )
  416. if existing_route:
  417. raise AlreadyExistsException(
  418. message=f"Model route '{model_in.name}' already exists. "
  419. "Please choose a different name or check the existing model route."
  420. )
  421. await validate_model_in(session, model_in)
  422. model_in_dict = model_in.model_dump(exclude={"enable_model_route", "overwrite_deleted"})
  423. # Stamp tenant scope. ModelBase has owner_principal_id defaulted to
  424. # PLATFORM_PRINCIPAL_ID, so `model_dump()` always emits the key —
  425. # `setdefault` would silently leave it at 1 even when the caller is
  426. # acting under a different Org. Override directly:
  427. # - Caller has a current Org context → that Org wins
  428. # - Caller is admin in "All" mode → fall back to the chosen
  429. # cluster's owner Org so the model lives where it actually runs
  430. # (otherwise it'd land in Platform/Default and the cluster's Org
  431. # couldn't see / manage it)
  432. target_org_id = ctx.current_principal_id
  433. if target_org_id is None and model_in.cluster_id is not None:
  434. cluster = await Cluster.one_by_id(session, model_in.cluster_id)
  435. if cluster is not None:
  436. target_org_id = cluster.owner_principal_id
  437. if target_org_id is not None:
  438. model_in_dict["owner_principal_id"] = target_org_id
  439. try:
  440. model: Model = await Model.create(
  441. session, source=model_in_dict, auto_commit=(not should_create_route)
  442. )
  443. if should_create_route:
  444. model_route = ModelRoute(
  445. name=model.name,
  446. description=model.description,
  447. categories=model.categories,
  448. generic_proxy=model.generic_proxy,
  449. created_by_model=True,
  450. access_policy=model.access_policy,
  451. owner_principal_id=model.owner_principal_id,
  452. )
  453. model_route: ModelRoute = await ModelRoute.create(
  454. session, source=model_route, auto_commit=False
  455. )
  456. model_route_target = ModelRouteTarget(
  457. name=f"{model.name}-deployment",
  458. route_name=model_route.name,
  459. generic_proxy=model.generic_proxy,
  460. model_route=model_route,
  461. model=model,
  462. weight=100,
  463. state=TargetStateEnum.UNAVAILABLE,
  464. )
  465. await ModelRouteTarget.create(
  466. session,
  467. source=model_route_target,
  468. auto_commit=False,
  469. )
  470. await session.commit()
  471. await revoke_model_access_cache(session=session)
  472. except Exception as e:
  473. raise InternalServerErrorException(message=f"Failed to create model: {e}")
  474. return model
  475. @router.put("/{id}", response_model=ModelPublic)
  476. async def update_model(
  477. session: SessionDep, ctx: TenantContextDep, id: int, model_in: ModelUpdate
  478. ):
  479. model = await Model.one_by_id(session, id)
  480. assert_resource_visible(ctx, model, not_found_message="Model not found")
  481. await validate_model_in(session, model_in)
  482. if model_in.backend != BackendEnum.CUSTOM.value and (
  483. model.run_command or model.image_name
  484. ):
  485. patch = model_in.model_dump(exclude_unset=True)
  486. patch["run_command"] = None
  487. patch["image_name"] = None
  488. model_in = patch
  489. try:
  490. await ModelService(session).update(model, model_in)
  491. except Exception as e:
  492. raise InternalServerErrorException(message=f"Failed to update model: {e}")
  493. return model
  494. @router.delete("/{id}")
  495. async def delete_model(session: SessionDep, ctx: TenantContextDep, id: int):
  496. model = await Model.one_by_id(
  497. session,
  498. id,
  499. options=[
  500. selectinload(Model.instances),
  501. selectinload(Model.model_route_targets),
  502. ],
  503. )
  504. assert_resource_visible(ctx, model, not_found_message="Model not found")
  505. try:
  506. await ModelService(session).delete(model)
  507. except Exception as e:
  508. raise InternalServerErrorException(message=f"Failed to delete model: {e}")