proxy_server.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. #!/usr/bin/env python3
  2. """
  3. HTTP Proxy Server based on asyncio
  4. Supports HTTP/1.1 proxy protocol:
  5. - HTTP requests: Client sends full URI (http://host:port/path)
  6. - HTTPS requests: CONNECT method to establish tunnel
  7. Run:
  8. python -m gpustack.websocket_proxy.main --host 0.0.0.0 --port 8080
  9. Test:
  10. curl --proxy http://localhost:8080 http://example.com
  11. curl --proxy http://localhost:8080 https://example.com
  12. """
  13. import logging
  14. import asyncio
  15. from fastapi import HTTPException
  16. from typing import Optional, Dict, Callable, Coroutine, Any, Tuple, TypeAlias
  17. from urllib.parse import urlparse
  18. from .connection_manager import BaseConnectionManager
  19. from .connection import AsyncIOConnection, tunnel, IOConnection
  20. logger = logging.getLogger(__name__)
  21. # ==================== Handler Functions ====================
# Looks up the connection manager for a target host/IP string; None means no
# manager is available for that target (the server then answers 502).
ConnectionManagerGetter: TypeAlias = Callable[[str], Optional[BaseConnectionManager]]
# Async predicate over the parsed request headers; a falsy result makes the
# server answer 401.
HeaderAuthenticator: TypeAlias = Callable[[Dict[str, str]], Coroutine[Any, Any, bool]]
# Async hook mapping request headers to a (host, port) override; a None host
# means "no override, derive the target from the request URI".
HeaderRouter: TypeAlias = Callable[
    [Dict[str, str]], Coroutine[Any, Any, Tuple[Optional[str], int]]
]
  27. async def _read_line(reader: asyncio.StreamReader) -> Optional[str]:
  28. """Read a line from the client"""
  29. try:
  30. line = await reader.readline()
  31. if not line:
  32. return None
  33. return line.decode('utf-8').strip()
  34. except Exception:
  35. return None
  36. async def _read_headers(reader: asyncio.StreamReader) -> Dict[str, str]:
  37. """Read all headers from the client"""
  38. headers: Dict[str, str] = {}
  39. while True:
  40. line = await _read_line(reader)
  41. if not line:
  42. break
  43. if line == "":
  44. break
  45. if ":" in line:
  46. key, value = line.split(":", 1)
  47. headers[key.strip().lower()] = value.strip()
  48. return headers
  49. async def _handle_connect(
  50. client_connection: AsyncIOConnection,
  51. target: str,
  52. connection_manager: BaseConnectionManager,
  53. ) -> None:
  54. """Handle CONNECT request for HTTPS tunnel"""
  55. logger.debug(f"[Proxy] CONNECT target: {target}")
  56. # Parse host and port
  57. if ":" in target:
  58. host, port_str = target.rsplit(":", 1)
  59. try:
  60. port = int(port_str)
  61. except ValueError:
  62. await client_connection.send_error(400, "Invalid port")
  63. return
  64. else:
  65. host = target
  66. port = 443
  67. try:
  68. target_url = f"tcp://{host}:{port}"
  69. logger.debug(
  70. f"[Proxy] Connecting to {target_url} with manager: {connection_manager}"
  71. )
  72. connection = await asyncio.wait_for(
  73. connection_manager.connect(target_url), timeout=30
  74. )
  75. logger.debug(f"[Proxy] Connection established: {connection}")
  76. except Exception as e:
  77. logger.exception(
  78. f"[Proxy] Error connecting to target {host}:{port}, error: {e}"
  79. )
  80. await client_connection.send_error(502, str(e))
  81. return
  82. # Send 200 Connection Established
  83. await client_connection.write_connect_established()
  84. await tunnel(client_connection, connection)
