worker.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. import asyncio
  2. from collections import defaultdict
  3. from contextlib import asynccontextmanager
  4. import logging
  5. from typing import Optional, Tuple
  6. import json
  7. import os
  8. import uuid
  9. from pathlib import Path
  10. import aiohttp
  11. from fastapi import FastAPI
  12. import setproctitle
  13. import tenacity
  14. import uvicorn
  15. from urllib.parse import urlparse
  16. from starlette.middleware.base import BaseHTTPMiddleware
  17. from gpustack_runtime.deployer.k8s.deviceplugin import (
  18. serve_async as kdp_serve_async,
  19. get_resource_injection_policy,
  20. )
  21. from gpustack.api import exceptions
  22. from gpustack.config.config import (
  23. Config,
  24. WorkerConfig,
  25. )
  26. from gpustack.schemas.config import (
  27. GatewayModeEnum,
  28. ModelInstanceProxyModeEnum,
  29. PredefinedConfigNoDefaults,
  30. )
  31. from gpustack import envs
  32. from gpustack.routes import config as route_config, debug, probes
  33. from gpustack.routes.worker import logs, proxy, filesystem, cluster_proxy
  34. from gpustack.routes.token import worker_auth
  35. from gpustack.server import catalog
  36. from gpustack.utils.network import (
  37. get_first_non_loopback_ip,
  38. get_ifname_by_ip_hostname,
  39. )
  40. from gpustack.client import ClientSet
  41. from gpustack.logging import setup_logging
  42. from gpustack.utils.process import add_signal_handlers_in_loop
  43. from gpustack.utils.system_check import check_glibc_version
  44. from gpustack.utils.task import run_periodically_in_thread
  45. from gpustack.worker.benchmark_manager import BenchmarkManager
  46. from gpustack.worker.inference_backend_manager import InferenceBackendManager
  47. from gpustack.worker.model_file_manager import ModelFileManager
  48. from gpustack.worker.runtime_metrics_aggregator import RuntimeMetricsAggregator
  49. from gpustack.worker.serve_manager import ServeManager
  50. from gpustack.worker.exporter import MetricExporter
  51. from gpustack.worker.tools_manager import ToolsManager
  52. from gpustack.worker.worker_manager import WorkerManager
  53. from gpustack.worker.collector import WorkerStatusCollector
  54. from gpustack.config.registration import read_worker_token
  55. from gpustack.config import registration
  56. from gpustack.gateway import init_async_k8s_config
  57. from gpustack.client.generated_http_client import default_versioned_prefix
  58. from gpustack.worker.workload_cleaner import WorkloadCleaner
  59. from gpustack.utils.uuid import get_worker_name, get_legacy_uuid
  60. from gpustack.websocket_proxy.message_client import MessageClient
  61. from gpustack.api.auth import BearerTokenAuthenticator
