| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571 |
- import secrets
- from urllib.parse import urlparse
- from enum import Enum
- from typing import ClassVar, Optional, Dict, Any, List
- from pydantic import (
- BaseModel,
- computed_field,
- field_validator,
- ConfigDict,
- PrivateAttr,
- Field as PydanticField,
- )
- from sqlmodel import (
- Field,
- Relationship,
- Column,
- SQLModel,
- Text,
- Integer,
- ForeignKey,
- JSON,
- String,
- )
- import sqlalchemy as sa
- from typing import TYPE_CHECKING
- from gpustack.schemas.config import (
- SensitivePredefinedConfig,
- PredefinedConfigNoDefaults,
- )
- from gpustack.mixins import BaseModelMixin
- from gpustack.schemas.common import (
- PublicFields,
- ListParams,
- PaginatedList,
- pydantic_column_type,
- )
- if TYPE_CHECKING:
- from gpustack.schemas.models import Model, ModelInstance
- from gpustack.schemas.workers import Worker
- from gpustack.schemas.users import User
class WorkerPoolUpdate(SQLModel):
    """
    Worker pool fields accepted on update.

    Also the root of the worker pool schema hierarchy: WorkerPoolCreate,
    WorkerPoolBase, WorkerPool and WorkerPoolPublic all inherit these fields.
    """

    name: str
    # Optional; when given it must be >= 1.
    batch_size: Optional[int] = Field(default=None, ge=1)
    # Replica count; defaults to 1, and 0 is allowed.
    replicas: int = Field(default=1, ge=0)
    # String-to-string labels stored in a JSON column. Pydantic copies mutable
    # defaults per instance, so the ``{}`` default is not shared between models.
    labels: Optional[Dict[str, str]] = Field(sa_column=Column(JSON), default={})
class Volume(BaseModel):
    """
    One entry of ``CloudOptions.volumes``.

    All fields are optional. ``name``, when provided, is limited to 60
    characters drawn from ASCII letters, digits, ``-`` and ``.``.
    """

    format: Optional[str] = None
    size_gb: Optional[int] = None
    name: Optional[str] = None

    @field_validator("name")
    def validate_name(cls, v):
        """Check length and character set of a non-empty volume name.

        Returns the value unchanged; ``None`` and ``""`` are accepted as-is.

        Raises:
            ValueError: if the name exceeds 60 characters or contains a
                character outside ``[A-Za-z0-9.-]``.
        """
        if not v:
            return v

        # The worker id is appended to the name later to keep it unique, so
        # 60 characters is the cap that leaves room for that suffix.
        if len(v) > 60:
            raise ValueError("Volume name too long, max 60 characters")

        permitted = frozenset(
            "abcdefghijklmnopqrstuvwxyz"
            "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            "0123456789"
            "-."
        )
        if not permitted.issuperset(v):
            raise ValueError("Volume name contains invalid characters")
        return v
class HostPathVolumeSource(BaseModel):
    """Mirrors the Kubernetes ``hostPath`` volume source."""

    # Unknown keys are ignored rather than rejected.
    model_config = ConfigDict(populate_by_name=True, extra="ignore")
    path: str = PydanticField(
        ...,
        description="Path of the directory on the host. If the path is a symlink, it will follow the link to the real path.",
    )
    type: Optional[str] = PydanticField(None, description="Type for HostPath Volume.")
class PersistentVolumeClaimVolumeSource(BaseModel):
    """
    Mirrors the Kubernetes ``persistentVolumeClaim`` volume source.

    Accepts the camelCase aliases (``claimName``, ``readOnly``) and, because
    ``populate_by_name=True``, the snake_case field names as well.
    """

    # Unknown keys are ignored rather than rejected.
    model_config = ConfigDict(populate_by_name=True, extra="ignore")
    claim_name: str = PydanticField(
        ...,
        alias="claimName",
        description="ClaimName is the name of a PersistentVolumeClaim in the same namespace as the pod using this volume.",
    )
    read_only: bool = PydanticField(
        False,
        alias="readOnly",
        description="Will force the ReadOnly setting in VolumeMounts.",
    )
