serve_manager.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633
  1. import asyncio
  2. import contextlib
  3. from datetime import datetime, timezone
  4. import multiprocessing
  5. import re
  6. import threading
  7. import time
  8. import requests
  9. import setproctitle
  10. import os
  11. from typing import Dict, Optional, Set, List, Callable
  12. from pathlib import Path
  13. import logging
  14. from gpustack_runtime.deployer import (
  15. get_workload,
  16. WorkloadStatusStateEnum,
  17. delete_workload,
  18. logs_workload,
  19. )
  20. from gpustack_runtime.deployer.__utils__ import compare_versions
  21. from gpustack.api.exceptions import NotFoundException
  22. from gpustack.config.config import Config
  23. from gpustack.config import registration
  24. from gpustack.logging import (
  25. RedirectStdoutStderr,
  26. )
  27. from gpustack.schemas.inference_backend import (
  28. InferenceBackend,
  29. is_built_in_backend,
  30. is_custom_backend,
  31. )
  32. from gpustack.utils import network
  33. from gpustack.utils.convert import safe_int
  34. from gpustack.utils.attrs import set_attr
  35. from gpustack.utils.command import find_int_parameter
  36. from gpustack.utils.process import terminate_process_tree, add_signal_handlers
  37. from gpustack.worker.backends.ascend_mindie import AscendMindIEServer
  38. from gpustack.worker.backends.sglang import SGLangServer
  39. from gpustack.worker.backends.vllm import VLLMServer
  40. from gpustack.worker.backends.vox_box import VoxBoxServer
  41. from gpustack.worker.backends.custom import CustomServer
  42. from gpustack.routes.worker.logs import (
  43. extract_container_restart_count,
  44. extract_restart_count,
  45. )
  46. from gpustack.worker.model_meta import get_meta_from_running_instance
  47. from gpustack.client import ClientSet
  48. from gpustack.schemas.models import (
  49. BackendEnum,
  50. Model,
  51. ModelUpdate,
  52. ModelInstance,
  53. ModelInstanceUpdate,
  54. ModelInstanceStateEnum,
  55. get_backend,
  56. DistributedServerCoordinateModeEnum,
  57. ModelInstanceSubordinateWorker,
  58. CategoryEnum,
  59. )
  60. from gpustack.server.bus import Event, EventType
  61. from gpustack.worker.inference_backend_manager import InferenceBackendManager
logger = logging.getLogger(__name__)

# Inference health check error message
_INFERENCE_HEALTH_CHECK_FAILED_MESSAGE = "Inference health check failed."

# Global lock for port assignment to avoid pickle serialization issues
_port_lock = threading.Lock()

