workers.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. from datetime import datetime, timezone
  2. from enum import Enum
  3. from typing import ClassVar, Dict, Optional, Any
  4. from pydantic import ConfigDict, BaseModel, field_validator
  5. from urllib.parse import urlparse
  6. from sqlmodel import (
  7. Field,
  8. SQLModel,
  9. JSON,
  10. Column,
  11. Text,
  12. Relationship,
  13. Integer,
  14. ForeignKey,
  15. )
  16. from sqlalchemy import String
  17. from gpustack import envs
  18. from gpustack.mixins import BaseModelMixin
  19. from gpustack.schemas.common import (
  20. ListParams,
  21. PaginatedList,
  22. UTCDateTime,
  23. pydantic_column_type,
  24. )
  25. from typing import List
  26. from sqlalchemy.orm import declarative_base
  27. from gpustack.utils.network import is_offline
  28. from .clusters import ClusterProvider, Cluster, WorkerPool
  29. from gpustack.schemas.config import (
  30. PredefinedConfigNoDefaults,
  31. ModelInstanceProxyModeEnum,
  32. )
# SQLAlchemy declarative base created at import time.
# NOTE(review): no class in the visible part of this module subclasses it —
# confirm whether other modules import it before removing.
Base = declarative_base()
  34. class UtilizationInfo(BaseModel):
  35. total: int = Field(default=None)
  36. utilization_rate: Optional[float] = Field(default=None) # rang from 0 to 100
  37. class MemoryInfo(UtilizationInfo):
  38. is_unified_memory: bool = Field(default=False)
  39. used: Optional[int] = Field(default=None)
  40. allocated: Optional[int] = Field(default=None)
  41. class CPUInfo(UtilizationInfo):
  42. pass
  43. class GPUCoreInfo(UtilizationInfo):
  44. pass
  45. class GPUNetworkInfo(BaseModel):
  46. status: str = Field(default="down") # Network status (up/down)
  47. inet: str = Field(default="") # IPv4 address
  48. netmask: str = Field(default="") # Subnet mask
  49. mac: str = Field(default="") # MAC address
  50. gateway: str = Field(default="") # Default gateway
  51. iface: Optional[str] = Field(default=None) # Network interface name
  52. mtu: Optional[int] = Field(default=None) # Maximum Transmission Unit
  53. class SwapInfo(UtilizationInfo):
  54. used: Optional[int] = Field(default=None)
  55. pass
  56. class GPUDeviceInfo(BaseModel):
  57. vendor: Optional[str] = Field(default="")
  58. """
  59. Manufacturer of the GPU device, e.g. nvidia, amd, ascend, etc.
  60. """
  61. type: Optional[str] = Field(default="")
  62. """
  63. Device runtime backend type, e.g. cuda, rocm, cann, etc.
  64. """
  65. index: Optional[int] = Field(default=None)
  66. """
  67. GPU index, which is the logic ID of the GPU chip,
  68. which is a human-readable index and counted from 0 generally.
  69. It might be recognized as the GPU device ID in some cases, when there is no more than one GPU chip on the same card.
  70. """
  71. device_index: Optional[int] = Field(default=0)
  72. """
  73. GPU device index, which is the index of the onboard GPU device.
  74. In Linux, it can be retrieved under the /dev/ path.
  75. For example, /dev/nvidia0 (the first Nvidia card), /dev/davinci2(the third Ascend card), etc.
  76. """
  77. device_chip_index: Optional[int] = Field(default=0)
  78. """
  79. GPU device chip index, which is the index of the GPU chip on the card.
  80. It works with `device_index` to identify a GPU chip uniquely.
  81. For example, the first chip on the first card is 0, and the second chip on the first card is 1.
  82. """
  83. arch_family: Optional[str] = Field(default=None)
  84. """
  85. Architecture family of the GPU device.
  86. """
  87. name: str = Field(default="")
  88. """
  89. GPU name, e.g. NVIDIA A100-SXM4-40GB, NVIDIA RTX 3090, AMD MI100, Ascend 310P, etc.
  90. """
  91. uuid: Optional[str] = Field(default="")
  92. """
  93. UUID is a unique identifier assigned to each GPU device.
  94. """
  95. driver_version: Optional[str] = Field(default=None)
  96. """
  97. Driver version of the GPU device, e.g. for NVIDIA GPUs.
  98. """
  99. runtime_version: Optional[str] = Field(default=None)
  100. """
  101. Runtime version of the GPU device, e.g. CUDA version for NVIDIA GPUs.
  102. """
  103. compute_capability: Optional[str] = Field(default=None)
  104. """
  105. Compute compatibility version of the GPU device, e.g. for NVIDIA GPUs.
  106. """
  107. class GPUDeviceStatus(GPUDeviceInfo):
  108. core: Optional[GPUCoreInfo] = Field(sa_column=Column(JSON), default=None)
  109. """
  110. Core information of the GPU device.
  111. """
  112. memory: Optional[MemoryInfo] = Field(sa_column=Column(JSON), default=None)
  113. """
  114. Memory information of the GPU device.
  115. """
  116. temperature: Optional[float] = Field(default=None)
  117. """
  118. Temperature of the GPU device in Celsius.
  119. """
  120. network: Optional[GPUNetworkInfo] = Field(sa_column=Column(JSON), default=None)
  121. """
  122. Network information of the GPU device, mainly for Ascend devices.
  123. """