class ConfigMapVolumeSource(BaseModel):
    """
    Mirrors the Kubernetes ``configMap`` volume source.

    Not used yet: this kind of volume cannot currently be created through the UI.
    """

    # Unknown keys are ignored rather than rejected.
    model_config = ConfigDict(populate_by_name=True, extra="ignore")
    name: str = PydanticField(..., description="Name of the referent.")
    optional: Optional[bool] = PydanticField(
        None, description="Specify whether the ConfigMap or its keys must be defined."
    )
class VolumeSource(BaseModel):
    """
    Subset of the Kubernetes VolumeSource union.

    Holds one optional member per supported source type, using the Kubernetes
    camelCase aliases. Nothing in this model enforces that exactly one member
    is set.
    """

    # populate_by_name allows snake_case names too; unknown keys are ignored.
    model_config = ConfigDict(populate_by_name=True, extra="ignore")
    host_path: Optional[HostPathVolumeSource] = PydanticField(None, alias="hostPath")
    persistent_volume_claim: Optional[PersistentVolumeClaimVolumeSource] = (
        PydanticField(None, alias="persistentVolumeClaim")
    )
    config_map: Optional[ConfigMapVolumeSource] = PydanticField(None, alias="configMap")
class K8sVolumeMount(BaseModel):
    """
    A Kubernetes volume mount definition (``ClusterUpdate.k8s_volume_mounts``).

    ``name`` must be a valid DNS-1123 label. ``volume_source``, when given,
    describes where the mounted volume comes from.
    """

    # populate_by_name allows snake_case names too; unknown keys are ignored.
    model_config = ConfigDict(populate_by_name=True, extra="ignore")
    name: str
    mount_path: str = PydanticField(..., alias="mountPath")
    read_only: bool = PydanticField(False, alias="readOnly")
    volume_source: Optional[VolumeSource] = PydanticField(
        default=None,
        alias="volumeSource",
        description=(
            "Kubernetes VolumeSource definition. Examples:\n"
            '- hostPath: `{"hostPath": {"path": "/data", "type": "Directory"}}`\n'
            '- persistentVolumeClaim: `{"persistentVolumeClaim": {"claimName": "my-pvc"}}`\n'
            '- configMap: `{"configMap": {"name": "my-configmap"}}`'
        ),
    )

    @field_validator("name")
    def validate_name(cls, v):
        """Ensure ``name`` is a valid Kubernetes DNS-1123 label.

        Returns the name unchanged when valid.

        Raises:
            ValueError: if the name is empty, longer than 63 characters, or
                not made of lowercase alphanumerics/'-' starting and ending
                with an alphanumeric character.
        """
        import re

        # A DNS-1123 label cannot be empty. The previous early return for
        # falsy values let "" through, because this field is a required str,
        # so only "" (never None) could reach that branch.
        if not v:
            raise ValueError("Volume name must not be empty")
        if len(v) > 63:
            raise ValueError("Volume name must be less than 64 characters")
        if not re.fullmatch(r"[a-z0-9]([-a-z0-9]*[a-z0-9])?", v):
            raise ValueError(
                "Volume name must be a valid DNS-1123 label (e.g. 'my-name', or '123-abc'); "
                "it must consist of lower case alphanumeric characters or '-', "
                "and must start and end with an alphanumeric character."
            )
        return v
class CloudOptions(BaseModel):
    """Provider-specific options for a worker pool (``WorkerPoolCreate.cloud_options``)."""

    volumes: Optional[List[Volume]] = None
class WorkerPoolCreate(WorkerPoolUpdate):
    """Worker pool creation payload: the update fields plus instance/image settings."""

    instance_type: str
    os_image: str
    image_name: str
    # Serialized via pydantic_column_type(CloudOptions).
    # NOTE(review): the default is a plain dict ``{}``, not a CloudOptions
    # instance, and pydantic does not validate defaults, so unset values stay
    # a dict. Confirm pydantic_column_type accepts that shape.
    cloud_options: Optional[CloudOptions] = Field(
        default={}, sa_column=Column(pydantic_column_type(CloudOptions))
    )
    zone: Optional[str] = None
    # instance_spec is for UI to store the instance_type's extended specifications for display.
    instance_spec: Optional[Dict[str, Any]] = Field(
        default=None, sa_column=Column(JSON)
    )
