| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718 |
- import secrets
- import datetime
- import base64
- import uuid
- import logging
- import asyncio
- from typing import Optional, List, Dict, Any, Set
- from sqlmodel.ext.asyncio.session import AsyncSession
- from sqlalchemy.exc import IntegrityError
- from sqlalchemy.orm import selectinload
- from urllib.parse import urlencode
- from fastapi import APIRouter, Depends, Response, Request
- from fastapi.responses import StreamingResponse, RedirectResponse
- from gpustack.api.exceptions import (
- AlreadyExistsException,
- InternalServerErrorException,
- NotFoundException,
- ForbiddenException,
- InvalidException,
- )
- from gpustack.config.config import get_global_config
- from gpustack.api.tenant import (
- bypass_tenant_filter,
- assert_cluster_resource_visible,
- assert_org_owned_writable,
- cluster_resource_visibility_conditions,
- )
- from gpustack.server.deps import (
- SessionDep,
- CurrentUserDep,
- TenantContextDep,
- )
- from gpustack.server.db import async_session
- from gpustack.server.worker_status_buffer import (
- heartbeat_flush_buffer,
- heartbeat_flush_buffer_lock,
- worker_status_flush_buffer,
- worker_status_flush_buffer_lock,
- )
- from gpustack.schemas.workers import (
- WorkerCreate,
- WorkerListParams,
- WorkerPublic,
- WorkerUpdate,
- WorkersPublic,
- Worker,
- WorkerRegistrationPublic,
- WorkerStatusStored,
- WorkerStateEnum,
- )
- from gpustack.schemas.clusters import Cluster, Credential, ClusterStateEnum
- from gpustack.schemas.users import User, UserRole
- from gpustack.schemas.api_keys import ApiKey
- from gpustack.schemas.config import (
- SensitivePredefinedConfig,
- PredefinedConfigNoDefaults,
- )
- from gpustack.security import get_secret_hash, API_KEY_PREFIX
- from gpustack.server.services import WorkerService, create_user_with_principal
- from gpustack.cloud_providers.common import key_bytes_to_openssh_pem
- from gpustack.utils.grafana import resolve_grafana_base_url
router = APIRouter()
logger = logging.getLogger(__name__)

# Prefix for the username and API-key name of the system user created for
# each registered worker.
system_name_prefix = "system/worker"

# Module-level cap on concurrent worker registrations, to limit database
# contention.
# FIXME: replace with an optimized implementation
create_worker_semaphore = asyncio.Semaphore(10)
def to_worker_public(input: Worker, me: bool) -> WorkerPublic:
    """Serialize a ``Worker`` into the public API model.

    :param input: the worker row to convert.
    :param me: when truthy, the ``me`` field is set on the result so callers
        can recognize their own worker record.
    :returns: a validated ``WorkerPublic`` instance.
    """
    extra = {"me": me} if me else {}
    return WorkerPublic.model_validate({**input.model_dump(), **extra})
def _make_worker_visibility_filter(ctx):
    """Build an in-memory visibility predicate for worker rows.

    Intended to agree with the SQL conditions produced by
    ``cluster_resource_visibility_conditions`` so that Python-side filtering
    (used for watch streams) exposes the same workers as list queries.

    A worker is visible when any of these holds:
      * the tenant filter is bypassed for ``ctx``;
      * its ``owner_principal_id`` equals ``ctx.current_principal_id`` (both
        non-None);
      * its ``cluster_id`` is in ``ctx.accessible_cluster_ids``.
    """

    def _visible(w) -> bool:
        if bypass_tenant_filter(ctx):
            return True
        caller_principal = ctx.current_principal_id
        owner_principal = getattr(w, "owner_principal_id", None)
        if (
            caller_principal is not None
            and owner_principal is not None
            and owner_principal == caller_principal
        ):
            return True
        return getattr(w, "cluster_id", None) in ctx.accessible_cluster_ids

    return _visible