# Module-level logger used by the Worker class defined in this file.
logger = logging.getLogger(__name__)
  63. class Worker:
  64. _default_config: PredefinedConfigNoDefaults
  65. _clientset: ClientSet
  66. _register_clientset: ClientSet
  67. _status_collector: WorkerStatusCollector
  68. _worker_manager: WorkerManager
  69. _serve_manager: ServeManager
  70. _benchmark_manager: BenchmarkManager
  71. _workload_cleaner: WorkloadCleaner
  72. _config: Config
  73. _worker_ip: Optional[str] = None
  74. _worker_ifname: Optional[str] = None
  75. _worker_id: Optional[int] = None
  76. _worker_name: Optional[str] = None
  77. _worker_uuid: Optional[str] = None
  78. _cluster_id: Optional[int] = None
  79. def worker_ip(self) -> str:
  80. return self._config.worker_ip or self._worker_ip
  81. def worker_ifname(self) -> str:
  82. return self._config.worker_ifname or self._worker_ifname
  83. def worker_name(self) -> Optional[str]:
  84. return (
  85. self._config.worker_name
  86. or self._worker_name
  87. or get_worker_name(self._config.data_dir)
  88. )
  89. def worker_uuid(self) -> str:
  90. return self._worker_uuid or get_legacy_uuid(self._config.data_dir) or ""
  91. def worker_id(self) -> int:
  92. return self._worker_id
  93. def clientset(self) -> ClientSet:
  94. return self._clientset
  95. def cluster_id(self) -> Optional[int]:
  96. return self._cluster_id
  97. def __init__(self, cfg: Config):
  98. self._config = cfg
  99. self._is_embedded = cfg.server_role() == Config.ServerRole.BOTH
  100. self._log_dir = cfg.log_dir
  101. self._address = "0.0.0.0"
  102. self._exporter_enabled = not cfg.disable_worker_metrics
  103. self._async_tasks = []
  104. self._worker_ip, self._worker_ifname = self._detect_worker_ip_and_ifname()
  105. self._runtime_metrics_cache = defaultdict()
  106. self._status_collector = WorkerStatusCollector(
  107. cfg=cfg,
  108. worker_ip_getter=self.worker_ip,
  109. worker_ifname_getter=self.worker_ifname,
  110. worker_id_getter=self.worker_id,
  111. worker_uuid_getter=self.worker_uuid,
  112. )
  113. self._worker_manager = WorkerManager(
  114. cfg=cfg,
  115. is_embedded=self._is_embedded,
  116. collector=self._status_collector,
  117. )
  118. self._exporter = MetricExporter(
  119. cfg=cfg,
  120. collector=self._status_collector,
  121. worker_ip_getter=self.worker_ip,
  122. worker_id_getter=self.worker_id,
  123. worker_name_getter=self.worker_name,
  124. clientset_getter=self.clientset,
  125. cache=self._runtime_metrics_cache,
  126. )
  127. self._serve_manager = ServeManager(
  128. worker_id_getter=self.worker_id,
  129. clientset_getter=self.clientset,
  130. cfg=self._config,
  131. )
  132. self._benchmark_manager = BenchmarkManager(
  133. worker_id_getter=self.worker_id,
  134. clientset_getter=self.clientset,
  135. cfg=self._config,
  136. )
  137. self._workload_cleaner = WorkloadCleaner(
  138. worker_id_getter=self.worker_id,
  139. clientset_getter=self.clientset,
  140. )
  141. migrate_worker_name(cfg)
  142. @tenacity.retry(
  143. stop=tenacity.stop_after_attempt(10),
  144. wait=tenacity.wait_fixed(3),
  145. reraise=True,
  146. before_sleep=lambda retry_state: logger.debug(
  147. f"Retrying to get worker ID (attempt {retry_state.attempt_number}) due to: {retry_state.outcome.exception()}"
  148. ),
  149. )
  150. async def _register(self):
  151. self._clientset, self._default_config = (
  152. await self._worker_manager.register_with_server()
  153. )
  154. # Worker ID is available after the worker registration.
  155. worker_list = self._clientset.workers.list(
  156. params={"me": 'true'},
  157. )
  158. name = self.worker_name() or "<not specified>"
  159. if len(worker_list.items) != 1:
  160. raise Exception(f"Worker {name} not registered.")
  161. self._worker_id = worker_list.items[0].id
  162. self._cluster_id = worker_list.items[0].cluster_id
  163. self._worker_name = worker_list.items[0].name
  164. self._worker_uuid = worker_list.items[0].worker_uuid
  165. def _create_async_task(self, coro):
  166. self._async_tasks.append(asyncio.create_task(coro))
  167. def start(self):
  168. setup_logging(self._config.debug)
  169. if self._is_embedded:
  170. setproctitle.setproctitle("gpustack_worker")
  171. check_glibc_version()
  172. init_async_k8s_config(cfg=self._config)
  173. tools_manager = ToolsManager(
  174. tools_download_base_url=self._config.tools_download_base_url,
  175. pipx_path=self._config.pipx_path,
  176. data_dir=self._config.data_dir,
  177. bin_dir=self._config.bin_dir,
  178. )
  179. tools_manager.prepare_tools()
  180. catalog.prepare_chat_templates(self._config.data_dir)
  181. try:
  182. asyncio.run(self.start_async())
  183. except (KeyboardInterrupt, asyncio.CancelledError):
  184. pass
  185. except Exception as e:
  186. logger.error(f"Error serving worker APIs: {e}")
  187. finally:
  188. logger.info("Worker has shut down.")
  189. def log_worker_config(self):
  190. fields = {
  191. k: v
  192. for k, v in self._config.model_dump(
  193. exclude_none=True,
  194. exclude_unset=True,
  195. exclude_defaults=True,
  196. exclude={'token'},
  197. ).items()
  198. if k in WorkerConfig.model_fields
  199. }
  200. hf_token = fields.get("huggingface_token", None)
  201. if hf_token is not None:
  202. fields["huggingface_token"] = "*" * len(hf_token)
  203. logger.info(
  204. "Worker starting with config: %s",
  205. json.dumps(fields, indent=2, ensure_ascii=False),
  206. )
  207. async def start_async(self):
  208. """
  209. Start the worker.
  210. """
  211. logger.info("Starting GPUStack worker.")
  212. add_signal_handlers_in_loop()
  213. # Check version compatibility with server before registration
  214. await self._worker_manager.check_server_version()
  215. await self._register()
  216. self._config.reload_worker_config(self._default_config)
  217. self.log_worker_config()
  218. if self._exporter_enabled:
  219. # Start the runtime metrics cacher.
  220. _runtime_metrics_aggregator = RuntimeMetricsAggregator(
  221. cache=self._runtime_metrics_cache,
  222. worker_id_getter=self.worker_id,
  223. clientset=self._clientset,
  224. )
  225. run_periodically_in_thread(_runtime_metrics_aggregator.aggregate, 3, 30)
  226. # Start the metric exporter with retry.
  227. run_periodically_in_thread(self._exporter.start, 15)
  228. # Monitor the ip change, if not fixed.
  229. if self._config.worker_ip is None or self._config.worker_ifname is None:
  230. # Check worker ip change every 15 seconds.
  231. run_periodically_in_thread(self._check_worker_ip_change, 15)
  232. # Send heartbeat to the server every WORKER_HEARTBEAT_INTERVAL seconds.
  233. run_periodically_in_thread(self._heartbeat, envs.WORKER_HEARTBEAT_INTERVAL)
  234. # Report the worker node status to the server every WORKER_STATUS_SYNC_INTERVAL seconds.
  235. run_periodically_in_thread(
  236. self._worker_manager.sync_worker_status, envs.WORKER_STATUS_SYNC_INTERVAL
  237. )
  238. # Start the worker server to expose APIs.
  239. self._create_async_task(self._serve_apis())
  240. inference_backend_manager = InferenceBackendManager(self._clientset)
  241. # Start InferenceBackend listener to cache backend data
  242. self._create_async_task(inference_backend_manager.start_listener())
  243. # Trigger cache refresh
  244. registration.determine_default_registry(
  245. self._config.system_default_container_registry
  246. )
  247. self._serve_manager._inference_backend_manager = inference_backend_manager
  248. run_periodically_in_thread(
  249. self._serve_manager.sync_model_instances_state,
  250. envs.MODEL_INSTANCE_HEALTH_CHECK_INTERVAL,
  251. )
  252. # Use a short fixed loop interval so that per-model intervals
  253. # shorter than the global default can still be honoured.
  254. run_periodically_in_thread(
  255. self._serve_manager.sync_model_instances_inference_health,
  256. 10,
  257. )
  258. run_periodically_in_thread(
  259. self._workload_cleaner.cleanup_orphan_workloads, 120, 15
  260. )
  261. run_periodically_in_thread(self._benchmark_manager.sync_benchmark_state, 3, 15)
  262. self._create_async_task(self._serve_manager.watch_models())
  263. self._create_async_task(self._serve_manager.watch_model_instances_event())
  264. self._create_async_task(self._serve_manager.watch_model_instances())
  265. self._create_async_task(self._benchmark_manager.watch_benchmarks_event())
  266. model_file_manager = ModelFileManager(
  267. worker_id=self._worker_id, clientset=self._clientset, cfg=self._config
  268. )
  269. self._create_async_task(model_file_manager.watch_model_files())
  270. # Start Kubernetes Device Plugin server if allowed.
  271. if get_resource_injection_policy() == "kdp":
  272. self._create_async_task(kdp_serve_async(stop_event=asyncio.Event()))
  273. if self._config.proxy_mode == ModelInstanceProxyModeEnum.TUNNEL:
  274. docker_sock = Path("/var/run/docker.sock")
  275. sockets = [str(docker_sock)] if docker_sock.exists() else []
  276. # Start websocket proxy message client to handle CONNECT_REQUEST from server
  277. self._message_client = MessageClient(
  278. server_endpoint=self._config.get_server_url(),
  279. client_id=uuid.UUID(self.worker_uuid()),
  280. cidrs=[f"{self.worker_ip()}/32"] if self.worker_ip() else [],
  281. unix_sockets=sockets,
  282. authenticator=BearerTokenAuthenticator(headers=self._clientset.headers),
  283. )
  284. self._create_async_task(self._message_client.run())
  285. else:
  286. self._message_client = None
  287. # wait for a while to let other tasks start
  288. await asyncio.sleep(0.5)
  289. logger.info("GPUStack worker startup completed.")
  290. await asyncio.gather(*self._async_tasks)
  291. async def _serve_apis(self):
  292. """
  293. Start the worker server to expose APIs.
  294. """
  295. @asynccontextmanager
  296. async def lifespan(app: FastAPI):
  297. connector = aiohttp.TCPConnector(
  298. limit=envs.TCP_CONNECTOR_LIMIT,
  299. force_close=True,
  300. )
  301. app.state.http_client = aiohttp.ClientSession(
  302. connector=connector, trust_env=True
  303. )
  304. app.state.http_client_no_proxy = aiohttp.ClientSession(connector=connector)
  305. yield
  306. await app.state.http_client.close()
  307. await app.state.http_client_no_proxy.close()
  308. kube_session = getattr(app.state, "kube_api_session", None)
  309. if kube_session is not None and not kube_session.closed:
  310. await kube_session.close()
  311. app = FastAPI(
  312. title="GPUStack Worker",
  313. response_model_exclude_unset=True,
  314. lifespan=lifespan,
  315. )
  316. app.state.config = self._config
  317. app.state.token = read_worker_token(self._config.data_dir)
  318. app.state.worker_ip_getter = self.worker_ip
  319. app.state.get_instance_port_by_model_instance_id = (
  320. self._serve_manager.get_instance_port_by_model_instance_id
  321. )
  322. app.state.record_successful_inference = (
  323. self._serve_manager.record_successful_inference
  324. )
  325. app.add_middleware(BaseHTTPMiddleware, dispatch=proxy.set_port_from_model_name)
  326. app.include_router(route_config.router, prefix=default_versioned_prefix)
  327. app.include_router(debug.router, prefix="/debug")
  328. app.include_router(probes.router)
  329. app.include_router(logs.router)
  330. app.include_router(proxy.router)
  331. app.include_router(filesystem.router)
  332. app.include_router(cluster_proxy.router)
  333. app.add_api_route(
  334. path="/token-auth",
  335. endpoint=worker_auth,
  336. methods=["GET"],
  337. )
  338. exceptions.register_handlers(app)
  339. config = uvicorn.Config(
  340. app,
  341. host=self._address,
  342. port=self._config.get_api_port(self._is_embedded),
  343. access_log=False,
  344. log_level="error",
  345. )
  346. setup_logging()
  347. worker_api_message = f"Serving worker APIs on {config.host}:{config.port}."
  348. if not self._is_embedded:
  349. logger.debug(worker_api_message)
  350. logger.info(f"Worker gateway mode: {self._config.gateway_mode.value}.")
  351. if self._config.gateway_mode == GatewayModeEnum.embedded:
  352. logger.info(f"Serving worker on {self._config.get_gateway_port()}.")
  353. else:
  354. logger.info(worker_api_message)
  355. server = uvicorn.Server(config)
  356. await server.serve()
  357. def _detect_worker_ip_and_ifname(self) -> Tuple[Optional[str], Optional[str]]:
  358. """
  359. Detect the worker IP and ifname.
  360. """
  361. static_worker_ip = self._config.worker_ip
  362. static_worker_ifname = self._config.worker_ifname
  363. detected_ifname = None
  364. detected_ip = None
  365. if static_worker_ip is not None and static_worker_ifname is not None:
  366. pass
  367. elif static_worker_ip is not None:
  368. # if ip is set, use it to detect ifname
  369. detected_ifname = get_ifname_by_ip_hostname(static_worker_ip)
  370. elif static_worker_ifname is not None:
  371. # if ifname is set, used it to detect ip
  372. detected_ip = get_first_non_loopback_ip(
  373. expected_ifname=static_worker_ifname
  374. )
  375. else:
  376. # detect both ip and ifname
  377. # detect_ifname may be None if the hostname resolves to a loopback address.
  378. # This typically happens when the worker and server run on the same host, or for embedded workers.
  379. detected_ifname = get_ifname_by_ip_hostname(
  380. urlparse(self._config.get_server_url()).hostname
  381. )
  382. try:
  383. # if the expected_ifname is none, it will scan all interfaces
  384. detected_ip = get_first_non_loopback_ip(expected_ifname=detected_ifname)
  385. except Exception:
  386. logger.warning(
  387. f"Failed to detect worker IP from interface {detected_ifname}. Using first non-loopback IP."
  388. )
  389. # avoid edge case where the detected_ifname has no valid IPv4 address
  390. detected_ip = get_first_non_loopback_ip()
  391. if detected_ifname is None:
  392. detected_ifname = get_ifname_by_ip_hostname(detected_ip)
  393. return detected_ip, detected_ifname
  394. def _check_worker_ip_change(self):
  395. """
  396. Detect if the worker IP has changed. If so, delete legacy model
  397. instances so they can be recreated with the new worker IP.
  398. """
  399. new_ip, new_ifname = self._detect_worker_ip_and_ifname()
  400. old_ip, old_ifname = self._worker_ip, self._worker_ifname
  401. if new_ip == old_ip and new_ifname == old_ifname:
  402. return
  403. logger.info(
  404. f"Worker IP changed from {old_ip}({old_ifname}) to {new_ip}({new_ifname})"
  405. )
  406. self._worker_ip = new_ip
  407. self._worker_ifname = new_ifname
  408. self._worker_manager.sync_worker_status()
  409. for instance in self._clientset.model_instances.list(
  410. params={"worker_id": str(self._worker_id)}
  411. ).items:
  412. self._clientset.model_instances.delete(instance.id)
  413. if self._message_client:
  414. loop = asyncio.get_event_loop()
  415. asyncio.run_coroutine_threadsafe(
  416. self._message_client.update_cidrs([f"{self._worker_ip}/32"]),
  417. loop,
  418. )
  419. def _heartbeat(self):
  420. """
  421. Send heartbeat to the server to indicate the worker is alive.
  422. """
  423. if self._worker_id is None:
  424. logger.debug("Worker ID is not set, skipping heartbeat.")
  425. return
  426. try:
  427. resp = self._clientset.http_client.get_httpx_client().post(
  428. "/worker-heartbeat", json={}
  429. )
  430. if resp.status_code != 204:
  431. logger.error(
  432. f"Failed to send heartbeat to server, status code: {resp.status_code}"
  433. )
  434. except Exception as e:
  435. logger.error(f"Failed to send heartbeat to server: {e}")
  436. def migrate_worker_name(cfg: Config):
  437. """
  438. Based on the situation that registration of worker changed in version v2.0, v2.1, we need to
  439. recreate the worker_name file doesn't exist. Following are the files involved in the migration:
  440. | File Name | < v0.7 | ~ v0.7 | ~ v2.0 | ~ v2.1 |
  441. | ------------ | ------ | ------ | ------ | ------ |
  442. | worker_name | Y | Y | Opt | YS |
  443. | worker_uuid | | Y | Opt | YS |
  444. | worker_token | | | YS | YS |
  445. Y means the file must exist and the content is generated locally.
  446. Opt means the file may exist depends on the startup configuration.
  447. YS means the file must exist and the content is generated from server.
  448. When upgrading from v2.0 to v2.1, if the worker_name file doesn't exist, we need to migrate
  449. the worker_name from configuration to the worker_name file. This is because in v2.0, the worker_name won't be written
  450. to file if the worker is started with `--worker-name` argument or `GPUSTACK_WORKER_NAME` environment variable.
  451. In the end, we can generate the worker_name file based on the existance of the worker_token file and the worker_name configuration.
  452. """
  453. worker_name_file = os.path.join(cfg.data_dir, "worker_name")
  454. worker_token_file = os.path.join(cfg.data_dir, "worker_token")
  455. if not os.path.exists(worker_name_file) and os.path.exists(worker_token_file):
  456. if cfg.worker_name:
  457. with open(worker_name_file, "w") as f:
  458. f.write(cfg.worker_name)
  459. else:
  460. raise RuntimeError("worker_name not found for v2.0 upgrade")