worker_request.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. import asyncio
  2. from contextlib import asynccontextmanager
  3. import logging
  4. import time
  5. from typing import (
  6. AsyncGenerator,
  7. AsyncIterator,
  8. Callable,
  9. Dict,
  10. Literal,
  11. Optional,
  12. Tuple,
  13. Union,
  14. )
  15. import aiohttp
  16. from gpustack.schemas.workers import Worker
  17. from gpustack.utils.network import use_proxy_env_for_url
  18. logger = logging.getLogger(__name__)
  19. _TIMEOUT = 15
  20. async def is_worker_reachable(
  21. worker: Worker,
  22. proxy_client: Optional[aiohttp.ClientSession] = None,
  23. no_proxy_client: Optional[aiohttp.ClientSession] = None,
  24. timeout_in_second: int = 10,
  25. retry_interval_in_second: int = 3,
  26. ) -> bool:
  27. """
  28. Check if a worker is reachable via a lightweight health check.
  29. Args:
  30. worker: Target worker.
  31. proxy_client: HTTP client with proxy.
  32. no_proxy_client: HTTP client without proxy.
  33. timeout_in_second: Timeout in seconds. Defaults to 10.
  34. retry_interval_in_second: Retry interval in seconds. Defaults to 3.
  35. Returns:
  36. True if worker responds with status < 500, False otherwise.
  37. """
  38. end_time = time.time() + timeout_in_second
  39. while time.time() < end_time:
  40. try:
  41. async with _request_to_worker(
  42. worker=worker,
  43. method="GET",
  44. path="healthz",
  45. proxy_client=proxy_client,
  46. no_proxy_client=no_proxy_client,
  47. timeout=aiohttp.ClientTimeout(total=2),
  48. raise_on_error=False,
  49. ) as resp:
  50. if resp.status == 200:
  51. return True
  52. except Exception:
  53. pass
  54. await asyncio.sleep(retry_interval_in_second)
  55. return False
  56. def _build_url(worker: Worker, path: str) -> str:
  57. """Build URL for a worker request."""
  58. hostname = (
  59. worker.advertise_address
  60. if worker.advertise_address and not worker.get_proxy_address()
  61. else worker.ip
  62. )
  63. return f"http://{hostname}:{worker.port}/{path.lstrip('/')}"
  64. def _convert_params(params: Optional[Dict]) -> Optional[Dict]:
  65. """Convert bool params to str for aiohttp compatibility."""
  66. if params:
  67. return {
  68. k: str(v).lower() if isinstance(v, bool) else v for k, v in params.items()
  69. }
  70. return params
  71. @asynccontextmanager
  72. async def _request_to_worker(
  73. worker: Worker,
  74. method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
  75. path: str,
  76. proxy_client: Optional[aiohttp.ClientSession] = None,
  77. no_proxy_client: Optional[aiohttp.ClientSession] = None,
  78. params: Optional[Dict] = None,
  79. data: Optional[Union[bytes, AsyncIterator[bytes], aiohttp.FormData]] = None,
  80. headers: Optional[Dict[str, str]] = None,
  81. timeout: Optional[aiohttp.ClientTimeout] = None,
  82. raise_on_error: bool = True,
  83. ):
  84. """
  85. Async context manager for worker requests. Yields resp and auto-closes on exit.
  86. Raises:
  87. aiohttp.ClientError: If raise_on_error=True and response is non-2xx.
  88. """
  89. url = _build_url(worker, path)
  90. params = _convert_params(params)
  91. use_env_proxy = use_proxy_env_for_url(url)
  92. client = (
  93. proxy_client
  94. if use_env_proxy and worker.get_proxy_address() is None
  95. else no_proxy_client
  96. )
  97. if client is None:
  98. raise ValueError(
  99. f"No http client available: proxy_client={proxy_client}, no_proxy_client={no_proxy_client}"
  100. )
  101. req_headers = {"Authorization": f"Bearer {worker.token}"}
  102. if headers:
  103. req_headers.update(headers)
  104. if timeout is None:
  105. timeout = aiohttp.ClientTimeout(total=_TIMEOUT, sock_connect=5)
  106. resp = None
  107. try:
  108. resp = await client.request(
  109. method=method,
  110. url=url,
  111. params=params,
  112. data=data,
  113. headers=req_headers,
  114. timeout=timeout,
  115. proxy=worker.get_proxy_address(),
  116. )
  117. if resp.status >= 400 and raise_on_error:
  118. error_text = await resp.text()
  119. raise aiohttp.ClientError(
  120. f"Worker request failed: {worker.id} {method} {url} "
  121. f"status={resp.status}, error={error_text}"
  122. )
  123. yield resp
  124. except aiohttp.ClientError:
  125. raise
  126. except Exception as e:
  127. logger.error(f"Worker request failed: {worker.id} {method} {url}: {e}")
  128. raise
  129. finally:
  130. if resp is not None:
  131. resp.close()
  132. async def request_to_worker(
  133. worker: Worker,
  134. method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
  135. path: str,
  136. proxy_client: Optional[aiohttp.ClientSession] = None,
  137. no_proxy_client: Optional[aiohttp.ClientSession] = None,
  138. params: Optional[Dict] = None,
  139. data: Optional[Union[bytes, AsyncIterator[bytes], aiohttp.FormData]] = None,
  140. headers: Optional[Dict[str, str]] = None,
  141. timeout: Optional[aiohttp.ClientTimeout] = None,
  142. raise_on_error: bool = True,
  143. ) -> Tuple[aiohttp.ClientResponse, Optional[bytes]]:
  144. """
  145. Send a request to a worker.
  146. Returns:
  147. Tuple of (response, body_bytes). Body is None if no content.
  148. Raises:
  149. aiohttp.ClientError: If raise_on_error=True and response is non-2xx, or on other errors.
  150. """
  151. async with _request_to_worker(
  152. worker=worker,
  153. method=method,
  154. path=path,
  155. proxy_client=proxy_client,
  156. no_proxy_client=no_proxy_client,
  157. params=params,
  158. data=data,
  159. headers=headers,
  160. timeout=timeout,
  161. raise_on_error=raise_on_error,
  162. ) as resp:
  163. body = await resp.read()
  164. return resp, body if body else None
  165. def _process_stream_line(line_bytes: bytes) -> str:
  166. """Process a line of bytes to ensure it is properly formatted for streaming."""
  167. line = line_bytes.decode("utf-8").strip()
  168. return line + "\n\n" if line else ""
  169. async def _stream_response_chunks(
  170. resp: aiohttp.ClientResponse,
  171. ) -> AsyncGenerator[str, None]:
  172. """Stream the response content in chunks, processing each line for SSE format."""
  173. chunk_size = 4096 # 4KB
  174. chunk_buffer = b""
  175. async for data in resp.content.iter_chunked(chunk_size):
  176. lines = (chunk_buffer + data).split(b'\n')
  177. chunk_buffer = lines.pop(-1)
  178. for line_bytes in lines:
  179. if line_bytes:
  180. yield _process_stream_line(line_bytes)
  181. if chunk_buffer:
  182. yield _process_stream_line(chunk_buffer)
  183. async def stream_to_worker(
  184. worker: Worker,
  185. method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
  186. path: str,
  187. proxy_client: Optional[aiohttp.ClientSession] = None,
  188. no_proxy_client: Optional[aiohttp.ClientSession] = None,
  189. params: Optional[Dict] = None,
  190. data: Optional[Union[bytes, AsyncIterator[bytes], aiohttp.FormData]] = None,
  191. headers: Optional[Dict[str, str]] = None,
  192. timeout: Optional[aiohttp.ClientTimeout] = None,
  193. on_exception: Optional[
  194. Callable[[Exception, aiohttp.ClientTimeout], Tuple[str, int]]
  195. ] = None,
  196. raw: bool = False,
  197. ) -> AsyncGenerator[Tuple[Union[bytes, str], dict, int], None]:
  198. """
  199. Stream a request to a worker and yield response chunks.
  200. Yields tuples of (chunk, headers, status).
  201. Automatically handles:
  202. - URL construction (advertise_address or ip + port)
  203. - Proxy selection (based on use_proxy_env_for_url)
  204. - Authorization header
  205. Args:
  206. worker: Target worker
  207. method: HTTP method
  208. path: API path
  209. proxy_client: HTTP client with proxy
  210. no_proxy_client: HTTP client without proxy
  211. params: Query parameters
  212. data: Bytes, async iterator of bytes, or FormData
  213. headers: Additional headers
  214. timeout: Request timeout
  215. on_exception: Optional callback(exception, timeout) -> (error_msg, status_code).
  216. Called when an exception occurs during streaming. If not provided,
  217. the exception is raised.
  218. raw: If True, yield raw bytes without SSE line formatting (use for log streams).
  219. If False (default), format each line as SSE (use for OpenAI-compatible streams).
  220. """
  221. try:
  222. async with _request_to_worker(
  223. worker=worker,
  224. method=method,
  225. path=path,
  226. proxy_client=proxy_client,
  227. no_proxy_client=no_proxy_client,
  228. params=params,
  229. data=data,
  230. headers=headers,
  231. timeout=timeout,
  232. raise_on_error=False,
  233. ) as resp:
  234. if resp.status >= 400:
  235. body = await resp.read()
  236. yield body, dict(resp.headers), resp.status
  237. return
  238. if raw:
  239. async for chunk in resp.content.iter_any():
  240. yield chunk, dict(resp.headers), resp.status
  241. else:
  242. async for chunk in _stream_response_chunks(resp):
  243. yield chunk, dict(resp.headers), resp.status
  244. except Exception as e:
  245. logger.error(
  246. f"Worker stream failed: {worker.id} {method} {path}: {e}", exc_info=True
  247. )
  248. if on_exception is not None:
  249. error_msg, status_code = on_exception(e, timeout)
  250. yield error_msg, {}, status_code
  251. else:
  252. raise