test_model_routes.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from contextlib import asynccontextmanager
  2. import pytest
  3. from sqlalchemy import true
  4. from gpustack.routes import model_routes
  5. from gpustack.schemas.common import Pagination
  6. from gpustack.schemas.model_routes import (
  7. ModelRouteListParams,
  8. ModelRoutesPublic,
  9. MyModel,
  10. )
  11. @pytest.mark.asyncio
  12. async def test_get_model_routes_filters_categories_on_target_class(monkeypatch):
  13. captured = {}
  14. @asynccontextmanager
  15. async def fake_async_session():
  16. yield object()
  17. def fake_build_category_conditions(session, target_class, categories):
  18. captured["target_class"] = target_class
  19. captured["categories"] = categories
  20. return [true()]
  21. async def fake_paginated_by_query(**kwargs):
  22. captured["fields"] = kwargs["fields"]
  23. captured["extra_conditions"] = kwargs["extra_conditions"]
  24. return ModelRoutesPublic(
  25. items=[],
  26. pagination=Pagination(page=1, perPage=24, total=0, totalPage=0),
  27. )
  28. monkeypatch.setattr(model_routes, "async_session", fake_async_session)
  29. monkeypatch.setattr(
  30. model_routes, "build_category_conditions", fake_build_category_conditions
  31. )
  32. monkeypatch.setattr(MyModel, "paginated_by_query", fake_paginated_by_query)
  33. await model_routes._get_model_routes(
  34. params=ModelRouteListParams(page=1, perPage=24),
  35. categories=["image"],
  36. target_class=MyModel,
  37. user_id=123,
  38. )
  39. assert captured["target_class"] is MyModel
  40. assert captured["categories"] == ["image"]
  41. assert captured["fields"]["user_id"] == 123
  42. assert captured["extra_conditions"]