| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- from fastapi import APIRouter
- from fastapi.responses import StreamingResponse
- from sqlalchemy.orm import selectinload
- from gpustack.api.exceptions import (
- InternalServerErrorException,
- NotFoundException,
- ForbiddenException,
- )
- from gpustack.api.tenant import (
- assert_org_owned_writable,
- assert_resource_visible,
- tenant_list_conditions,
- )
- from gpustack.server.db import async_session
- from gpustack.server.deps import ListParamsDep, SessionDep, TenantContextDep
- from gpustack.schemas.clusters import (
- WorkerPoolPublic,
- WorkerPoolsPublic,
- WorkerPoolUpdate,
- WorkerPool,
- )
- WORKER_POOL_LOAD_OPTIONS = [selectinload(WorkerPool.pool_workers)]
- router = APIRouter()
- @router.get("", response_model=WorkerPoolsPublic)
- async def list(
- ctx: TenantContextDep,
- params: ListParamsDep,
- name: str = None,
- search: str = None,
- cluster_id: int = None,
- ):
- fuzzy_fields = {}
- if search:
- fuzzy_fields = {"name": search}
- fields = {"deleted_at": None}
- if cluster_id:
- fields["cluster_id"] = cluster_id
- if name:
- fields["name"] = name
- if params.watch:
- return StreamingResponse(
- WorkerPool.streaming(
- fields=fields,
- fuzzy_fields=fuzzy_fields,
- options=WORKER_POOL_LOAD_OPTIONS,
- ),
- media_type="text/event-stream",
- )
- async with async_session() as session:
- # Worker pools mirror their cluster's owner_principal_id; same filter
- # rules as cloud_credentials apply.
- extra_conditions = tenant_list_conditions(ctx, WorkerPool)
- return await WorkerPool.paginated_by_query(
- session=session,
- fields=fields,
- fuzzy_fields=fuzzy_fields,
- extra_conditions=extra_conditions,
- page=params.page,
- per_page=params.perPage,
- options=WORKER_POOL_LOAD_OPTIONS,
- )
- @router.get("/{id}", response_model=WorkerPoolPublic)
- async def get(session: SessionDep, ctx: TenantContextDep, id: int):
- existing = await WorkerPool.one_by_id(session, id, options=WORKER_POOL_LOAD_OPTIONS)
- if not existing or existing.deleted_at is not None:
- raise NotFoundException(message=f"worker pool {id} not found")
- assert_resource_visible(
- ctx,
- existing,
- not_found_message=f"worker pool {id} not found",
- )
- return existing
- @router.put("/{id}", response_model=WorkerPoolPublic)
- async def update(
- session: SessionDep, ctx: TenantContextDep, id: int, input: WorkerPoolUpdate
- ):
- existing = await WorkerPool.one_by_id(session, id)
- if not existing or existing.deleted_at is not None:
- raise NotFoundException(message=f"worker pool {id} not found")
- assert_org_owned_writable(ctx, existing, resource_label="worker pool")
- try:
- await WorkerPool.update(existing, session=session, source=input)
- except Exception as e:
- raise InternalServerErrorException(
- message=f"Failed to update worker pool {id}: {e}"
- )
- return await WorkerPool.one_by_id(session, id, options=WORKER_POOL_LOAD_OPTIONS)
- @router.delete("/{id}")
- async def delete(session: SessionDep, ctx: TenantContextDep, id: int):
- existing = await WorkerPool.one_by_id(session, id, options=WORKER_POOL_LOAD_OPTIONS)
- if not existing or existing.deleted_at is not None:
- raise NotFoundException(message=f"worker pool {id} not found")
- assert_org_owned_writable(ctx, existing, resource_label="worker pool")
- if len(existing.pool_workers) > 0:
- raise ForbiddenException(
- message=f"worker pool {id} has workers, cannot be deleted"
- )
- try:
- await existing.delete(session=session)
- except Exception as e:
- raise InternalServerErrorException(message=f"Failed to delete worker pool: {e}")
|