# Alias for a list of GPUDeviceStatus entries; used as the type of
# WorkerStatus.gpu_devices.
GPUDevicesStatus = List[GPUDeviceStatus]
  125. class MountPoint(BaseModel):
  126. name: str = Field(default="")
  127. mount_point: str = Field(default="")
  128. mount_from: str = Field(default="")
  129. total: int = Field(default=None) # in bytes
  130. used: Optional[int] = Field(default=None)
  131. free: Optional[int] = Field(default=None)
  132. available: Optional[int] = Field(default=None)
# Alias for a list of MountPoint entries; used as the type of
# SystemInfo.filesystem.
FileSystemInfo = List[MountPoint]
  134. class OperatingSystemInfo(BaseModel):
  135. name: str = Field(default="")
  136. version: str = Field(default="")
  137. class KernelInfo(BaseModel):
  138. name: str = Field(default="")
  139. release: str = Field(default="")
  140. version: str = Field(default="")
  141. architecture: str = Field(default="")
  142. class UptimeInfo(BaseModel):
  143. uptime: float = Field(default=None) # in seconds
  144. boot_time: str = Field(default="")
  145. class SystemReserved(BaseModel):
  146. ram: Optional[int] = Field(default=None)
  147. vram: Optional[int] = Field(default=None)
  148. class RPCServer(BaseModel):
  149. pid: Optional[int] = None
  150. port: Optional[int] = None
  151. gpu_index: Optional[int] = None
  152. class WorkerStateEnum(str, Enum):
  153. r"""
  154. Enum for Worker State
  155. State Transition Diagram:
  156. Phase 1: Provisioning Controller | Phase 2: Healthcheck Controller
  157. ------------------------------------------|------------------------------------
  158. PENDING > PROVISIONING > INITIALIZING > READY < -----|-----------> UNREACHABLE
  159. | | | ^ | (Worker Endpoint Unreachable)
  160. | | | | |
  161. |-------------|---------|------| └-----------> NOT_READY
  162. \_____________________________/| (Worker Stop Posting Status)
  163. ERROR | (Provisioning failed) ^
  164. | | | |
  165. v | v |
  166. DELETING <---┘ (provisioning end) |
  167. | |
  168. | |
  169. Phase 3: Upgrade and Maintain | |
  170. -------------------------------------------|-----------------------------|-----
  171. v |
  172. MAINTENANCE <---------------------┘
  173. (Back to Ready/Not Ready after maintenance completed)
  174. """
  175. NOT_READY = "not_ready"
  176. READY = "ready"
  177. UNREACHABLE = "unreachable"
  178. PENDING = "pending"
  179. PROVISIONING = "provisioning"
  180. INITIALIZING = "initializing"
  181. DELETING = "deleting"
  182. ERROR = "error"
  183. MAINTENANCE = "maintenance"
  184. @property
  185. def is_provisioning(self) -> bool:
  186. return self in [
  187. WorkerStateEnum.PENDING,
  188. WorkerStateEnum.PROVISIONING,
  189. WorkerStateEnum.INITIALIZING,
  190. WorkerStateEnum.DELETING,
  191. WorkerStateEnum.ERROR,
  192. ]
  193. class SystemInfo(BaseModel):
  194. cpu: Optional[CPUInfo] = Field(sa_column=Column(JSON), default=None)
  195. memory: Optional[MemoryInfo] = Field(sa_column=Column(JSON), default=None)
  196. swap: Optional[SwapInfo] = Field(sa_column=Column(JSON), default=None)
  197. filesystem: Optional[FileSystemInfo] = Field(sa_column=Column(JSON), default=None)
  198. os: Optional[OperatingSystemInfo] = Field(sa_column=Column(JSON), default=None)
  199. kernel: Optional[KernelInfo] = Field(sa_column=Column(JSON), default=None)
  200. uptime: Optional[UptimeInfo] = Field(sa_column=Column(JSON), default=None)
  201. class Maintenance(BaseModel):
  202. enabled: bool = False
  203. message: Optional[str] = None
  204. class WorkerStatus(SystemInfo):
  205. """
  206. rpc_servers: Deprecated
  207. """
  208. gpu_devices: Optional[GPUDevicesStatus] = Field(
  209. sa_column=Column(JSON), default=None
  210. )
  211. rpc_servers: Optional[Dict[int, RPCServer]] = Field(
  212. sa_column=Column(JSON), default=None
  213. )
  214. model_config = ConfigDict(from_attributes=True)
  215. @classmethod
  216. def get_default_status(cls) -> 'WorkerStatus':
  217. return WorkerStatus(
  218. cpu=CPUInfo(total=0),
  219. memory=MemoryInfo(total=0, is_unified_memory=False),
  220. swap=SwapInfo(total=0),
  221. filesystem=[],
  222. os=OperatingSystemInfo(name="", version=""),
  223. kernel=KernelInfo(name="", release="", version="", architecture=""),
  224. uptime=UptimeInfo(uptime=0, boot_time=""),
  225. gpu_devices=[],
  226. )
  227. class WorkerStatusStored(BaseModel):
  228. advertise_address: Optional[str] = None
  229. hostname: str
  230. ip: str
  231. ifname: str
  232. port: int
  233. metrics_port: Optional[int] = None
  234. system_reserved: Optional[SystemReserved] = Field(
  235. default=None, sa_column=Column(pydantic_column_type(SystemReserved))
  236. )
  237. state_message: Optional[str] = Field(
  238. default=None, sa_column=Column(Text, nullable=True)
  239. )
  240. status: Optional[WorkerStatus] = Field(
  241. sa_column=Column(pydantic_column_type(WorkerStatus))
  242. )
  243. worker_uuid: str = Field(sa_column=Column(String(255), nullable=False))
  244. machine_id: Optional[str] = Field(
  245. default=None
  246. ) # The machine ID of the worker, used for identifying the worker in the cluster
  247. proxy_mode: Optional[ModelInstanceProxyModeEnum] = Field(
  248. default=ModelInstanceProxyModeEnum.WORKER,
  249. )
  250. class WorkerStatusPublic(WorkerStatusStored):
  251. gateway_endpoint: Optional[str] = None
  252. class WorkerUpdate(SQLModel):
  253. """
  254. WorkerUpdate: updatable fields for Worker
  255. """
  256. name: str = Field(index=True, unique=True)
  257. labels: Dict[str, str] = Field(sa_column=Column(JSON), default={})
  258. maintenance: Optional[Maintenance] = Field(
  259. default=None,
  260. sa_column=Column(pydantic_column_type(Maintenance), default=None),
  261. )
  262. class WorkerCreate(WorkerStatusStored, WorkerUpdate):
  263. cluster_id: Optional[int] = Field(
  264. sa_column=Column(Integer, ForeignKey("clusters.id"), nullable=False),
  265. default=None,
  266. )
  267. # Denormalized from cluster.owner_principal_id for per-row tenant
  268. # filtering. NULL = belongs to a global cluster (admin-managed).
  269. owner_principal_id: Optional[int] = Field(
  270. default=None,
  271. sa_column=Column(Integer, ForeignKey("principals.id"), nullable=True),
  272. )
  273. external_id: Optional[str] = Field(
  274. default=None, sa_column=Column(String(255), nullable=True)
  275. )
  276. worker_version: Optional[str] = Field(
  277. default=None, sa_column=Column(String(100), nullable=True)
  278. )
  279. class WorkerBase(WorkerCreate):
  280. state: WorkerStateEnum = WorkerStateEnum.NOT_READY
  281. heartbeat_time: Optional[datetime] = Field(
  282. sa_column=Column(UTCDateTime), default=None
  283. )
  284. unreachable: bool = False
  285. def compute_state(self):
  286. if self.maintenance and self.maintenance.enabled:
  287. self.state = WorkerStateEnum.MAINTENANCE
  288. self.state_message = self.maintenance.message
  289. return
  290. if self.state.is_provisioning:
  291. return
  292. if self.state == WorkerStateEnum.NOT_READY and self.state_message is not None:
  293. return
  294. is_not_ready_flag, last_heartbeat_str = is_offline(
  295. self.heartbeat_time,
  296. envs.WORKER_HEARTBEAT_GRACE_PERIOD,
  297. datetime.now(timezone.utc),
  298. )
  299. if is_not_ready_flag:
  300. reschedule_minutes = envs.MODEL_INSTANCE_RESCHEDULE_GRACE_PERIOD / 60
  301. self.state = WorkerStateEnum.NOT_READY
  302. self.state_message = (
  303. f"Heartbeat lost (last heartbeat: {last_heartbeat_str}). "
  304. f"If the worker remains unresponsive for more than {reschedule_minutes:.1f} minutes, "
  305. "the instances on this worker will be rescheduled automatically. "
  306. "If this downtime is planned maintenance, please enable maintenance mode. "
  307. "Otherwise, please <a href='https://docs.gpustack.ai/latest/troubleshooting/#view-gpustack-logs'>check the worker logs</a>."
  308. )
  309. return
  310. if self.unreachable:
  311. address = self.advertise_address or self.ip
  312. healthz_url = f"http://{address}:{self.port}/healthz"
  313. msg = (
  314. "Server cannot access the "
  315. f"worker's health check endpoint at {healthz_url}. "
  316. "Please verify the port requirements in the "
  317. "<a href='https://docs.gpustack.ai/latest/installation/requirements/#port-requirements'>documentation</a>"
  318. )
  319. self.state = WorkerStateEnum.UNREACHABLE
  320. self.state_message = msg
  321. return
  322. self.state = WorkerStateEnum.READY
  323. self.state_message = None
  324. provider: ClusterProvider = Field(default=ClusterProvider.Docker)
  325. worker_pool_id: Optional[int] = Field(
  326. default=None,
  327. sa_column=Column(Integer, ForeignKey("worker_pools.id"), nullable=True),
  328. ) # The worker pool this worker belongs to
  329. # Not setting foreign key to manage lifecycle
  330. ssh_key_id: Optional[int] = Field(
  331. default=None, sa_column=Column(Integer, nullable=True)
  332. )
  333. provider_config: Optional[Dict[str, Any]] = Field(
  334. default=None, sa_column=Column(JSON, nullable=True)
  335. )
  336. # Server side proxy field
  337. proxy_address: Optional[str] = Field(
  338. default=None, sa_column=Column(String(255), nullable=True)
  339. )
  340. @field_validator("proxy_address", mode="before")
  341. def validate_proxy_address(cls, v):
  342. if v is None:
  343. return v
  344. if not isinstance(v, str):
  345. raise ValueError("proxy_address must be a string or None")
  346. # proxy address must be in url format, e.g. http://1.2.3.4:8000
  347. result = urlparse(v)
  348. if not all([result.scheme, result.netloc]):
  349. raise ValueError("proxy_address must be a valid URL")
  350. return v
  351. def get_proxy_address(self) -> Optional[str]:
  352. """
  353. Get the proxy address for the worker. If the worker has a proxy_address, return it.
  354. Otherwise, return None to indicate that no proxy should be used.
  355. """
  356. if self.proxy_mode != ModelInstanceProxyModeEnum.TUNNEL:
  357. return None
  358. return self.proxy_address
  359. class Worker(WorkerBase, BaseModelMixin, table=True):
  360. __tablename__ = 'workers'
  361. id: Optional[int] = Field(default=None, primary_key=True)
  362. cluster: Cluster = Relationship(
  363. back_populates="cluster_workers", sa_relationship_kwargs={"lazy": "noload"}
  364. )
  365. worker_pool: Optional[WorkerPool] = Relationship(
  366. back_populates="pool_workers", sa_relationship_kwargs={"lazy": "noload"}
  367. )
  368. # This field should be replaced by x509 credential if mTLS is supported.
  369. token: Optional[str] = Field(default=None, nullable=True)
  370. @property
  371. def provision_progress(self) -> Optional[str]:
  372. """
  373. The provisioning progress should have following steps:
  374. 1. create_ssh_key
  375. 2. create_instance with created ssh_key
  376. 3. wait_for_started
  377. 4. wait_for_public_ip
  378. 5. [optional] create_volumes_and_attach
  379. """
  380. if self.state == WorkerStateEnum.INITIALIZING:
  381. return "5/5"
  382. if (
  383. self.state != WorkerStateEnum.PROVISIONING
  384. and self.state != WorkerStateEnum.PENDING
  385. ):
  386. return None
  387. format = "{}/{}"
  388. total = 5
  389. current = sum(
  390. [
  391. self.state == WorkerStateEnum.PROVISIONING,
  392. self.ssh_key_id is not None,
  393. self.external_id is not None,
  394. self.ip is not None and self.ip != "",
  395. "volume_ids" in (self.provider_config or {}),
  396. ]
  397. )
  398. return format.format(current, total)
  399. def __hash__(self):
  400. return hash(self.id)
  401. def __eq__(self, other):
  402. if super().__eq__(other) and isinstance(other, Worker):
  403. return self.id == other.id
  404. return False
  405. class WorkerListParams(ListParams):
  406. sortable_fields: ClassVar[List[str]] = [
  407. "name",
  408. "state",
  409. "ip",
  410. "status.cpu.utilization_rate",
  411. "status.memory.utilization_rate",
  412. "gpus", # gpu count, the same naming pattern as in Clusters
  413. "created_at",
  414. "updated_at",
  415. ]
  416. class WorkerPublic(
  417. WorkerBase,
  418. ):
  419. id: int
  420. created_at: datetime
  421. updated_at: datetime
  422. me: Optional[bool] = None # Indicates if the worker is the current worker
  423. provision_progress: Optional[str] = None # Indicates the provisioning progress
  424. worker_uuid: Optional[str] = Field(default=None, exclude=True)
  425. machine_id: Optional[str] = Field(default=None, exclude=True)
  426. class WorkerRegistrationPublic(WorkerPublic):
  427. token: str
  428. worker_uuid: str
  429. worker_config: Optional["PredefinedConfigNoDefaults"] = None
# Paginated list type whose items are WorkerPublic.
WorkersPublic = PaginatedList[WorkerPublic]