from datetime import datetime, timezone from enum import Enum from typing import ClassVar, Dict, Optional, Any from pydantic import ConfigDict, BaseModel, field_validator from urllib.parse import urlparse from sqlmodel import ( Field, SQLModel, JSON, Column, Text, Relationship, Integer, ForeignKey, ) from sqlalchemy import String from gpustack import envs from gpustack.mixins import BaseModelMixin from gpustack.schemas.common import ( ListParams, PaginatedList, UTCDateTime, pydantic_column_type, ) from typing import List from sqlalchemy.orm import declarative_base from gpustack.utils.network import is_offline from .clusters import ClusterProvider, Cluster, WorkerPool from gpustack.schemas.config import ( PredefinedConfigNoDefaults, ModelInstanceProxyModeEnum, ) Base = declarative_base() class UtilizationInfo(BaseModel): total: int = Field(default=None) utilization_rate: Optional[float] = Field(default=None) # rang from 0 to 100 class MemoryInfo(UtilizationInfo): is_unified_memory: bool = Field(default=False) used: Optional[int] = Field(default=None) allocated: Optional[int] = Field(default=None) class CPUInfo(UtilizationInfo): pass class GPUCoreInfo(UtilizationInfo): pass class GPUNetworkInfo(BaseModel): status: str = Field(default="down") # Network status (up/down) inet: str = Field(default="") # IPv4 address netmask: str = Field(default="") # Subnet mask mac: str = Field(default="") # MAC address gateway: str = Field(default="") # Default gateway iface: Optional[str] = Field(default=None) # Network interface name mtu: Optional[int] = Field(default=None) # Maximum Transmission Unit class SwapInfo(UtilizationInfo): used: Optional[int] = Field(default=None) pass class GPUDeviceInfo(BaseModel): vendor: Optional[str] = Field(default="") """ Manufacturer of the GPU device, e.g. nvidia, amd, ascend, etc. """ type: Optional[str] = Field(default="") """ Device runtime backend type, e.g. cuda, rocm, cann, etc. """ index: Optional[int] = Field(default=None) """ GPU index, which is the logic ID of the GPU chip, which is a human-readable index and counted from 0 generally. It might be recognized as the GPU device ID in some cases, when there is no more than one GPU chip on the same card. """ device_index: Optional[int] = Field(default=0) """ GPU device index, which is the index of the onboard GPU device. In Linux, it can be retrieved under the /dev/ path. For example, /dev/nvidia0 (the first Nvidia card), /dev/davinci2(the third Ascend card), etc. """ device_chip_index: Optional[int] = Field(default=0) """ GPU device chip index, which is the index of the GPU chip on the card. It works with `device_index` to identify a GPU chip uniquely. For example, the first chip on the first card is 0, and the second chip on the first card is 1. """ arch_family: Optional[str] = Field(default=None) """ Architecture family of the GPU device. """ name: str = Field(default="") """ GPU name, e.g. NVIDIA A100-SXM4-40GB, NVIDIA RTX 3090, AMD MI100, Ascend 310P, etc. """ uuid: Optional[str] = Field(default="") """ UUID is a unique identifier assigned to each GPU device. """ driver_version: Optional[str] = Field(default=None) """ Driver version of the GPU device, e.g. for NVIDIA GPUs. """ runtime_version: Optional[str] = Field(default=None) """ Runtime version of the GPU device, e.g. CUDA version for NVIDIA GPUs. """ compute_capability: Optional[str] = Field(default=None) """ Compute compatibility version of the GPU device, e.g. for NVIDIA GPUs. """ class GPUDeviceStatus(GPUDeviceInfo): core: Optional[GPUCoreInfo] = Field(sa_column=Column(JSON), default=None) """ Core information of the GPU device. """ memory: Optional[MemoryInfo] = Field(sa_column=Column(JSON), default=None) """ Memory information of the GPU device. """ temperature: Optional[float] = Field(default=None) """ Temperature of the GPU device in Celsius. """ network: Optional[GPUNetworkInfo] = Field(sa_column=Column(JSON), default=None) """ Network information of the GPU device, mainly for Ascend devices. """ GPUDevicesStatus = List[GPUDeviceStatus] class MountPoint(BaseModel): name: str = Field(default="") mount_point: str = Field(default="") mount_from: str = Field(default="") total: int = Field(default=None) # in bytes used: Optional[int] = Field(default=None) free: Optional[int] = Field(default=None) available: Optional[int] = Field(default=None) FileSystemInfo = List[MountPoint] class OperatingSystemInfo(BaseModel): name: str = Field(default="") version: str = Field(default="") class KernelInfo(BaseModel): name: str = Field(default="") release: str = Field(default="") version: str = Field(default="") architecture: str = Field(default="") class UptimeInfo(BaseModel): uptime: float = Field(default=None) # in seconds boot_time: str = Field(default="") class SystemReserved(BaseModel): ram: Optional[int] = Field(default=None) vram: Optional[int] = Field(default=None) class RPCServer(BaseModel): pid: Optional[int] = None port: Optional[int] = None gpu_index: Optional[int] = None class WorkerStateEnum(str, Enum): r""" Enum for Worker State State Transition Diagram: Phase 1: Provisioning Controller | Phase 2: Healthcheck Controller ------------------------------------------|------------------------------------ PENDING > PROVISIONING > INITIALIZING > READY < -----|-----------> UNREACHABLE | | | ^ | (Worker Endpoint Unreachable) | | | | | |-------------|---------|------| └-----------> NOT_READY \_____________________________/| (Worker Stop Posting Status) ERROR | (Provisioning failed) ^ | | | | v | v | DELETING <---┘ (provisioning end) | | | | | Phase 3: Upgrade and Maintain | | -------------------------------------------|-----------------------------|----- v | MAINTENANCE <---------------------┘ (Back to Ready/Not Ready after maintenance completed) """ NOT_READY = "not_ready" READY = "ready" UNREACHABLE = "unreachable" PENDING = "pending" PROVISIONING = "provisioning" INITIALIZING = "initializing" DELETING = "deleting" ERROR = "error" MAINTENANCE = "maintenance" @property def is_provisioning(self) -> bool: return self in [ WorkerStateEnum.PENDING, WorkerStateEnum.PROVISIONING, WorkerStateEnum.INITIALIZING, WorkerStateEnum.DELETING, WorkerStateEnum.ERROR, ] class SystemInfo(BaseModel): cpu: Optional[CPUInfo] = Field(sa_column=Column(JSON), default=None) memory: Optional[MemoryInfo] = Field(sa_column=Column(JSON), default=None) swap: Optional[SwapInfo] = Field(sa_column=Column(JSON), default=None) filesystem: Optional[FileSystemInfo] = Field(sa_column=Column(JSON), default=None) os: Optional[OperatingSystemInfo] = Field(sa_column=Column(JSON), default=None) kernel: Optional[KernelInfo] = Field(sa_column=Column(JSON), default=None) uptime: Optional[UptimeInfo] = Field(sa_column=Column(JSON), default=None) class Maintenance(BaseModel): enabled: bool = False message: Optional[str] = None class WorkerStatus(SystemInfo): """ rpc_servers: Deprecated """ gpu_devices: Optional[GPUDevicesStatus] = Field( sa_column=Column(JSON), default=None ) rpc_servers: Optional[Dict[int, RPCServer]] = Field( sa_column=Column(JSON), default=None ) model_config = ConfigDict(from_attributes=True) @classmethod def get_default_status(cls) -> 'WorkerStatus': return WorkerStatus( cpu=CPUInfo(total=0), memory=MemoryInfo(total=0, is_unified_memory=False), swap=SwapInfo(total=0), filesystem=[], os=OperatingSystemInfo(name="", version=""), kernel=KernelInfo(name="", release="", version="", architecture=""), uptime=UptimeInfo(uptime=0, boot_time=""), gpu_devices=[], ) class WorkerStatusStored(BaseModel): advertise_address: Optional[str] = None hostname: str ip: str ifname: str port: int metrics_port: Optional[int] = None system_reserved: Optional[SystemReserved] = Field( default=None, sa_column=Column(pydantic_column_type(SystemReserved)) ) state_message: Optional[str] = Field( default=None, sa_column=Column(Text, nullable=True) ) status: Optional[WorkerStatus] = Field( sa_column=Column(pydantic_column_type(WorkerStatus)) ) worker_uuid: str = Field(sa_column=Column(String(255), nullable=False)) machine_id: Optional[str] = Field( default=None ) # The machine ID of the worker, used for identifying the worker in the cluster proxy_mode: Optional[ModelInstanceProxyModeEnum] = Field( default=ModelInstanceProxyModeEnum.WORKER, ) class WorkerStatusPublic(WorkerStatusStored): gateway_endpoint: Optional[str] = None class WorkerUpdate(SQLModel): """ WorkerUpdate: updatable fields for Worker """ name: str = Field(index=True, unique=True) labels: Dict[str, str] = Field(sa_column=Column(JSON), default={}) maintenance: Optional[Maintenance] = Field( default=None, sa_column=Column(pydantic_column_type(Maintenance), default=None), ) class WorkerCreate(WorkerStatusStored, WorkerUpdate): cluster_id: Optional[int] = Field( sa_column=Column(Integer, ForeignKey("clusters.id"), nullable=False), default=None, ) # Denormalized from cluster.owner_principal_id for per-row tenant # filtering. NULL = belongs to a global cluster (admin-managed). owner_principal_id: Optional[int] = Field( default=None, sa_column=Column(Integer, ForeignKey("principals.id"), nullable=True), ) external_id: Optional[str] = Field( default=None, sa_column=Column(String(255), nullable=True) ) worker_version: Optional[str] = Field( default=None, sa_column=Column(String(100), nullable=True) ) class WorkerBase(WorkerCreate): state: WorkerStateEnum = WorkerStateEnum.NOT_READY heartbeat_time: Optional[datetime] = Field( sa_column=Column(UTCDateTime), default=None ) unreachable: bool = False def compute_state(self): if self.maintenance and self.maintenance.enabled: self.state = WorkerStateEnum.MAINTENANCE self.state_message = self.maintenance.message return if self.state.is_provisioning: return if self.state == WorkerStateEnum.NOT_READY and self.state_message is not None: return is_not_ready_flag, last_heartbeat_str = is_offline( self.heartbeat_time, envs.WORKER_HEARTBEAT_GRACE_PERIOD, datetime.now(timezone.utc), ) if is_not_ready_flag: reschedule_minutes = envs.MODEL_INSTANCE_RESCHEDULE_GRACE_PERIOD / 60 self.state = WorkerStateEnum.NOT_READY self.state_message = ( f"Heartbeat lost (last heartbeat: {last_heartbeat_str}). " f"If the worker remains unresponsive for more than {reschedule_minutes:.1f} minutes, " "the instances on this worker will be rescheduled automatically. " "If this downtime is planned maintenance, please enable maintenance mode. " "Otherwise, please check the worker logs." ) return if self.unreachable: address = self.advertise_address or self.ip healthz_url = f"http://{address}:{self.port}/healthz" msg = ( "Server cannot access the " f"worker's health check endpoint at {healthz_url}. " "Please verify the port requirements in the " "documentation" ) self.state = WorkerStateEnum.UNREACHABLE self.state_message = msg return self.state = WorkerStateEnum.READY self.state_message = None provider: ClusterProvider = Field(default=ClusterProvider.Docker) worker_pool_id: Optional[int] = Field( default=None, sa_column=Column(Integer, ForeignKey("worker_pools.id"), nullable=True), ) # The worker pool this worker belongs to # Not setting foreign key to manage lifecycle ssh_key_id: Optional[int] = Field( default=None, sa_column=Column(Integer, nullable=True) ) provider_config: Optional[Dict[str, Any]] = Field( default=None, sa_column=Column(JSON, nullable=True) ) # Server side proxy field proxy_address: Optional[str] = Field( default=None, sa_column=Column(String(255), nullable=True) ) @field_validator("proxy_address", mode="before") def validate_proxy_address(cls, v): if v is None: return v if not isinstance(v, str): raise ValueError("proxy_address must be a string or None") # proxy address must be in url format, e.g. http://1.2.3.4:8000 result = urlparse(v) if not all([result.scheme, result.netloc]): raise ValueError("proxy_address must be a valid URL") return v def get_proxy_address(self) -> Optional[str]: """ Get the proxy address for the worker. If the worker has a proxy_address, return it. Otherwise, return None to indicate that no proxy should be used. """ if self.proxy_mode != ModelInstanceProxyModeEnum.TUNNEL: return None return self.proxy_address class Worker(WorkerBase, BaseModelMixin, table=True): __tablename__ = 'workers' id: Optional[int] = Field(default=None, primary_key=True) cluster: Cluster = Relationship( back_populates="cluster_workers", sa_relationship_kwargs={"lazy": "noload"} ) worker_pool: Optional[WorkerPool] = Relationship( back_populates="pool_workers", sa_relationship_kwargs={"lazy": "noload"} ) # This field should be replaced by x509 credential if mTLS is supported. token: Optional[str] = Field(default=None, nullable=True) @property def provision_progress(self) -> Optional[str]: """ The provisioning progress should have following steps: 1. create_ssh_key 2. create_instance with created ssh_key 3. wait_for_started 4. wait_for_public_ip 5. [optional] create_volumes_and_attach """ if self.state == WorkerStateEnum.INITIALIZING: return "5/5" if ( self.state != WorkerStateEnum.PROVISIONING and self.state != WorkerStateEnum.PENDING ): return None format = "{}/{}" total = 5 current = sum( [ self.state == WorkerStateEnum.PROVISIONING, self.ssh_key_id is not None, self.external_id is not None, self.ip is not None and self.ip != "", "volume_ids" in (self.provider_config or {}), ] ) return format.format(current, total) def __hash__(self): return hash(self.id) def __eq__(self, other): if super().__eq__(other) and isinstance(other, Worker): return self.id == other.id return False class WorkerListParams(ListParams): sortable_fields: ClassVar[List[str]] = [ "name", "state", "ip", "status.cpu.utilization_rate", "status.memory.utilization_rate", "gpus", # gpu count, the same naming pattern as in Clusters "created_at", "updated_at", ] class WorkerPublic( WorkerBase, ): id: int created_at: datetime updated_at: datetime me: Optional[bool] = None # Indicates if the worker is the current worker provision_progress: Optional[str] = None # Indicates the provisioning progress worker_uuid: Optional[str] = Field(default=None, exclude=True) machine_id: Optional[str] = Field(default=None, exclude=True) class WorkerRegistrationPublic(WorkerPublic): token: str worker_uuid: str worker_config: Optional["PredefinedConfigNoDefaults"] = None WorkersPublic = PaginatedList[WorkerPublic]