- def _build_worker_list_filters(name, uuid, cluster_id, search):
- fuzzy_fields = {"name": search} if search else {}
- fields = {}
- if name:
- fields["name"] = name
- if uuid:
- fields["worker_uuid"] = uuid
- if cluster_id:
- fields["cluster_id"] = cluster_id
- return fields, fuzzy_fields
- def _normalize_worker_order_by(order_by):
- if not order_by:
- return order_by
- out = []
- for field, direction in order_by:
- # maps gpus (gpu count) to internal JSON array length representation
- if field == "gpus":
- out.append(("status.gpu_devices[]", direction))
- else:
- out.append((field, direction))
- return out
@router.get("", response_model=WorkersPublic)
async def get_workers(
    user: CurrentUserDep,
    ctx: TenantContextDep,
    params: WorkerListParams = Depends(),
    name: str = None,
    search: str = None,
    uuid: str = None,
    me: Optional[bool] = None,
    cluster_id: Optional[int] = None,
):
    """List workers visible to the caller, or stream changes when watching.

    With ``params.watch`` set, a ``text/event-stream`` response is returned
    and rows are filtered in Python by the tenant visibility predicate.
    Otherwise a paginated list is returned, with visibility applied as SQL
    conditions. When ``me`` is truthy and the caller is a worker, all other
    filters are replaced by the caller's own worker id. If the caller is a
    worker, items are converted with ``to_worker_public`` so its own record
    carries the ``me`` flag.
    """
    fields, fuzzy_fields = _build_worker_list_filters(name, uuid, cluster_id, search)
    # Worker carries a denormalized owner_principal_id (synced from its
    # cluster), so the OR-of-{own-Org, cluster_access} rule used by
    # cluster_resource_visibility_conditions applies directly to worker rows.
    sql_visibility = cluster_resource_visibility_conditions(ctx, Worker)
    row_visibility = _make_worker_visibility_filter(ctx)

    if params.watch:
        stream = Worker.streaming(
            fields=fields, fuzzy_fields=fuzzy_fields, filter_func=row_visibility
        )
        return StreamingResponse(stream, media_type="text/event-stream")

    if me and user.worker is not None:
        # An explicit "me" query replaces every other filter.
        fields, fuzzy_fields = {"id": user.worker.id}, {}

    async with async_session() as session:
        page = await Worker.paginated_by_query(
            session=session,
            fields=fields,
            fuzzy_fields=fuzzy_fields,
            extra_conditions=sql_visibility,
            page=params.page,
            per_page=params.perPage,
            order_by=_normalize_worker_order_by(params.order_by),
        )
        if not user.worker:
            return page
        own_id = user.worker.id
        return WorkersPublic(
            items=[to_worker_public(w, own_id == w.id) for w in page.items],
            pagination=page.pagination,
        )
@router.get("/{id}", response_model=WorkerPublic)
async def get_worker(
    user: CurrentUserDep,
    ctx: TenantContextDep,
    session: SessionDep,
    id: int,
):
    """Return one worker if the caller is allowed to see it.

    Visibility is enforced by ``assert_cluster_resource_visible`` with the
    message "worker not found". When the worker is the caller's own, the
    response is built with ``to_worker_public(..., True)`` so the ``me`` flag
    is set. Otherwise the model is returned directly.
    """
    worker = await Worker.one_by_id(session, id)
    assert_cluster_resource_visible(ctx, worker, not_found_message="worker not found")
    own_worker = user.worker
    is_self = own_worker is not None and own_worker.id == worker.id
    return to_worker_public(worker, True) if is_self else worker
@router.get("/{id}/dashboard")
async def get_worker_dashboard(
    session: SessionDep,
    ctx: TenantContextDep,
    id: int,
    request: Request,
):
    """Redirect (HTTP 302) to the Grafana dashboard for a worker.

    The target is ``<grafana base>/d/<grafana_worker_dashboard_uid>/gpustack-worker``.
    The Grafana base is resolved for this request. The query string always
    carries ``var-worker_name``, and carries ``var-cluster_name`` first when
    the worker's cluster could be loaded.

    :raises InternalServerErrorException: the Grafana URL or the worker
        dashboard UID is not configured.
    """
    worker = await Worker.one_by_id(session, id)
    assert_cluster_resource_visible(ctx, worker, not_found_message="worker not found")

    cfg = get_global_config()
    if not (cfg.get_grafana_url() and cfg.grafana_worker_dashboard_uid):
        raise InternalServerErrorException(
            message="Grafana dashboard settings are not configured"
        )

    cluster = (
        await Cluster.one_by_id(session, worker.cluster_id)
        if worker.cluster_id is not None
        else None
    )
    variables = {}
    if cluster is not None:
        variables["var-cluster_name"] = cluster.name
    variables["var-worker_name"] = worker.name

    base_url = resolve_grafana_base_url(cfg, request)
    dashboard_path = f"/d/{cfg.grafana_worker_dashboard_uid}/gpustack-worker"
    # ``variables`` always holds at least the worker name, so a query string
    # is always appended.
    target = f"{base_url}{dashboard_path}?{urlencode(variables)}"
    return RedirectResponse(url=target, status_code=302)
