# draft_models.py
  1. import math
  2. from typing import List, Optional
  3. from fastapi import APIRouter, Depends, Query
  4. from gpustack.schemas.common import PaginatedList, Pagination
  5. from gpustack.server.catalog import (
  6. DraftModel,
  7. get_catalog_draft_models,
  8. )
  9. from gpustack.server.deps import ListParamsDep
# Router for the draft-model catalog endpoint; the GET handler below registers on it.
router = APIRouter()
  11. @router.get("", response_model=PaginatedList[DraftModel])
  12. async def get_draft_models(
  13. params: ListParamsDep,
  14. search: str = None,
  15. algorithm: Optional[str] = Query(None, description="Filter by algorithm."),
  16. draft_models: List[DraftModel] = Depends(get_catalog_draft_models),
  17. ):
  18. if search:
  19. search = search.strip().lower()
  20. draft_models = [model for model in draft_models if search in model.name.lower()]
  21. if algorithm:
  22. draft_models = [
  23. model
  24. for model in draft_models
  25. if model.algorithm is not None and model.algorithm == algorithm
  26. ]
  27. count = len(draft_models)
  28. if params.page < 1 or params.perPage < 1:
  29. # Return all items.
  30. pagination = Pagination(
  31. page=1,
  32. perPage=count,
  33. total=count,
  34. totalPage=1,
  35. )
  36. return PaginatedList[DraftModel](items=draft_models, pagination=pagination)
  37. # Paginate results.
  38. total_page = math.ceil(count / params.perPage)
  39. start_index = (params.page - 1) * params.perPage
  40. end_index = start_index + params.perPage
  41. paginated_items = draft_models[start_index:end_index]
  42. pagination = Pagination(
  43. page=params.page,
  44. perPage=params.perPage,
  45. total=count,
  46. totalPage=total_page,
  47. )
  48. return PaginatedList[DraftModel](items=paginated_items, pagination=pagination)