| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287 |
- import asyncio
- from contextlib import asynccontextmanager
- import logging
- import time
- from typing import (
- AsyncGenerator,
- AsyncIterator,
- Callable,
- Dict,
- Literal,
- Optional,
- Tuple,
- Union,
- )
- import aiohttp
- from gpustack.schemas.workers import Worker
- from gpustack.utils.network import use_proxy_env_for_url
# Module-level logger, named after this module.
- logger = logging.getLogger(__name__)
# Total request timeout (seconds) used for a worker request when the caller
# does not pass an explicit aiohttp.ClientTimeout.
- _TIMEOUT = 15
async def is_worker_reachable(
    worker: Worker,
    proxy_client: Optional[aiohttp.ClientSession] = None,
    no_proxy_client: Optional[aiohttp.ClientSession] = None,
    timeout_in_second: int = 10,
    retry_interval_in_second: int = 3,
) -> bool:
    """
    Check if a worker is reachable via a lightweight health check.

    Polls ``GET /healthz`` on the worker until it answers with HTTP 200 or
    the overall time budget is used up. Each attempt has its own 2 second
    timeout.

    Args:
        worker: Target worker.
        proxy_client: HTTP client with proxy.
        no_proxy_client: HTTP client without proxy.
        timeout_in_second: Overall time budget in seconds. Defaults to 10.
            A value <= 0 returns False without making any request.
        retry_interval_in_second: Pause between attempts in seconds.
            Defaults to 3.

    Returns:
        True as soon as the worker responds with status 200; False if no
        attempt succeeded within the time budget.
    """
    # Monotonic clock: the deadline must not shift if the wall clock is adjusted.
    deadline = time.monotonic() + timeout_in_second
    while time.monotonic() < deadline:
        try:
            async with _request_to_worker(
                worker=worker,
                method="GET",
                path="healthz",
                proxy_client=proxy_client,
                no_proxy_client=no_proxy_client,
                timeout=aiohttp.ClientTimeout(total=2),
                raise_on_error=False,
            ) as resp:
                if resp.status == 200:
                    return True
                logger.debug(
                    "Worker %s health check returned status %s",
                    worker.id,
                    resp.status,
                )
        except Exception as e:
            # Best-effort probe: any failure just means "not reachable yet".
            logger.debug("Worker %s health check failed: %s", worker.id, e)

        # Do not sleep past the deadline; give up as soon as the budget is spent.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        await asyncio.sleep(min(retry_interval_in_second, remaining))
    return False
def _build_url(worker: Worker, path: str) -> str:
    """
    Build the HTTP URL for a request to a worker.

    The host is the worker's advertise address when one is set and the worker
    has no proxy address; otherwise the worker's ip is used. An IPv6 literal
    host is wrapped in square brackets so its colons are not confused with
    the port separator.

    Args:
        worker: Target worker (provides address, port and proxy address).
        path: Request path; leading slashes are ignored.

    Returns:
        URL of the form ``http://<host>:<port>/<path>``.
    """
    # Function-scope import keeps this block self-contained.
    import ipaddress

    if worker.advertise_address and not worker.get_proxy_address():
        hostname = worker.advertise_address
    else:
        hostname = worker.ip

    host = str(hostname)
    try:
        is_ipv6_literal = ipaddress.ip_address(host).version == 6
    except ValueError:
        # Not an IP literal (DNS name, already bracketed, ...): use as-is.
        is_ipv6_literal = False
    if is_ipv6_literal:
        host = f"[{host}]"

    return f"http://{host}:{worker.port}/{path.lstrip('/')}"
def _convert_params(params: Optional[Dict]) -> Optional[Dict]:
    """
    Make query params aiohttp-friendly by stringifying booleans.

    Boolean values become the lowercase strings "true"/"false"; every other
    value is passed through untouched. A falsy ``params`` (None or empty) is
    returned as-is.
    """
    if not params:
        return params

    converted = {}
    for key, value in params.items():
        converted[key] = str(value).lower() if isinstance(value, bool) else value
    return converted