def update_worker_data(
    worker_in: WorkerCreate,
    existing: Optional[Worker] = None,
    **kwargs,
) -> Worker:
    """Build the ``Worker`` to persist for a registration request.

    New worker (``existing`` is None):
      * built from the payload, named ``worker_in.name`` or, when that is
        falsy, ``worker_in.hostname``;
      * the reported ``worker_uuid`` is discarded (set to "");
      * ``state`` is READY and ``owner_principal_id`` comes from
        ``kwargs["cluster"]`` (None without a cluster);
      * every ``kwargs`` entry is applied last and wins.

    Existing worker:
      * the stored fields are overlaid with the payload;
      * ``maintenance`` keeps the stored value when the payload leaves it None;
      * ``labels`` are merged, with incoming keys winning;
      * ``cluster_id`` stays pinned to the stored worker;
      * ``owner_principal_id`` is re-synced from the cluster when one is given,
        in case org ownership changed;
      * ``state`` is reset to READY;
      * only ``kwargs["cluster"]`` is used.

    In both cases the cluster (if given) is attached and ``compute_state()``
    runs before the worker is returned.
    """
    cluster: Optional[Cluster] = kwargs.get("cluster")

    if existing is None:
        # A brand-new worker must not adopt the worker_uuid it reported.
        payload = worker_in.model_dump(exclude={"name", "worker_uuid"})
        payload["name"] = worker_in.name or worker_in.hostname
        payload["worker_uuid"] = ""
        payload["state"] = WorkerStateEnum.READY
        payload["owner_principal_id"] = (
            cluster.owner_principal_id if cluster is not None else None
        )
        payload.update(kwargs)
        result = Worker.model_validate(payload)
    else:
        incoming = worker_in.model_dump()
        # Keep the stored maintenance setting unless the request sets one.
        if incoming.get("maintenance") is None and existing.maintenance is not None:
            incoming["maintenance"] = existing.maintenance
        merged = existing.model_dump()
        merged.update(incoming)
        merged["labels"] = {**existing.labels, **worker_in.labels}
        merged["cluster_id"] = existing.cluster_id
        # Re-sync from cluster in case org ownership changed.
        merged["owner_principal_id"] = (
            cluster.owner_principal_id
            if cluster is not None
            else existing.owner_principal_id
        )
        merged["state"] = WorkerStateEnum.READY
        result = Worker.model_validate(merged)

    if cluster is not None:
        result.cluster = cluster
    result.compute_state()
    return result
- def _matches_exact_fields(worker: Worker, fields: Optional[Dict[str, Any]]) -> bool:
- if not fields:
- return True
- return all(getattr(worker, k, None) == v for k, v in fields.items())
- def _matches_fuzzy_fields(worker: Worker, fuzzy_fields: Dict[str, str]) -> bool:
- if not fuzzy_fields:
- return True
- for k, v in fuzzy_fields.items():
- attr = getattr(worker, k, None)
- if not isinstance(attr, str) or v.lower() not in attr.lower():
- return False
- return True
def filter_workers_by_fields(
    workers: List[Worker],
    fields: Optional[Dict[str, Any]],
    fuzzy_fields: Optional[Dict[str, str]] = None,
) -> List[Worker]:
    """Return the workers that satisfy all exact and fuzzy constraints.

    :param workers: candidate workers. The list is not modified.
    :param fields: attribute -> value pairs that must compare equal. A falsy
        value means no exact constraint.
    :param fuzzy_fields: attribute -> substring pairs, matched
        case-insensitively against string attributes. ``None`` or empty
        means no fuzzy constraint.
    :returns: matching workers in input order. With no constraints at all,
        the input list object itself is returned, not a copy.
    """
    # ``None`` default instead of a ``{}`` literal: a mutable default is
    # created once at definition time and shared by every call.
    # _matches_fuzzy_fields already treats None as "no constraint".
    if not fields and not fuzzy_fields:
        return workers
    return [
        worker
        for worker in workers
        if _matches_exact_fields(worker, fields)
        and _matches_fuzzy_fields(worker, fuzzy_fields)
    ]
