generated_inference_backend_client.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  import asyncio
  import inspect
  import json
  import logging
  import threading
  from typing import Any, Awaitable, Callable, Dict, Optional, Union

  import httpx

  from gpustack.api.exceptions import (
      async_raise_if_response_error,
      raise_if_response_error,
  )
  from gpustack.schemas import *
  from gpustack.schemas.common import Pagination
  from gpustack.server.bus import Event, EventType

  from .generated_http_client import HTTPClient
  # Module-scoped logger, named after this module.
  logger: logging.Logger = logging.getLogger(name=__name__)
  class InferenceBackendClient:
      """Client for the ``/inference-backends`` API with an optional local cache.

      When ``enable_cache`` is True and at least one ``awatch()`` stream is
      currently open, ``list()`` and ``get()`` are served from an in-memory dict
      that ``awatch()`` keeps in sync from the event stream. While no watch
      stream is open they call the API.
      """

      # Query params that need server-side paging; list() never answers these
      # from the cache.
      _PAGINATION_PARAMS = frozenset({"page", "perPage"})

      def __init__(self, client: HTTPClient, enable_cache: bool = True):
          self._client = client
          self._url = "/inference-backends"
          self._enable_cache = enable_cache
          # Resource id -> latest object received from watch events or get().
          self._cache: Dict[int, InferenceBackendPublic] = {}
          # threading.Lock (not asyncio.Lock) because list()/get() are
          # synchronous; presumably they may be called from threads other than
          # the one running awatch() -- confirm with callers.
          self._cache_lock = threading.Lock()
          # Number of awatch() streams currently open; guarded by _cache_lock.
          self._active_watches = 0
          # Mirrors `_active_watches > 0`; read without the lock by list()/get().
          self._watch_started = False
          # Ensures the "cache populated" debug line is emitted only once.
          self._initial_sync_logged = False

      def list(
          self, params: Optional[Dict[str, Any]] = None, use_cache: bool = True
      ) -> InferenceBackendsPublic:
          """
          List resources.

          Args:
              params: Query parameters for filtering.
              use_cache: Whether to use the cache. Defaults to True. The cache is
                  only consulted while an awatch() stream is open; otherwise the
                  API is called. If 'page' or 'perPage' params are provided, the
                  API is always called.

          Returns:
              List of resources.
          """
          has_pagination = bool(params) and not self._PAGINATION_PARAMS.isdisjoint(
              params
          )
          should_use_cache = (
              use_cache
              and self._enable_cache
              and self._watch_started
              and not has_pagination
          )
          if should_use_cache:
              return self._list_from_cache(params)

          response = self._client.get_httpx_client().get(self._url, params=params)
          raise_if_response_error(response)
          return InferenceBackendsPublic.model_validate(response.json())

      def _list_from_cache(
          self, params: Optional[Dict[str, Any]] = None
      ) -> InferenceBackendsPublic:
          """
          Build a list response from the cache instead of calling the API.

          All matching items are returned as a single page. Filtering follows
          ``_matches_params`` (lenient: unknown or None-valued attributes match).
          """
          # Snapshot under the lock; filter outside it to keep the lock short.
          with self._cache_lock:
              snapshot = list(self._cache.values())
          items = [item for item in snapshot if self._matches_params(item, params)]

          total = len(items)
          pagination = Pagination(
              page=1,
              perPage=total if total > 0 else 100,
              total=total,
              totalPage=1 if total > 0 else 0,
          )
          return InferenceBackendsPublic(
              items=items, total=total, pagination=pagination
          )

      @staticmethod
      def _matches_params(item: Any, params: Optional[Dict[str, Any]]) -> bool:
          """
          Return True if ``item`` passes every filter in ``params``.

          The 'watch' key is ignored. Values are compared by their ``str()``
          form. A param counts as matching when ``item`` has no such attribute
          or the attribute is None, so unknown params filter nothing out.
          """
          for key, value in (params or {}).items():
              if key == "watch":
                  continue
              attr_value = getattr(item, key, None)
              if attr_value is not None and str(attr_value) != str(value):
                  return False
          return True

      @staticmethod
      def _is_id_only_event(event: Event) -> bool:
          """Return True if the event payload is just ``{"id": ...}``."""
          data = event.data
          return (
              isinstance(data, dict)
              and event.id is not None
              and set(data.keys()) == {"id"}
          )

      @staticmethod
      def _with_watch_param(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
          """
          Return a new dict containing ``params`` plus ``watch=true``.

          A copy is made so the caller's dict is never mutated; otherwise the
          'watch' key would leak into any later request that reuses that dict.
          """
          return {**(params or {}), "watch": "true"}

      def _enter_watch(self) -> bool:
          """Record that a watch stream opened; return True if it is the only one."""
          with self._cache_lock:
              self._active_watches += 1
              self._watch_started = True
              return self._active_watches == 1

      def _exit_watch(self) -> bool:
          """Record that a watch stream closed; return True if none remain open."""
          with self._cache_lock:
              self._active_watches = max(0, self._active_watches - 1)
              self._watch_started = self._active_watches > 0
              return not self._watch_started

      async def _update_cache_from_event(self, event: Event):
          """
          Apply one watch event to the cache.

          Performs no network I/O and holds the cache lock only for the dict
          operation itself. For an ID-only DELETED event whose item was cached,
          ``event.data`` is replaced in place with the cached object so that
          callbacks (e.g. ServeManager) see a validated model instead of
          ``{"id": ...}``. Any error is logged and swallowed so that one bad
          event does not stop the watch.
          """
          if not self._enable_cache:
              return
          try:
              # Per the server contract, ID-only payloads are only expected for
              # DELETED events (the server could not enrich a row already gone
              # from the DB); CREATED/UPDATED carry the full object.
              if event.type == EventType.DELETED and self._is_id_only_event(event):
                  with self._cache_lock:
                      cached = self._cache.pop(event.id, None)
                  if cached is not None:
                      event.data = cached
                  logger.debug(f"Cache: removed inferencebackend {event.id}")
                  return

              item = InferenceBackendPublic.model_validate(event.data)
              # hasattr() is True for a declared field even when its value is
              # None, so test the value itself before using it as a cache key.
              item_id = getattr(item, "id", None)
              if item_id is None:
                  return

              if event.type == EventType.DELETED:
                  with self._cache_lock:
                      self._cache.pop(item_id, None)
                  logger.debug(f"Cache: removed inferencebackend {item_id}")
              else:  # CREATED or UPDATED
                  with self._cache_lock:
                      self._cache[item_id] = item
                  logger.trace(f"Cache: updated inferencebackend {item_id}")
          except Exception as e:
              logger.error(
                  f"Failed to update inference-backends cache from event: {e}"
              )

      def watch(
          self,
          callback: Optional[Callable[[Event], None]] = None,
          stop_condition: Optional[Callable[[Event], bool]] = None,
          params: Optional[Dict[str, Any]] = None,
      ):
          """
          Stream watch events synchronously.

          Blocks until the stream ends or ``stop_condition`` returns True.
          Unlike ``awatch()``, this does not update the cache.

          Args:
              callback: Called with each parsed event.
              stop_condition: Stop reading once this returns True for an event.
              params: Extra query parameters; the caller's dict is not mutated.
          """
          params = self._with_watch_param(params)
          with self._client.get_httpx_client().stream(
              "GET", self._url, params=params, timeout=None
          ) as response:
              raise_if_response_error(response)
              for line in response.iter_lines():
                  if not line:
                      continue
                  event = Event(**json.loads(line))
                  if callback:
                      callback(event)
                  if stop_condition is not None and stop_condition(event):
                      break

      async def awatch(
          self,
          callback: Optional[
              Union[Callable[[Event], None], Callable[[Event], Awaitable[Any]]]
          ] = None,
          stop_condition: Optional[Callable[[Event], bool]] = None,
          params: Optional[Dict[str, Any]] = None,
      ):
          """
          Stream watch events asynchronously and keep the cache in sync.

          Once the server has accepted the request, list()/get() may be served
          from the cache. When this stream closes for any reason (stop_condition,
          an error, or cancellation), they fall back to the API unless another
          awatch() stream is still open.

          Args:
              callback: Sync or async callable invoked with each event. Events
                  whose payload is still ID-only after the cache update are not
                  passed to it.
              stop_condition: Stop reading once this returns True for an event.
              params: Extra query parameters; the caller's dict is not mutated.

          Raises:
              Exception: "watch timeout" if no line arrives within 45 seconds.
          """
          params = self._with_watch_param(params)
          async with self._client.get_async_httpx_client().stream(
              "GET",
              self._url,
              params=params,
              timeout=httpx.Timeout(connect=10, read=None, write=10, pool=10),
          ) as response:
              await async_raise_if_response_error(response)
              # Trust the cache only once the stream is established, and stop
              # trusting it as soon as this stream goes away.
              # NOTE(review): the cache is not cleared on (re)start, so items
              # deleted while no watch was open may linger -- confirm intended.
              if self._enter_watch() and self._enable_cache:
                  logger.debug("inference-backends cache watch started")
              try:
                  await self._consume_watch_events(response, callback, stop_condition)
              finally:
                  if self._exit_watch() and self._enable_cache:
                      logger.debug("inference-backends cache watch stopped")

      async def _consume_watch_events(
          self,
          response: httpx.Response,
          callback: Optional[Callable[[Event], Any]],
          stop_condition: Optional[Callable[[Event], bool]],
      ) -> None:
          """Read, apply, and dispatch events from an open watch response."""
          lines = response.aiter_lines()
          while True:
              # Only the read is covered by the timeout handler, so a
              # TimeoutError raised by a callback is not mislabelled as a watch
              # timeout. 45 s presumably exceeds the server's heartbeat interval
              # -- confirm. When the server closes the stream,
              # StopAsyncIteration propagates to the caller.
              try:
                  line = await asyncio.wait_for(lines.__anext__(), timeout=45)
              except asyncio.TimeoutError as e:
                  raise Exception("watch timeout") from e
              if not line:
                  continue
              event = Event(**json.loads(line))
              if self._enable_cache:
                  await self._update_cache_from_event(event)
                  self._log_initial_sync_once(event)
              if self._is_id_only_event(event):
                  # Still ID-only after the cache update (e.g. DELETED for an
                  # item this client never saw): subscribers validate event.data
                  # and cannot act on a bare id, so skip them.
                  logger.debug(
                      f"Skipping callback for ID-only {event.type} event on inference-backends {event.id}"
                  )
              elif callback:
                  # Await any awaitable result: covers coroutine functions as
                  # well as partials/callable objects wrapping async code.
                  result = callback(event)
                  if inspect.isawaitable(result):
                      await result
              if stop_condition is not None and stop_condition(event):
                  break

      def _log_initial_sync_once(self, event: Event) -> None:
          """Emit a one-time debug line once a CREATED event leaves the cache non-empty."""
          if self._initial_sync_logged or event.type != EventType.CREATED:
              return
          with self._cache_lock:
              cache_size = len(self._cache)
          if cache_size > 0:
              self._initial_sync_logged = True
              logger.debug(
                  f"inference-backends cache populated with {cache_size} items"
              )

      def get(self, id: int, use_cache: bool = True) -> InferenceBackendPublic:
          """
          Get a resource by ID.

          Args:
              id: Resource ID.
              use_cache: Whether to use the cache. Defaults to True. The cache is
                  only consulted while an awatch() stream is open; on a miss (or
                  with no open stream) the API is called.

          Returns:
              Resource object.
          """
          if use_cache and self._enable_cache and self._watch_started:
              with self._cache_lock:
                  cached = self._cache.get(id)
              if cached is not None:
                  logger.trace(f"Cache hit for inferencebackend {id}")
                  return cached

          response = self._client.get_httpx_client().get(f"{self._url}/{id}")
          raise_if_response_error(response)
          result = InferenceBackendPublic.model_validate(response.json())
          # Seed the cache only while a watch is open: entries written with no
          # watch running would never be evicted on deletion. setdefault()
          # keeps any object the watch delivered meanwhile, which may be newer
          # than this HTTP snapshot.
          # NOTE(review): a DELETED event handled between the request and this
          # write can still leave a stale entry; consider not caching here.
          if self._enable_cache and self._watch_started:
              with self._cache_lock:
                  self._cache.setdefault(id, result)
          return result

      def create(self, model_create: InferenceBackendCreate) -> InferenceBackendPublic:
          """Create a resource and return the object parsed from the response."""
          response = self._client.get_httpx_client().post(
              self._url,
              content=model_create.model_dump_json(),
              headers={"Content-Type": "application/json"},
          )
          raise_if_response_error(response)
          return InferenceBackendPublic.model_validate(response.json())

      def update(
          self, id: int, model_update: InferenceBackendUpdate
      ) -> InferenceBackendPublic:
          """Replace a resource by ID (HTTP PUT) and return the parsed response."""
          response = self._client.get_httpx_client().put(
              f"{self._url}/{id}",
              content=model_update.model_dump_json(),
              headers={"Content-Type": "application/json"},
          )
          raise_if_response_error(response)
          return InferenceBackendPublic.model_validate(response.json())

      def delete(self, id: int) -> None:
          """Delete a resource by ID."""
          response = self._client.get_httpx_client().delete(f"{self._url}/{id}")
          raise_if_response_error(response)