| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313 |
import asyncio
import enum
import json
import logging
import threading
from typing import Any, Awaitable, Callable, Dict, Optional, Union

import httpx

from gpustack.api.exceptions import (
    raise_if_response_error,
    async_raise_if_response_error,
)
from gpustack.server.bus import Event, EventType
from gpustack.schemas import *
from gpustack.schemas.common import Pagination

from .generated_http_client import HTTPClient
logger = logging.getLogger(__name__)


class WorkerClient:
    """Client for the ``/workers`` API with an optional watch-driven cache.

    When ``enable_cache`` is true, every event received by :meth:`awatch` is
    applied to an in-memory ``{id: WorkerPublic}`` map. Once :meth:`awatch`
    has been called on this client, :meth:`list` and :meth:`get` read from
    that map instead of calling the API (unless ``use_cache=False``).
    """

    # Query keys that force list() to hit the API: the cache can only
    # produce a single, unpaginated page.
    _PAGINATION_PARAMS = frozenset({"page", "perPage"})

    # Sentinel that tells "the model has no such attribute" apart from an
    # attribute whose value is None when filtering cached items.
    _MISSING = object()

    def __init__(self, client: HTTPClient, enable_cache: bool = True):
        self._client = client
        self._url = "/workers"
        self._enable_cache = enable_cache
        # id -> latest WorkerPublic seen via watch events or get().
        self._cache: Dict[int, WorkerPublic] = {}
        # Guards every read/write of self._cache.
        self._cache_lock = threading.Lock()
        # Set by the first awatch() call and never cleared; gates cache reads.
        self._watch_started = False
        # Ensures the "cache populated" debug line is emitted only once.
        self._initial_sync_logged = False

    def list(
        self, params: Optional[Dict[str, Any]] = None, use_cache: bool = True
    ) -> WorkersPublic:
        """
        List resources.

        Args:
            params: Query parameters for filtering. Not mutated.
            use_cache: Whether to use cache. Defaults to True (use cache if
                available). The API is used instead until awatch() has been
                called at least once on this client, or when caching is
                disabled for this client.
                Note: If 'page' or 'perPage' params are provided, always
                calls API.

        Returns:
            List of resources
        """
        # The cache cannot honour pagination, so any pagination key forces
        # a real API call.
        has_pagination = bool(params) and not self._PAGINATION_PARAMS.isdisjoint(
            params
        )
        should_use_cache = (
            use_cache
            and self._enable_cache
            and self._watch_started
            and not has_pagination
        )
        if should_use_cache:
            return self._list_from_cache(params)

        response = self._client.get_httpx_client().get(self._url, params=params)
        raise_if_response_error(response)
        return WorkersPublic.model_validate(response.json())

    def _list_from_cache(
        self, params: Optional[Dict[str, Any]] = None
    ) -> WorkersPublic:
        """Build a ``WorkersPublic`` result from cached items.

        Every entry in ``params`` except ``watch`` must equal the item's
        attribute of the same name. Values are compared as strings, with Enum
        members reduced to their ``.value``. Keys that are not attributes of
        the item are ignored. The result is always a single page containing
        every match.

        Note: the cache is populated by awatch(); the first awatch() call
        sets ``_watch_started`` and thereby enables cache reads.
        """
        with self._cache_lock:
            items = list(self._cache.values())

        if params:
            items = [item for item in items if self._matches_filters(item, params)]

        total = len(items)
        # Mirror the PaginatedList shape returned by the API.
        pagination = Pagination(
            page=1,
            perPage=total if total > 0 else 100,
            total=total,
            totalPage=1 if total > 0 else 0,
        )
        return WorkersPublic(items=items, total=total, pagination=pagination)

    @classmethod
    def _matches_filters(cls, item: Any, params: Dict[str, Any]) -> bool:
        """Return True if ``item`` satisfies every filter in ``params``."""
        for key, value in params.items():
            # 'watch' is a transport flag, not a field filter.
            if key == 'watch':
                continue
            attr_value = getattr(item, key, cls._MISSING)
            # Keys that are not model attributes cannot be evaluated locally
            # and are ignored rather than excluding every item.
            if attr_value is cls._MISSING:
                continue
            if cls._filter_repr(attr_value) != cls._filter_repr(value):
                return False
        return True

    @staticmethod
    def _filter_repr(value: Any) -> str:
        """String form used to compare a filter value with a model attribute.

        Enum members are reduced to ``.value`` first. ``str()`` of a
        ``(str, Enum)`` member yields ``"Cls.MEMBER"``, which would never
        equal a plain query string such as ``"ready"``.
        """
        if isinstance(value, enum.Enum):
            value = value.value
        return str(value)

    @staticmethod
    def _is_id_only(event: Event) -> bool:
        """True when the event payload carries nothing but ``{"id": ...}``."""
        return (
            isinstance(event.data, dict)
            and event.id is not None
            and set(event.data.keys()) == {"id"}
        )

    async def _update_cache_from_event(self, event: Event):
        """Apply one watch event to the local cache.

        No network I/O happens here; only the in-memory cache is touched,
        under ``_cache_lock``. For an ID-only DELETED event whose item is
        cached, ``event.data`` is replaced in place with the cached object,
        so downstream callbacks see a validated model instead of
        ``{"id": ...}``. Any error is logged and swallowed so that one bad
        event does not abort the watch.
        """
        if not self._enable_cache:
            return
        try:
            # Server only emits ID-only events for DELETED (when its own
            # enrichment cache misses on a row that's already gone from DB).
            # CREATED/UPDATED are always enriched server-side or dropped, so
            # we only handle the DELETED case here.
            if event.type == EventType.DELETED and self._is_id_only(event):
                with self._cache_lock:
                    cached = self._cache.pop(event.id, None)
                if cached is not None:
                    event.data = cached
                    logger.debug(f"Cache: removed worker {event.id}")
                return

            item = WorkerPublic.model_validate(event.data)
            if not hasattr(item, 'id'):
                return
            with self._cache_lock:
                if event.type == EventType.DELETED:
                    self._cache.pop(item.id, None)
                    logger.debug(f"Cache: removed worker {item.id}")
                else:  # CREATED or UPDATED
                    self._cache[item.id] = item
                    logger.trace(f"Cache: updated worker {item.id}")
        except Exception as e:
            logger.error(f"Failed to update workers cache from event: {e}")

    def _maybe_log_initial_sync(self, event: Event) -> None:
        """Log the cache size once, after the first CREATED event fills it."""
        if self._initial_sync_logged or event.type != EventType.CREATED:
            return
        with self._cache_lock:
            cache_size = len(self._cache)
        if cache_size > 0:
            self._initial_sync_logged = True
            logger.debug(f"workers cache populated with {cache_size} items")

    def watch(
        self,
        callback: Optional[Callable[[Event], None]] = None,
        stop_condition: Optional[Callable[[Event], bool]] = None,
        params: Optional[Dict[str, Any]] = None,
    ):
        """Stream worker events synchronously. Does not touch the cache.

        Args:
            callback: Called with each parsed event.
            stop_condition: Stop streaming once it returns True for an event.
            params: Extra query parameters. Copied, never mutated.
        """
        # Build a new dict so the caller's params are left untouched.
        query = {**(params or {}), "watch": "true"}
        if stop_condition is None:
            stop_condition = lambda event: False

        with self._client.get_httpx_client().stream(
            "GET", self._url, params=query, timeout=None
        ) as response:
            raise_if_response_error(response)
            for line in response.iter_lines():
                if not line:
                    continue
                event = Event(**json.loads(line))
                if callback:
                    callback(event)
                if stop_condition(event):
                    break

    async def awatch(
        self,
        callback: Optional[
            Union[Callable[[Event], None], Callable[[Event], Awaitable[Any]]]
        ] = None,
        stop_condition: Optional[Callable[[Event], bool]] = None,
        params: Optional[Dict[str, Any]] = None,
    ):
        """Stream worker events asynchronously, keeping the cache in sync.

        Args:
            callback: Sync or async callable invoked with each event after
                the cache has been updated. If calling it returns a
                coroutine, that coroutine is awaited. Events that are still
                ID-only after the cache update are not passed to it.
            stop_condition: Stop streaming once it returns True for an event.
            params: Extra query parameters. Copied, never mutated.

        Raises:
            Exception: "watch timeout" if no line arrives within 45 seconds.

        NOTE(review): the cache is only updated while a stream is open. Items
        deleted while no awatch() is connected are not evicted unless a later
        DELETED event arrives; confirm how the server replays state on a new
        watch.
        """
        # Build a new dict so the caller's params are left untouched.
        query = {**(params or {}), "watch": "true"}
        if stop_condition is None:
            stop_condition = lambda event: False

        # The first awatch() call enables cache reads in list()/get().
        if self._enable_cache and not self._watch_started:
            self._watch_started = True
            logger.debug("workers cache watch started")

        async with self._client.get_async_httpx_client().stream(
            "GET",
            self._url,
            params=query,
            timeout=httpx.Timeout(connect=10, read=None, write=10, pool=10),
        ) as response:
            await async_raise_if_response_error(response)
            lines = response.aiter_lines()
            while True:
                try:
                    # read=None disables httpx's read timeout, so this
                    # per-line limit is the only guard against a stalled
                    # stream. Only the read is inside the try, so timeouts
                    # raised by callbacks are not mislabelled.
                    line = await asyncio.wait_for(lines.__anext__(), timeout=45)
                except asyncio.TimeoutError as e:
                    raise Exception("watch timeout") from e

                if not line:
                    continue
                event = Event(**json.loads(line))

                if self._enable_cache:
                    await self._update_cache_from_event(event)
                    self._maybe_log_initial_sync(event)

                # Subscribers like ServeManager call
                # model_validate(event.data) and cannot act on an ID-only
                # payload (e.g. DELETED for an item this client never saw).
                if self._is_id_only(event):
                    logger.debug(
                        f"Skipping callback for ID-only {event.type} event on workers {event.id}"
                    )
                elif callback is not None:
                    result = callback(event)
                    # Await coroutines from async callbacks, including
                    # callables with async __call__; tasks/futures a sync
                    # callback returns are deliberately not awaited.
                    if asyncio.iscoroutine(result):
                        await result

                if stop_condition(event):
                    break

    def get(self, id: int, use_cache: bool = True) -> WorkerPublic:
        """
        Get a resource by ID.

        Args:
            id: Resource ID
            use_cache: Whether to use cache. Defaults to True (use cache if
                available). The cache is consulted only after awatch() has
                been called at least once; a cache miss falls back to the API.

        Returns:
            Resource object. API results are stored in the cache when caching
            is enabled.
        """
        if use_cache and self._enable_cache and self._watch_started:
            with self._cache_lock:
                cached = self._cache.get(id)
            if cached is not None:
                logger.trace(f"Cache hit for worker {id}")
                return cached

        response = self._client.get_httpx_client().get(f"{self._url}/{id}")
        raise_if_response_error(response)
        result = WorkerPublic.model_validate(response.json())

        if self._enable_cache:
            with self._cache_lock:
                self._cache[id] = result
        return result

    def create(self, model_create: WorkerCreate) -> WorkerPublic:
        """POST a new worker and return the server's representation."""
        response = self._client.get_httpx_client().post(
            self._url,
            content=model_create.model_dump_json(),
            headers={"Content-Type": "application/json"},
        )
        raise_if_response_error(response)
        return WorkerPublic.model_validate(response.json())

    def update(self, id: int, model_update: WorkerUpdate) -> WorkerPublic:
        """PUT an update for worker ``id`` and return the server's result."""
        response = self._client.get_httpx_client().put(
            f"{self._url}/{id}",
            content=model_update.model_dump_json(),
            headers={"Content-Type": "application/json"},
        )
        raise_if_response_error(response)
        return WorkerPublic.model_validate(response.json())

    def delete(self, id: int):
        """DELETE worker ``id``; raises via raise_if_response_error on failure."""
        response = self._client.get_httpx_client().delete(f"{self._url}/{id}")
        raise_if_response_error(response)
|