def get_existing_worker(
    cluster_id: int, worker_in: WorkerCreate, workers: List[Worker]
) -> Optional[Worker]:
    """Find the stored worker a registration request refers to.

    Returns None right away when ``worker_in.name`` is the empty string.
    Otherwise it searches the entries of ``workers`` in ``cluster_id`` with
    ``deleted_at`` None. It tries a match on ``external_id`` first, then on
    ``worker_uuid``, skipping identifiers that are ``None``. If neither
    matches and the request has a truthy ``gpustack.existence-check`` label,
    it falls back to a name lookup across all of ``workers``.

    :raises AlreadyExistsException: the name fallback found a worker with the
        same name in a different cluster.
    """
    scope = {"deleted_at": None, "cluster_id": cluster_id}
    if worker_in.name == "":
        return None

    for identifier in ("external_id", "worker_uuid"):
        value = getattr(worker_in, identifier, None)
        # NOTE(review): only None is skipped. An empty-string identifier
        # (e.g. an unset worker_uuid, if the schema defaults it to "") would
        # match any stored worker whose identifier is also "". Confirm whether
        # empty values should be skipped too.
        if value is None:
            continue
        hits = filter_workers_by_fields(workers, {**scope, identifier: value})
        if hits:
            return hits[0]

    labels = worker_in.labels
    if not (labels and labels.get("gpustack.existence-check")):
        return None

    same_name = filter_workers_by_fields(workers, {"name": worker_in.name})
    if not same_name:
        return None
    match = same_name[0]
    if match.cluster_id != cluster_id:
        raise AlreadyExistsException(
            message=f"worker with name {worker_in.name} already exists in another cluster"
        )
    return match
def check_worker_name_conflict(
    name: str, workers: List[Worker], existing_id: Optional[int] = None
):
    """Reject a worker name that another worker already uses.

    ``existing_id`` is the id of the worker being re-registered and is
    excluded from the comparison. An empty ``name`` is allowed for a new
    worker (``existing_id`` is None) and rejected for an existing one.

    :raises InvalidException: ``name`` is empty while ``existing_id`` is set.
    :raises AlreadyExistsException: another worker in ``workers`` has ``name``.
    """
    if name == "":
        if existing_id is None:
            return
        raise InvalidException(message="worker name cannot be empty")

    others = [w for w in workers if w.id != existing_id]
    if filter_workers_by_fields(others, {"name": name}):
        raise AlreadyExistsException(message=f"worker with name {name} already exists")
def find_available_worker_name(
    original_name: str, current_name: str, related_names: Set[str]
) -> str:
    """Pick a worker name that is not in ``related_names``.

    Returns ``original_name`` when it is free. Otherwise it returns
    ``f"{original_name}-{n}"`` for the first free ``n``. The search starts just
    after the numeric suffix of ``current_name`` (the previously attempted
    name) when that has the form ``f"{original_name}-<digits>"``, and at 1
    otherwise.
    """
    if original_name not in related_names:
        return original_name

    index = 1
    prefix = f"{original_name}-"
    if current_name.startswith(prefix):
        suffix = current_name[len(prefix):]
        # isdecimal(), not isdigit(): isdigit() also accepts characters such
        # as "²" or "①" that int() cannot parse (it would raise ValueError).
        if suffix.isdecimal():
            index = int(suffix) + 1

    candidate = f"{original_name}-{index}"
    while candidate in related_names:
        index += 1
        candidate = f"{original_name}-{index}"
    return candidate
