# clusters.py (17 KB)
  1. import secrets
  2. from urllib.parse import urlparse
  3. from enum import Enum
  4. from typing import ClassVar, Optional, Dict, Any, List
  5. from pydantic import (
  6. BaseModel,
  7. computed_field,
  8. field_validator,
  9. ConfigDict,
  10. PrivateAttr,
  11. Field as PydanticField,
  12. )
  13. from sqlmodel import (
  14. Field,
  15. Relationship,
  16. Column,
  17. SQLModel,
  18. Text,
  19. Integer,
  20. ForeignKey,
  21. JSON,
  22. String,
  23. )
  24. import sqlalchemy as sa
  25. from typing import TYPE_CHECKING
  26. from gpustack.schemas.config import (
  27. SensitivePredefinedConfig,
  28. PredefinedConfigNoDefaults,
  29. )
  30. from gpustack.mixins import BaseModelMixin
  31. from gpustack.schemas.common import (
  32. PublicFields,
  33. ListParams,
  34. PaginatedList,
  35. pydantic_column_type,
  36. )
  37. if TYPE_CHECKING:
  38. from gpustack.schemas.models import Model, ModelInstance
  39. from gpustack.schemas.workers import Worker
  40. from gpustack.schemas.users import User
  41. class WorkerPoolUpdate(SQLModel):
  42. name: str
  43. batch_size: Optional[int] = Field(default=None, ge=1)
  44. replicas: int = Field(default=1, ge=0)
  45. labels: Optional[Dict[str, str]] = Field(sa_column=Column(JSON), default={})
  46. class Volume(BaseModel):
  47. format: Optional[str] = None
  48. size_gb: Optional[int] = None
  49. name: Optional[str] = None
  50. @field_validator("name")
  51. def validate_name(cls, v):
  52. if not v:
  53. return v
  54. # the worker id will be appended to the name to ensure uniqueness
  55. # so the max length is 60 characters to leave room for the worker id
  56. if len(v) > 60:
  57. raise ValueError("Volume name too long, max 60 characters")
  58. # allow alphanumeric characters, dashes, and periods
  59. allowed_chars = set(
  60. "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-."
  61. )
  62. if not all(c in allowed_chars for c in v):
  63. raise ValueError("Volume name contains invalid characters")
  64. return v
  65. class HostPathVolumeSource(BaseModel):
  66. model_config = ConfigDict(populate_by_name=True, extra="ignore")
  67. path: str = PydanticField(
  68. ...,
  69. description="Path of the directory on the host. If the path is a symlink, it will follow the link to the real path.",
  70. )
  71. type: Optional[str] = PydanticField(None, description="Type for HostPath Volume.")
  72. class PersistentVolumeClaimVolumeSource(BaseModel):
  73. model_config = ConfigDict(populate_by_name=True, extra="ignore")
  74. claim_name: str = PydanticField(
  75. ...,
  76. alias="claimName",
  77. description="ClaimName is the name of a PersistentVolumeClaim in the same namespace as the pod using this volume.",
  78. )
  79. read_only: bool = PydanticField(
  80. False,
  81. alias="readOnly",
  82. description="Will force the ReadOnly setting in VolumeMounts.",
  83. )
  84. class ConfigMapVolumeSource(BaseModel):
  85. """
  86. This source will not be used for now. You won't be able to create this kind of volume through UI.
  87. """
  88. model_config = ConfigDict(populate_by_name=True, extra="ignore")
  89. name: str = PydanticField(..., description="Name of the referent.")
  90. optional: Optional[bool] = PydanticField(
  91. None, description="Specify whether the ConfigMap or its keys must be defined."
  92. )
  93. class VolumeSource(BaseModel):
  94. model_config = ConfigDict(populate_by_name=True, extra="ignore")
  95. host_path: Optional[HostPathVolumeSource] = PydanticField(None, alias="hostPath")
  96. persistent_volume_claim: Optional[PersistentVolumeClaimVolumeSource] = (
  97. PydanticField(None, alias="persistentVolumeClaim")
  98. )
  99. config_map: Optional[ConfigMapVolumeSource] = PydanticField(None, alias="configMap")
  100. class K8sVolumeMount(BaseModel):
  101. model_config = ConfigDict(populate_by_name=True, extra="ignore")
  102. name: str
  103. mount_path: str = PydanticField(..., alias="mountPath")
  104. read_only: bool = PydanticField(False, alias="readOnly")
  105. volume_source: Optional[VolumeSource] = PydanticField(
  106. default=None,
  107. alias="volumeSource",
  108. description=(
  109. "Kubernetes VolumeSource definition. Examples:\n"
  110. '- hostPath: `{"hostPath": {"path": "/data", "type": "Directory"}}`\n'
  111. '- persistentVolumeClaim: `{"persistentVolumeClaim": {"claimName": "my-pvc"}}`\n'
  112. '- configMap: `{"configMap": {"name": "my-configmap"}}`'
  113. ),
  114. )
  115. @field_validator("name")
  116. def validate_name(cls, v):
  117. if not v:
  118. return v
  119. if len(v) > 63:
  120. raise ValueError("Volume name must be less than 64 characters")
  121. import re
  122. if not re.fullmatch(r"[a-z0-9]([-a-z0-9]*[a-z0-9])?", v):
  123. raise ValueError(
  124. "Volume name must be a valid DNS-1123 label (e.g. 'my-name', or '123-abc'); "
  125. "it must consist of lower case alphanumeric characters or '-', "
  126. "and must start and end with an alphanumeric character."
  127. )
  128. return v
  129. class CloudOptions(BaseModel):
  130. volumes: Optional[List[Volume]] = None
  131. class WorkerPoolCreate(WorkerPoolUpdate):
  132. instance_type: str
  133. os_image: str
  134. image_name: str
  135. cloud_options: Optional[CloudOptions] = Field(
  136. default={}, sa_column=Column(pydantic_column_type(CloudOptions))
  137. )
  138. zone: Optional[str] = None
  139. # instance_spec is for UI to store the instance_type's extended specifications for display.
  140. instance_spec: Optional[Dict[str, Any]] = Field(
  141. default=None, sa_column=Column(JSON)
  142. )
  143. class WorkerPoolBase(WorkerPoolCreate):
  144. cluster_id: int = Field(
  145. sa_column=Column(Integer, ForeignKey("clusters.id", ondelete="CASCADE"))
  146. )
  147. # Mirrors the cluster's owner_principal_id (NOT NULL since clusters are
  148. # always Org-owned). The route layer copies the parent cluster's
  149. # value so the row can be filtered without a join.
  150. owner_principal_id: Optional[int] = Field(
  151. default=None, foreign_key="principals.id", nullable=False
  152. )
  153. class WorkerPool(WorkerPoolBase, BaseModelMixin, table=True):
  154. __tablename__ = "worker_pools"
  155. __table_args__ = (
  156. sa.Index("idx_worker_pools_deleted_at_created_at", "deleted_at", "created_at"),
  157. )
  158. id: Optional[int] = Field(default=None, primary_key=True)
  159. cluster: Optional["Cluster"] = Relationship(
  160. back_populates="cluster_worker_pools",
  161. sa_relationship_kwargs={"lazy": "noload"},
  162. )
  163. pool_workers: list["Worker"] = Relationship(
  164. sa_relationship_kwargs={"lazy": "noload"},
  165. back_populates="worker_pool",
  166. )
  167. _workers: int = PrivateAttr(default=-1)
  168. _ready_workers: int = PrivateAttr(default=-1)
  169. @computed_field()
  170. @property
  171. def workers(self) -> int:
  172. try:
  173. if self._workers >= 0:
  174. return self._workers
  175. except TypeError:
  176. pass
  177. return len(self.pool_workers or [])
  178. @computed_field()
  179. @property
  180. def ready_workers(self) -> int:
  181. try:
  182. if self._ready_workers >= 0:
  183. return self._ready_workers
  184. except TypeError:
  185. pass
  186. return len([w for w in self.pool_workers or [] if w.state.value == 'ready'])
  187. def __hash__(self):
  188. return hash(self.id)
  189. def __eq__(self, other):
  190. if super().__eq__(other) and isinstance(other, WorkerPool):
  191. return self.id == other.id
  192. return False
  193. def __init__(
  194. self,
  195. workers: int = -1,
  196. ready_workers: int = -1,
  197. **kwargs,
  198. ):
  199. super().__init__(**kwargs)
  200. self._workers = workers
  201. self._ready_workers = ready_workers
  202. class WorkerPoolPublic(WorkerPoolBase, PublicFields):
  203. workers: int = Field(default=0)
  204. ready_workers: int = Field(default=0)
  205. WorkerPoolsPublic = PaginatedList[WorkerPoolPublic]
  206. class ClusterProvider(Enum):
  207. Docker = "Docker"
  208. Kubernetes = "Kubernetes"
  209. DigitalOcean = "DigitalOcean"
  210. class CloudCredentialBase(SQLModel):
  211. """
  212. Supports providers other than Kubernetes and Docker.
  213. """
  214. name: str
  215. description: Optional[str] = None
  216. provider: ClusterProvider = Field(default=ClusterProvider.DigitalOcean)
  217. key: Optional[str] = None
  218. options: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON))
  219. # Every cloud credential belongs to one Org (mirrors cluster scope).
  220. # The route fills this with ctx.current_principal_id or PLATFORM_ORG when
  221. # the caller omits it.
  222. owner_principal_id: Optional[int] = Field(
  223. default=None, foreign_key="principals.id", nullable=False
  224. )
  225. class CloudCredentialUpdate(CloudCredentialBase):
  226. secret: Optional[str] = None
  227. class CloudCredentialCreate(CloudCredentialUpdate):
  228. pass
  229. class CloudCredential(CloudCredentialCreate, BaseModelMixin, table=True):
  230. __tablename__ = "cloud_credentials"
  231. __table_args__ = (
  232. sa.Index(
  233. "idx_cloud_credentials_deleted_at_created_at", "deleted_at", "created_at"
  234. ),
  235. )
  236. id: Optional[int] = Field(default=None, primary_key=True)
  237. def __hash__(self):
  238. return hash(self.id)
  239. def __eq__(self, other):
  240. if super().__eq__(other) and isinstance(other, CloudCredential):
  241. return self.id == other.id
  242. return False
  243. class CloudCredentialListParams(ListParams):
  244. sortable_fields: ClassVar[List[str]] = [
  245. "name",
  246. "provider",
  247. "created_at",
  248. "updated_at",
  249. ]
  250. class CloudCredentialPublic(CloudCredentialBase, PublicFields):
  251. pass
  252. CloudCredentialsPublic = PaginatedList[CloudCredentialPublic]
  253. class ClusterStateEnum(str, Enum):
  254. PENDING = 'pending'
  255. PROVISIONING = 'provisioning'
  256. PROVISIONED = 'provisioned'
  257. READY = 'ready'
  258. class ClusterUpdate(SQLModel):
  259. name: str
  260. description: Optional[str] = None
  261. gateway_endpoint: Optional[str] = None
  262. server_url: Optional[str] = None
  263. worker_config: Optional[PredefinedConfigNoDefaults] = Field(
  264. default=None,
  265. sa_column=Column(
  266. pydantic_column_type(
  267. PredefinedConfigNoDefaults,
  268. exclude_none=True,
  269. exclude_unset=True,
  270. exclude_defaults=True,
  271. )
  272. ),
  273. )
  274. k8s_volume_mounts: Optional[List[K8sVolumeMount]] = Field(
  275. default=None,
  276. sa_column=Column(
  277. pydantic_column_type(
  278. List[K8sVolumeMount],
  279. exclude_none=True,
  280. exclude_unset=True,
  281. exclude_defaults=True,
  282. )
  283. ),
  284. )
  285. @field_validator("server_url")
  286. def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
  287. if v is not None and len(v) == 0:
  288. return None
  289. if v is not None:
  290. parsed = urlparse(v)
  291. if not parsed.scheme or not parsed.netloc:
  292. raise ValueError("Invalid server_url format")
  293. return v
  294. class ClusterCreateBase(ClusterUpdate):
  295. provider: ClusterProvider = Field(default=ClusterProvider.Docker)
  296. credential_id: Optional[int] = Field(
  297. default=None, foreign_key="cloud_credentials.id"
  298. )
  299. region: Optional[str] = None
  300. # Every cluster belongs to one Org. The route layer fills this with
  301. # ctx.current_principal_id (or PLATFORM_PRINCIPAL_ID for admin in "All"
  302. # mode) when callers omit it; sharing across Orgs is expressed via
  303. # cluster_access rather than NULL ownership.
  304. owner_principal_id: Optional[int] = Field(
  305. default=None, foreign_key="principals.id", nullable=False
  306. )
  307. class ClusterCreate(ClusterCreateBase):
  308. worker_pools: Optional[List[WorkerPoolCreate]] = Field(default=None)
  309. class ClusterBase(ClusterCreateBase):
  310. state: ClusterStateEnum = ClusterStateEnum.PROVISIONING
  311. state_message: Optional[str] = Field(
  312. default=None, sa_column=Column(Text, nullable=True)
  313. )
  314. reported_gateway_endpoint: Optional[str] = None
  315. is_default: bool = Field(default=False)
  316. class Cluster(ClusterBase, BaseModelMixin, table=True):
  317. __tablename__ = "clusters"
  318. __table_args__ = (
  319. sa.Index("idx_clusters_deleted_at_created_at", "deleted_at", "created_at"),
  320. # At most one default cluster per Org (partial unique on
  321. # is_default + soft-delete predicate). Each Org's deploy form
  322. # falls back to its own default; admin "All" falls back to the
  323. # platform Org's default.
  324. sa.Index(
  325. "uix_clusters_default_per_org",
  326. "owner_principal_id",
  327. unique=True,
  328. sqlite_where=sa.text("is_default = 1 AND deleted_at IS NULL"),
  329. postgresql_where=sa.text("is_default = true AND deleted_at IS NULL"),
  330. ),
  331. )
  332. id: Optional[int] = Field(default=None, primary_key=True)
  333. hashed_suffix: str = Field(nullable=False, default=secrets.token_hex(6))
  334. registration_token: Optional[str] = Field(
  335. nullable=True, default=secrets.token_hex(16)
  336. )
  337. cluster_worker_pools: List[WorkerPool] = Relationship(
  338. sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
  339. back_populates="cluster",
  340. )
  341. cluster_models: List["Model"] = Relationship(
  342. sa_relationship_kwargs={"lazy": "noload"}, back_populates="cluster"
  343. )
  344. cluster_model_instances: List["ModelInstance"] = Relationship(
  345. sa_relationship_kwargs={"lazy": "noload"}, back_populates="cluster"
  346. )
  347. cluster_users: list["User"] = Relationship(
  348. sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
  349. back_populates="cluster",
  350. )
  351. cluster_workers: List["Worker"] = Relationship(
  352. sa_relationship_kwargs={"cascade": "delete", "lazy": "noload"},
  353. back_populates="cluster",
  354. )
  355. _models: int = PrivateAttr(default=-1)
  356. _workers: int = PrivateAttr(default=-1)
  357. _ready_workers: int = PrivateAttr(default=-1)
  358. _gpus: int = PrivateAttr(default=-1)
  359. @computed_field()
  360. @property
  361. def workers(self) -> int:
  362. try:
  363. if self._workers >= 0:
  364. return self._workers
  365. except TypeError:
  366. pass
  367. return len(self.cluster_workers or [])
  368. @computed_field()
  369. @property
  370. def ready_workers(self) -> int:
  371. try:
  372. if self._ready_workers >= 0:
  373. return self._ready_workers
  374. except TypeError:
  375. pass
  376. return len([w for w in self.cluster_workers or [] if w.state.value == 'ready'])
  377. @computed_field(alias="gpus")
  378. @property
  379. def gpus(self) -> int:
  380. try:
  381. if self._gpus >= 0:
  382. return self._gpus
  383. except TypeError:
  384. pass
  385. count = 0
  386. for worker in self.cluster_workers or []:
  387. if worker.status is None or worker.status.gpu_devices is None:
  388. continue
  389. count += len(worker.status.gpu_devices)
  390. return count
  391. @computed_field(alias="models")
  392. @property
  393. def models(self) -> int:
  394. try:
  395. if self._models >= 0:
  396. return self._models
  397. except TypeError:
  398. pass
  399. return len(self.cluster_models or [])
  400. def __hash__(self):
  401. return hash(self.id)
  402. def __eq__(self, other):
  403. if super().__eq__(other) and isinstance(other, Cluster):
  404. return self.id == other.id
  405. return False
  406. def __init__(
  407. self,
  408. workers: int = -1,
  409. ready_workers: int = -1,
  410. gpus: int = -1,
  411. models: int = -1,
  412. **kwargs,
  413. ):
  414. super().__init__(**kwargs)
  415. self._workers = workers
  416. self._ready_workers = ready_workers
  417. self._gpus = gpus
  418. self._models = models
  419. class ClusterListParams(ListParams):
  420. sortable_fields: ClassVar[List[str]] = [
  421. "name",
  422. "provider",
  423. "state",
  424. "workers",
  425. "ready_workers",
  426. "gpus",
  427. "models",
  428. "created_at",
  429. "updated_at",
  430. ]
  431. class ClusterPublic(ClusterBase, PublicFields):
  432. workers: int = Field(default=0)
  433. ready_workers: int = Field(default=0)
  434. gpus: int = Field(default=0)
  435. models: int = Field(default=0)
  436. worker_config: Optional[PredefinedConfigNoDefaults] = Field(default=None)
  437. ClustersPublic = PaginatedList[ClusterPublic]
  438. class SensitiveRegistrationConfig(SensitivePredefinedConfig):
  439. model_config = ConfigDict(extra="ignore")
  440. token: str
  441. class ClusterRegistrationTokenPublic(BaseModel):
  442. """
  443. The arguments of docker run command to register a worker.
  444. The env attribute is basically a dict of environment variables parsed from SensitiveRegistrationConfig.
  445. """
  446. token: str
  447. server_url: str
  448. image: str
  449. env: Dict[str, str]
  450. args: List[str]
  451. class CredentialType(str, Enum):
  452. SSH = "ssh"
  453. CA = "ca"
  454. X509 = "x509"
  455. class SSHKeyOptions(BaseModel):
  456. algorithm: str = Field(default="RSA")
  457. length: int = Field(default=2048)
  458. class CredentialBase(SQLModel):
  459. external_id: Optional[str] = Field(
  460. default=None, sa_column=Column(String(255), nullable=True)
  461. )
  462. credential_type: CredentialType = Field(default=CredentialType.SSH)
  463. # pem format public key
  464. public_key: str = Field(sa_column=Column(Text, nullable=False))
  465. # base64 encoded private key
  466. encoded_private_key: str = Field(default="", sa_column=Column(Text, nullable=False))
  467. # e.g. RSA, ED25519
  468. ssh_key_options: Optional[SSHKeyOptions] = Field(
  469. default=None,
  470. sa_column=Column(pydantic_column_type(SSHKeyOptions), nullable=True),
  471. )
  472. class Credential(CredentialBase, BaseModelMixin, table=True):
  473. __tablename__ = "credentials"
  474. __table_args__ = (sa.Index("idx_credentials_external_id", "external_id"),)
  475. id: Optional[int] = Field(default=None, primary_key=True)