class WorkerPoolBase(WorkerPoolCreate):
    """Persisted worker pool fields: creation fields plus ownership columns."""

    # Parent cluster; the FK uses ON DELETE CASCADE.
    cluster_id: int = Field(
        sa_column=Column(Integer, ForeignKey("clusters.id", ondelete="CASCADE"))
    )
    # Mirrors the cluster's owner_principal_id (NOT NULL since clusters are
    # always Org-owned). The route layer copies the parent cluster's
    # value so the row can be filtered without a join.
    owner_principal_id: Optional[int] = Field(
        default=None, foreign_key="principals.id", nullable=False
    )
class WorkerPool(WorkerPoolBase, BaseModelMixin, table=True):
    """
    Database table ``worker_pools``.

    ``workers`` and ``ready_workers`` are computed fields. A non-negative
    value passed to ``__init__`` is returned as-is; otherwise the count is
    derived from ``pool_workers``. Hashing uses the primary key ``id``.
    """

    __tablename__ = "worker_pools"
    __table_args__ = (
        sa.Index("idx_worker_pools_deleted_at_created_at", "deleted_at", "created_at"),
    )

    id: Optional[int] = Field(default=None, primary_key=True)

    # lazy="noload": SQLAlchemy never loads these implicitly.
    cluster: Optional["Cluster"] = Relationship(
        back_populates="cluster_worker_pools",
        sa_relationship_kwargs={"lazy": "noload"},
    )
    pool_workers: list["Worker"] = Relationship(
        sa_relationship_kwargs={"lazy": "noload"},
        back_populates="worker_pool",
    )

    # Precomputed counters; -1 means "not supplied".
    _workers: int = PrivateAttr(default=-1)
    _ready_workers: int = PrivateAttr(default=-1)

    def __init__(
        self,
        workers: int = -1,
        ready_workers: int = -1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._workers = workers
        self._ready_workers = ready_workers

    @computed_field()
    @property
    def workers(self) -> int:
        # The read and the comparison stay inside the try so any TypeError
        # (e.g. a None counter) falls back to counting the relationship.
        try:
            precomputed = self._workers
            use_precomputed = precomputed >= 0
        except TypeError:
            use_precomputed = False
        if use_precomputed:
            return precomputed
        return len(self.pool_workers or [])

    @computed_field()
    @property
    def ready_workers(self) -> int:
        try:
            precomputed = self._ready_workers
            use_precomputed = precomputed >= 0
        except TypeError:
            use_precomputed = False
        if use_precomputed:
            return precomputed
        return sum(1 for w in (self.pool_workers or []) if w.state.value == 'ready')

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        # Foreign types never compare equal; otherwise require the inherited
        # model comparison and then matching primary keys.
        if not isinstance(other, WorkerPool):
            return False
        if not super().__eq__(other):
            return False
        return self.id == other.id
class WorkerPoolPublic(WorkerPoolBase, PublicFields):
    """API representation of a worker pool, including the worker counters."""

    # Plain fields here; on the WorkerPool table model these are computed.
    workers: int = Field(default=0)
    ready_workers: int = Field(default=0)


# Paginated list response of WorkerPoolPublic items.
WorkerPoolsPublic = PaginatedList[WorkerPoolPublic]
class ClusterProvider(Enum):
    """
    Supported cluster providers.

    Clusters default to ``Docker`` (ClusterCreateBase). Cloud credentials
    default to ``DigitalOcean`` (CloudCredentialBase).
    """

    Docker = "Docker"
    Kubernetes = "Kubernetes"
    DigitalOcean = "DigitalOcean"
class CloudCredentialBase(SQLModel):
    """
    Credentials for cloud providers.

    Supports providers other than Kubernetes and Docker.
    """

    name: str
    description: Optional[str] = None
    provider: ClusterProvider = Field(default=ClusterProvider.DigitalOcean)
    key: Optional[str] = None
    # Arbitrary provider options stored as JSON.
    options: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON))
    # Every cloud credential belongs to one Org (mirrors cluster scope).
    # The route fills this with ctx.current_principal_id or PLATFORM_ORG when
    # the caller omits it.
    owner_principal_id: Optional[int] = Field(
        default=None, foreign_key="principals.id", nullable=False
    )
