test_proxy_server.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726
  1. """
  2. Tests for proxy_server module (HTTPSProxyServer).
  3. """
import asyncio
import contextlib
import uuid

import pytest
import uvicorn
from fastapi import FastAPI

from gpustack.websocket_proxy.message_server import MessageServerHandler, router
from gpustack.websocket_proxy.proxy_server import HTTPSProxyServer
  11. def get_free_port(host: str) -> int:
  12. """Get a free port on the host"""
  13. import socket
  14. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  15. s.bind((host, 0))
  16. return s.getsockname()[1]
  17. async def start_websocket_server(
  18. message_handler: MessageServerHandler, host: str, port: int
  19. ):
  20. """Start WebSocket server using uvicorn"""
  21. app = FastAPI()
  22. app.state.message_server_handler = message_handler
  23. app.include_router(router)
  24. actual_port = get_free_port(host) if port == 0 else port
  25. config = uvicorn.Config(app, host=host, port=actual_port, log_level="error")
  26. server = uvicorn.Server(config)
  27. server_task = asyncio.create_task(server.serve())
  28. await asyncio.sleep(0.5)
  29. return server, server_task, actual_port
  30. class TestProxyAuthenticator:
  31. """Test HTTPSProxyServer authenticator functionality."""
  32. @pytest.mark.asyncio
  33. async def test_no_authenticator_passes(self):
  34. """Test that requests pass when no authenticator is configured."""
  35. message_handler = MessageServerHandler(
  36. listen_address="127.0.0.1",
  37. listen_port=0,
  38. proxy_port=0,
  39. )
  40. _, ws_task, ws_port = await start_websocket_server(
  41. message_handler, "127.0.0.1", 0
  42. )
  43. # Proxy without authenticator
  44. proxy_server = HTTPSProxyServer(
  45. host="127.0.0.1",
  46. port=0,
  47. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  48. authenticator=None, # No authenticator
  49. )
  50. proxy_task = asyncio.create_task(proxy_server.start())
  51. await asyncio.sleep(0.5)
  52. proxy_addr = proxy_server.server.sockets[0].getsockname()
  53. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  54. # Start test server
  55. async def handle_request(reader, writer):
  56. response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"
  57. writer.write(response)
  58. await writer.drain()
  59. writer.close()
  60. await writer.wait_closed()
  61. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  62. test_port = test_server.sockets[0].getsockname()[1]
  63. # Connect client
  64. from gpustack.websocket_proxy.message_client import MessageClient
  65. client = MessageClient(
  66. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  67. client_id=uuid.uuid4(),
  68. cidrs=["127.0.0.1/32"],
  69. )
  70. client_task = asyncio.create_task(client.run())
  71. await asyncio.sleep(0.5)
  72. try:
  73. # Connect through proxy using CONNECT tunnel
  74. reader, writer = await asyncio.open_connection(proxy_host, proxy_port)
  75. connect_request = f"CONNECT 127.0.0.1:{test_port} HTTP/1.1\r\nHost: 127.0.0.1:{test_port}\r\n\r\n"
  76. writer.write(connect_request.encode())
  77. await writer.drain()
  78. connect_response = await asyncio.wait_for(
  79. reader.readuntil(b"\r\n\r\n"), timeout=5.0
  80. )
  81. assert connect_response.startswith(b"HTTP/1.1 200")
  82. # Make HTTP request through tunnel
  83. request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{test_port}\r\n\r\n"
  84. writer.write(request.encode())
  85. await writer.drain()
  86. response = await asyncio.wait_for(reader.read(1024), timeout=5.0)
  87. assert b"200 OK" in response
  88. writer.close()
  89. await writer.wait_closed()
  90. finally:
  91. client_task.cancel()
  92. try:
  93. await client_task
  94. except asyncio.CancelledError:
  95. pass
  96. proxy_task.cancel()
  97. try:
  98. await proxy_task
  99. except asyncio.CancelledError:
  100. pass
  101. await proxy_server.stop()
  102. test_server.close()
  103. await test_server.wait_closed()
  104. ws_task.cancel()
  105. try:
  106. await ws_task
  107. except asyncio.CancelledError:
  108. pass
  109. @pytest.mark.asyncio
  110. async def test_authenticator_allows_valid_request(self):
  111. """Test that authenticator returning True allows the request."""
  112. message_handler = MessageServerHandler(
  113. listen_address="127.0.0.1",
  114. listen_port=0,
  115. proxy_port=0,
  116. )
  117. _, ws_task, ws_port = await start_websocket_server(
  118. message_handler, "127.0.0.1", 0
  119. )
  120. # Authenticator that allows requests with valid auth header
  121. async def auth_check(headers: dict) -> bool:
  122. return headers.get("authorization") == "Bearer valid_token"
  123. proxy_server = HTTPSProxyServer(
  124. host="127.0.0.1",
  125. port=0,
  126. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  127. authenticator=auth_check,
  128. )
  129. proxy_task = asyncio.create_task(proxy_server.start())
  130. await asyncio.sleep(0.5)
  131. proxy_addr = proxy_server.server.sockets[0].getsockname()
  132. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  133. # Start test server
  134. async def handle_request(reader, writer):
  135. response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"
  136. writer.write(response)
  137. await writer.drain()
  138. writer.close()
  139. await writer.wait_closed()
  140. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  141. test_port = test_server.sockets[0].getsockname()[1]
  142. # Connect client
  143. from gpustack.websocket_proxy.message_client import MessageClient
  144. client = MessageClient(
  145. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  146. client_id=uuid.uuid4(),
  147. cidrs=["127.0.0.1/32"],
  148. )
  149. client_task = asyncio.create_task(client.run())
  150. await asyncio.sleep(0.5)
  151. try:
  152. # Connect through proxy using CONNECT tunnel with auth
  153. reader, writer = await asyncio.open_connection(proxy_host, proxy_port)
  154. connect_request = (
  155. f"CONNECT 127.0.0.1:{test_port} HTTP/1.1\r\n"
  156. f"Host: 127.0.0.1:{test_port}\r\n"
  157. f"Authorization: Bearer valid_token\r\n\r\n"
  158. )
  159. writer.write(connect_request.encode())
  160. await writer.drain()
  161. connect_response = await asyncio.wait_for(
  162. reader.readuntil(b"\r\n\r\n"), timeout=5.0
  163. )
  164. assert connect_response.startswith(b"HTTP/1.1 200")
  165. # Make HTTP request through tunnel (auth already validated by CONNECT)
  166. request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{test_port}\r\n\r\n"
  167. writer.write(request.encode())
  168. await writer.drain()
  169. response = await asyncio.wait_for(reader.read(1024), timeout=5.0)
  170. assert b"200 OK" in response, f"Expected 200 OK, got: {response}"
  171. writer.close()
  172. await writer.wait_closed()
  173. finally:
  174. client_task.cancel()
  175. try:
  176. await client_task
  177. except asyncio.CancelledError:
  178. pass
  179. proxy_task.cancel()
  180. try:
  181. await proxy_task
  182. except asyncio.CancelledError:
  183. pass
  184. await proxy_server.stop()
  185. test_server.close()
  186. await test_server.wait_closed()
  187. ws_task.cancel()
  188. try:
  189. await ws_task
  190. except asyncio.CancelledError:
  191. pass
  192. @pytest.mark.asyncio
  193. async def test_authenticator_rejects_invalid_request(self):
  194. """Test that authenticator returning False returns 401 Unauthorized."""
  195. message_handler = MessageServerHandler(
  196. listen_address="127.0.0.1",
  197. listen_port=0,
  198. proxy_port=0,
  199. )
  200. ws_server, ws_task, ws_port = await start_websocket_server(
  201. message_handler, "127.0.0.1", 0
  202. )
  203. # Authenticator that rejects requests
  204. async def auth_check(headers: dict) -> bool:
  205. return headers.get("authorization") == "Bearer valid_token"
  206. proxy_server = HTTPSProxyServer(
  207. host="127.0.0.1",
  208. port=0,
  209. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  210. authenticator=auth_check,
  211. )
  212. proxy_task = asyncio.create_task(proxy_server.start())
  213. await asyncio.sleep(0.5)
  214. proxy_addr = proxy_server.server.sockets[0].getsockname()
  215. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  216. # Connect client
  217. from gpustack.websocket_proxy.message_client import MessageClient
  218. client = MessageClient(
  219. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  220. client_id=uuid.uuid4(),
  221. cidrs=["127.0.0.1/32"],
  222. )
  223. client_task = asyncio.create_task(client.run())
  224. await asyncio.sleep(0.5)
  225. try:
  226. # Make request WITHOUT valid auth
  227. reader, writer = await asyncio.open_connection(proxy_host, proxy_port)
  228. request = "GET / HTTP/1.1\r\nHost: 127.0.0.1:9999\r\n\r\n"
  229. writer.write(request.encode())
  230. await writer.drain()
  231. response = await asyncio.wait_for(reader.read(1024), timeout=5.0)
  232. assert b"401" in response, f"Expected 401 Unauthorized, got: {response}"
  233. writer.close()
  234. await writer.wait_closed()
  235. finally:
  236. client_task.cancel()
  237. try:
  238. await client_task
  239. except asyncio.CancelledError:
  240. pass
  241. proxy_task.cancel()
  242. try:
  243. await proxy_task
  244. except asyncio.CancelledError:
  245. pass
  246. await proxy_server.stop()
  247. ws_task.cancel()
  248. try:
  249. await ws_task
  250. except asyncio.CancelledError:
  251. pass
  252. @pytest.mark.asyncio
  253. async def test_authenticator_with_basic_auth(self):
  254. """Test authenticator with Basic authentication scheme."""
  255. message_handler = MessageServerHandler(
  256. listen_address="127.0.0.1",
  257. listen_port=0,
  258. proxy_port=0,
  259. )
  260. _, ws_task, ws_port = await start_websocket_server(
  261. message_handler, "127.0.0.1", 0
  262. )
  263. import base64
  264. async def auth_check(headers: dict) -> bool:
  265. auth = headers.get("authorization", "")
  266. if auth.lower().startswith("basic "):
  267. try:
  268. decoded = base64.b64decode(auth[6:]).decode("utf-8")
  269. return decoded == "admin:secret"
  270. except Exception:
  271. return False
  272. return False
  273. proxy_server = HTTPSProxyServer(
  274. host="127.0.0.1",
  275. port=0,
  276. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  277. authenticator=auth_check,
  278. )
  279. proxy_task = asyncio.create_task(proxy_server.start())
  280. await asyncio.sleep(0.5)
  281. proxy_addr = proxy_server.server.sockets[0].getsockname()
  282. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  283. # Start test server
  284. async def handle_request(reader, writer):
  285. response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"
  286. writer.write(response)
  287. await writer.drain()
  288. writer.close()
  289. await writer.wait_closed()
  290. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  291. test_port = test_server.sockets[0].getsockname()[1]
  292. # Connect client
  293. from gpustack.websocket_proxy.message_client import MessageClient
  294. client = MessageClient(
  295. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  296. client_id=uuid.uuid4(),
  297. cidrs=["127.0.0.1/32"],
  298. )
  299. client_task = asyncio.create_task(client.run())
  300. await asyncio.sleep(0.5)
  301. try:
  302. # Connect through proxy using CONNECT tunnel with Basic auth
  303. reader, writer = await asyncio.open_connection(proxy_host, proxy_port)
  304. credentials = base64.b64encode(b"admin:secret").decode("utf-8")
  305. connect_request = (
  306. f"CONNECT 127.0.0.1:{test_port} HTTP/1.1\r\n"
  307. f"Host: 127.0.0.1:{test_port}\r\n"
  308. f"Authorization: Basic {credentials}\r\n\r\n"
  309. )
  310. writer.write(connect_request.encode())
  311. await writer.drain()
  312. connect_response = await asyncio.wait_for(
  313. reader.readuntil(b"\r\n\r\n"), timeout=5.0
  314. )
  315. assert connect_response.startswith(b"HTTP/1.1 200")
  316. # Make HTTP request through tunnel (auth already validated by CONNECT)
  317. request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{test_port}\r\n\r\n"
  318. writer.write(request.encode())
  319. await writer.drain()
  320. response = await asyncio.wait_for(reader.read(1024), timeout=5.0)
  321. assert b"200 OK" in response, f"Expected 200 OK, got: {response}"
  322. writer.close()
  323. await writer.wait_closed()
  324. finally:
  325. client_task.cancel()
  326. try:
  327. await client_task
  328. except asyncio.CancelledError:
  329. pass
  330. proxy_task.cancel()
  331. try:
  332. await proxy_task
  333. except asyncio.CancelledError:
  334. pass
  335. await proxy_server.stop()
  336. test_server.close()
  337. await test_server.wait_closed()
  338. ws_task.cancel()
  339. try:
  340. await ws_task
  341. except asyncio.CancelledError:
  342. pass
  343. class TestHTTPProxy:
  344. """Test HTTP proxy direct forwarding (not CONNECT tunnel)."""
  345. @pytest.mark.asyncio
  346. async def test_http_proxy_with_content_length(self):
  347. """Test HTTP proxy with Content-Length request body (direct TCP forwarding)."""
  348. # Direct connection manager - creates direct TCP connections without WebSocket tunnel
  349. def direct_connection_manager_getter(_target_ip: str):
  350. from gpustack.websocket_proxy.connection_manager import ConnectionManager
  351. return ConnectionManager(websocket=None)
  352. # HTTP proxy with direct TCP connection
  353. proxy_server = HTTPSProxyServer(
  354. host="127.0.0.1",
  355. port=0,
  356. connection_manager_getter=direct_connection_manager_getter,
  357. )
  358. proxy_task = asyncio.create_task(proxy_server.start())
  359. await asyncio.sleep(0.5)
  360. proxy_addr = proxy_server.server.sockets[0].getsockname()
  361. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  362. # Track received body
  363. received_body = None
  364. # Start test HTTP server
  365. async def handle_request(reader, writer):
  366. nonlocal received_body
  367. try:
  368. # Read all request data
  369. request_data = b""
  370. while True:
  371. chunk = await reader.read(8192)
  372. if not chunk:
  373. break
  374. request_data += chunk
  375. # Check if we have complete request (headers + body)
  376. if b"\r\n\r\n" in request_data:
  377. # Parse Content-Length from headers
  378. header_end = request_data.find(b"\r\n\r\n")
  379. headers = request_data[:header_end].decode(
  380. "utf-8", errors="ignore"
  381. )
  382. content_length = 0
  383. for line in headers.split("\r\n"):
  384. if line.lower().startswith("content-length:"):
  385. content_length = int(line.split(":")[1].strip())
  386. break
  387. # Check if we have complete body
  388. body_start = header_end + 4
  389. body_len = len(request_data) - body_start
  390. if body_len >= content_length:
  391. received_body = request_data[
  392. body_start : body_start + content_length
  393. ]
  394. break
  395. if len(request_data) > 65536:
  396. break
  397. # Send response
  398. response_body = b"OK"
  399. response = (
  400. f"HTTP/1.1 200 OK\r\n"
  401. f"Content-Length: {len(response_body)}\r\n"
  402. f"\r\n"
  403. ).encode() + response_body
  404. writer.write(response)
  405. await writer.drain()
  406. except Exception as e:
  407. print(f"[Test Server] Error: {e}")
  408. finally:
  409. writer.close()
  410. await writer.wait_closed()
  411. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  412. test_port = test_server.sockets[0].getsockname()[1]
  413. try:
  414. # Send HTTP request directly through proxy (not CONNECT tunnel)
  415. post_body = b"This is the POST body via HTTP proxy"
  416. request = (
  417. f"POST http://127.0.0.1:{test_port}/ HTTP/1.1\r\n"
  418. f"Host: 127.0.0.1:{test_port}\r\n"
  419. f"Content-Length: {len(post_body)}\r\n"
  420. f"\r\n"
  421. ).encode() + post_body
  422. reader, writer = await asyncio.open_connection(proxy_host, proxy_port)
  423. writer.write(request)
  424. await writer.drain()
  425. # Read response
  426. response = await asyncio.wait_for(reader.read(8192), timeout=5.0)
  427. assert b"200 OK" in response, f"Expected 200 OK, got: {response}"
  428. assert (
  429. received_body == post_body
  430. ), f"Body mismatch: {received_body!r} != {post_body!r}"
  431. writer.close()
  432. await writer.wait_closed()
  433. finally:
  434. proxy_task.cancel()
  435. try:
  436. await proxy_task
  437. except asyncio.CancelledError:
  438. pass
  439. await proxy_server.stop()
  440. test_server.close()
  441. await test_server.wait_closed()
  442. @pytest.mark.asyncio
  443. async def test_http_proxy_with_chunked_body(self): # noqa C901
  444. """Test HTTP proxy with chunked transfer encoding request body (direct TCP forwarding)."""
  445. # Direct connection manager - creates direct TCP connections without WebSocket tunnel
  446. def direct_connection_manager_getter(_target_ip: str):
  447. from gpustack.websocket_proxy.connection_manager import ConnectionManager
  448. return ConnectionManager(websocket=None)
  449. # HTTP proxy with direct TCP connection
  450. proxy_server = HTTPSProxyServer(
  451. host="127.0.0.1",
  452. port=0,
  453. connection_manager_getter=direct_connection_manager_getter,
  454. )
  455. proxy_task = asyncio.create_task(proxy_server.start())
  456. await asyncio.sleep(0.5)
  457. proxy_addr = proxy_server.server.sockets[0].getsockname()
  458. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  459. # Track received body
  460. received_body = None
  461. # Start test HTTP server that handles chunked body
  462. async def handle_chunked_request(reader, writer):
  463. nonlocal received_body
  464. import logging
  465. logger = logging.getLogger("test_server")
  466. try:
  467. logger.debug("[Test Server] Starting to read headers")
  468. # Read headers
  469. header_data = b""
  470. while b"\r\n\r\n" not in header_data:
  471. chunk = await reader.read(8192)
  472. if not chunk:
  473. logger.debug(
  474. "[Test Server] Connection closed while reading headers"
  475. )
  476. break
  477. header_data += chunk
  478. logger.debug(f"[Test Server] Received headers: {header_data[:200]!r}")
  479. header_end = header_data.find(b"\r\n\r\n")
  480. if header_end < 0:
  481. logger.debug("[Test Server] Headers incomplete, returning")
  482. return
  483. headers = header_data[:header_end].decode("utf-8", errors="ignore")
  484. body_data = header_data[header_end + 4 :]
  485. logger.debug(
  486. f"[Test Server] Headers complete, body_data: {body_data!r}"
  487. )
  488. # Check for Transfer-Encoding: chunked
  489. is_chunked = any(
  490. "transfer-encoding" in line.lower() and "chunked" in line.lower()
  491. for line in headers.split("\r\n")
  492. )
  493. logger.debug(f"[Test Server] is_chunked: {is_chunked}")
  494. if is_chunked:
  495. # Decode chunked body
  496. body = b""
  497. buffer = body_data # Use buffer to track unprocessed data
  498. async def read_until_crlf():
  499. """Read from buffer or socket until we have a complete line ending with CRLF."""
  500. nonlocal buffer
  501. while b"\r\n" not in buffer:
  502. c = await reader.read(1)
  503. if not c:
  504. return False # EOF
  505. buffer += c
  506. return True
  507. while True:
  508. logger.debug(f"[Test Server] Buffer: {buffer!r}")
  509. # Read chunk size line
  510. logger.debug("[Test Server] Reading chunk size line")
  511. if not await read_until_crlf():
  512. logger.debug("[Test Server] EOF while reading chunk size")
  513. break
  514. line_end = buffer.find(b"\r\n")
  515. line = buffer[:line_end]
  516. buffer = buffer[line_end + 2 :]
  517. logger.debug(f"[Test Server] Chunk size line: {line!r}")
  518. chunk_size = int(line.strip(), 16)
  519. if chunk_size == 0:
  520. logger.debug(
  521. "[Test Server] Got chunk size 0, chunked body complete"
  522. )
  523. break
  524. # Read chunk data
  525. logger.debug(
  526. f"[Test Server] Need {chunk_size} bytes of chunk data, buffer has {len(buffer)}"
  527. )
  528. while len(buffer) < chunk_size:
  529. needed = chunk_size - len(buffer)
  530. c = await reader.read(needed)
  531. if not c:
  532. logger.debug(
  533. "[Test Server] EOF while reading chunk data"
  534. )
  535. break
  536. buffer += c
  537. chunk = buffer[:chunk_size]
  538. buffer = buffer[chunk_size:]
  539. body += chunk
  540. logger.debug(f"[Test Server] Read chunk: {chunk!r}")
  541. # Read trailing \r\n after chunk
  542. if not await read_until_crlf():
  543. logger.debug(
  544. "[Test Server] EOF while reading chunk terminator"
  545. )
  546. break
  547. buffer = buffer[2:] # Skip the \r\n
  548. received_body = body
  549. logger.debug(
  550. f"[Test Server] Chunked body complete: {received_body!r}"
  551. )
  552. else:
  553. received_body = body_data
  554. logger.debug(f"[Test Server] Non-chunked body: {received_body!r}")
  555. print(f"[Test Server] Received body: {received_body!r}")
  556. # Send response
  557. response_body = b"OK"
  558. response = (
  559. f"HTTP/1.1 200 OK\r\n"
  560. f"Content-Length: {len(response_body)}\r\n"
  561. f"\r\n"
  562. ).encode() + response_body
  563. logger.debug("[Test Server] Sending response")
  564. writer.write(response)
  565. await writer.drain()
  566. logger.debug("[Test Server] Response sent, draining")
  567. except Exception as e:
  568. logger.exception(f"[Test Server] Error: {e}")
  569. finally:
  570. logger.debug("[Test Server] Closing connection")
  571. writer.close()
  572. await writer.wait_closed()
  573. logger.debug("[Test Server] Connection closed")
  574. test_server = await asyncio.start_server(handle_chunked_request, "127.0.0.1", 0)
  575. test_port = test_server.sockets[0].getsockname()[1]
  576. try:
  577. # Send HTTP request with chunked body
  578. post_body = b"Hello, chunked world!"
  579. # Chunked body format: <hex size>\r\n<data>\r\n...0\r\n\r\n
  580. chunked_body = (
  581. f"{len(post_body):x}\r\n".encode() + post_body + b"\r\n0\r\n\r\n"
  582. )
  583. request = (
  584. f"POST http://127.0.0.1:{test_port}/ HTTP/1.1\r\n"
  585. f"Host: 127.0.0.1:{test_port}\r\n"
  586. f"Transfer-Encoding: chunked\r\n"
  587. f"\r\n"
  588. ).encode() + chunked_body
  589. reader, writer = await asyncio.open_connection(proxy_host, proxy_port)
  590. writer.write(request)
  591. await writer.drain()
  592. # Read response
  593. response = await asyncio.wait_for(reader.read(8192), timeout=5.0)
  594. assert b"200 OK" in response, f"Expected 200 OK, got: {response}"
  595. assert (
  596. received_body == post_body
  597. ), f"Body mismatch: {received_body!r} != {post_body!r}"
  598. writer.close()
  599. await writer.wait_closed()
  600. finally:
  601. proxy_task.cancel()
  602. try:
  603. await proxy_task
  604. except asyncio.CancelledError:
  605. pass
  606. await proxy_server.stop()
  607. test_server.close()
  608. await test_server.wait_closed()
  609. if __name__ == "__main__":
  610. pytest.main([__file__, "-v"])