async def retry_create_worker(
    session: AsyncSession, to_create: Worker, workers: List[Worker]
) -> Worker:
    """Create ``to_create``, renaming it when its name collides.

    Names already taken are seeded from the non-deleted workers in
    ``workers`` whose name contains ``to_create.name`` (case-insensitive).
    Each attempt sets the worker's ``name`` and its ``worker-name`` label to a
    free candidate from ``find_available_worker_name``. An ``IntegrityError``
    marks that candidate as taken and triggers the next attempt, up to 5.

    Each attempt runs inside a SAVEPOINT (``session.begin_nested()``).
    Without it, a failed flush leaves the session's transaction needing a
    rollback, so later attempts and the caller's remaining work in the same
    session fail with ``PendingRollbackError`` instead of retrying. Leaving
    the nested block also flushes, so the collision surfaces inside the
    ``try``. Nothing is committed here; the caller owns the outer transaction.

    NOTE(review): SAVEPOINT needs driver support; the pysqlite/aiosqlite
    drivers require SQLAlchemy's documented workaround for it.

    :raises InternalServerErrorException: all 5 attempts collided.
    """
    related_workers = filter_workers_by_fields(
        workers,
        fields={"deleted_at": None},
        fuzzy_fields={"name": to_create.name},
    )
    taken_names = {worker.name for worker in related_workers}
    original_name = to_create.name
    current_name = original_name
    max_attempts = 5

    for attempt in range(1, max_attempts + 1):
        current_name = find_available_worker_name(
            original_name, current_name, taken_names
        )
        to_create.name = current_name
        to_create.labels["worker-name"] = current_name
        try:
            async with session.begin_nested():
                new_worker = await Worker.create(session, to_create, auto_commit=False)
        except IntegrityError:
            logger.warning(
                f"Worker name collision detected for worker name {to_create.name}, retrying... (attempt {attempt}/{max_attempts})"
            )
            taken_names.add(current_name)
            await asyncio.sleep(0.1)  # small delay before retrying to reduce contention
            continue
        return new_worker

    raise InternalServerErrorException(
        message="Failed to create worker with unique name after multiple attempts"
    )
def retry_create_unique_worker_uuid(workers: List[Worker]) -> str:
    """Return a random UUID4 string not already used as a ``worker_uuid``.

    Empty ``worker_uuid`` values in ``workers`` are ignored. A collision is
    extremely unlikely, but up to 5 draws are made so such a rare event
    resolves without manual intervention.

    :raises InternalServerErrorException: all 5 draws collided.
    """
    in_use = {w.worker_uuid for w in workers if w.worker_uuid != ""}
    attempt = 0
    while attempt < 5:
        attempt += 1
        candidate = str(uuid.uuid4())
        if candidate not in in_use:
            return candidate
        logger.warning(
            f"UUID collision detected for worker_uuid {candidate}, retrying... (attempt {attempt}/5)"
        )
    raise InternalServerErrorException(
        message="Failed to generate unique worker UUID after multiple attempts"
    )
- def _resolve_create_worker_cluster_id(user, worker_in: WorkerCreate) -> int:
- cluster_id = (
- worker_in.cluster_id if worker_in.cluster_id is not None else user.cluster_id
- )
- if cluster_id is None:
- raise ForbiddenException(message="Missing cluster_id for worker registration")
- return cluster_id
def _build_worker_config_dict(cluster: Cluster) -> Dict[str, Any]:
    """Assemble the worker config sent back to a registering worker.

    The result is ``cluster.worker_config`` dumped without the fields declared
    on ``SensitivePredefinedConfig``, or ``{}`` when the cluster has no worker
    config. The global ``system_default_container_registry`` is then added
    when it is set and non-empty and the key is not already present.

    NOTE(review): ``setdefault`` does not replace a key that is present with
    value None. If the dumped cluster config contains that key as None, the
    global default is never applied. Confirm this is intended.
    """
    hidden_fields = set(SensitivePredefinedConfig.model_fields.keys())
    if cluster.worker_config is not None:
        config = cluster.worker_config.model_dump(exclude=hidden_fields)
    else:
        config = {}

    registry = get_global_config().system_default_container_registry
    if registry is not None and len(registry) > 0:
        config.setdefault("system_default_container_registry", registry)
    return config