class CloudCredentialUpdate(CloudCredentialBase):
    """Cloud credential update payload: base fields plus ``secret``."""

    # Declared here rather than on CloudCredentialBase, so schemas derived
    # only from the base (e.g. CloudCredentialPublic) do not carry it.
    secret: Optional[str] = None


class CloudCredentialCreate(CloudCredentialUpdate):
    """Cloud credential creation payload; identical to the update payload."""

    pass
class CloudCredential(CloudCredentialCreate, BaseModelMixin, table=True):
    """
    Database table ``cloud_credentials``.

    Hashes by primary key. Equality requires the inherited model comparison
    to succeed and the ids to match.
    """

    __tablename__ = "cloud_credentials"
    __table_args__ = (
        sa.Index(
            "idx_cloud_credentials_deleted_at_created_at", "deleted_at", "created_at"
        ),
    )

    id: Optional[int] = Field(default=None, primary_key=True)

    def __eq__(self, other):
        # Reject foreign types first, then require the inherited comparison
        # to pass before comparing primary keys.
        if not isinstance(other, CloudCredential):
            return False
        if not super().__eq__(other):
            return False
        return self.id == other.id

    def __hash__(self):
        return hash(self.id)
class CloudCredentialListParams(ListParams):
    """List/query parameters for cloud credentials."""

    # Field names callers may sort by; presumably enforced by ListParams,
    # so confirm there.
    sortable_fields: ClassVar[List[str]] = [
        "name",
        "provider",
        "created_at",
        "updated_at",
    ]
class CloudCredentialPublic(CloudCredentialBase, PublicFields):
    """API representation of a cloud credential (built on the base fields, without ``secret``)."""

    pass


# Paginated list response of CloudCredentialPublic items.
CloudCredentialsPublic = PaginatedList[CloudCredentialPublic]
class ClusterStateEnum(str, Enum):
    """State of a cluster; ClusterBase defaults new clusters to PROVISIONING."""

    PENDING = 'pending'
    PROVISIONING = 'provisioning'
    PROVISIONED = 'provisioned'
    READY = 'ready'
class ClusterUpdate(SQLModel):
    """
    Cluster fields accepted on update; root of the cluster schema hierarchy.

    ``worker_config`` and ``k8s_volume_mounts`` are stored through
    ``pydantic_column_type`` with None/unset/default values excluded.
    ``server_url`` is normalized: an empty string becomes ``None``, and any
    other value must parse with both a scheme and a network location.
    """

    name: str
    description: Optional[str] = None
    gateway_endpoint: Optional[str] = None
    server_url: Optional[str] = None
    worker_config: Optional[PredefinedConfigNoDefaults] = Field(
        default=None,
        sa_column=Column(
            pydantic_column_type(
                PredefinedConfigNoDefaults,
                exclude_defaults=True,
                exclude_none=True,
                exclude_unset=True,
            )
        ),
    )
    k8s_volume_mounts: Optional[List[K8sVolumeMount]] = Field(
        default=None,
        sa_column=Column(
            pydantic_column_type(
                List[K8sVolumeMount],
                exclude_defaults=True,
                exclude_none=True,
                exclude_unset=True,
            )
        ),
    )

    @field_validator("server_url")
    def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
        """Map ``None``/``""`` to ``None``; otherwise require scheme and netloc.

        Raises:
            ValueError: if a non-empty value lacks a scheme or a netloc.
        """
        # The value is either None or a str here, so falsiness covers both
        # None and the empty string.
        if not v:
            return None
        parsed = urlparse(v)
        if parsed.scheme and parsed.netloc:
            return v
        raise ValueError("Invalid server_url format")
class ClusterCreateBase(ClusterUpdate):
    """Cluster creation fields shared by ClusterCreate and ClusterBase."""

    provider: ClusterProvider = Field(default=ClusterProvider.Docker)
    # Optional reference to a cloud_credentials row.
    credential_id: Optional[int] = Field(
        default=None, foreign_key="cloud_credentials.id"
    )
    region: Optional[str] = None
    # Every cluster belongs to one Org. The route layer fills this with
    # ctx.current_principal_id (or PLATFORM_PRINCIPAL_ID for admin in "All"
    # mode) when callers omit it; sharing across Orgs is expressed via
    # cluster_access rather than NULL ownership.
    owner_principal_id: Optional[int] = Field(
        default=None, foreign_key="principals.id", nullable=False
    )
