# test_websocket_bench.py
  1. """
  2. Test suite for WebSocket proxy benchmark functionality.
  3. Tests WebSocket proxy over HTTP tunnel with various scenarios.
  4. """
import asyncio
import contextlib
import json
import time
import uuid

import pytest
import uvicorn
from fastapi import FastAPI

try:
    import aiohttp
except ImportError:
    pytest.fail("aiohttp package not installed")

from gpustack.websocket_proxy.proxy_server import HTTPSProxyServer
from gpustack.websocket_proxy.message_server import MessageServerHandler, router
from gpustack.websocket_proxy.message_client import MessageClient
  18. def get_free_port(host: str) -> int:
  19. """Get a free port on the host"""
  20. import socket
  21. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  22. s.bind((host, 0))
  23. return s.getsockname()[1]
  24. async def start_websocket_server(
  25. message_handler: MessageServerHandler, host: str, port: int
  26. ):
  27. """Start WebSocket server using uvicorn"""
  28. app = FastAPI()
  29. app.state.message_server_handler = message_handler
  30. app.include_router(router)
  31. # Get a free port if port is 0
  32. actual_port = get_free_port(host) if port == 0 else port
  33. config = uvicorn.Config(app, host=host, port=actual_port, log_level="error")
  34. server = uvicorn.Server(config)
  35. # Start server in background
  36. server_task = asyncio.create_task(server.serve())
  37. # Wait for server to start
  38. await asyncio.sleep(0.5)
  39. return server, server_task, actual_port
  40. class TestProxyWebSocketTunnel:
  41. """Test proxy functionality over WebSocket tunnel."""
  42. @pytest.mark.asyncio
  43. async def test_response_data_integrity(self):
  44. """Test that the proxy returns exact response data from the server."""
  45. message_handler = MessageServerHandler(
  46. listen_address="127.0.0.1",
  47. listen_port=0,
  48. proxy_port=0,
  49. )
  50. ws_server, ws_task, ws_port = await start_websocket_server(
  51. message_handler, "127.0.0.1", 0
  52. )
  53. proxy_server = HTTPSProxyServer(
  54. host="127.0.0.1",
  55. port=0,
  56. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  57. )
  58. proxy_task = asyncio.create_task(proxy_server.start())
  59. await asyncio.sleep(0.5)
  60. proxy_addr = proxy_server.server.sockets[0].getsockname()
  61. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  62. # Test server returns specific response body
  63. expected_body = b"Hello, WebSocket Proxy!"
  64. expected_ctype = "text/plain"
  65. async def handle_request(reader, writer):
  66. try:
  67. await reader.read(8192)
  68. response = (
  69. "HTTP/1.1 200 OK\r\n"
  70. f"Content-Type: {expected_ctype}\r\n"
  71. f"Content-Length: {len(expected_body)}\r\n"
  72. "\r\n"
  73. ).encode() + expected_body
  74. writer.write(response)
  75. await writer.drain()
  76. except Exception:
  77. pass
  78. finally:
  79. writer.close()
  80. await writer.wait_closed()
  81. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  82. test_port = test_server.sockets[0].getsockname()[1]
  83. client = MessageClient(
  84. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  85. client_id=uuid.uuid4(),
  86. cidrs=["127.0.0.1/32"],
  87. )
  88. client_task = asyncio.create_task(client.run())
  89. await asyncio.sleep(0.5)
  90. url = f"http://127.0.0.1:{test_port}/"
  91. connector = aiohttp.TCPConnector(limit=1)
  92. async with aiohttp.ClientSession(
  93. proxy=f"http://{proxy_host}:{proxy_port}", connector=connector
  94. ) as session:
  95. async with session.get(
  96. url, timeout=aiohttp.ClientTimeout(total=10)
  97. ) as resp:
  98. body = await resp.read()
  99. assert resp.status == 200
  100. assert (
  101. body == expected_body
  102. ), f"Body mismatch: {body!r} != {expected_body!r}"
  103. assert resp.headers.get("Content-Type") == expected_ctype
  104. # Cleanup
  105. client_task.cancel()
  106. try:
  107. await client_task
  108. except asyncio.CancelledError:
  109. pass
  110. proxy_task.cancel()
  111. try:
  112. await proxy_task
  113. except asyncio.CancelledError:
  114. pass
  115. await proxy_server.stop()
  116. test_server.close()
  117. await test_server.wait_closed()
  118. ws_task.cancel()
  119. try:
  120. await ws_task
  121. except asyncio.CancelledError:
  122. pass
  123. @pytest.mark.asyncio
  124. async def test_chunked_transfer_encoding(self):
  125. """Test that chunked transfer encoding is correctly forwarded through the proxy."""
  126. message_handler = MessageServerHandler(
  127. listen_address="127.0.0.1",
  128. listen_port=0,
  129. proxy_port=0,
  130. )
  131. ws_server, ws_task, ws_port = await start_websocket_server(
  132. message_handler, "127.0.0.1", 0
  133. )
  134. proxy_server = HTTPSProxyServer(
  135. host="127.0.0.1",
  136. port=0,
  137. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  138. )
  139. proxy_task = asyncio.create_task(proxy_server.start())
  140. await asyncio.sleep(0.5)
  141. proxy_addr = proxy_server.server.sockets[0].getsockname()
  142. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  143. # Test server returns chunked response
  144. # Chunked response: each chunk is "size\r\ndata\r\n", final chunk is "0\r\n\r\n"
  145. chunks = [b"Hello", b" World", b"!"]
  146. expected_body = b"".join(chunks)
  147. async def handle_request(reader, writer):
  148. try:
  149. await reader.read(8192)
  150. # Build chunked response
  151. response = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"
  152. for chunk in chunks:
  153. response += f"{len(chunk):x}\r\n".encode()
  154. response += chunk + b"\r\n"
  155. response += b"0\r\n\r\n"
  156. writer.write(response)
  157. await writer.drain()
  158. except Exception as e:
  159. print(f"[Test Server] Error: {e}")
  160. finally:
  161. writer.close()
  162. await writer.wait_closed()
  163. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  164. test_port = test_server.sockets[0].getsockname()[1]
  165. client = MessageClient(
  166. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  167. client_id=uuid.uuid4(),
  168. cidrs=["127.0.0.1/32"],
  169. )
  170. client_task = asyncio.create_task(client.run())
  171. await asyncio.sleep(0.5)
  172. url = f"http://127.0.0.1:{test_port}/"
  173. connector = aiohttp.TCPConnector(limit=1)
  174. async with aiohttp.ClientSession(
  175. proxy=f"http://{proxy_host}:{proxy_port}", connector=connector
  176. ) as session:
  177. async with session.get(
  178. url, timeout=aiohttp.ClientTimeout(total=10)
  179. ) as resp:
  180. body = await resp.read()
  181. assert resp.status == 200, f"Expected 200, got {resp.status}"
  182. assert (
  183. body == expected_body
  184. ), f"Body mismatch: {body!r} != {expected_body!r}"
  185. # aiohttp should automatically decode chunked transfer
  186. print(f"[Test] Chunked response received: {body!r}")
  187. # Cleanup
  188. client_task.cancel()
  189. try:
  190. await client_task
  191. except asyncio.CancelledError:
  192. pass
  193. proxy_task.cancel()
  194. try:
  195. await proxy_task
  196. except asyncio.CancelledError:
  197. pass
  198. await proxy_server.stop()
  199. test_server.close()
  200. await test_server.wait_closed()
  201. ws_task.cancel()
  202. try:
  203. await ws_task
  204. except asyncio.CancelledError:
  205. pass
  206. @pytest.mark.asyncio
  207. async def test_response_json_data(self):
  208. """Test that JSON response is correctly forwarded through the proxy."""
  209. message_handler = MessageServerHandler(
  210. listen_address="127.0.0.1",
  211. listen_port=0,
  212. proxy_port=0,
  213. )
  214. ws_server, ws_task, ws_port = await start_websocket_server(
  215. message_handler, "127.0.0.1", 0
  216. )
  217. proxy_server = HTTPSProxyServer(
  218. host="127.0.0.1",
  219. port=0,
  220. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  221. )
  222. proxy_task = asyncio.create_task(proxy_server.start())
  223. await asyncio.sleep(0.5)
  224. proxy_addr = proxy_server.server.sockets[0].getsockname()
  225. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  226. import json
  227. expected_json = {"status": "ok", "data": [1, 2, 3], "message": "proxy works"}
  228. expected_body = json.dumps(expected_json).encode()
  229. async def handle_request(reader, writer):
  230. try:
  231. await reader.read(8192)
  232. response = (
  233. "HTTP/1.1 200 OK\r\n"
  234. f"Content-Type: application/json\r\n"
  235. f"Content-Length: {len(expected_body)}\r\n"
  236. "\r\n"
  237. ).encode() + expected_body
  238. writer.write(response)
  239. await writer.drain()
  240. except Exception:
  241. pass
  242. finally:
  243. writer.close()
  244. await writer.wait_closed()
  245. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  246. test_port = test_server.sockets[0].getsockname()[1]
  247. client = MessageClient(
  248. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  249. client_id=uuid.uuid4(),
  250. cidrs=["127.0.0.1/32"],
  251. )
  252. client_task = asyncio.create_task(client.run())
  253. await asyncio.sleep(0.5)
  254. url = f"http://127.0.0.1:{test_port}/api/data"
  255. connector = aiohttp.TCPConnector(limit=1)
  256. async with aiohttp.ClientSession(
  257. proxy=f"http://{proxy_host}:{proxy_port}", connector=connector
  258. ) as session:
  259. async with session.get(
  260. url, timeout=aiohttp.ClientTimeout(total=10)
  261. ) as resp:
  262. data = await resp.json()
  263. assert (
  264. data == expected_json
  265. ), f"JSON mismatch: {data} != {expected_json}"
  266. # Cleanup
  267. client_task.cancel()
  268. try:
  269. await client_task
  270. except asyncio.CancelledError:
  271. pass
  272. proxy_task.cancel()
  273. try:
  274. await proxy_task
  275. except asyncio.CancelledError:
  276. pass
  277. await proxy_server.stop()
  278. test_server.close()
  279. await test_server.wait_closed()
  280. ws_task.cancel()
  281. try:
  282. await ws_task
  283. except asyncio.CancelledError:
  284. pass
  285. @pytest.mark.asyncio
  286. async def test_large_response_data(self):
  287. """Test that large response body is fully returned through the proxy."""
  288. message_handler = MessageServerHandler(
  289. listen_address="127.0.0.1",
  290. listen_port=0,
  291. proxy_port=0,
  292. )
  293. ws_server, ws_task, ws_port = await start_websocket_server(
  294. message_handler, "127.0.0.1", 0
  295. )
  296. proxy_server = HTTPSProxyServer(
  297. host="127.0.0.1",
  298. port=0,
  299. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  300. )
  301. proxy_task = asyncio.create_task(proxy_server.start())
  302. await asyncio.sleep(0.5)
  303. proxy_addr = proxy_server.server.sockets[0].getsockname()
  304. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  305. # 64KB response body
  306. response_size = 64 * 1024
  307. expected_body = b"X" * response_size
  308. async def handle_request(reader, writer):
  309. try:
  310. await reader.read(8192)
  311. response = (
  312. "HTTP/1.1 200 OK\r\n" f"Content-Length: {response_size}\r\n" "\r\n"
  313. ).encode() + expected_body
  314. writer.write(response)
  315. await writer.drain()
  316. except Exception:
  317. pass
  318. finally:
  319. writer.close()
  320. await writer.wait_closed()
  321. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  322. test_port = test_server.sockets[0].getsockname()[1]
  323. client = MessageClient(
  324. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  325. client_id=uuid.uuid4(),
  326. cidrs=["127.0.0.1/32"],
  327. )
  328. client_task = asyncio.create_task(client.run())
  329. await asyncio.sleep(0.5)
  330. url = f"http://127.0.0.1:{test_port}/"
  331. connector = aiohttp.TCPConnector(limit=1)
  332. async with aiohttp.ClientSession(
  333. proxy=f"http://{proxy_host}:{proxy_port}", connector=connector
  334. ) as session:
  335. async with session.get(
  336. url, timeout=aiohttp.ClientTimeout(total=10)
  337. ) as resp:
  338. body = await resp.read()
  339. assert (
  340. len(body) == response_size
  341. ), f"Size mismatch: {len(body)} != {response_size}"
  342. assert body == expected_body, "Body content mismatch"
  343. # Cleanup
  344. client_task.cancel()
  345. try:
  346. await client_task
  347. except asyncio.CancelledError:
  348. pass
  349. proxy_task.cancel()
  350. try:
  351. await proxy_task
  352. except asyncio.CancelledError:
  353. pass
  354. await proxy_server.stop()
  355. test_server.close()
  356. await test_server.wait_closed()
  357. ws_task.cancel()
  358. try:
  359. await ws_task
  360. except asyncio.CancelledError:
  361. pass
  362. @pytest.mark.asyncio
  363. async def test_response_headers_forwarded(self):
  364. """Test that response headers are correctly forwarded through the proxy."""
  365. message_handler = MessageServerHandler(
  366. listen_address="127.0.0.1",
  367. listen_port=0,
  368. proxy_port=0,
  369. )
  370. ws_server, ws_task, ws_port = await start_websocket_server(
  371. message_handler, "127.0.0.1", 0
  372. )
  373. proxy_server = HTTPSProxyServer(
  374. host="127.0.0.1",
  375. port=0,
  376. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  377. )
  378. proxy_task = asyncio.create_task(proxy_server.start())
  379. await asyncio.sleep(0.5)
  380. proxy_addr = proxy_server.server.sockets[0].getsockname()
  381. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  382. expected_body = b"headers-test"
  383. custom_header_value = "custom-value-123"
  384. async def handle_request(reader, writer):
  385. try:
  386. await reader.read(8192)
  387. response = (
  388. "HTTP/1.1 200 OK\r\n"
  389. "Content-Type: text/custom\r\n"
  390. f"X-Custom-Header: {custom_header_value}\r\n"
  391. f"Content-Length: {len(expected_body)}\r\n"
  392. "\r\n"
  393. ).encode() + expected_body
  394. writer.write(response)
  395. await writer.drain()
  396. except Exception:
  397. pass
  398. finally:
  399. writer.close()
  400. await writer.wait_closed()
  401. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  402. test_port = test_server.sockets[0].getsockname()[1]
  403. client = MessageClient(
  404. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  405. client_id=uuid.uuid4(),
  406. cidrs=["127.0.0.1/32"],
  407. )
  408. client_task = asyncio.create_task(client.run())
  409. await asyncio.sleep(0.5)
  410. url = f"http://127.0.0.1:{test_port}/"
  411. connector = aiohttp.TCPConnector(limit=1)
  412. async with aiohttp.ClientSession(
  413. proxy=f"http://{proxy_host}:{proxy_port}", connector=connector
  414. ) as session:
  415. async with session.get(
  416. url, timeout=aiohttp.ClientTimeout(total=10)
  417. ) as resp:
  418. assert resp.status == 200
  419. assert resp.headers.get("Content-Type") == "text/custom"
  420. assert resp.headers.get("X-Custom-Header") == custom_header_value
  421. body = await resp.read()
  422. assert body == expected_body
  423. # Cleanup
  424. client_task.cancel()
  425. try:
  426. await client_task
  427. except asyncio.CancelledError:
  428. pass
  429. proxy_task.cancel()
  430. try:
  431. await proxy_task
  432. except asyncio.CancelledError:
  433. pass
  434. await proxy_server.stop()
  435. test_server.close()
  436. await test_server.wait_closed()
  437. ws_task.cancel()
  438. try:
  439. await ws_task
  440. except asyncio.CancelledError:
  441. pass
  442. @pytest.mark.asyncio
  443. async def test_single_get_request(self):
  444. """Test a single GET request through WebSocket proxy tunnel."""
  445. # Setup WebSocket server
  446. message_handler = MessageServerHandler(
  447. listen_address="127.0.0.1",
  448. listen_port=0,
  449. proxy_port=0,
  450. )
  451. ws_server, ws_task, ws_port = await start_websocket_server(
  452. message_handler, "127.0.0.1", 0
  453. )
  454. print(f"[Test] WebSocket server started on port {ws_port}")
  455. # Setup HTTP proxy server
  456. proxy_server = HTTPSProxyServer(
  457. host="127.0.0.1",
  458. port=0,
  459. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  460. )
  461. proxy_task = asyncio.create_task(proxy_server.start())
  462. await asyncio.sleep(0.5)
  463. proxy_addr = proxy_server.server.sockets[0].getsockname()
  464. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  465. # Start test HTTP server
  466. async def handle_request(reader, writer):
  467. try:
  468. await reader.read(8192)
  469. response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"
  470. writer.write(response)
  471. await writer.drain()
  472. except Exception:
  473. pass
  474. finally:
  475. writer.close()
  476. await writer.wait_closed()
  477. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  478. test_port = test_server.sockets[0].getsockname()[1]
  479. # Connect MessageClient to register the test server IP
  480. client = MessageClient(
  481. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  482. client_id="test-client",
  483. cidrs=["127.0.0.1/32"], # Register test server IP
  484. )
  485. client_task = asyncio.create_task(client.run())
  486. await asyncio.sleep(0.5)
  487. # Make request through proxy
  488. url = f"http://127.0.0.1:{test_port}/"
  489. connector = aiohttp.TCPConnector(limit=1)
  490. async with aiohttp.ClientSession(
  491. proxy=f"http://{proxy_host}:{proxy_port}", connector=connector
  492. ) as session:
  493. start = time.time()
  494. async with session.get(
  495. url, timeout=aiohttp.ClientTimeout(total=10)
  496. ) as resp:
  497. body = await resp.text()
  498. assert resp.status == 200, f"Expected 200, got {resp.status}"
  499. assert body == "OK", f"Expected 'OK', got {body!r}"
  500. elapsed = time.time() - start
  501. # Cleanup
  502. client_task.cancel()
  503. try:
  504. await client_task
  505. except asyncio.CancelledError:
  506. pass
  507. proxy_task.cancel()
  508. try:
  509. await proxy_task
  510. except asyncio.CancelledError:
  511. pass
  512. await proxy_server.stop()
  513. test_server.close()
  514. await test_server.wait_closed()
  515. ws_task.cancel()
  516. try:
  517. await ws_task
  518. except asyncio.CancelledError:
  519. pass
  520. # Should complete in reasonable time
  521. assert elapsed < 5.0, f"Request took too long: {elapsed:.2f}s"
  522. @pytest.mark.asyncio
  523. async def test_single_post_request(self):
  524. """Test a single POST request through WebSocket proxy tunnel."""
  525. # Setup WebSocket server
  526. message_handler = MessageServerHandler(
  527. listen_address="127.0.0.1",
  528. listen_port=0,
  529. proxy_port=0,
  530. )
  531. _, ws_task, ws_port = await start_websocket_server(
  532. message_handler, "127.0.0.1", 0
  533. )
  534. # Setup HTTP proxy server
  535. proxy_server = HTTPSProxyServer(
  536. host="127.0.0.1",
  537. port=0,
  538. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  539. )
  540. proxy_task = asyncio.create_task(proxy_server.start())
  541. await asyncio.sleep(0.5)
  542. proxy_addr = proxy_server.server.sockets[0].getsockname()
  543. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  544. # Start test HTTP server
  545. request_size = 1024 # 1KB
  546. response_size = 2048 # 2KB
  547. async def handle_request(reader, writer):
  548. try:
  549. # Read all headers
  550. content_length = 0
  551. while True:
  552. line = await reader.readline()
  553. if not line:
  554. return
  555. if line == b'\r\n':
  556. break
  557. if line.lower().startswith(b'content-length:'):
  558. content_length = int(line.split(b':')[1].strip())
  559. # Read body if present
  560. if content_length > 0:
  561. await reader.readexactly(content_length)
  562. # Send response
  563. response_body = b'X' * response_size
  564. response = (
  565. "HTTP/1.1 200 OK\r\n" f"Content-Length: {response_size}\r\n" "\r\n"
  566. ).encode() + response_body
  567. writer.write(response)
  568. await writer.drain()
  569. except Exception as e:
  570. print(f"[Test Server] Error: {e}")
  571. finally:
  572. writer.close()
  573. await writer.wait_closed()
  574. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  575. test_port = test_server.sockets[0].getsockname()[1]
  576. # Connect MessageClient
  577. client = MessageClient(
  578. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  579. client_id=uuid.uuid4(),
  580. cidrs=["127.0.0.1/32"],
  581. )
  582. client_task = asyncio.create_task(client.run())
  583. await asyncio.sleep(0.5)
  584. # Make POST request through proxy
  585. url = f"http://127.0.0.1:{test_port}/"
  586. data = b'A' * request_size
  587. connector = aiohttp.TCPConnector(limit=1)
  588. async with aiohttp.ClientSession(
  589. proxy=f"http://{proxy_host}:{proxy_port}", connector=connector
  590. ) as session:
  591. start = time.time()
  592. async with session.post(
  593. url, data=data, timeout=aiohttp.ClientTimeout(total=10)
  594. ) as resp:
  595. received = await resp.read()
  596. elapsed = time.time() - start
  597. # Verify response content
  598. assert (
  599. len(received) == response_size
  600. ), f"Response size mismatch: {len(received)} != {response_size}"
  601. assert received == b'X' * response_size, "Response body content mismatch"
  602. # Cleanup
  603. client_task.cancel()
  604. try:
  605. await client_task
  606. except asyncio.CancelledError:
  607. pass
  608. proxy_task.cancel()
  609. try:
  610. await proxy_task
  611. except asyncio.CancelledError:
  612. pass
  613. await proxy_server.stop()
  614. test_server.close()
  615. await test_server.wait_closed()
  616. ws_task.cancel()
  617. try:
  618. await ws_task
  619. except asyncio.CancelledError:
  620. pass
  621. # Should complete in reasonable time
  622. assert elapsed < 5.0, f"Request took too long: {elapsed:.2f}s"
  623. class TestProxyThroughput:
  624. """Test proxy throughput with different payload sizes."""
  625. @pytest.mark.asyncio
  626. async def test_small_payload_throughput(self):
  627. """Test throughput with 512B payload."""
  628. await self._test_throughput(
  629. request_size=512, response_size=512, num_requests=50, concurrency=5
  630. )
  631. @pytest.mark.asyncio
  632. async def test_medium_payload_throughput(self):
  633. """Test throughput with 4KB payload."""
  634. await self._test_throughput(
  635. request_size=4096, response_size=4096, num_requests=50, concurrency=5
  636. )
  637. async def _test_throughput(
  638. self, request_size, response_size, num_requests, concurrency
  639. ):
  640. """Helper method to test throughput."""
  641. # Setup WebSocket server
  642. message_handler = MessageServerHandler(
  643. listen_address="127.0.0.1",
  644. listen_port=0,
  645. proxy_port=0,
  646. )
  647. ws_server, ws_task, ws_port = await start_websocket_server(
  648. message_handler, "127.0.0.1", 0
  649. )
  650. # Setup HTTP proxy server
  651. proxy_server = HTTPSProxyServer(
  652. host="127.0.0.1",
  653. port=0,
  654. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  655. )
  656. proxy_task = asyncio.create_task(proxy_server.start())
  657. await asyncio.sleep(0.5)
  658. proxy_addr = proxy_server.server.sockets[0].getsockname()
  659. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  660. # Start test HTTP server
  661. async def handle_request(reader, writer):
  662. try:
  663. await reader.read(8192)
  664. response_body = b'X' * response_size
  665. response = (
  666. f"HTTP/1.1 200 OK\r\n"
  667. f"Content-Length: {response_size}\r\n"
  668. f"\r\n"
  669. ).encode() + response_body
  670. writer.write(response)
  671. await writer.drain()
  672. except Exception:
  673. pass
  674. finally:
  675. writer.close()
  676. await writer.wait_closed()
  677. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  678. test_port = test_server.sockets[0].getsockname()[1]
  679. # Connect MessageClient
  680. client = MessageClient(
  681. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  682. client_id=uuid.uuid4(),
  683. cidrs=["127.0.0.1/32"],
  684. )
  685. client_task = asyncio.create_task(client.run())
  686. await asyncio.sleep(0.5)
  687. # Make requests through proxy
  688. url = f"http://127.0.0.1:{test_port}/"
  689. data = b'A' * request_size
  690. connector = aiohttp.TCPConnector(limit=concurrency)
  691. async with aiohttp.ClientSession(
  692. proxy=f"http://{proxy_host}:{proxy_port}", connector=connector
  693. ) as session:
  694. start = time.time()
  695. async def make_request():
  696. req_start = time.time()
  697. async with session.post(
  698. url, data=data, timeout=aiohttp.ClientTimeout(total=10)
  699. ) as resp:
  700. received = await resp.read()
  701. assert (
  702. received == b'X' * response_size
  703. ), "Response body content mismatch"
  704. return time.time() - req_start
  705. # Run requests in batches
  706. for batch_start in range(0, num_requests, concurrency):
  707. batch_size = min(concurrency, num_requests - batch_start)
  708. tasks = [make_request() for _ in range(batch_size)]
  709. await asyncio.gather(*tasks)
  710. done = batch_start + batch_size
  711. if done % 25 == 0:
  712. print(f"[Test] Progress: {done}/{num_requests}")
  713. elapsed = time.time() - start
  714. # Cleanup
  715. client_task.cancel()
  716. try:
  717. await client_task
  718. except asyncio.CancelledError:
  719. pass
  720. proxy_task.cancel()
  721. try:
  722. await proxy_task
  723. except asyncio.CancelledError:
  724. pass
  725. await proxy_server.stop()
  726. test_server.close()
  727. await test_server.wait_closed()
  728. ws_task.cancel()
  729. try:
  730. await ws_task
  731. except asyncio.CancelledError:
  732. pass
  733. # Calculate throughput
  734. total_bytes = num_requests * (request_size + response_size)
  735. throughput_bps = total_bytes / elapsed
  736. throughput_mbps = throughput_bps / (1024 * 1024)
  737. print(
  738. f"\n[Test] Throughput: {throughput_mbps:.2f} MB/s ({throughput_bps / 1024:.0f} KB/s)"
  739. )
  740. # Throughput targets
  741. assert throughput_mbps >= 0.1, f"Throughput too low: {throughput_mbps:.2f} MB/s"
  742. class TestProxyLatency:
  743. """Test proxy latency characteristics."""
  744. @pytest.mark.asyncio
  745. async def test_latency_distribution(self):
  746. """Test distribution of request latencies."""
  747. # Setup WebSocket server
  748. message_handler = MessageServerHandler(
  749. listen_address="127.0.0.1",
  750. listen_port=0,
  751. proxy_port=0,
  752. )
  753. ws_server, ws_task, ws_port = await start_websocket_server(
  754. message_handler, "127.0.0.1", 0
  755. )
  756. # Setup HTTP proxy server
  757. proxy_server = HTTPSProxyServer(
  758. host="127.0.0.1",
  759. port=0,
  760. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  761. )
  762. proxy_task = asyncio.create_task(proxy_server.start())
  763. await asyncio.sleep(0.5)
  764. proxy_addr = proxy_server.server.sockets[0].getsockname()
  765. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  766. # Start test HTTP server
  767. async def handle_request(reader, writer):
  768. try:
  769. await reader.read(8192)
  770. response = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"
  771. writer.write(response)
  772. await writer.drain()
  773. except Exception:
  774. pass
  775. finally:
  776. writer.close()
  777. await writer.wait_closed()
  778. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  779. test_port = test_server.sockets[0].getsockname()[1]
  780. # Connect MessageClient
  781. client = MessageClient(
  782. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  783. client_id=uuid.uuid4(),
  784. cidrs=["127.0.0.1/32"],
  785. )
  786. client_task = asyncio.create_task(client.run())
  787. await asyncio.sleep(0.5)
  788. # Measure latencies
  789. url = f"http://127.0.0.1:{test_port}/"
  790. connector = aiohttp.TCPConnector(limit=10)
  791. num_requests = 50
  792. async with aiohttp.ClientSession(
  793. proxy=f"http://{proxy_host}:{proxy_port}", connector=connector
  794. ) as session:
  795. async def make_request():
  796. req_start = time.time()
  797. async with session.get(
  798. url, timeout=aiohttp.ClientTimeout(total=10)
  799. ) as resp:
  800. await resp.text()
  801. return time.time() - req_start
  802. tasks = [make_request() for _ in range(num_requests)]
  803. latencies = await asyncio.gather(*tasks)
  804. # Cleanup
  805. client_task.cancel()
  806. try:
  807. await client_task
  808. except asyncio.CancelledError:
  809. pass
  810. proxy_task.cancel()
  811. try:
  812. await proxy_task
  813. except asyncio.CancelledError:
  814. pass
  815. await proxy_server.stop()
  816. test_server.close()
  817. await test_server.wait_closed()
  818. ws_task.cancel()
  819. try:
  820. await ws_task
  821. except asyncio.CancelledError:
  822. pass
  823. # Calculate statistics
  824. avg_latency = sum(latencies) / len(latencies) * 1000
  825. p50_latency = sorted(latencies)[int(len(latencies) * 0.5)] * 1000
  826. p95_latency = sorted(latencies)[int(len(latencies) * 0.95)] * 1000
  827. max_latency = max(latencies) * 1000
  828. print(f"\n[Test] Latency Statistics ({num_requests} requests):")
  829. print(f"[Test] Average: {avg_latency:.2f}ms")
  830. print(f"[Test] P50: {p50_latency:.2f}ms")
  831. print(f"[Test] P95: {p95_latency:.2f}ms")
  832. print(f"[Test] Max: {max_latency:.2f}ms")
  833. # Performance targets
  834. assert avg_latency < 500, f"Average latency too high: {avg_latency:.2f}ms"
  835. assert p95_latency < 1000, f"P95 latency too high: {p95_latency:.2f}ms"
  836. if __name__ == "__main__":
  837. pytest.main([__file__, "-v"])