async def _resolve_existing_worker_user(
    session, existing_worker: Optional[Worker]
) -> Optional[User]:
    """Load the user linked to ``existing_worker`` together with its API keys.

    Returns None when there is no existing worker. Otherwise it returns the
    result of looking up ``User`` by ``worker_id``, with ``api_keys`` eagerly
    loaded via ``selectinload``.
    """
    if existing_worker is not None:
        return await User.one_by_field(
            session=session,
            field="worker_id",
            value=existing_worker.id,
            options=[selectinload(User.api_keys)],
        )
    return None
- def _existing_api_key(existing_user: Optional[User]) -> Optional[ApiKey]:
- if existing_user is None or not existing_user.api_keys:
- return None
- return existing_user.api_keys[0]
async def _persist_worker_registration(
    session,
    *,
    existing_worker: Optional[Worker],
    new_worker: Worker,
    new_token: str,
    to_create_user: Optional[User],
    existing_user: Optional[User],
    to_create_apikey: Optional[ApiKey],
    all_workers: List[Worker],
    cluster: Cluster,
) -> Worker:
    """Stage a worker registration and commit it in one transaction.

    Steps, in order:
      1. Update ``existing_worker`` from ``new_worker``, first replacing the
         token with ``new_token`` when a new API key is being created. With no
         existing worker, create ``new_worker`` via ``retry_create_worker``.
      2. Create ``to_create_user`` (if any), bound to the worker.
      3. Create ``to_create_apikey`` (if any) for ``existing_user`` or, if
         that is falsy, for the user created in step 2.
      4. Mark ``cluster`` READY if it is not already.
    All writes use ``auto_commit=False`` and are committed once at the end.
    No rollback happens here on failure; that is left to the caller.

    :returns: the persisted worker (``existing_worker`` when updating).
    """
    if existing_worker is None:
        worker = await retry_create_worker(session, new_worker, all_workers)
    else:
        if to_create_apikey is not None:
            new_worker.token = new_token
        await WorkerService(session).update(
            existing_worker, new_worker, auto_commit=False
        )
        worker = existing_worker

    created_user = None
    if to_create_user is not None:
        to_create_user.worker = worker
        created_user = await create_user_with_principal(session, to_create_user)

    if to_create_apikey is not None:
        key_owner = existing_user or created_user
        to_create_apikey.user = key_owner
        to_create_apikey.user_id = key_owner.id
        await ApiKey.create(session=session, source=to_create_apikey, auto_commit=False)

    if cluster.state != ClusterStateEnum.READY:
        cluster.state = ClusterStateEnum.READY
        await cluster.update(session=session, auto_commit=False)

    await session.commit()
    return worker