# Maps each built-in backend to the server class used to provision it.
# Backends that are not listed here fall back to CustomServer
# (see ServeManager._serve_model_instance).
_SERVER_CLASS_MAPPING = {
    BackendEnum.VLLM: VLLMServer,
    BackendEnum.SGLANG: SGLangServer,
    BackendEnum.VOX_BOX: VoxBoxServer,
    BackendEnum.ASCEND_MINDIE: AscendMindIEServer,
}
  73. class ServeManager:
  74. @property
  75. def _worker_id(self) -> int:
  76. return self._worker_id_getter()
  77. """
  78. The ID of current worker.
  79. """
    _config: Config
    """
    Global configuration.
    """
    # Set in __init__ to f"{cfg.log_dir}/serve"; the directory is created there as well.
    _serve_log_dir: str
    """
    The directory to store logs of serving model instances(in subprocess).
    """
  88. @property
  89. def _clientset(self) -> ClientSet:
  90. return self._clientset_getter()
  91. """
  92. The clientset to access the API server.
  93. """
    # NOTE(review): declared here but not assigned in the __init__ visible in this
    # file section — confirm where it is set before relying on it.
    _inference_backend_manager: InferenceBackendManager
    """
    The inference backend manager.
    """
    _provisioning_processes: Dict[int, multiprocessing.Process]
    """
    The mapping of model instance ID to provisioning (sub)process.
    When the (sub)process is alive, the model instance is provisioning.
    If the (sub)process exited, the model instance is either running or failed.
    """
    _log_persistence_threads: Dict[int, List[threading.Thread]]
    """
    The mapping of model instance ID to log persistence threads.
    Each model instance may have multiple threads (one per loggable container).
    """
    _log_persistence_stop_events: Dict[int, List[threading.Event]]
    """
    The mapping of model instance ID to stop events for log persistence threads.
    Used to signal threads to stop gracefully.
    """
    _error_model_instances: Dict[int, ModelInstance]
    """
    The mapping of model instance ID to error model instances.
    Used to restart error model instances.
    """
    _model_cache_by_instance: Dict[int, Model]
    """
    The cache of models by model instance ID.
    Used to avoid redundant API calls to get model information.
    """
    # Model instances keyed by model instance ID. Filled from watch events for
    # instances whose main worker is this worker (_handle_model_instance_event)
    # and read by sync_model_instances_inference_health.
    _model_instance_by_instance_id: Dict[int, ModelInstance]
    # Zero-argument callable returning the clientset; backs the `_clientset` property.
    _clientset_getter: Callable[[], ClientSet]
    # Zero-argument callable returning the worker ID; backs the `_worker_id` property.
    _worker_id_getter: Callable[[], int]
  127. def __init__(
  128. self,
  129. worker_id_getter: Callable[[], int],
  130. clientset_getter: Callable[[], ClientSet],
  131. cfg: Config,
  132. ):
  133. self._worker_id_getter = worker_id_getter
  134. self._config = cfg
  135. self._serve_log_dir = f"{cfg.log_dir}/serve"
  136. self._clientset_getter = clientset_getter
  137. self._provisioning_processes = {}
  138. self._log_persistence_threads = {}
  139. self._log_persistence_stop_events = {}
  140. self._error_model_instances = {}
  141. self._model_cache_by_instance = {}
  142. self._model_instance_by_instance_id = {}
  143. # Instance-level port tracking to avoid conflicts
  144. self._assigned_ports: Dict[int, Set[int]] = {}
  145. self._restart_backoff_counts: Dict[int, int] = {}
  146. # Inference health check failure tracking
  147. # {model_instance_id: failure_count}
  148. self._inference_health_check_failures: Dict[int, int] = {}
  149. # Track last successful inference per port (set by worker proxy)
  150. self._last_successful_inference: Dict[int, float] = {}
  151. # Track last health check time per model instance
  152. self._last_health_check_time: Dict[int, float] = {}
  153. os.makedirs(self._serve_log_dir, exist_ok=True)
  154. def record_successful_inference(self, instance_id: int):
  155. """Called by worker proxy on successful inference response."""
  156. self._last_successful_inference[instance_id] = time.time()
  157. async def watch_models(self):
  158. """
  159. Loop to watch models to keep the cache updated.
  160. """
  161. logger.debug("Watching models.")
  162. while True:
  163. try:
  164. # Watch models without callback to keep the cache updated.
  165. await self._clientset.models.awatch(callback=None)
  166. except asyncio.CancelledError:
  167. break
  168. except Exception as e:
  169. logger.error(f"Error watching models: {e}")
  170. await asyncio.sleep(5)
  171. async def watch_model_instances_event(self):
  172. """
  173. Loop to watch model instances' event and handle.
  174. """
  175. logger.debug("Watching model instances event.")
  176. while True:
  177. try:
  178. await self._clientset.model_instances.awatch(
  179. callback=self._handle_model_instance_event
  180. )
  181. except asyncio.CancelledError:
  182. break
  183. except Exception as e:
  184. logger.error(f"Error watching model instances: {e}")
  185. await asyncio.sleep(5)
  186. async def watch_model_instances(self):
  187. """
  188. Loop to post process model instances, for example, restarting error instances.
  189. """
  190. logger.debug("Watching model instances.")
  191. while True:
  192. try:
  193. for mi in list(self._error_model_instances.values()):
  194. self._restart_error_model_instance(mi)
  195. await asyncio.sleep(10)
  196. except Exception as e:
  197. logger.error(f"Error restarting model instances: {e}")
  198. await asyncio.sleep(5)
    def sync_model_instances_state(self):  # noqa: C901
        """
        Synchronize model instances' state.

        - If the model instance is scheduled but not initialized, skip.
        - If the provision process is still alive, skip.
        - If the workload is still launching, skip.
        - If the workload is not existed, unhealthy, inactive or failed, update the model instance state to ERROR.
        - If everything is fine, update the model instance state to RUNNING.

        Instances are considered when this worker is either their main worker
        or one of their subordinate workers. For a subordinate worker only its
        own entry in ``distributed_servers.subordinate_workers`` is patched.
        """
        # Get all model instances assigned to this worker.
        #
        # FIXME(thxCode): This may cause performance issues when there are many model instances in the system.
        # A mechanism is needed to improve efficiency here.
        model_instances_page = self._clientset.model_instances.list(use_cache=False)
        if not model_instances_page.items:
            return

        model_instances: List[ModelInstance] = []
        for model_instance in model_instances_page.items:
            # if the model instance is assigned to this worker, it must be scheduled.
            # But we don't need to sync the scheduled model when it is not initialized yet.
            if (
                model_instance.worker_id == self._worker_id
                and model_instance.state != ModelInstanceStateEnum.SCHEDULED
            ):
                model_instances.append(model_instance)
            # Also pick up instances this worker serves as a subordinate.
            if (
                model_instance.distributed_servers
                and model_instance.distributed_servers.subordinate_workers
            ):
                for sw in model_instance.distributed_servers.subordinate_workers:
                    if sw.worker_id == self._worker_id:
                        model_instances.append(model_instance)
                        break

        for model_instance in model_instances:
            # Skip if the provision process has not exited yet.
            if self._is_provisioning(model_instance):
                logger.trace(
                    f"Model instance {model_instance.name} is provisioning. Skipping sync."
                )
                continue

            is_main_worker = model_instance.worker_id == self._worker_id

            # Skip if the workload is still launching.
            # Use deployment metadata name for subordinate workers (e.g., "model-f0")
            # since their workload name differs from the model instance name.
            if is_main_worker:
                workload = get_workload(model_instance.name)
            else:
                deployment_metadata = model_instance.get_deployment_metadata(
                    self._worker_id
                )
                workload_name = (
                    deployment_metadata.name
                    if deployment_metadata
                    else model_instance.name
                )
                workload = get_workload(workload_name)
            if workload and workload.state in [
                WorkloadStatusStateEnum.PENDING,
                WorkloadStatusStateEnum.INITIALIZING,
            ]:
                logger.trace(
                    f"Model instance {model_instance.name} workload is still launching. Skipping sync."
                )
                continue

            # Update model instance state to ERROR if the workload is not existed, unhealthy, inactive or failed.
            if not workload or workload.state in [
                WorkloadStatusStateEnum.UNKNOWN,  # Rare, but possible, for example, leaving pause container.
                WorkloadStatusStateEnum.UNHEALTHY,
                WorkloadStatusStateEnum.INACTIVE,
                WorkloadStatusStateEnum.FAILED,
            ]:
                # Only if not in ERROR state yet.
                if model_instance.state != ModelInstanceStateEnum.ERROR:
                    # A NotFoundException raised while patching is ignored
                    # (presumably the instance was deleted in the meantime).
                    with contextlib.suppress(NotFoundException):
                        # Get patch dict for main worker.
                        if is_main_worker:
                            patch_dict = {
                                "state": ModelInstanceStateEnum.ERROR,
                                "state_message": "Inference server exited or unhealthy.",
                            }
                        # Get patch dict for subordinate worker.
                        else:
                            # next() has no default: a non-main instance only reaches
                            # this point because this worker was found among its
                            # subordinate workers when the list was built above.
                            sw_pos = next(
                                (
                                    i
                                    for i, sw in enumerate(
                                        model_instance.distributed_servers.subordinate_workers
                                    )
                                    if sw.worker_id == self._worker_id
                                ),
                            )
                            sw = model_instance.distributed_servers.subordinate_workers[
                                sw_pos
                            ]
                            sw.state = ModelInstanceStateEnum.ERROR
                            sw.state_message = "Inference server exited or unhealthy."
                            # Patch only this worker's slot in the subordinate list.
                            patch_dict = {
                                f"distributed_servers.subordinate_workers.{sw_pos}": sw,
                            }
                        # Update model instance.
                        self._update_model_instance(model_instance.id, **patch_dict)
                # The workload is unusable either way; never fall through to the RUNNING logic.
                continue

            # Otherwise, update model instance state to RUNNING if everything is fine.
            model = self._get_model(model_instance)
            if not model.backend_version:
                # backend version may be empty on initialization.
                # try to refresh to get updated model info on syncs.
                model = self._refresh_model(model_instance)
            backend = get_backend(model)
            health_check_path = self._get_health_check_path(backend)
            if model.env and 'GPUSTACK_MODEL_HEALTH_CHECK_PATH' in model.env:
                # NOTE: There is no known use case for now. Keep this in case the built-in backends
                # introduce breaking changes and the default health check path no longer works.
                health_check_path = model.env['GPUSTACK_MODEL_HEALTH_CHECK_PATH']

            # Every `continue` below leaves the `with` block and moves on to the
            # next model instance without patching anything.
            with contextlib.suppress(NotFoundException):
                # Get patch dict for main worker.
                if is_main_worker:
                    subordinate_state = self._get_main_worker_distributed_state(
                        model_instance
                    )
                    # None means there are no subordinate workers or all of them are RUNNING.
                    if subordinate_state is None:
                        if model_instance.state == ModelInstanceStateEnum.RUNNING:
                            # Already running: just clear the restart backoff counter.
                            self._restart_backoff_counts.pop(model_instance.id, None)
                            continue
                        if (
                            model_instance.state == ModelInstanceStateEnum.ERROR
                            or not is_ready(
                                backend, model_instance, health_check_path, model
                            )
                        ):
                            continue
                        self._restart_backoff_counts.pop(model_instance.id, None)
                        patch_dict = {
                            "state": ModelInstanceStateEnum.RUNNING,
                            "state_message": "",
                        }
                        # Fetch model meta once running.
                        meta = get_meta_from_running_instance(
                            model_instance, backend, model
                        )
                        if meta:
                            # Some meta is set in server evaluation and should be preserved, so we update meta instead of overwrite.
                            merged_meta = dict(model.meta or {})
                            merged_meta.update(meta)
                            if merged_meta != model.meta:
                                self._update_model(model.id, meta=merged_meta)
                    elif subordinate_state["should_update"]:
                        # Mirror the subordinate's ERROR/UNREACHABLE state onto the instance.
                        patch_dict = {
                            "state": subordinate_state["state"],
                            "state_message": subordinate_state["state_message"],
                        }
                    else:
                        continue
                # Get patch dict for subordinate worker.
                else:
                    # For initialize later mode, the state is set to RUNNING directly,
                    # which means the subordinate worker doesn't need to wait for the main worker to be healthy.
                    if (
                        model_instance.distributed_servers.mode
                        == DistributedServerCoordinateModeEnum.INITIALIZE_LATER
                    ):
                        continue
                    # Otherwise, update subordinate worker state to RUNNING.
                    sw_pos = next(
                        (
                            i
                            for i, sw in enumerate(
                                model_instance.distributed_servers.subordinate_workers
                            )
                            if sw.worker_id == self._worker_id
                        ),
                    )
                    sw = model_instance.distributed_servers.subordinate_workers[sw_pos]
                    if sw.state == ModelInstanceStateEnum.RUNNING:
                        continue
                    sw.state = ModelInstanceStateEnum.RUNNING
                    sw.state_message = ""
                    patch_dict = {
                        f"distributed_servers.subordinate_workers.{sw_pos}": sw,
                    }
                # Update model instance.
                self._update_model_instance(model_instance.id, **patch_dict)
  381. @staticmethod
  382. def _get_main_worker_distributed_state(
  383. model_instance: ModelInstance,
  384. ) -> Optional[dict]:
  385. subordinate_workers = (
  386. model_instance.distributed_servers.subordinate_workers
  387. if (
  388. model_instance.distributed_servers
  389. and model_instance.distributed_servers.subordinate_workers
  390. )
  391. else []
  392. )
  393. if not subordinate_workers:
  394. return None
  395. error_sw = None
  396. unreachable_sw = None
  397. all_running = True
  398. for sw in subordinate_workers:
  399. if sw.state == ModelInstanceStateEnum.ERROR:
  400. error_sw = sw
  401. break
  402. if (
  403. sw.state == ModelInstanceStateEnum.UNREACHABLE
  404. and unreachable_sw is None
  405. ):
  406. unreachable_sw = sw
  407. if sw.state != ModelInstanceStateEnum.RUNNING:
  408. all_running = False
  409. if error_sw:
  410. return {
  411. "should_update": model_instance.state != ModelInstanceStateEnum.ERROR,
  412. "state": ModelInstanceStateEnum.ERROR,
  413. "state_message": (
  414. f"Distributed serving error in subordinate worker "
  415. f"{error_sw.worker_ip}: {error_sw.state_message}."
  416. ),
  417. }
  418. if unreachable_sw:
  419. return {
  420. "should_update": model_instance.state
  421. != ModelInstanceStateEnum.UNREACHABLE,
  422. "state": ModelInstanceStateEnum.UNREACHABLE,
  423. "state_message": (
  424. f"Distributed serving unreachable in subordinate worker "
  425. f"{unreachable_sw.worker_ip}: {unreachable_sw.state_message}."
  426. ),
  427. }
  428. if not all_running:
  429. return {"should_update": False}
  430. return None
  431. @staticmethod
  432. def _serve_model_instance(
  433. mi: ModelInstance,
  434. backend: BackendEnum,
  435. client_headers: dict,
  436. log_file_path: str,
  437. cfg: Config,
  438. worker_id: int,
  439. inference_backend: InferenceBackend,
  440. fallback_registry: Optional[str] = None,
  441. ):
  442. """
  443. Serve model instance in a subprocess.
  444. Exits the subprocess when serving ends.
  445. Args:
  446. mi: The model instance to serve.
  447. backend: The backend of the model instance.
  448. client_headers: The headers for the clientset.
  449. log_file_path: The path to the log file.
  450. cfg: The configuration.
  451. worker_id: The ID of the worker.
  452. inference_backend: The inference backend configuration.
  453. fallback_registry: The fallback container registry to use if needed.
  454. """
  455. setproctitle.setproctitle(f"gpustack_model_instance_{mi.id}")
  456. add_signal_handlers()
  457. clientset = ClientSet(
  458. base_url=cfg.get_server_url(),
  459. headers=client_headers,
  460. )
  461. with open(log_file_path, "w", buffering=1, encoding="utf-8") as log_file:
  462. with RedirectStdoutStderr(log_file):
  463. try:
  464. server_cls = _SERVER_CLASS_MAPPING.get(backend, CustomServer)
  465. server_ins = server_cls(
  466. clientset,
  467. mi,
  468. cfg,
  469. worker_id,
  470. inference_backend,
  471. fallback_registry,
  472. )
  473. logger.info(f"Provisioning model instance {mi.name}")
  474. server_ins.start()
  475. logger.info(f"Finished provisioning model instance {mi.name}")
  476. except Exception as e:
  477. logger.exception(
  478. f"Error provisioning model instance {mi.name}: {e}"
  479. )
  480. raise e
  481. def sync_model_instances_inference_health(self):
  482. """
  483. Synchronize model instances' inference health by sending actual inference requests.
  484. Per-model configuration is read from model.env:
  485. - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_ENABLED: "true"/"false" (default: false)
  486. - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_INTERVAL: seconds (default: global env)
  487. - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_TIMEOUT: seconds (default: 15)
  488. - GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD: count (default: global env)
  489. If the model has received successful inference traffic recently
  490. (within the configured interval), the active health check is skipped.
  491. """
  492. # Use the event-driven local cache instead of an API call.
  493. model_instances = [
  494. mi
  495. for mi in self._model_instance_by_instance_id.values()
  496. if mi.state == ModelInstanceStateEnum.RUNNING
  497. ]
  498. if not model_instances:
  499. return
  500. now = time.time()
  501. for model_instance in model_instances:
  502. model = self._get_model(model_instance)
  503. if not model:
  504. continue
  505. # Read per-model config from model.env.
  506. config = _get_inference_health_check_config(model)
  507. if not config["enabled"]:
  508. continue
  509. interval = config["interval"]
  510. timeout = config["timeout"]
  511. threshold = config["threshold"]
  512. # Skip if the model is still provisioning.
  513. if self._is_provisioning(model_instance):
  514. continue
  515. # Skip if not enough time has passed since last check.
  516. last_check = self._last_health_check_time.get(model_instance.id, 0)
  517. if now - last_check < interval:
  518. continue
  519. self._last_health_check_time[model_instance.id] = now
  520. # Skip if recent successful inference was observed for this instance.
  521. last_success = self._last_successful_inference.get(model_instance.id, 0)
  522. if last_success > now - interval:
  523. logger.debug(
  524. f"Model instance {model_instance.name} had recent successful "
  525. f"inference, skipping health check."
  526. )
  527. # Reset failure count since real traffic is succeeding.
  528. self._inference_health_check_failures.pop(model_instance.id, None)
  529. continue
  530. # Perform inference health check.
  531. if not is_inference_ready(model_instance, model, timeout=timeout):
  532. failure_count = self._inference_health_check_failures.get(
  533. model_instance.id, 0
  534. )
  535. failure_count += 1
  536. self._inference_health_check_failures[model_instance.id] = failure_count
  537. if failure_count >= threshold:
  538. logger.warning(
  539. f"Model instance {model_instance.name} inference health check failed "
  540. f"{failure_count} times, updating state to ERROR."
  541. )
  542. patch_dict = {
  543. "state": ModelInstanceStateEnum.ERROR,
  544. "state_message": _INFERENCE_HEALTH_CHECK_FAILED_MESSAGE,
  545. }
  546. self._update_model_instance(model_instance.id, **patch_dict)
  547. # Reset failure count after marking as error.
  548. del self._inference_health_check_failures[model_instance.id]
  549. else:
  550. logger.debug(
  551. f"Model instance {model_instance.name} inference health check failed "
  552. f"{failure_count}/{threshold} times."
  553. )
  554. else:
  555. # Reset failure count on success.
  556. self._inference_health_check_failures.pop(model_instance.id, None)
    def _handle_model_instance_event(self, event: Event):  # noqa: C901
        """
        Handle model instance events.

        The event payload is validated into a ModelInstance. Events are first
        filtered by this worker's role for the instance (main worker or
        subordinate worker of a distributed serving), then dispatched on the
        event type: DELETED stops the instance, UPDATED caches error instances
        for restart / restarts scheduled instances / (re)starts subordinate
        workloads, CREATED starts the instance on its main worker.

        Args:
            event: The model instance event to handle.
        """
        mi = ModelInstance.model_validate(event.data)

        logger.trace(
            f"Received event: {str(event.type)}, id: {mi.id}, name: {mi.name}, state: {str(mi.state)}"
        )

        is_main_worker = mi.worker_id == self._worker_id

        # NOTE(review): the wait/return checks below run before the event type is
        # inspected, so a DELETED event can also be dropped by them — confirm this
        # is intended.
        if is_main_worker:
            # Keep the local cache used by the inference health check up to date.
            self._model_instance_by_instance_id[mi.id] = mi
            # Return if all subordinate workers aren't running.
            if (
                mi.distributed_servers
                and mi.distributed_servers.mode
                == DistributedServerCoordinateModeEnum.RUN_FIRST
                and mi.distributed_servers.subordinate_workers
            ):
                ready = all(
                    sw.state == ModelInstanceStateEnum.RUNNING
                    for sw in mi.distributed_servers.subordinate_workers
                )
                if not ready:
                    logger.info(
                        f"Model instance {mi.name} waits for all subordinate workers to be ready."
                    )
                    return
        else:
            # Return if it isn't a distribution serving.
            if not mi.distributed_servers:
                return
            # Return if it's a delegated distribution,
            # which means the main worker is responsible for serving.
            if (
                mi.distributed_servers.mode
                == DistributedServerCoordinateModeEnum.DELEGATED
            ):
                return
            # Return if it isn't the member of the distribution serving.
            joined = any(
                sw.worker_id == self._worker_id
                for sw in mi.distributed_servers.subordinate_workers or []
            )
            if not joined:
                return
            # Return if the main worker isn't initialized.
            if (
                mi.distributed_servers.mode
                == DistributedServerCoordinateModeEnum.INITIALIZE_LATER
                and (
                    mi.state
                    not in [
                        ModelInstanceStateEnum.STARTING,
                        ModelInstanceStateEnum.RUNNING,
                        ModelInstanceStateEnum.ERROR,
                    ]
                )
            ):
                logger.info(
                    f"Model instance {mi.name} waits for main worker {mi.worker_ip} to be initialized."
                )
                return
            # FIXME: This is a temporary solution to prevent the main worker from being unable to start due to phantom reads.
            # We confirm whether the operation should be performed by checking the state of the earlier subordinate worker.
            for sw in mi.distributed_servers.subordinate_workers:
                # Only subordinate workers listed before this worker are checked.
                if sw.worker_id == self._worker_id:
                    break
                if sw.state not in [
                    ModelInstanceStateEnum.RUNNING,
                    ModelInstanceStateEnum.ERROR,
                ]:
                    logger.info(
                        f"Model instance {mi.name} waits for previous subordinate worker {sw.worker_ip} to be ready."
                    )
                    return

        if event.type == EventType.DELETED:
            self._stop_model_instance(mi)
            logger.trace(f"DELETED event: stopped deleted model instance {mi.name}.")
            return

        if event.type == EventType.UPDATED:
            # Caching matched ERROR instances for restart handling.
            if mi.state == ModelInstanceStateEnum.ERROR:
                # NOTE(review): sync_model_instances_inference_health guards against
                # a falsy _get_model() result; here the result is used directly — confirm
                # it cannot be None.
                model = self._get_model(mi)
                if model.restart_on_error:
                    self._error_model_instances[mi.id] = mi
                    logger.trace(
                        f"UPDATED event: cached error model instance {mi.name} for restart."
                    )
                return
            # Restart if scheduled and this is the assigned worker.
            if is_main_worker and mi.state == ModelInstanceStateEnum.SCHEDULED:
                self._restart_model_instance(mi)
                logger.trace(
                    f"UPDATED event: restarted scheduled model instance {mi.name}."
                )
            # Start on subordinate worker if not started yet, or restart if failed.
            if not is_main_worker:
                # Subordinate workloads are looked up by their deployment metadata
                # name when available, otherwise by the model instance name.
                deployment_metadata = mi.get_deployment_metadata(self._worker_id)
                workload_name = (
                    deployment_metadata.name if deployment_metadata else mi.name
                )
                workload = get_workload(workload_name)
                if not workload:
                    self._start_model_instance(mi)
                    logger.trace(
                        f"UPDATED event: started model instance {mi.name} on subordinate worker."
                    )
                elif workload.state in [
                    WorkloadStatusStateEnum.UNKNOWN,
                    WorkloadStatusStateEnum.UNHEALTHY,
                    WorkloadStatusStateEnum.INACTIVE,
                    WorkloadStatusStateEnum.FAILED,
                ]:
                    # Stop first, keeping the restart backoff counter, then start again.
                    self._stop_model_instance(mi, clear_restart_backoff=False)
                    self._start_model_instance(mi)
                    logger.trace(
                        f"UPDATED event: restarted failed model instance {mi.name} on subordinate worker."
                    )
            return

        if event.type == EventType.CREATED:
            # Only handle CREATED if this is the assigned worker
            if not is_main_worker:
                return
            if mi.state == ModelInstanceStateEnum.RUNNING:
                logger.warning(
                    f"Model instance {mi.name} is already running. Skipping start."
                )
                return
            self._start_model_instance(mi)
            logger.trace(f"CREATED event: started created model instance {mi.name}.")
  def _get_numbered_log_path(self, mi: ModelInstance) -> str:
      """Build the serve log path for the instance's current restart.

      Args:
          mi: The model instance.

      Returns:
          A path of the form
          ``{log_dir}/{model_instance_id}.{restart_count}.log``. A missing
          (falsy) restart count is treated as 0.
      """
      attempt = mi.restart_count if mi.restart_count else 0
      file_name = f"{mi.id}.{attempt}.log"
      return f"{self._serve_log_dir}/{file_name}"
  def _persist_container_logs(
      self,
      workload_name: str,
      log_path: str,
      stop_event: threading.Event,
      token: Optional[str] = None,
  ):
      """Stream a container's logs into a local file.

      Blocking; meant to run on its own thread. While ``stop_event`` is not
      set, any failure to obtain or read the log stream is retried after a
      2 second wait (for example while the container does not exist yet).
      Once a stream has been consumed to its end the loop exits.

      Args:
          workload_name: Name of the container workload.
          log_path: File the log lines are written to (opened in 'w' mode).
          stop_event: Event used to ask this thread to stop.
          token: Operation token identifying a specific container in the
              workload. If None, logs from the default (index=0) container
              are fetched.
      """
      attempts = 0
      while not stop_event.is_set():
          try:
              stream = logs_workload(
                  name=workload_name,
                  token=token,
                  tail=-1,
                  follow=True,
              )

              if hasattr(stream, '__iter__'):
                  with open(log_path, 'w', buffering=1, encoding='utf-8') as sink:
                      for chunk in stream:
                          if stop_event.is_set():
                              break
                          # The stream may yield bytes or already-decoded text.
                          text = (
                              chunk.decode('utf-8', errors='replace')
                              if isinstance(chunk, bytes)
                              else str(chunk)
                          )
                          sink.write(text)
                          sink.flush()
              # Stream finished (or nothing to iterate): stop following.
              break
          except Exception as e:
              if stop_event.is_set():
                  break
              attempts += 1
              logger.debug(
                  f"Container not ready for {workload_name}, retrying "
                  f"(attempt {attempts}): {e}"
              )
              stop_event.wait(timeout=2)

      logger.debug(f"Log persistence thread for {workload_name} exiting")
  def _discover_sidecar_logs(
      self,
      mi_id: int,
      workload_name: str,
      restart_count: int,
      stop_event: threading.Event,
  ):
      """Background loop that waits for sidecar containers to appear.

      Polls get_workload() every 2 seconds until the workload's loggable
      list contains at least one entry not named "default", then starts the
      sidecar log persistence threads and exits. Also exits as soon as
      ``stop_event`` is set.

      Args:
          mi_id: Model instance ID.
          workload_name: Workload name.
          restart_count: Current restart count, used for log file naming.
          stop_event: Event used to ask this thread to stop.
      """
      while not stop_event.is_set():
          try:
              workload = get_workload(workload_name)
              # The lookup may have taken a while. If a stop was requested in
              # the meantime, do not spawn sidecar threads: they would be
              # registered after the stop had already collected the threads
              # and events it is going to signal.
              if stop_event.is_set():
                  return
              loggable = workload.loggable if workload else None
              if loggable and any(op.name != "default" for op in loggable):
                  self._start_sidecar_log_threads(
                      mi_id,
                      workload_name,
                      loggable,
                      restart_count,
                  )
                  logger.debug(f"Sidecar discovery for {workload_name} complete")
                  return
          except Exception as e:
              # Best effort: keep polling, but leave a trace instead of
              # silently discarding the failure.
              logger.debug(
                  f"Sidecar discovery for {workload_name} failed, retrying: {e}"
              )
          stop_event.wait(timeout=2)
  def _start_sidecar_log_threads(
      self,
      mi_id: int,
      workload_name: str,
      loggable_ops: list,
      restart_count: int,
  ):
      """Spawn one log persistence thread per sidecar container.

      Invoked by the sidecar discovery loop once the workload exposes
      loggable containers other than "default". Every spawned thread gets
      its own stop event; both are appended to the per-instance tracking
      lists so they are stopped together with the main log thread.

      Args:
          mi_id: Model instance ID.
          workload_name: Workload name.
          loggable_ops: Entries of ``workload.loggable``; each is read for
              its ``name`` and ``token``.
          restart_count: Current restart count, used for log file naming.
      """
      started = []
      # The "default" entry is the main container; its log thread is
      # started elsewhere.
      sidecar_ops = (op for op in loggable_ops if op.name != "default")
      for op in sidecar_ops:
          sidecar_log_path = (
              f"{self._serve_log_dir}/{mi_id}.container."
              f"{op.name}.{restart_count}.log"
          )
          sidecar_stop = threading.Event()
          worker = threading.Thread(
              target=self._persist_container_logs,
              args=(workload_name, sidecar_log_path, sidecar_stop, op.token),
              daemon=True,
              name=f"log-persist-{workload_name}-{op.name}",
          )
          worker.start()

          # Register alongside whatever is already tracked for the instance.
          self._log_persistence_threads.setdefault(mi_id, []).append(worker)
          self._log_persistence_stop_events.setdefault(mi_id, []).append(
              sidecar_stop
          )
          started.append(op.name)

      if not started:
          return
      logger.debug(
          f"Started sidecar log persistence threads for {workload_name}: "
          f"{started}"
      )
  def _start_container_log_persistence(self, mi: ModelInstance):
      """Launch the background threads that persist container logs.

      Two daemon threads are started and tracked for the instance: one that
      persists the main container's log, and one that polls for sidecar
      containers (e.g., Ray head) and spawns additional log threads when
      they appear. Both share a single stop event.

      Args:
          mi: The model instance.
      """
      # Make sure nothing from a previous run is still active.
      self._stop_container_log_persistence(mi.id)

      # The real workload name comes from the deployment metadata; it
      # differs from the instance name on subordinate workers
      # (e.g., "model-f0").
      metadata = mi.get_deployment_metadata(self._worker_id)
      workload_name = metadata.name if metadata else mi.name

      attempt = mi.restart_count or 0
      main_log_path = f"{self._serve_log_dir}/{mi.id}.container.{attempt}.log"
      shared_stop = threading.Event()

      main_thread = threading.Thread(
          target=self._persist_container_logs,
          args=(workload_name, main_log_path, shared_stop),
          daemon=True,
          name=f"log-persist-{workload_name}",
      )
      sidecar_watcher = threading.Thread(
          target=self._discover_sidecar_logs,
          args=(mi.id, workload_name, attempt, shared_stop),
          daemon=True,
          name=f"log-discover-{workload_name}",
      )
      main_thread.start()
      sidecar_watcher.start()

      self._log_persistence_threads[mi.id] = [main_thread, sidecar_watcher]
      self._log_persistence_stop_events[mi.id] = [shared_stop]

      logger.debug(f"Started container log persistence thread for {mi.name}")
  857. def _stop_container_log_persistence(
  858. self, model_instance_id: int, timeout: float = 2.0
  859. ):
  860. """Stop all container log persistence threads for a model instance.
  861. Args:
  862. model_instance_id: The model instance ID
  863. timeout: Maximum time to wait for each thread to stop (seconds)
  864. """
  865. # Signal all threads to stop
  866. stop_events = self._log_persistence_stop_events.pop(model_instance_id, [])
  867. for stop_event in stop_events:
  868. stop_event.set()
  869. # Wait for all threads to finish
  870. threads = self._log_persistence_threads.pop(model_instance_id, [])
  871. for thread in threads:
  872. if thread and thread.is_alive():
  873. thread.join(timeout=timeout)
  874. if thread.is_alive():
  875. logger.warning(
  876. f"Log persistence thread {thread.name} for model instance "
  877. f"{model_instance_id} did not stop within {timeout}s"
  878. )
  def _cleanup_old_logs(self, model_instance_id: int, current_restart_count: int):
      """Prune serve log files of a model instance, keeping the newest runs.

      Files are grouped into main logs, default container logs
      (e.g., 42.container.0.log) and sidecar container logs
      (e.g., 42.container.ray-head.0.log); each group is then pruned so that
      only restart counts R and R-1 survive (only R when R is 0), where R is
      ``current_restart_count``. Any failure is logged and swallowed.

      Args:
          model_instance_id: Model instance ID.
          current_restart_count: Restart count for the upcoming run (same as
              the one used in the log path).
      """
      try:
          log_dir = Path(self._serve_log_dir)

          # "{id}.*.log" also matches container logs, so filter those out.
          main_logs = []
          for candidate in log_dir.glob(f"{model_instance_id}.*.log"):
              if '.container.' not in candidate.name:
                  main_logs.append(candidate)

          container_files = list(
              log_dir.glob(f"{model_instance_id}.container.*.log")
          )

          # A container file is a "default" one when the container restart
          # count helper reports a positive count or the name has the plain
          # "{id}.container.{n}.log" shape; everything else is a sidecar log.
          default_logs = [
              candidate
              for candidate in container_files
              if extract_container_restart_count(candidate.name) > 0
              or re.match(
                  rf'{model_instance_id}\.container\.\d+\.log', candidate.name
              )
          ]
          sidecar_logs = [
              candidate
              for candidate in container_files
              if candidate not in default_logs
          ]

          groups = (
              (main_logs, "main"),
              (default_logs, "container"),
              (sidecar_logs, "sidecar_container"),
          )
          for files, log_type in groups:
              self._cleanup_log_type(files, current_restart_count, log_type)

      except Exception as e:
          logger.error(f"Failed to cleanup old logs for {model_instance_id}: {e}")
  def _cleanup_log_type(
      self,
      log_files: List[Path],
      current_restart_count: int,
      log_type: str,
  ):
      """Delete the given log files unless they belong to the current or
      previous restart.

      Args:
          log_files: Candidate files, all of the same ``log_type``.
          current_restart_count: Restart count R; files for R and, when
              R > 0, R-1 are kept.
          log_type: "main", "container" or "sidecar_container"; selects how
              the restart count is parsed from the file name. Unknown values
              are parsed like "container".
      """
      if current_restart_count > 0:
          retained = {current_restart_count, current_restart_count - 1}
      else:
          retained = {current_restart_count}

      def _sidecar_restart_count(filename: str) -> int:
          """Parse {id}.container.{name}.{restart_count}.log; 0 if no match."""
          found = re.match(r'\d+\.container\.[^.]+\.(\d+)\.log', filename)
          if found is None:
              return 0
          return int(found.group(1))

      if log_type == "main":
          parse_restart_count = extract_restart_count
      elif log_type == "sidecar_container":
          parse_restart_count = _sidecar_restart_count
      else:
          parse_restart_count = extract_container_restart_count

      for log_file in log_files:
          if parse_restart_count(log_file.name) in retained:
              continue
          # Deletion is best effort; one failure must not stop the sweep.
          try:
              log_file.unlink()
              logger.info(f"Deleted old {log_type} log file: {log_file}")
          except Exception as e:
              logger.warning(f"Failed to delete {log_type} log file {log_file}: {e}")
  def _start_model_instance(self, mi: ModelInstance):  # noqa: C901
      """
      Start model instance through a subprocess.

      Skips the start when a provisioning process for the instance is still
      alive. Otherwise old log files are pruned, ports are assigned, the
      serving process and the container log persistence threads are started,
      and the instance (main worker) or this worker's subordinate entry is
      patched to its new state. Any exception raised while starting is
      recorded as an ERROR state on the instance / subordinate entry.

      Args:
          mi: The model instance to start.
      """
      if self._is_provisioning(mi):
          logger.warning(f"Model instance {mi.name} is provisioning. Skipping start.")
          return

      # Clean up old log files before starting
      self._cleanup_old_logs(mi.id, mi.restart_count or 0)

      is_main_worker = mi.worker_id == self._worker_id
      log_file_path = self._get_numbered_log_path(mi)

      sw_pos: Optional[int] = None
      sw: Optional[ModelInstanceSubordinateWorker] = None
      if not is_main_worker:
          subordinate_workers = (
              mi.distributed_servers.subordinate_workers
              if mi.distributed_servers
              else None
          ) or []
          sw_pos = next(
              (
                  i
                  for i, candidate in enumerate(subordinate_workers)
                  if candidate.worker_id == self._worker_id
              ),
              None,
          )
          if sw_pos is None:
              # Without a subordinate entry there is nothing to patch, so
              # bail out instead of letting StopIteration/AttributeError
              # escape to the caller.
              logger.error(
                  f"Worker {self._worker_id} is not a subordinate worker of "
                  f"model instance {mi.name}. Skipping start."
              )
              return
          sw = subordinate_workers[sw_pos]

      try:
          model = self._get_model(mi)
          backend = get_backend(model)

          self._assign_ports(mi, model, backend)

          # Only the main worker reports the ports it serves on.
          ports_note = (
              f" on ports {mi.ports if mi.ports else [mi.port]}"
              if is_main_worker
              else ""
          )
          logger.debug(f"Starting model instance {mi.name}{ports_note}")

          fallback_registry = (
              registration.determine_default_registry(
                  self._config.system_default_container_registry
              )
              if is_built_in_backend(backend)
              else None
          )

          process = multiprocessing.Process(
              target=ServeManager._serve_model_instance,
              args=(
                  mi,
                  backend,
                  self._clientset.headers,
                  log_file_path,
                  self._config,
                  self._worker_id,
                  self._inference_backend_manager.get_backend_by_name(backend),
                  fallback_registry,
              ),
          )
          process.daemon = False
          process.start()
          self._provisioning_processes[mi.id] = process

          # Start container log persistence for containerized backends
          self._start_container_log_persistence(mi)

          # Get patch dict for main worker.
          if is_main_worker:
              patch_dict = {
                  "state": ModelInstanceStateEnum.INITIALIZING,
                  "port": mi.port,
                  "ports": mi.ports,
                  "pid": process.pid,
              }
          # Get patch dict for subordinate worker.
          else:
              sw.state = ModelInstanceStateEnum.INITIALIZING
              # For initialize later mode, the state is set to RUNNING directly,
              # which means the subordinate worker doesn't need to wait for the
              # main worker to be healthy.
              if (
                  mi.distributed_servers.mode
                  == DistributedServerCoordinateModeEnum.INITIALIZE_LATER
              ):
                  sw.state = ModelInstanceStateEnum.RUNNING
              sw.pid = process.pid
              patch_dict = {
                  f"distributed_servers.subordinate_workers.{sw_pos}": sw,
              }

          self._update_model_instance(mi.id, **patch_dict)
          logger.info(f"Started model instance {mi.name}{ports_note}")

      except Exception as e:
          # Clean up provisioning process if started.
          if mi.id in self._provisioning_processes:
              self._stop_model_instance(mi)

          # Get patch dict for main worker.
          if is_main_worker:
              patch_dict = {
                  "state": ModelInstanceStateEnum.ERROR,
                  "state_message": f"Failed to start model instance: {e}",
              }
          # Get patch dict for subordinate worker.
          else:
              sw.state = ModelInstanceStateEnum.ERROR
              sw.state_message = f"Failed to start model instance: {e}"
              patch_dict = {
                  f"distributed_servers.subordinate_workers.{sw_pos}": sw,
              }

          self._update_model_instance(mi.id, **patch_dict)
          logger.error(f"Failed to start model instance {mi.name}: {e}")
  def _assign_ports(
      self,
      mi: ModelInstance,
      model: Model,
      backend: BackendEnum,
  ) -> None:
      """
      Assign ports to the model instance.

      Does nothing when the instance already has a port. Otherwise, while
      holding ``_port_lock``, picks free ports that are not already recorded
      for other instances and stores them on ``mi.port`` / ``mi.ports`` and
      in the assigned-ports bookkeeping:
      - Main serving port
      - RPC port for vLLM DP communication (if applicable)
      - Connecting port for subordinate workers (if applicable)

      Args:
          mi: The model instance to assign ports to.
          model: The model associated with the instance.
          backend: The backend type (e.g., vLLM, SGLang).
      """
      if mi.port:
          # Port already assigned, skip.
          return

      with _port_lock:
          # Re-check now that the lock is held.
          if mi.port:
              return

          taken = (
              set.union(*self._assigned_ports.values())
              if self._assigned_ports
              else set()
          )

          def _reserve_port() -> int:
              """Pick a free port and mark it as taken for this allocation."""
              picked = network.get_free_port(
                  port_range=self._config.service_port_range,
                  unavailable_ports=taken,
                  host=mi.worker_ip,
              )
              taken.add(picked)
              return picked

          # Main serving port
          mi.port = _reserve_port()
          mi.ports = [mi.port]

          # Additional ports for distributed servers
          distributed = mi.distributed_servers
          if distributed and distributed.subordinate_workers:
              # RPC port for DP communication in vLLM backend
              if backend == BackendEnum.VLLM:
                  dps = find_int_parameter(
                      model.backend_parameters,
                      ["data-parallel-size", "dp"],
                  )
                  if dps and dps > 1:
                      mi.ports.append(_reserve_port())
              # Connecting port for subordinate workers communication
              mi.ports.append(_reserve_port())

          self._assigned_ports[mi.id] = set(mi.ports)
  def _restart_model_instance(self, mi: ModelInstance):
      """
      Stop the model instance and start it again.

      The restart backoff counter is kept across the stop so that repeated
      failures keep backing off.

      Args:
          mi: The model instance to restart.
      """
      keep_backoff = False
      self._stop_model_instance(mi, clear_restart_backoff=keep_backoff)
      self._start_model_instance(mi)
  def _update_model(self, id: int, **kwargs):
      """
      Update a model with the given fields.

      The current model is fetched, the requested fields are overwritten and
      the result is sent back as an update. A model that no longer exists is
      logged as a warning and otherwise ignored.

      Args:
          id: The ID of the model to update.
          **kwargs: The fields to update, as field name / value pairs.
      """
      try:
          current = self._clientset.models.get(id=id)
          update = ModelUpdate(**current.model_dump())
          for field_name, new_value in kwargs.items():
              set_attr(update, field_name, new_value)
          self._clientset.models.update(id=id, model_update=update)
      except NotFoundException:
          logger.warning(f"Model with ID {id} not found when trying to update.")
  def _update_model_instance(self, id: int, **kwargs):
      """
      Update a model instance with the given fields.

      The current instance is fetched, the requested fields are overwritten
      and the result is sent back as an update. An instance that no longer
      exists is logged as a warning and otherwise ignored.

      Args:
          id: The ID of the model instance to update.
          **kwargs: The fields to update, as field name / value pairs.
      """
      try:
          current = self._clientset.model_instances.get(id=id)
          update = ModelInstanceUpdate(**current.model_dump())
          for field_name, new_value in kwargs.items():
              set_attr(update, field_name, new_value)
          self._clientset.model_instances.update(id=id, model_update=update)
      except NotFoundException:
          logger.warning(
              f"Model instance with ID {id} not found when trying to update."
          )
  def _stop_model_instance(
      self, mi: ModelInstance, clear_restart_backoff: bool = True
  ):
      """
      Stop model instance and clean up.

      Stops the log persistence threads, terminates a still-alive
      provisioning process tree, deletes the workload named by this worker's
      deployment metadata (if any) and drops the per-instance bookkeeping.

      Args:
          mi: The model instance to stop.
          clear_restart_backoff: Whether to clear transient restart backoff state.
      """
      instance_id = mi.id
      logger.debug(f"Stopping model instance {mi.name or mi.id}")

      # Stop container log persistence thread
      self._stop_container_log_persistence(instance_id)

      # Teardown provisioning process if still alive.
      if self._is_provisioning(mi):
          provisioning = self._provisioning_processes[instance_id]
          terminate_process_tree(provisioning.pid)

      # Delete workload.
      metadata = mi.get_deployment_metadata(self._worker_id)
      if metadata:
          delete_workload(metadata.name)

      # Cleanup internal states.
      self._provisioning_processes.pop(instance_id, None)
      self._assigned_ports.pop(instance_id, None)
      self._error_model_instances.pop(instance_id, None)
      self._model_cache_by_instance.pop(instance_id, None)
      self._model_instance_by_instance_id.pop(instance_id, None)
      if clear_restart_backoff:
          self._restart_backoff_counts.pop(instance_id, None)
      # Health check bookkeeping is reset regardless of the backoff flag.
      self._inference_health_check_failures.pop(instance_id, None)
      self._last_health_check_time.pop(instance_id, None)
      self._last_successful_inference.pop(instance_id, None)

      logger.info(f"Stopped model instance {mi.name or mi.id}")
  def _restart_error_model_instance(self, mi: ModelInstance):
      """
      Reschedule an errored model instance using exponential backoff.

      The first attempt is immediate; attempt k (k >= 1 previous attempts)
      waits 10 * 2**(k-1) seconds since the last restart, capped at 300
      seconds (5 minutes). When the delay has elapsed the instance is
      patched back to SCHEDULED with an incremented restart count.

      Args:
          mi: The model instance to restart.
      """
      if self._is_provisioning(mi):
          logger.debug(f"Model instance {mi.name} is provisioning. Skipping restart.")
          return

      previous_restarts = mi.restart_count or 0
      attempts_so_far = self._restart_backoff_counts.get(mi.id, 0)
      reference_time = mi.last_restart_time or mi.updated_at
      now = datetime.now(timezone.utc)

      if attempts_so_far > 0:
          delay = min(10 * (2 ** (attempts_so_far - 1)), 300)
          if reference_time:
              elapsed = (now - reference_time).total_seconds()
              if elapsed < delay:
                  logger.trace(
                      f"Delaying restart of {mi.name} for {delay - elapsed:.2f} seconds."
                  )
                  return
      else:
          delay = 0

      logger.info(
          f"Restarting model instance {mi.name} "
          f"(attempt {attempts_so_far + 1}) after {delay} seconds delay."
      )

      with contextlib.suppress(NotFoundException):
          self._restart_backoff_counts[mi.id] = attempts_so_far + 1
          self._update_model_instance(
              mi.id,
              restart_count=previous_restarts + 1,
              last_restart_time=now,
              state=ModelInstanceStateEnum.SCHEDULED,
              state_message="",
          )
          # Drop it from the error cache; if the next start fails too, it is
          # cached again by the UPDATED event handling.
          self._error_model_instances.pop(mi.id, None)
  def _get_model(self, mi: ModelInstance) -> Model:
      """
      Return the model of a model instance, using the per-instance cache.

      On a cache miss the model is fetched through the client and cached
      under the instance ID.

      Args:
          mi: The model instance whose model to get.
      """
      cached = self._model_cache_by_instance.get(mi.id)
      if cached:
          return cached

      fetched = self._clientset.models.get(mi.model_id)
      self._model_cache_by_instance[mi.id] = fetched
      return fetched
  def _refresh_model(self, mi: ModelInstance) -> Model:
      """
      Re-fetch the instance's model from the server and update the cache.

      Args:
          mi: The model instance whose model to refresh.

      Returns:
          The freshly fetched model.
      """
      logger.debug(f"Refreshing model {mi.model_name} information from server.")
      latest = self._clientset.models.get(mi.model_id)
      # Overwrite whatever was cached for this instance.
      self._model_cache_by_instance[mi.id] = latest
      return latest
  def _is_provisioning(self, mi: ModelInstance) -> bool:
      """
      Tell whether the instance's provisioning process is still alive.

      Args:
          mi: The model instance to check.

      Returns:
          True only when a provisioning process is tracked for the instance
          and is still alive after a non-blocking join.
      """
      process = self._provisioning_processes.get(mi.id)
      if not process or not process.is_alive():
          return False
      # Non-blocking join, then look again.
      process.join(timeout=0)
      return process.is_alive()
  1252. def _get_health_check_path(self, backend: str) -> Optional[str]:
  1253. """
  1254. Get health check path for the given backend.
  1255. Args:
  1256. backend: The backend name.
  1257. Returns:
  1258. The health check path if exists, else None.
  1259. """
  1260. inference_backend = self._inference_backend_manager.get_backend_by_name(backend)
  1261. return inference_backend.health_check_path if inference_backend else None
  def get_instance_port_by_model_instance_id(
      self, model_instance_id: int
  ) -> Optional[int]:
      """
      Get the serving port of a cached, running model instance.

      Args:
          model_instance_id: The model instance ID to get the port for.

      Returns:
          The first entry of the instance's ``ports`` (falling back to
          ``port`` when ``ports`` is empty or None) if the instance is cached
          and in RUNNING state, else None.
      """
      instance = self._model_instance_by_instance_id.get(model_instance_id)
      if not instance or instance.state != ModelInstanceStateEnum.RUNNING:
          return None
      # `ports` may be unset while only `port` is populated; mirror the
      # `ports if ports else [port]` fallback used when starting instances
      # instead of raising on an empty/None list.
      return instance.ports[0] if instance.ports else instance.port
  def is_ready(
      backend: str,
      mi: ModelInstance,
      health_check_path: Optional[str] = None,
      model: Model = None,
  ) -> bool:
      """
      Probe the health endpoint of a model instance to see if it is servable.

      Returns True without probing for custom / non-built-in backends that
      have no health check path, and for SGLang image models older than
      v0.5.5.post3. Otherwise issues a GET with a 1 second timeout and
      reports whether the response status is 200.
      """
      built_in = is_built_in_backend(backend)
      custom_like = not built_in or backend == BackendEnum.CUSTOM

      if custom_like and not health_check_path:
          # If custom backend does not have health check path, consider it always ready.
          return True

      if backend == BackendEnum.ASCEND_MINDIE and not health_check_path:
          # Ref: https://www.hiascend.com/document/detail/zh/mindie/21RC2/mindieservice/servicedev/mindie_service0066.html
          # /info provides metadata information and requires more time to respond. Use it for health check.
          health_check_path = "/info"
      elif (
          backend == BackendEnum.SGLANG
          and model
          and CategoryEnum.IMAGE in model.categories
      ):
          if not model.backend_version:
              # version may be empty at initialization, consider it not ready.
              return False
          if compare_versions(model.backend_version, "0.5.5.post3") < 0:
              # Older versions do not support health check, consider it always ready.
              return True
          # SGLang Diffusion supported health check path at v0.5.5.post3
          health_check_path = "/health"
      elif not custom_like and not health_check_path:
          # Built-in backends (vLLM, SGLang, vox-box) except (Custom, MindIE) use /v1/models as health check path.
          health_check_path = "/v1/models"

      try:
          # Use the worker IP instead of localhost for health check.
          # Reasons:
          # 1. Connectivity to the loopback address does not work with Ascend MindIE.
          # 2. More adaptable to container networks.
          probe_url = f"http://{mi.worker_ip}:{mi.port}{health_check_path}"
          reply = requests.get(probe_url, timeout=1)
          if reply.status_code == 200:
              return True
      except Exception as e:
          logger.debug(f"Error checking model instance {mi.name} health: {e}")

      return False
  def _get_inference_endpoint_and_payload(model: Model) -> tuple[str, dict] | None:
      """
      Choose the endpoint and minimal request body for an inference probe.

      Returns None when the model has any category for which the probe is
      skipped (image, speech-to-text, text-to-speech, unknown). Otherwise
      embedding models take precedence over rerankers, and everything else
      is probed through chat completions.
      """
      skipped = (
          CategoryEnum.IMAGE,
          CategoryEnum.SPEECH_TO_TEXT,
          CategoryEnum.TEXT_TO_SPEECH,
          CategoryEnum.UNKNOWN,
      )
      categories = model.categories
      if not set(skipped).isdisjoint(categories):
          return None

      # Pick endpoint and payload by model type (priority order).
      if CategoryEnum.EMBEDDING in categories:
          embedding_payload = {"model": model.name, "input": "test"}
          return "/v1/embeddings", embedding_payload

      if CategoryEnum.RERANKER in categories:
          rerank_payload = {
              "model": model.name,
              "query": "test",
              "documents": ["test"],
          }
          return "/v1/rerank", rerank_payload

      chat_payload = {
          "model": model.name,
          "messages": [{"role": "user", "content": "ping"}],
          "max_tokens": 1,
          "max_completion_tokens": 1,
      }
      return "/v1/chat/completions", chat_payload
  def _get_inference_health_check_config(model: Model) -> dict:
      """Read the per-model inference health check settings from model.env.

      Returns a dict with:
          enabled: True only when the ENABLED variable is "true" or "1"
              (case-insensitive); defaults to disabled.
          interval: INTERVAL variable parsed by safe_int, fallback 300.
          timeout: TIMEOUT variable parsed by safe_int, fallback 15.
          threshold: FAILURE_THRESHOLD variable parsed by safe_int, fallback 3.
      """
      env = model.env or {}

      raw_enabled = env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_ENABLED", "false")
      config = {"enabled": raw_enabled.lower() in ("true", "1")}

      config["interval"] = safe_int(
          env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_INTERVAL"),
          300,
      )
      config["timeout"] = safe_int(
          env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_TIMEOUT"),
          15,
      )
      config["threshold"] = safe_int(
          env.get("GPUSTACK_MODEL_INFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"),
          3,
      )
      return config
  def is_inference_ready(mi: ModelInstance, model: Model, timeout: int = 15) -> bool:
      """
      Verify inference works by sending a minimal request to the instance.

      Custom backends and model types without a probe endpoint are reported
      as ready without sending anything; an instance without a port is not
      ready. Otherwise the probe is POSTed and only a 200 response counts as
      ready.
      """
      # Custom backends expose no standard inference API to probe.
      if is_custom_backend(model.backend):
          return True

      # Nothing to talk to until a port has been assigned.
      if not mi.port:
          logger.debug(f"Model instance {mi.name} does not have port assigned yet.")
          return False

      # None means this model type is exempt from the inference probe.
      probe = _get_inference_endpoint_and_payload(model)
      if not probe:
          logger.debug(f"Skipping inference check for {mi.name}")
          return True

      endpoint_path, payload = probe
      inference_url = f"http://{mi.worker_ip}:{mi.port}{endpoint_path}"

      try:
          reply = requests.post(inference_url, json=payload, timeout=timeout)
          if reply.status_code == 200:
              return True
          logger.warning(
              f"Model instance {mi.name} inference health check failed "
              f"with status {reply.status_code} for endpoint {endpoint_path}"
          )
      except Exception as e:
          logger.debug(
              f"Error checking model instance {mi.name} inference at {endpoint_path}: {e}"
          )

      return False