gpu_devices.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
from typing import Optional

from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse

from gpustack.api.tenant import (
    bypass_tenant_filter,
    assert_cluster_resource_visible,
    cluster_resource_visibility_conditions,
)
from gpustack.server.db import async_session
from gpustack.server.deps import SessionDep, TenantContextDep
from gpustack.schemas.gpu_devices import (
    GPUDevice,
    GPUDeviceListParams,
    GPUDevicesPublic,
    GPUDevicePublic,
)
  16. router = APIRouter()
  17. @router.get("", response_model=GPUDevicesPublic)
  18. async def get_gpus(
  19. ctx: TenantContextDep,
  20. params: GPUDeviceListParams = Depends(),
  21. search: str = None,
  22. cluster_id: int = None,
  23. ):
  24. fuzzy_fields = {}
  25. if search:
  26. fuzzy_fields = {"name": search}
  27. fields = {}
  28. if cluster_id:
  29. fields["cluster_id"] = cluster_id
  30. extra_conditions = cluster_resource_visibility_conditions(ctx, GPUDevice)
  31. def _gpu_visible(g) -> bool:
  32. if bypass_tenant_filter(ctx):
  33. return True
  34. org_id = getattr(g, "owner_principal_id", None)
  35. if (
  36. ctx.current_principal_id is not None
  37. and org_id is not None
  38. and org_id == ctx.current_principal_id
  39. ):
  40. return True
  41. if getattr(g, "cluster_id", None) in ctx.accessible_cluster_ids:
  42. return True
  43. return False
  44. if params.watch:
  45. return StreamingResponse(
  46. GPUDevice.streaming(
  47. fuzzy_fields=fuzzy_fields, fields=fields, filter_func=_gpu_visible
  48. ),
  49. media_type="text/event-stream",
  50. )
  51. async with async_session() as session:
  52. return await GPUDevice.paginated_by_query(
  53. session=session,
  54. fuzzy_fields=fuzzy_fields,
  55. page=params.page,
  56. per_page=params.perPage,
  57. fields=fields,
  58. extra_conditions=extra_conditions,
  59. order_by=params.order_by,
  60. )
  61. @router.get("/{id}", response_model=GPUDevicePublic)
  62. async def get_gpu(session: SessionDep, ctx: TenantContextDep, id: str):
  63. model = await GPUDevice.one_by_id(session, id)
  64. assert_cluster_resource_visible(
  65. ctx, model, not_found_message="GPU device not found"
  66. )
  67. return model