@router.post("", response_model=WorkerRegistrationPublic)
async def create_worker(user: CurrentUserDep, worker_in: WorkerCreate):
    """Register a worker, or re-register an existing one, and return its credentials.

    Flow, serialized through ``create_worker_semaphore``:
      * resolve the target cluster and match the request against non-deleted
        workers (``get_existing_worker``), rejecting name conflicts and
        unknown ``external_id`` values;
      * re-read the matched worker with ``for_update=True`` and load the
        cluster, which must exist and not be soft-deleted;
      * build the worker record, assigning a fresh ``worker_uuid`` when it is
        empty, and prepare a system user and an API key when the worker does
        not already have them;
      * persist everything in one commit and return the worker with its
        token and the non-sensitive worker config.

    :raises ForbiddenException: the caller is neither admin nor a system
        identity, or the cluster id cannot be resolved.
    :raises InvalidException, AlreadyExistsException: from the name checks.
    :raises NotFoundException: unknown ``external_id`` or missing cluster.
    :raises InternalServerErrorException: persisting the registration or
        building the response failed. The error is logged with its traceback
        and the session is rolled back.
    """
    # Worker registration runs through two paths: (1) v1_base_router with a
    # human session, admin-only because spinning up workers is a
    # platform-level action; (2) cluster_client_router with the cluster
    # service-account token (user.is_system=True). Allow both, deny the rest.
    if not (user.is_admin or getattr(user, "is_system", False)):
        raise ForbiddenException(message="Only platform admin can register workers")

    async with create_worker_semaphore:
        async with async_session() as session:
            cluster_id = _resolve_create_worker_cluster_id(user, worker_in)
            all_workers = await Worker.all_by_fields(session, {"deleted_at": None})
            existing_worker = get_existing_worker(cluster_id, worker_in, all_workers)
            check_worker_name_conflict(
                worker_in.name,
                all_workers,
                existing_id=existing_worker.id if existing_worker else None,
            )
            if existing_worker is None and worker_in.external_id is not None:
                # avoid creating a worker with a non-existent external_id
                raise NotFoundException(
                    message=f"worker with external_id {worker_in.external_id} not found"
                )
            if existing_worker is not None:
                # Re-read the matched worker with for_update=True (presumably
                # a row lock) before modifying it.
                existing_worker = await Worker.one_by_id(
                    session=session, id=existing_worker.id, for_update=True
                )

            cluster = await Cluster.one_by_id(session, cluster_id)
            if cluster is None or cluster.deleted_at is not None:
                raise NotFoundException(message="Cluster not found")
            worker_config = _build_worker_config_dict(cluster)

            # Fresh credentials. The token is stored on a new worker, or on an
            # existing one that receives a new API key. The key pair backs
            # the new API key, if one is created below.
            hashed_suffix = secrets.token_hex(6)
            access_key = secrets.token_hex(8)
            secret_key = secrets.token_hex(16)
            new_token = f"{API_KEY_PREFIX}_{access_key}_{secret_key}"

            new_worker = update_worker_data(
                worker_in,
                existing=existing_worker,
                # provider and token only apply to a new worker; cluster is
                # also used for an existing one (ownership re-sync, attach).
                provider=cluster.provider,
                cluster=cluster,
                token=new_token,
            )
            if new_worker.worker_uuid == "":
                new_worker.worker_uuid = retry_create_unique_worker_uuid(all_workers)

            existing_user = await _resolve_existing_worker_user(
                session, existing_worker
            )
            to_create_user = (
                User(
                    username=f'{system_name_prefix}-{hashed_suffix}',
                    is_system=True,
                    role=UserRole.Worker,
                    hashed_password="",
                    cluster=cluster,
                )
                if existing_user is None
                else None
            )
            existing_api_key = _existing_api_key(existing_user)
            to_create_apikey = (
                ApiKey(
                    name=f'{system_name_prefix}-{hashed_suffix}',
                    access_key=access_key,
                    hashed_secret_key=get_secret_hash(secret_key),
                )
                if existing_api_key is None
                else None
            )

            try:
                worker = await _persist_worker_registration(
                    session,
                    existing_worker=existing_worker,
                    new_worker=new_worker,
                    new_token=new_token,
                    to_create_user=to_create_user,
                    existing_user=existing_user,
                    to_create_apikey=to_create_apikey,
                    all_workers=all_workers,
                    cluster=cluster,
                )
                worker_dump = worker.model_dump()
                worker_dump["token"] = worker.token
                worker_dump["worker_config"] = (
                    PredefinedConfigNoDefaults.model_validate(worker_config)
                )
                return WorkerRegistrationPublic.model_validate(worker_dump)
            except Exception as e:
                # Log first so the traceback survives even if rollback fails.
                logger.exception("Failed to create worker")
                await session.rollback()
                raise InternalServerErrorException(
                    message=f"Failed to create worker: {e}"
                ) from e
@router.put("/{id}", response_model=WorkerPublic)
async def update_worker(
    ctx: TenantContextDep,
    session: SessionDep,
    id: int,
    worker_in: WorkerUpdate,
):
    """Apply a ``WorkerUpdate`` to a non-deleted worker.

    A soft-deleted worker is reported like a missing one. The caller must be
    able to see the worker and must pass the org write check. When
    ``maintenance`` is supplied, it is set on the worker, the state is
    recomputed, and the new state is added to the update.

    :raises InternalServerErrorException: the update failed.
    """
    found = await Worker.one_by_id(session, id)
    worker = found if found is not None and found.deleted_at is None else None
    assert_cluster_resource_visible(ctx, worker, not_found_message="worker not found")
    assert_org_owned_writable(ctx, worker, resource_label="worker")

    # NOTE(review): model_dump() without exclude_unset includes every
    # WorkerUpdate field, so fields the client omitted (e.g. maintenance=None)
    # are passed to WorkerService.update too. Confirm the service ignores None
    # values; otherwise omitted fields may be cleared.
    changes = worker_in.model_dump()
    if worker_in.maintenance is not None:
        worker.maintenance = worker_in.maintenance
        worker.compute_state()
        changes["state"] = worker.state

    try:
        await WorkerService(session).update(worker, changes)
    except Exception as e:
        raise InternalServerErrorException(message=f"Failed to update worker: {e}")
    return worker