class ClusterCreate(ClusterCreateBase):
    """Cluster creation payload; may include worker pools to create with it."""

    worker_pools: Optional[List[WorkerPoolCreate]] = Field(default=None)
class ClusterBase(ClusterCreateBase):
    """Persisted cluster fields: creation fields plus state and defaults."""

    state: ClusterStateEnum = ClusterStateEnum.PROVISIONING
    # Free-text state detail, stored in a nullable TEXT column.
    state_message: Optional[str] = Field(
        default=None, sa_column=Column(Text, nullable=True)
    )
    reported_gateway_endpoint: Optional[str] = None
    # At most one default cluster per owner; see the partial unique index on
    # the clusters table.
    is_default: bool = Field(default=False)
class Cluster(ClusterBase, BaseModelMixin, table=True):
    """
    Database table ``clusters``.

    Besides the persisted columns, the model exposes four computed counters:
    ``workers``, ``ready_workers``, ``gpus`` and ``models``. Each can be
    supplied precomputed through ``__init__``. A negative value (the default
    ``-1``) means "not supplied", and the counter is then derived from the
    matching relationship collection. Hashing uses the primary key ``id``.
    """

    __tablename__ = "clusters"
    __table_args__ = (
        sa.Index("idx_clusters_deleted_at_created_at", "deleted_at", "created_at"),
        # At most one default cluster per Org (partial unique on
        # is_default + soft-delete predicate). Each Org's deploy form
        # falls back to its own default; admin "All" falls back to the
        # platform Org's default.
        sa.Index(
            "uix_clusters_default_per_org",
            "owner_principal_id",
            unique=True,
            sqlite_where=sa.text("is_default = 1 AND deleted_at IS NULL"),
            postgresql_where=sa.text("is_default = true AND deleted_at IS NULL"),
        ),
    )

    id: Optional[int] = Field(default=None, primary_key=True)
    # Both values are generated per instance. They used to be passed as
    # ``default=secrets.token_hex(...)``, which is evaluated once at import
    # time. Every cluster created without explicit values then received the
    # same suffix and, worse, the same registration token.
    hashed_suffix: str = Field(
        nullable=False, default_factory=lambda: secrets.token_hex(6)
    )
    registration_token: Optional[str] = Field(
        nullable=True, default_factory=lambda: secrets.token_hex(16)
    )

    # All relationships use lazy="noload": SQLAlchemy never loads them
    # implicitly, so they only hold what the caller loads or assigns.
    cluster_worker_pools: List[WorkerPool] = Relationship(
        sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
        back_populates="cluster",
    )
    cluster_models: List["Model"] = Relationship(
        sa_relationship_kwargs={"lazy": "noload"}, back_populates="cluster"
    )
    cluster_model_instances: List["ModelInstance"] = Relationship(
        sa_relationship_kwargs={"lazy": "noload"}, back_populates="cluster"
    )
    cluster_users: list["User"] = Relationship(
        sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
        back_populates="cluster",
    )
    cluster_workers: List["Worker"] = Relationship(
        sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
        back_populates="cluster",
    )

    # Precomputed counters; -1 means "not supplied".
    _models: int = PrivateAttr(default=-1)
    _workers: int = PrivateAttr(default=-1)
    _ready_workers: int = PrivateAttr(default=-1)
    _gpus: int = PrivateAttr(default=-1)

    @computed_field()
    @property
    def workers(self) -> int:
        # TypeError (e.g. a None counter) falls back to counting.
        try:
            if self._workers >= 0:
                return self._workers
        except TypeError:
            pass
        return len(self.cluster_workers or [])

    @computed_field()
    @property
    def ready_workers(self) -> int:
        try:
            if self._ready_workers >= 0:
                return self._ready_workers
        except TypeError:
            pass
        return len([w for w in self.cluster_workers or [] if w.state.value == 'ready'])

    @computed_field(alias="gpus")
    @property
    def gpus(self) -> int:
        try:
            if self._gpus >= 0:
                return self._gpus
        except TypeError:
            pass
        # Sum the GPU devices reported by each worker, skipping workers
        # without status or without a device list.
        count = 0
        for worker in self.cluster_workers or []:
            if worker.status is None or worker.status.gpu_devices is None:
                continue
            count += len(worker.status.gpu_devices)
        return count

    @computed_field(alias="models")
    @property
    def models(self) -> int:
        try:
            if self._models >= 0:
                return self._models
        except TypeError:
            pass
        return len(self.cluster_models or [])

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        # Requires the inherited model comparison to pass and the ids to match.
        if super().__eq__(other) and isinstance(other, Cluster):
            return self.id == other.id
        return False

    def __init__(
        self,
        workers: int = -1,
        ready_workers: int = -1,
        gpus: int = -1,
        models: int = -1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._workers = workers
        self._ready_workers = ready_workers
        self._gpus = gpus
        self._models = models
class ClusterListParams(ListParams):
    """List/query parameters for clusters."""

    # Field names callers may sort by, including the computed counters;
    # presumably enforced by ListParams, so confirm there.
    sortable_fields: ClassVar[List[str]] = [
        "name",
        "provider",
        "state",
        "workers",
        "ready_workers",
        "gpus",
        "models",
        "created_at",
        "updated_at",
    ]
class ClusterPublic(ClusterBase, PublicFields):
    """API representation of a cluster, including the aggregate counters."""

    # Plain fields here; on the Cluster table model these are computed.
    workers: int = Field(default=0)
    ready_workers: int = Field(default=0)
    gpus: int = Field(default=0)
    models: int = Field(default=0)
    # Redeclared without ClusterUpdate's sa_column; presumably because this
    # schema is never mapped to a table.
    worker_config: Optional[PredefinedConfigNoDefaults] = Field(default=None)


# Paginated list response of ClusterPublic items.
ClustersPublic = PaginatedList[ClusterPublic]
class SensitiveRegistrationConfig(SensitivePredefinedConfig):
    """SensitivePredefinedConfig plus the worker registration ``token``; unknown keys are ignored."""

    model_config = ConfigDict(extra="ignore")
    token: str
class ClusterRegistrationTokenPublic(BaseModel):
    """
    Arguments for the ``docker run`` command that registers a worker.

    ``env`` is a dict of environment variables derived from
    SensitiveRegistrationConfig.
    """

    token: str
    server_url: str
    image: str
    env: Dict[str, str]
    args: List[str]
class CredentialType(str, Enum):
    """Kind of key material held by a Credential (``CredentialBase.credential_type``)."""

    SSH = "ssh"
    CA = "ca"
    X509 = "x509"
class SSHKeyOptions(BaseModel):
    """SSH key algorithm and key length; defaults to RSA with length 2048."""

    # Key algorithm, e.g. RSA or ED25519.
    algorithm: str = Field(default="RSA")
    length: int = Field(default=2048)
class CredentialBase(SQLModel):
    """Key material for a Credential: public key, encoded private key and SSH options."""

    # Optional external identifier; indexed on the credentials table.
    external_id: Optional[str] = Field(
        default=None, sa_column=Column(String(255), nullable=True)
    )
    credential_type: CredentialType = Field(default=CredentialType.SSH)
    # pem format public key
    public_key: str = Field(sa_column=Column(Text, nullable=False))
    # base64 encoded private key
    encoded_private_key: str = Field(default="", sa_column=Column(Text, nullable=False))
    # SSH key options: algorithm (e.g. RSA, ED25519) and key length.
    ssh_key_options: Optional[SSHKeyOptions] = Field(
        default=None,
        sa_column=Column(pydantic_column_type(SSHKeyOptions), nullable=True),
    )
class Credential(CredentialBase, BaseModelMixin, table=True):
    """Database table ``credentials``, with an index on ``external_id``."""

    __tablename__ = "credentials"
    __table_args__ = (sa.Index("idx_credentials_external_id", "external_id"),)
    id: Optional[int] = Field(default=None, primary_key=True)
|