generated_worker_client.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
import asyncio
import inspect
import json
import logging
import threading
from enum import Enum
from typing import Any, Callable, Dict, Optional, Union, Awaitable

import httpx

from gpustack.api.exceptions import (
    raise_if_response_error,
    async_raise_if_response_error,
)
from gpustack.server.bus import Event, EventType
from gpustack.schemas import *
from gpustack.schemas.common import Pagination
from .generated_http_client import HTTPClient
  15. logger = logging.getLogger(__name__)
  16. class WorkerClient:
  17. def __init__(self, client: HTTPClient, enable_cache: bool = True):
  18. self._client = client
  19. self._url = "/workers"
  20. self._enable_cache = enable_cache
  21. self._cache: Dict[int, WorkerPublic] = {}
  22. self._cache_lock = threading.Lock()
  23. self._watch_started = False
  24. self._initial_sync_logged = False
  25. def list(
  26. self, params: Dict[str, Any] = None, use_cache: bool = True
  27. ) -> WorkersPublic:
  28. """
  29. List resources.
  30. Args:
  31. params: Query parameters for filtering
  32. use_cache: Whether to use cache. Defaults to True (use cache if available).
  33. Automatically falls back to API if cache watch is not running.
  34. Note: If 'page' or 'perPage' params are provided, always calls API.
  35. Returns:
  36. List of resources
  37. """
  38. # Determine if we should use cache
  39. # Don't use cache if pagination params are provided
  40. pagination_params = {"page", "perPage"} if params else set()
  41. has_pagination = any(k in pagination_params for k in (params or {}))
  42. should_use_cache = (
  43. use_cache
  44. and self._enable_cache
  45. and self._watch_started
  46. and not has_pagination # Don't use cache if pagination params exist
  47. )
  48. # If cache should be used, try to read from cache
  49. if should_use_cache:
  50. return self._list_from_cache(params)
  51. # Otherwise, make API call
  52. response = self._client.get_httpx_client().get(self._url, params=params)
  53. raise_if_response_error(response)
  54. return WorkersPublic.model_validate(response.json())
  55. def _list_from_cache(self, params: Dict[str, Any] = None) -> WorkersPublic:
  56. """
  57. List resources from cache instead of making an API call.
  58. Note: Cache is automatically populated when awatch() is called.
  59. The first call to awatch() will set _watch_started=True and enable caching.
  60. """
  61. # Get all cached items
  62. with self._cache_lock:
  63. all_items = list(self._cache.values())
  64. # Apply filters if params provided
  65. if params:
  66. filtered_items = []
  67. for item in all_items:
  68. match = True
  69. for key, value in params.items():
  70. # Skip non-filter params like 'watch'
  71. if key == 'watch':
  72. continue
  73. # Convert attribute to string for comparison
  74. attr_value = getattr(item, key, None)
  75. if attr_value is not None and str(attr_value) != str(value):
  76. match = False
  77. break
  78. if match:
  79. filtered_items.append(item)
  80. all_items = filtered_items
  81. # Return in the same format as the original list()
  82. total = len(all_items)
  83. # Create pagination info for PaginatedList types
  84. pagination = Pagination(
  85. page=1,
  86. perPage=total if total > 0 else 100,
  87. total=total,
  88. totalPage=1 if total > 0 else 0,
  89. )
  90. return WorkersPublic(items=all_items, total=total, pagination=pagination)
  91. async def _update_cache_from_event(self, event: Event):
  92. """Update cache based on received event.
  93. Runs on the awatch event loop. Network I/O uses the async httpx
  94. client and happens outside the cache lock, so concurrent readers
  95. (list/get) are never blocked waiting on HTTP.
  96. """
  97. if not self._enable_cache:
  98. return
  99. try:
  100. # Server only emits ID-only events for DELETED (when its own
  101. # enrichment cache misses on a row that's already gone from DB).
  102. # CREATED/UPDATED are always enriched server-side or dropped, so
  103. # we only handle the DELETED case here.
  104. is_id_only_delete = (
  105. event.type == EventType.DELETED
  106. and isinstance(event.data, dict)
  107. and event.id is not None
  108. and set(event.data.keys()) == {"id"}
  109. )
  110. if is_id_only_delete:
  111. with self._cache_lock:
  112. item = self._cache.pop(event.id, None)
  113. if item is not None:
  114. # Enrich so downstream callbacks (e.g. ServeManager) see
  115. # a validated object instead of {"id": ...}.
  116. event.data = item
  117. logger.debug(f"Cache: removed worker {event.id}")
  118. return
  119. item = WorkerPublic.model_validate(event.data)
  120. if not hasattr(item, 'id'):
  121. return
  122. with self._cache_lock:
  123. if event.type == EventType.DELETED:
  124. self._cache.pop(item.id, None)
  125. logger.debug(f"Cache: removed worker {item.id}")
  126. else: # CREATED or UPDATED
  127. self._cache[item.id] = item
  128. logger.trace(f"Cache: updated worker {item.id}")
  129. except Exception as e:
  130. logger.error(f"Failed to update workers cache from event: {e}")
  131. def watch(
  132. self,
  133. callback: Optional[Callable[[Event], None]] = None,
  134. stop_condition: Optional[Callable[[Event], bool]] = None,
  135. params: Optional[Dict[str, Any]] = None,
  136. ):
  137. if params is None:
  138. params = {}
  139. params["watch"] = "true"
  140. if stop_condition is None:
  141. stop_condition = lambda event: False
  142. with self._client.get_httpx_client().stream(
  143. "GET", self._url, params=params, timeout=None
  144. ) as response:
  145. raise_if_response_error(response)
  146. for line in response.iter_lines():
  147. if line:
  148. event_data = json.loads(line)
  149. event = Event(**event_data)
  150. if callback:
  151. callback(event)
  152. if stop_condition(event):
  153. break
  154. async def awatch(
  155. self,
  156. callback: Optional[
  157. Union[Callable[[Event], None], Callable[[Event], Awaitable[Any]]]
  158. ] = None,
  159. stop_condition: Optional[Callable[[Event], bool]] = None,
  160. params: Optional[Dict[str, Any]] = None,
  161. ):
  162. if params is None:
  163. params = {}
  164. params["watch"] = "true"
  165. if stop_condition is None:
  166. stop_condition = lambda event: False
  167. # Mark watch as started when awatch is called
  168. # This enables list()/get() to use cache automatically
  169. if self._enable_cache and not self._watch_started:
  170. self._watch_started = True
  171. logger.debug(f"workers cache watch started")
  172. async with self._client.get_async_httpx_client().stream(
  173. "GET",
  174. self._url,
  175. params=params,
  176. timeout=httpx.Timeout(connect=10, read=None, write=10, pool=10),
  177. ) as response:
  178. await async_raise_if_response_error(response)
  179. lines = response.aiter_lines()
  180. while True:
  181. try:
  182. line = await asyncio.wait_for(lines.__anext__(), timeout=45)
  183. if line:
  184. event_data = json.loads(line)
  185. event = Event(**event_data)
  186. # Update cache if enabled
  187. if self._enable_cache:
  188. await self._update_cache_from_event(event)
  189. # Log cache size after initial events (approximately)
  190. if (
  191. not self._initial_sync_logged
  192. and event.type == EventType.CREATED
  193. ):
  194. # Check if we have accumulated enough items (heuristic)
  195. with self._cache_lock:
  196. cache_size = len(self._cache)
  197. if cache_size > 0:
  198. # Set a flag to avoid repeated logging
  199. self._initial_sync_logged = True
  200. logger.debug(
  201. f"workers cache populated with {cache_size} items"
  202. )
  203. # Skip the callback if the event is still ID-only after
  204. # cache update (e.g. DELETED for an item this client
  205. # never saw). Subscribers like ServeManager call
  206. # model_validate(event.data) and would otherwise fail;
  207. # also they can't act without the full object.
  208. if (
  209. isinstance(event.data, dict)
  210. and event.id is not None
  211. and set(event.data.keys()) == {"id"}
  212. ):
  213. logger.debug(
  214. f"Skipping callback for ID-only {event.type} event on workers {event.id}"
  215. )
  216. elif callback:
  217. if asyncio.iscoroutinefunction(callback):
  218. await callback(event)
  219. else:
  220. callback(event)
  221. if stop_condition(event):
  222. break
  223. except asyncio.TimeoutError:
  224. raise Exception("watch timeout")
  225. def get(self, id: int, use_cache: bool = True) -> WorkerPublic:
  226. """
  227. Get a resource by ID.
  228. Args:
  229. id: Resource ID
  230. use_cache: Whether to use cache. Defaults to True (use cache if available).
  231. Automatically falls back to API if cache watch is not running.
  232. Returns:
  233. Resource object
  234. """
  235. # Use cache if enabled, watch is running, and use_cache is True
  236. should_use_cache = use_cache and self._enable_cache and self._watch_started
  237. # Try to get from cache first if it should be used
  238. if should_use_cache:
  239. with self._cache_lock:
  240. if id in self._cache:
  241. logger.trace(f"Cache hit for worker {id}")
  242. return self._cache[id]
  243. # Fall back to API call
  244. response = self._client.get_httpx_client().get(f"{self._url}/{id}")
  245. raise_if_response_error(response)
  246. result = WorkerPublic.model_validate(response.json())
  247. # Update cache if enabled
  248. if self._enable_cache:
  249. with self._cache_lock:
  250. self._cache[id] = result
  251. return result
  252. def create(self, model_create: WorkerCreate):
  253. response = self._client.get_httpx_client().post(
  254. self._url,
  255. content=model_create.model_dump_json(),
  256. headers={"Content-Type": "application/json"},
  257. )
  258. raise_if_response_error(response)
  259. return WorkerPublic.model_validate(response.json())
  260. def update(self, id: int, model_update: WorkerUpdate):
  261. response = self._client.get_httpx_client().put(
  262. f"{self._url}/{id}",
  263. content=model_update.model_dump_json(),
  264. headers={"Content-Type": "application/json"},
  265. )
  266. raise_if_response_error(response)
  267. return WorkerPublic.model_validate(response.json())
  268. def delete(self, id: int):
  269. response = self._client.get_httpx_client().delete(f"{self._url}/{id}")
  270. raise_if_response_error(response)