# Header names (lower-case) that _filter_headers drops before forwarding.
# Per RFC 7230 Section 6.1 the hop-by-hop headers are:
#   connection, keep-alive, proxy-authenticate, proxy-authorization,
#   proxy-connection, te, trailer, transfer-encoding, upgrade
# Two of those are deliberately NOT in the set below:
#   - upgrade: treated as end-to-end so it reaches the backend server
#     (needed for WebSocket).
#   - transfer-encoding: kept so chunked encoding survives (see inline note).
# 'host' is not a hop-by-hop header; it is listed so the client's Host value is
# dropped, because _handle_http appends its own Host header after filtering.
HOP_BY_HOP_HEADERS = {
    'connection',
    'keep-alive',
    'proxy-authenticate',
    'proxy-authorization',
    'proxy-connection',
    'te',
    'trailer',
    # Do not filter 'transfer-encoding' as it may be needed for chunked encoding, and the backend server should handle it correctly.
    # 'transfer-encoding',
    'host',
}
  103. def _filter_headers(headers) -> list[tuple[str, str]]:
  104. """Filter out hop-by-hop headers that should not be forwarded.
  105. Supports list of tuples (for response headers) or dict (for request headers).
  106. """
  107. filtered: list[tuple[str, str]] = []
  108. connection_values: set[str] = set()
  109. conn_header = None
  110. # Extract values from Connection header
  111. if isinstance(headers, dict):
  112. conn_header = headers.get('connection', '')
  113. elif isinstance(headers, list):
  114. conn_header = _header_get(headers, 'connection')
  115. if conn_header:
  116. for v in conn_header.split(','):
  117. connection_values.add(v.strip().lower())
  118. # Iterate over headers (supports both dict and list of tuples)
  119. if isinstance(headers, dict):
  120. items = headers.items()
  121. else:
  122. items = headers
  123. for key, value in items:
  124. # Skip hop-by-hop headers
  125. if key.lower() in HOP_BY_HOP_HEADERS:
  126. continue
  127. # Skip headers listed in Connection header
  128. if key.lower() in connection_values:
  129. continue
  130. filtered.append((key, value))
  131. return filtered
  132. def _get_request(
  133. method: str,
  134. path: str,
  135. headers: Dict[str, str],
  136. ) -> str:
  137. """Construct HTTP request line and headers for forwarding"""
  138. request = f"{method} {path} HTTP/1.1\r\n"
  139. for key, value in headers:
  140. request += f"{key}: {value}\r\n"
  141. request += "\r\n"
  142. return request
  143. async def _handle_http(
  144. client_connection: AsyncIOConnection,
  145. method: str,
  146. uri: str,
  147. headers: Dict[str, str],
  148. connection_manager: BaseConnectionManager,
  149. header_router: Optional[HeaderRouter] = None,
  150. ) -> None:
  151. """Handle HTTP request with full URI"""
  152. parsed = urlparse(uri)
  153. host, port = await header_router(headers) if header_router else (None, 0)
  154. if host is None:
  155. host = parsed.hostname
  156. port = parsed.port or (80 if parsed.scheme == "http" else 443)
  157. path = parsed.path or "/"
  158. if parsed.query:
  159. path = f"{path}?{parsed.query}"
  160. if not host:
  161. await client_connection.send_error(400, "Invalid URI")
  162. return
  163. logger.debug(f"[Proxy] {method} {uri} -> {host}:{port}")
  164. # Filter hop-by-hop headers
  165. filtered_headers: list[tuple[str, str]] = _filter_headers(headers)
  166. host_value = f"{host}:{port}" if port != 80 else host
  167. filtered_headers.append(('host', host_value))
  168. try:
  169. target_url = f"tcp://{host}:{port}"
  170. connection = await asyncio.wait_for(
  171. connection_manager.connect(target_url), timeout=30
  172. )
  173. except Exception as e:
  174. logger.exception(
  175. f"[Proxy] Error connecting to target {host}:{port}, error: {e}"
  176. )
  177. await client_connection.send_error(502, str(e))
  178. return
  179. try:
  180. await tunnel(
  181. client_connection,
  182. connection,
  183. request=_get_request(method, path, filtered_headers),
  184. response_relay=wait_for_complete_response,
  185. )
  186. except HTTPException as e:
  187. await client_connection.send_error(e.status_code, e.detail)
  188. finally:
  189. if connection:
  190. await connection.close()
  191. def _header_get(headers: list[tuple[str, str]], key: str) -> str:
  192. """Get header value from headers list. Returns empty string if not found."""
  193. for k, v in headers:
  194. if k == key.lower():
  195. return v
  196. return ""
  197. async def wait_for_complete_response( # noqa: C901
  198. remote_reader: IOConnection,
  199. client_writer: IOConnection,
  200. ) -> None:
  201. """Wait for complete HTTP response from WebSocket tunnel and forward to client.
  202. For WebSocket tunnel responses, we forward the raw response data directly without
  203. header filtering, as the response comes from a trusted internal source and may use
  204. chunked transfer encoding that would be broken by header parsing/reconstruction.
  205. """
  206. if client_writer is None:
  207. logger.debug("[Proxy] Client writer is None")
  208. return
  209. pending_data = b''
  210. headers_sent = False
  211. content_length: Optional[int] = None
  212. body_remaining: Optional[int] = None
  213. try:
  214. while True:
  215. chunk = await remote_reader.read()
  216. if not chunk:
  217. logger.debug(
  218. "[Proxy] Remote connection closed while waiting for response"
  219. )
  220. break
  221. pending_data += chunk
  222. # Parse headers on first chunk
  223. if not headers_sent and b'\r\n\r\n' in pending_data:
  224. header_end = pending_data.find(b'\r\n\r\n')
  225. header_part = pending_data[:header_end].decode('utf-8', errors='ignore')
  226. body_start = header_end + 4
  227. # Look for Content-Length
  228. for line in header_part.split('\r\n'):
  229. if line.lower().startswith('content-length:'):
  230. try:
  231. content_length = int(line.split(':', 1)[1].strip())
  232. body_remaining = content_length
  233. except ValueError:
  234. pass # malformed value — fall through to chunked streaming
  235. break
  236. # Forward headers + any body data already received
  237. if body_start < len(pending_data):
  238. body_data = pending_data[body_start:]
  239. # Send headers first
  240. await client_writer.write(pending_data[:body_start])
  241. # Then send initial body
  242. if body_data:
  243. if body_remaining is not None:
  244. to_write = min(len(body_data), body_remaining)
  245. await client_writer.write(body_data[:to_write])
  246. body_remaining -= to_write
  247. pending_data = b''
  248. else:
  249. # Chunked - send and continue
  250. await client_writer.write(body_data)
  251. pending_data = b''
  252. else:
  253. await client_writer.write(pending_data)
  254. pending_data = b''
  255. headers_sent = True
  256. logger.trace(
  257. f"[Proxy] Headers sent, content_length={content_length}, body_remaining={body_remaining}"
  258. )
  259. if body_remaining is not None and body_remaining <= 0:
  260. return
  261. continue
  262. # Forward subsequent body data
  263. if body_remaining is not None:
  264. to_write = min(len(pending_data), body_remaining)
  265. if to_write > 0:
  266. await client_writer.write(pending_data[:to_write])
  267. pending_data = pending_data[to_write:]
  268. body_remaining -= to_write
  269. logger.trace(
  270. f"[Proxy] Forwarded {to_write} bytes, body_remaining={body_remaining}"
  271. )
  272. if body_remaining <= 0:
  273. return
  274. else:
  275. # No Content-Length: stream until source closes (chunked encoding)
  276. if pending_data:
  277. await client_writer.write(pending_data)
  278. pending_data = b''
  279. except asyncio.TimeoutError:
  280. logger.debug("[Proxy] Timeout waiting for response")
  281. return
  282. except Exception as e:
  283. logger.debug(f"[Proxy] Error waiting for complete response: {e}")
  284. return
  285. finally:
  286. await client_writer.close()
  287. # ==================== Server Class ====================
  288. class HTTPSProxyServer:
  289. """Async HTTP/HTTPS Proxy Server"""
  290. def __init__(
  291. self,
  292. host: str,
  293. port: int,
  294. connection_manager_getter: ConnectionManagerGetter,
  295. authenticator: Optional[HeaderAuthenticator] = None,
  296. header_router: Optional[HeaderRouter] = None,
  297. ) -> None:
  298. self.host = host
  299. self.port = port
  300. self.server: Optional[asyncio.Server] = None
  301. self.connection_manager_getter = connection_manager_getter
  302. self.authenticator = authenticator
  303. self.header_router = header_router
  304. async def start(self) -> None:
  305. """Start the proxy server"""
  306. self.server = await asyncio.start_server(
  307. self._handle_client, self.host, self.port
  308. )
  309. logger.debug(f"[Proxy] Server started on {self.host}:{self.port}")
  310. async with self.server:
  311. await self.server.serve_forever()
  312. async def stop(self) -> None:
  313. """Stop the proxy server"""
  314. if self.server:
  315. self.server.close()
  316. await self.server.wait_closed()
  317. async def _get_target_ip(
  318. self, method: str, uri: str, headers: Dict[str, str]
  319. ) -> Optional[str]:
  320. """Extract target IP/hostname from request"""
  321. if method == "CONNECT":
  322. # CONNECT target is the host:port (e.g., "example.com:443")
  323. return uri.split(":")[0] if ":" in uri else uri
  324. elif self.header_router:
  325. target_ip, _ = await self.header_router(headers)
  326. if target_ip:
  327. return target_ip
  328. # HTTP request: parse URI (e.g., "http://example.com:8080/path")
  329. parsed = urlparse(uri)
  330. if parsed.hostname:
  331. return parsed.hostname
  332. # Fallback to Host header
  333. return headers.get("host", "").split(":")[0] or None
  334. async def _handle_client(
  335. self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
  336. ) -> None:
  337. """Handle incoming client connection"""
  338. client_connection = AsyncIOConnection(reader=reader, writer=writer)
  339. client_addr = writer.get_extra_info('peername')
  340. try:
  341. request_line = await _read_line(reader)
  342. if not request_line:
  343. return
  344. parts = request_line.split()
  345. if len(parts) < 3:
  346. await client_connection.send_error(400, "Bad Request")
  347. return
  348. method, uri, _ = parts[0], parts[1], parts[2]
  349. headers = await _read_headers(reader)
  350. logger.debug(f"[Proxy] Received request: {method} {uri} from {client_addr}")
  351. logger.trace(f"[Proxy] Headers: {headers}")
  352. # Authenticate before any other processing
  353. # Skip authenticator for /metrics path on non-CONNECT requests
  354. result = urlparse(uri)
  355. should_skip_auth = method == "GET" and result.path == "/metrics"
  356. if self.authenticator and not should_skip_auth:
  357. if not await self.authenticator(headers):
  358. await client_connection.send_error(401, "Unauthorized")
  359. return
  360. # Extract target address from request
  361. target_ip = await self._get_target_ip(method, uri, headers)
  362. if not target_ip:
  363. await client_connection.send_error(
  364. 400, "Bad Request: No target address"
  365. )
  366. return
  367. # Get connection manager by target IP
  368. connection_manager = (
  369. self.connection_manager_getter(target_ip)
  370. if self.connection_manager_getter
  371. else None
  372. )
  373. if connection_manager is None:
  374. # failed to get connection manager, return error.
  375. logger.debug(
  376. f"[Proxy] No connection manager available for target: {target_ip}"
  377. )
  378. await client_connection.send_error(
  379. 502, "Bad Gateway: No connection manager available"
  380. )
  381. return
  382. if method == "CONNECT":
  383. await _handle_connect(client_connection, uri, connection_manager)
  384. else:
  385. await _handle_http(
  386. client_connection,
  387. method,
  388. uri,
  389. headers,
  390. connection_manager,
  391. self.header_router,
  392. )
  393. except Exception as e:
  394. logger.debug(f"[Proxy] Error handling {client_addr}: {e}")
  395. finally:
  396. await client_connection.close()