@router.delete("/{id}")
async def delete_worker(ctx: TenantContextDep, session: SessionDep, id: int):
    """Delete a worker the caller may manage.

    A soft-deleted worker is reported like a missing one. A worker with an
    ``external_id`` is moved to DELETING and passed to
    ``WorkerService.delete`` with ``soft=True``. Any other worker is passed
    with ``soft=False``.

    :raises InternalServerErrorException: the deletion failed.
    """
    found = await Worker.one_by_id(session, id)
    worker = found if found is not None and found.deleted_at is None else None
    assert_cluster_resource_visible(ctx, worker, not_found_message="worker not found")
    assert_org_owned_writable(ctx, worker, resource_label="worker")

    try:
        has_external_id = worker.external_id is not None
        if has_external_id:
            worker.state = WorkerStateEnum.DELETING
        await WorkerService(session).delete(worker, soft=has_external_id)
    except Exception as e:
        raise InternalServerErrorException(message=f"Failed to delete worker: {e}")
async def create_worker_status(user: CurrentUserDep, input: WorkerStatusStored):
    """Queue a status report from the calling worker for a batched write.

    The fields the worker explicitly set, plus a UTC ``heartbeat_time``
    truncated to whole seconds, are stored in ``worker_status_flush_buffer``
    under the worker's id. This replaces any report still pending for that
    worker. Responds 204.

    :raises ForbiddenException: the caller is not linked to a worker.
    """
    own_worker = user.worker
    if own_worker is None:
        raise ForbiddenException(message="Failed to find related worker")

    now = datetime.datetime.now(datetime.timezone.utc)
    reported_at = now.replace(microsecond=0)
    report = input.model_dump(exclude_unset=True)
    report["heartbeat_time"] = reported_at

    async with worker_status_flush_buffer_lock:
        worker_status_flush_buffer[own_worker.id] = report
    return Response(status_code=204)
async def heartbeat(user: CurrentUserDep):
    """Record a heartbeat from the calling worker.

    The worker id is added, under its lock, to ``heartbeat_flush_buffer`` for
    a later batched update. Responds 204.

    :raises ForbiddenException: the caller is not linked to a worker.
    """
    own_worker = user.worker
    if own_worker is None:
        raise ForbiddenException(message="Failed to find related worker")
    async with heartbeat_flush_buffer_lock:
        heartbeat_flush_buffer.add(own_worker.id)
    return Response(status_code=204)
@router.get("/{id}/privatekey")
async def get_worker_privatekey(
    ctx: TenantContextDep,
    session: SessionDep,
    id: int,
):
    """Download the worker's SSH private key as an OpenSSH PEM attachment.

    A soft-deleted worker is reported like a missing one. The org write check
    is required, not just visibility, because holding the key grants SSH
    access to the host; this mirrors the cluster registration-token endpoint
    in routes/clusters.py.

    :raises NotFoundException: the worker has no SSH key id, or the
        referenced credential does not exist.
    """
    found = await Worker.one_by_id(session, id)
    worker = found if found is not None and found.deleted_at is None else None
    assert_cluster_resource_visible(ctx, worker, not_found_message="worker not found")
    assert_org_owned_writable(ctx, worker, resource_label="worker")

    ssh_key = None
    if worker.ssh_key_id is not None:
        ssh_key = await Credential.one_by_id(session, worker.ssh_key_id)
    if not ssh_key:
        raise NotFoundException(message="worker ssh key not found")

    pem = key_bytes_to_openssh_pem(
        base64.b64decode(ssh_key.encoded_private_key),
        ssh_key.ssh_key_options.algorithm,
    )
    filename = f"worker-{id}-private_key.pem"
    return Response(
        content=pem,
        media_type="application/octet-stream",
        headers={"Content-Disposition": f"attachment; filename={filename}"},
    )
|