test_connection.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. """
  2. Tests for connection module (TunnelConnection, AsyncIOConnection).
  3. """
  4. import asyncio
  5. import pytest
  6. import uuid
  7. import uvicorn
  8. from fastapi import FastAPI
  9. from gpustack.websocket_proxy.proxy_server import HTTPSProxyServer
  10. from gpustack.websocket_proxy.message_server import MessageServerHandler, router
  11. from gpustack.websocket_proxy.message_client import MessageClient
  12. def get_free_port(host: str) -> int:
  13. """Get a free port on the host"""
  14. import socket
  15. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  16. s.bind((host, 0))
  17. return s.getsockname()[1]
  18. async def start_websocket_server(
  19. message_handler: MessageServerHandler, host: str, port: int
  20. ):
  21. """Start WebSocket server using uvicorn"""
  22. app = FastAPI()
  23. app.state.message_server_handler = message_handler
  24. app.include_router(router)
  25. actual_port = get_free_port(host) if port == 0 else port
  26. config = uvicorn.Config(app, host=host, port=actual_port, log_level="error")
  27. server = uvicorn.Server(config)
  28. server_task = asyncio.create_task(server.serve())
  29. await asyncio.sleep(0.5)
  30. return server, server_task, actual_port
  31. class TestConnectTunnel:
  32. """Test CONNECT tunnel functionality.
  33. These tests verify that POST request bodies can be sent through a CONNECT tunnel.
  34. In a proper implementation, the proxy should establish a direct TCP connection
  35. to the target and relay data bidirectionally.
  36. """
  37. @pytest.mark.asyncio
  38. async def test_connect_tunnel_with_post_body(self): # noqa C901
  39. """Test POST request body through CONNECT tunnel.
  40. This test sends a raw TCP connection through the proxy using CONNECT,
  41. then sends a POST request through that tunnel. The target server should
  42. receive the POST body and respond.
  43. Note: This test may fail with the current implementation because
  44. WebSocket tunnel mode is designed for HTTP proxy, not raw TCP tunneling.
  45. """
  46. # Setup WebSocket server
  47. message_handler = MessageServerHandler(
  48. listen_address="127.0.0.1",
  49. listen_port=0,
  50. proxy_port=0,
  51. )
  52. _, ws_task, ws_port = await start_websocket_server(
  53. message_handler, "127.0.0.1", 0
  54. )
  55. # Setup HTTP proxy server
  56. proxy_server = HTTPSProxyServer(
  57. host="127.0.0.1",
  58. port=0,
  59. connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
  60. )
  61. proxy_task = asyncio.create_task(proxy_server.start())
  62. await asyncio.sleep(0.5)
  63. proxy_addr = proxy_server.server.sockets[0].getsockname()
  64. proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]
  65. # Track what body we received at the target server
  66. received_body = None
  67. # Start test HTTP server
  68. async def handle_request(reader, writer):
  69. nonlocal received_body
  70. try:
  71. # Read headers first
  72. header_data = b""
  73. while b"\r\n\r\n" not in header_data:
  74. chunk = await reader.read(8192)
  75. if not chunk:
  76. break
  77. header_data += chunk
  78. header_end = header_data.find(b"\r\n\r\n")
  79. if header_end < 0:
  80. return
  81. headers = header_data[:header_end].decode("utf-8", errors="ignore")
  82. body_received = header_data[header_end + 4 :]
  83. # Parse Content-Length
  84. content_length = 0
  85. for line in headers.split("\r\n"):
  86. if line.lower().startswith("content-length:"):
  87. content_length = int(line.split(":")[1].strip())
  88. break
  89. # Continue reading body if needed
  90. while len(body_received) < content_length:
  91. chunk = await reader.read(8192)
  92. if not chunk:
  93. break
  94. body_received += chunk
  95. received_body = body_received
  96. print(f"[Test Server] Received body: {received_body!r}")
  97. # Send response
  98. response_body = b'POST_RECEIVED'
  99. response = (
  100. f"HTTP/1.1 200 OK\r\n"
  101. f"Content-Length: {len(response_body)}\r\n"
  102. f"\r\n"
  103. ).encode() + response_body
  104. writer.write(response)
  105. await writer.drain()
  106. except Exception as e:
  107. print(f"[Test Server] Error: {e}")
  108. finally:
  109. writer.close()
  110. await writer.wait_closed()
  111. test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
  112. test_port = test_server.sockets[0].getsockname()[1]
  113. # Connect MessageClient
  114. client = MessageClient(
  115. server_endpoint=f"ws://127.0.0.1:{ws_port}",
  116. client_id=uuid.uuid4(),
  117. cidrs=["127.0.0.1/32"],
  118. )
  119. client_task = asyncio.create_task(client.run())
  120. await asyncio.sleep(0.5)
  121. # Make a raw TCP connection through the proxy using CONNECT
  122. reader, writer = await asyncio.open_connection(proxy_host, proxy_port)
  123. try:
  124. # Send CONNECT request
  125. connect_request = f"CONNECT 127.0.0.1:{test_port} HTTP/1.1\r\nHost: 127.0.0.1:{test_port}\r\n\r\n"
  126. writer.write(connect_request.encode())
  127. await writer.drain()
  128. # Read CONNECT response with timeout
  129. connect_response = await asyncio.wait_for(
  130. reader.readuntil(b"\r\n\r\n"), timeout=5.0
  131. )
  132. assert connect_response.startswith(
  133. b"HTTP/1.1 200"
  134. ), f"CONNECT failed: {connect_response}"
  135. # Now send a POST request through the tunnel
  136. post_body = b"This is the POST body sent through CONNECT tunnel"
  137. post_request = (
  138. f"POST / HTTP/1.1\r\n"
  139. f"Host: 127.0.0.1:{test_port}\r\n"
  140. f"Content-Length: {len(post_body)}\r\n"
  141. f"Connection: keep-alive\r\n"
  142. f"\r\n"
  143. ).encode() + post_body
  144. writer.write(post_request)
  145. await writer.drain()
  146. # Read response with timeout
  147. response = await asyncio.wait_for(reader.read(8192), timeout=5.0)
  148. assert b"200 OK" in response, f"Expected 200 OK in response: {response}"
  149. assert (
  150. received_body == post_body
  151. ), f"Body mismatch: {received_body!r} != {post_body!r}"
  152. finally:
  153. writer.close()
  154. await writer.wait_closed()
  155. client_task.cancel()
  156. try:
  157. await client_task
  158. except asyncio.CancelledError:
  159. pass
  160. proxy_task.cancel()
  161. try:
  162. await proxy_task
  163. except asyncio.CancelledError:
  164. pass
  165. await proxy_server.stop()
  166. test_server.close()
  167. await test_server.wait_closed()
  168. ws_task.cancel()
  169. try:
  170. await ws_task
  171. except asyncio.CancelledError:
  172. pass
  173. if __name__ == "__main__":
  174. pytest.main([__file__, "-v"])