@asynccontextmanager
async def _request_to_worker(
    worker: Worker,
    method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
    path: str,
    proxy_client: Optional[aiohttp.ClientSession] = None,
    no_proxy_client: Optional[aiohttp.ClientSession] = None,
    params: Optional[Dict] = None,
    data: Optional[Union[bytes, AsyncIterator[bytes], aiohttp.FormData]] = None,
    headers: Optional[Dict[str, str]] = None,
    timeout: Optional[aiohttp.ClientTimeout] = None,
    raise_on_error: bool = True,
) -> AsyncIterator[aiohttp.ClientResponse]:
    """
    Async context manager for worker requests. Yields resp and auto-closes on exit.

    Client selection: ``proxy_client`` is used when the environment proxy
    applies to the worker URL and the worker has no proxy address of its own;
    otherwise ``no_proxy_client`` is used. The worker's own proxy address, if
    any, is passed to the request as ``proxy``.

    Args:
        worker: Target worker.
        method: HTTP method.
        path: API path on the worker.
        proxy_client: HTTP client with proxy.
        no_proxy_client: HTTP client without proxy.
        params: Query parameters (bool values are stringified).
        data: Request body.
        headers: Extra headers; merged over the default Authorization header.
        timeout: Request timeout. Defaults to total=_TIMEOUT, sock_connect=5.
        raise_on_error: Raise when the response status is 400 or above.

    Raises:
        ValueError: If the selected client is None.
        aiohttp.ClientError: If raise_on_error=True and the response status
            is >= 400, or if the request itself fails inside aiohttp.
    """
    url = _build_url(worker, path)
    params = _convert_params(params)
    use_env_proxy = use_proxy_env_for_url(url)
    # Resolved once: drives both client selection and the per-request proxy.
    worker_proxy = worker.get_proxy_address()

    client = proxy_client if use_env_proxy and worker_proxy is None else no_proxy_client
    if client is None:
        raise ValueError(
            f"No http client available: proxy_client={proxy_client}, no_proxy_client={no_proxy_client}"
        )

    req_headers = {"Authorization": f"Bearer {worker.token}"}
    if headers:
        req_headers.update(headers)

    if timeout is None:
        timeout = aiohttp.ClientTimeout(total=_TIMEOUT, sock_connect=5)

    resp = None
    try:
        resp = await client.request(
            method=method,
            url=url,
            params=params,
            data=data,
            headers=req_headers,
            timeout=timeout,
            proxy=worker_proxy,
        )
        if resp.status >= 400 and raise_on_error:
            error_text = await resp.text()
            raise aiohttp.ClientError(
                f"Worker request failed: {worker.id} {method} {url} "
                f"status={resp.status}, error={error_text}"
            )
        # The yield sits inside the try, so exceptions raised by the caller's
        # `async with` body also pass through the handlers below.
        yield resp
    except aiohttp.ClientError:
        # Client errors are propagated unchanged and not logged here.
        raise
    except Exception as e:
        logger.error(f"Worker request failed: {worker.id} {method} {url}: {e}")
        raise
    finally:
        # Always close the response, whether the body was consumed or not.
        if resp is not None:
            resp.close()
async def request_to_worker(
    worker: Worker,
    method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
    path: str,
    proxy_client: Optional[aiohttp.ClientSession] = None,
    no_proxy_client: Optional[aiohttp.ClientSession] = None,
    params: Optional[Dict] = None,
    data: Optional[Union[bytes, AsyncIterator[bytes], aiohttp.FormData]] = None,
    headers: Optional[Dict[str, str]] = None,
    timeout: Optional[aiohttp.ClientTimeout] = None,
    raise_on_error: bool = True,
) -> Tuple[aiohttp.ClientResponse, Optional[bytes]]:
    """
    Perform a single request against a worker and read the full body.

    All arguments are forwarded unchanged to the underlying request context
    manager; the body is read before that context exits.

    Returns:
        A ``(response, body_bytes)`` pair. ``body_bytes`` is None when the
        response body is empty.

    Raises:
        aiohttp.ClientError: If raise_on_error=True and the worker answers
            with an error status, or when the request fails in aiohttp.
    """
    request_kwargs = dict(
        worker=worker,
        method=method,
        path=path,
        proxy_client=proxy_client,
        no_proxy_client=no_proxy_client,
        params=params,
        data=data,
        headers=headers,
        timeout=timeout,
        raise_on_error=raise_on_error,
    )
    async with _request_to_worker(**request_kwargs) as resp:
        payload = await resp.read()
        # Normalise an empty body to None.
        return resp, (payload or None)
