worker_pools.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from fastapi import APIRouter
  2. from fastapi.responses import StreamingResponse
  3. from sqlalchemy.orm import selectinload
  4. from gpustack.api.exceptions import (
  5. InternalServerErrorException,
  6. NotFoundException,
  7. ForbiddenException,
  8. )
  9. from gpustack.api.tenant import (
  10. assert_org_owned_writable,
  11. assert_resource_visible,
  12. tenant_list_conditions,
  13. )
  14. from gpustack.server.db import async_session
  15. from gpustack.server.deps import ListParamsDep, SessionDep, TenantContextDep
  16. from gpustack.schemas.clusters import (
  17. WorkerPoolPublic,
  18. WorkerPoolsPublic,
  19. WorkerPoolUpdate,
  20. WorkerPool,
  21. )
  # Router that the worker-pool endpoints in this module register on.
  router = APIRouter()
  # Loader options shared by the endpoints that need WorkerPool.pool_workers
  # populated; the relationship is eager-loaded with SQLAlchemy's selectinload.
  WORKER_POOL_LOAD_OPTIONS = [
      selectinload(WorkerPool.pool_workers),
  ]
  @router.get("", response_model=WorkerPoolsPublic)
  async def list(
      ctx: TenantContextDep,
      params: ListParamsDep,
      name: str = None,
      search: str = None,
      cluster_id: int = None,
  ):
      """List worker pools, either as one page or as an event stream.

      ``name`` and ``cluster_id`` are passed as field filters and ``search``
      as a fuzzy filter on ``name``; rows whose ``deleted_at`` is set are
      always excluded. With ``watch`` enabled the response is a
      ``text/event-stream``; otherwise it is a single page additionally
      restricted by the caller's tenant conditions.
      """
      fuzzy_fields = {"name": search} if search else {}

      # Falsy query values ("" or 0) are treated as "not supplied".
      fields = {"deleted_at": None}
      if cluster_id:
          fields["cluster_id"] = cluster_id
      if name:
          fields["name"] = name

      if not params.watch:
          async with async_session() as session:
              # Worker pools mirror their cluster's owner_principal_id; same filter
              # rules as cloud_credentials apply.
              tenant_conditions = tenant_list_conditions(ctx, WorkerPool)
              return await WorkerPool.paginated_by_query(
                  session=session,
                  fields=fields,
                  fuzzy_fields=fuzzy_fields,
                  extra_conditions=tenant_conditions,
                  page=params.page,
                  per_page=params.perPage,
                  options=WORKER_POOL_LOAD_OPTIONS,
              )

      # NOTE(review): unlike the paginated branch above, the watch stream is
      # not given tenant_list_conditions(ctx, WorkerPool), so it does not appear
      # to be tenant-scoped here. Confirm whether WorkerPool.streaming filters
      # elsewhere, or pass equivalent conditions if its API supports them.
      event_source = WorkerPool.streaming(
          fields=fields,
          fuzzy_fields=fuzzy_fields,
          options=WORKER_POOL_LOAD_OPTIONS,
      )
      return StreamingResponse(event_source, media_type="text/event-stream")
  @router.get("/{id}", response_model=WorkerPoolPublic)
  async def get(session: SessionDep, ctx: TenantContextDep, id: int):
      """Return a single worker pool by id.

      Raises NotFoundException when the row is missing or its ``deleted_at``
      is set. The visibility check reuses the same "not found" message;
      presumably it reports invisible pools as not found rather than
      forbidden (TODO confirm against assert_resource_visible).
      """
      not_found = f"worker pool {id} not found"
      pool = await WorkerPool.one_by_id(session, id, options=WORKER_POOL_LOAD_OPTIONS)

      is_live = bool(pool) and pool.deleted_at is None
      if not is_live:
          raise NotFoundException(message=not_found)

      assert_resource_visible(ctx, pool, not_found_message=not_found)
      return pool
  @router.put("/{id}", response_model=WorkerPoolPublic)
  async def update(
      session: SessionDep, ctx: TenantContextDep, id: int, input: WorkerPoolUpdate
  ):
      """Apply ``input`` to an existing worker pool and return the result.

      Raises:
          NotFoundException: the pool does not exist or its ``deleted_at`` is set.
          InternalServerErrorException: ``WorkerPool.update`` raised; the
              original exception is chained as ``__cause__``.
      Authorization is delegated to ``assert_org_owned_writable``, which is
      expected to raise when ``ctx`` may not modify the pool.
      """
      existing = await WorkerPool.one_by_id(session, id)
      if not existing or existing.deleted_at is not None:
          raise NotFoundException(message=f"worker pool {id} not found")
      assert_org_owned_writable(ctx, existing, resource_label="worker pool")
      try:
          await WorkerPool.update(existing, session=session, source=input)
      except Exception as e:
          # Chain explicitly so logs/tracebacks keep the root cause.
          raise InternalServerErrorException(
              message=f"Failed to update worker pool {id}: {e}"
          ) from e
      # Re-read with the shared loader options so pool_workers is loaded
      # before the response is serialized.
      return await WorkerPool.one_by_id(session, id, options=WORKER_POOL_LOAD_OPTIONS)
  @router.delete("/{id}")
  async def delete(session: SessionDep, ctx: TenantContextDep, id: int):
      """Delete a worker pool that has no workers attached.

      Raises:
          NotFoundException: the pool does not exist or its ``deleted_at`` is set.
          ForbiddenException: the pool still has entries in ``pool_workers``.
          InternalServerErrorException: ``existing.delete`` raised; the
              original exception is chained as ``__cause__``.
      Authorization is delegated to ``assert_org_owned_writable``, which is
      expected to raise when ``ctx`` may not modify the pool.
      """
      # pool_workers is eager-loaded here because the guard below reads it.
      existing = await WorkerPool.one_by_id(session, id, options=WORKER_POOL_LOAD_OPTIONS)
      if not existing or existing.deleted_at is not None:
          raise NotFoundException(message=f"worker pool {id} not found")
      assert_org_owned_writable(ctx, existing, resource_label="worker pool")
      if existing.pool_workers:
          raise ForbiddenException(
              message=f"worker pool {id} has workers, cannot be deleted"
          )
      try:
          await existing.delete(session=session)
      except Exception as e:
          # Chain explicitly so logs/tracebacks keep the root cause.
          raise InternalServerErrorException(
              message=f"Failed to delete worker pool: {e}"
          ) from e