worker_manager.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import os
  2. import logging
  3. from typing import Optional, Tuple
  4. import httpx
  5. from gpustack import __version__, __git_commit__
  6. from gpustack.client import ClientSet
  7. from gpustack.client.worker_manager_clients import (
  8. WorkerStatusClient,
  9. WorkerRegistrationClient,
  10. )
  11. from gpustack.config.config import Config
  12. from gpustack.schemas.workers import (
  13. WorkerCreate,
  14. WorkerUpdate,
  15. WorkerRegistrationPublic,
  16. )
  17. from gpustack.schemas.config import PredefinedConfigNoDefaults
  18. from gpustack.security import API_KEY_PREFIX
  19. from gpustack.utils import platform
  20. from gpustack.worker.collector import WorkerStatusCollector
  21. from gpustack.config.registration import (
  22. registration_client,
  23. read_worker_token,
  24. write_worker_token,
  25. )
  26. from gpustack.utils.uuid import (
  27. set_worker_name,
  28. get_worker_name,
  29. set_legacy_uuid,
  30. get_legacy_uuid,
  31. )
  32. from gpustack.utils.version import is_worker_version_compatible
  33. logger = logging.getLogger(__name__)
  34. class WorkerManager:
  35. _is_embedded: bool
  36. _collector: WorkerStatusCollector
  37. _clientset: Optional[ClientSet] = None
  38. _registration_client: WorkerRegistrationClient
  39. _status_client: WorkerStatusClient
  40. def __init__(
  41. self,
  42. cfg: Config,
  43. is_embedded: bool,
  44. collector: WorkerStatusCollector,
  45. ):
  46. self._is_embedded = is_embedded
  47. self._cfg = cfg
  48. self._collector = collector
  49. worker_token = read_worker_token(self._cfg.data_dir)
  50. if worker_token:
  51. self._prepare_clients(worker_token)
  52. def _prepare_clients(self, token: str):
  53. self._clientset = ClientSet(
  54. base_url=self._cfg.get_server_url(),
  55. api_key=token,
  56. )
  57. self._status_client = WorkerStatusClient(self._clientset.http_client)
  58. def sync_worker_status(self):
  59. """
  60. Should be called periodically to sync the worker node status with the server.
  61. It registers the worker node with the server if necessary.
  62. """
  63. if self._status_client is None:
  64. return
  65. try:
  66. workerStatus = self._collector.timed_collect(self._clientset)
  67. except Exception as e:
  68. logger.error(f"Failed to collect status for worker: {e}")
  69. return
  70. try:
  71. self._status_client.create(workerStatus)
  72. except Exception as e:
  73. logger.error(f"Failed to update worker status: {e}")
  74. async def register_with_server(
  75. self,
  76. ) -> Tuple[ClientSet, Optional[PredefinedConfigNoDefaults]]:
  77. # always re-register the worker and retrive the token and config
  78. try:
  79. worker_registerred = await self._register_worker()
  80. token = worker_registerred.token
  81. write_worker_token(self._cfg.data_dir, token)
  82. self._prepare_clients(token)
  83. return self._clientset, worker_registerred.worker_config
  84. except Exception as e:
  85. logger.error(f"Failed to register worker: {e}")
  86. raise
  87. async def _register_worker(self) -> WorkerRegistrationPublic:
  88. name = self._cfg.worker_name or get_worker_name(self._cfg.data_dir)
  89. logger.info(
  90. f"Registering worker with name: {name or '<auto-generated-name>'}",
  91. )
  92. if self._is_embedded:
  93. # always reloads the token
  94. self._cfg.reload_token()
  95. self._registration_client = registration_client(
  96. data_dir=self._cfg.data_dir,
  97. server_url=self._cfg.get_server_url(),
  98. registration_token=self._cfg.token,
  99. wait_token_file=self._is_embedded,
  100. )
  101. external_id = None
  102. external_id_path = os.path.join(self._cfg.data_dir, 'external_id')
  103. if os.path.exists(external_id_path):
  104. with open(os.path.join(self._cfg.data_dir, 'external_id'), 'r') as f:
  105. external_id = f.read()
  106. workerStatus = self._collector.timed_collect(initial=True)
  107. # Set empty name if not specified to avoid validation error
  108. workerUpdate = WorkerUpdate(
  109. name=name or "",
  110. labels=self._ensure_builtin_labels(),
  111. )
  112. to_register = WorkerCreate.model_validate(
  113. {
  114. **workerStatus.model_dump(),
  115. **workerUpdate.model_dump(),
  116. "external_id": external_id,
  117. "worker_version": __version__,
  118. }
  119. )
  120. created = await self._registration_client.create_async(to_register)
  121. logger.info(f"Worker {created.name} registered with worker_id {created.id}.")
  122. set_worker_name(self._cfg.data_dir, created.name)
  123. set_legacy_uuid(self._cfg.data_dir, created.worker_uuid)
  124. return created
  125. def _register_shutdown_hooks(self):
  126. pass
  127. def _ensure_builtin_labels(self) -> dict:
  128. labels = {
  129. "os": platform.system(),
  130. "arch": platform.arch(),
  131. }
  132. # worker name label will be set during registration
  133. name = self._cfg.worker_name or get_worker_name(self._cfg.data_dir)
  134. if name:
  135. labels["worker-name"] = name
  136. # Legacy workers with version 0.7.x send worker_uuid as part of registration.
  137. # Legacy workers with version <0.7.x don't have worker_uuid, so we use this label as part of the registration allowance.
  138. is_legacy_token = self._cfg.token and not self._cfg.token.startswith(
  139. API_KEY_PREFIX
  140. )
  141. is_legacy_worker = get_legacy_uuid(self._cfg.data_dir) is None
  142. is_existing_worker = get_worker_name(self._cfg.data_dir) is not None
  143. if (is_legacy_token or is_legacy_worker) and is_existing_worker:
  144. labels["gpustack.existence-check"] = "true"
  145. return labels
  146. async def _fetch_server_version(self) -> Optional[dict]:
  147. """
  148. Fetch the server version from the /version endpoint.
  149. Returns None if the endpoint is not available (e.g., old server).
  150. """
  151. server_url = self._cfg.get_server_url()
  152. version_url = f"{server_url}/version"
  153. try:
  154. async with httpx.AsyncClient(timeout=5.0) as client:
  155. response = await client.get(version_url)
  156. if response.status_code == 404:
  157. logger.warning(
  158. "Server does not support version check. "
  159. "Please upgrade your server to enable version verification."
  160. )
  161. return None
  162. response.raise_for_status()
  163. return response.json()
  164. except httpx.TimeoutException:
  165. logger.error(f"Timeout while fetching server version from {version_url}")
  166. raise
  167. except Exception as e:
  168. logger.error(f"Failed to fetch server version: {e}")
  169. raise
  170. async def check_server_version(self):
  171. """
  172. Check if the worker version is compatible with the server version.
  173. Logs a warning if versions are incompatible but allows startup to continue.
  174. """
  175. try:
  176. server_version_info = await self._fetch_server_version()
  177. except Exception as e:
  178. logger.error(f"Server version check failed: {e}")
  179. return
  180. if server_version_info is None:
  181. # Old server without version endpoint - warn and continue
  182. logger.warning(
  183. "Server version unknown. Proceeding without version check. "
  184. "Consider upgrading your server."
  185. )
  186. return
  187. server_version = server_version_info.get("version", "unknown")
  188. server_git_commit = server_version_info.get("git_commit", "unknown")
  189. is_compatible = is_worker_version_compatible(__version__, server_version)
  190. if not is_compatible:
  191. warning_msg = (
  192. f"Version mismatch detected:\n"
  193. f" Worker version: {__version__} (commit: {__git_commit__})\n"
  194. f" Server version: {server_version} (commit: {server_git_commit})\n\n"
  195. f"Please upgrade your worker to match the server version."
  196. )
  197. logger.warning(warning_msg)
  198. else:
  199. logger.info(
  200. f"Version check passed: worker {__version__} matches server {server_version}"
  201. )