def _process_stream_line(line_bytes: bytes) -> str:
    """
    Format one raw line for streaming.

    The bytes are decoded as UTF-8 and stripped of surrounding whitespace.
    A non-blank line is returned followed by a blank line ("\\n\\n"); a blank
    line yields the empty string.
    """
    line = line_bytes.decode("utf-8").strip()
    return line + "\n\n" if line else ""


async def _stream_response_chunks(
    resp: aiohttp.ClientResponse,
) -> AsyncGenerator[str, None]:
    """
    Stream the response body line by line, formatting each line for SSE.

    The body is read in 4KB chunks and split on "\\n"; a partial trailing
    line is carried over to the next chunk. Lines that are empty or contain
    only whitespace (for example the bare "\\r" left by CRLF-delimited
    streams) are skipped, so no empty strings are yielded.
    """
    chunk_size = 4096  # 4KB
    pending = b""
    async for data in resp.content.iter_chunked(chunk_size):
        # Everything before the last "\n" is complete; the remainder may be
        # the start of a line continued in the next chunk.
        *complete_lines, pending = (pending + data).split(b"\n")
        for line_bytes in complete_lines:
            event = _process_stream_line(line_bytes)
            if event:
                yield event

    # Flush a final line that was not newline-terminated.
    tail = _process_stream_line(pending)
    if tail:
        yield tail
async def stream_to_worker(
    worker: Worker,
    method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
    path: str,
    proxy_client: Optional[aiohttp.ClientSession] = None,
    no_proxy_client: Optional[aiohttp.ClientSession] = None,
    params: Optional[Dict] = None,
    data: Optional[Union[bytes, AsyncIterator[bytes], aiohttp.FormData]] = None,
    headers: Optional[Dict[str, str]] = None,
    timeout: Optional[aiohttp.ClientTimeout] = None,
    on_exception: Optional[
        Callable[[Exception, aiohttp.ClientTimeout], Tuple[str, int]]
    ] = None,
    raw: bool = False,
) -> AsyncGenerator[Tuple[Union[bytes, str], dict, int], None]:
    """
    Send a request to a worker and stream the response back.

    Every yielded item is a ``(chunk, headers, status)`` tuple. URL building,
    client/proxy selection and the Authorization header are handled by the
    underlying request helper.

    A response with status >= 400 is not streamed: its whole body is read
    and yielded once, together with the response headers and status.

    Args:
        worker: Target worker.
        method: HTTP method.
        path: API path.
        proxy_client: HTTP client with proxy.
        no_proxy_client: HTTP client without proxy.
        params: Query parameters.
        data: Bytes, async iterator of bytes, or FormData.
        headers: Additional headers.
        timeout: Request timeout.
        on_exception: Optional callback(exception, timeout) -> (error_msg, status_code).
            When an exception occurs, its result is yielded as
            ``(error_msg, {}, status_code)``. Without it the exception is re-raised.
        raw: If True, yield raw bytes as they arrive (use for log streams).
            If False (default), yield SSE-formatted lines (use for
            OpenAI-compatible streams).
    """
    try:
        async with _request_to_worker(
            worker=worker,
            method=method,
            path=path,
            proxy_client=proxy_client,
            no_proxy_client=no_proxy_client,
            params=params,
            data=data,
            headers=headers,
            timeout=timeout,
            raise_on_error=False,
        ) as resp:
            if resp.status < 400:
                # Pick the chunk source once, then forward every piece.
                pieces = (
                    resp.content.iter_any() if raw else _stream_response_chunks(resp)
                )
                async for piece in pieces:
                    yield piece, dict(resp.headers), resp.status
            else:
                # Error response: hand back the complete body in one item.
                error_body = await resp.read()
                yield error_body, dict(resp.headers), resp.status
    except Exception as e:
        logger.error(
            f"Worker stream failed: {worker.id} {method} {path}: {e}", exc_info=True
        )
        if on_exception is None:
            raise
        # Let the caller translate the failure into a final stream item.
        error_msg, status_code = on_exception(e, timeout)
        yield error_msg, {}, status_code
|