| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- """
- Tests for connection module (TunnelConnection, AsyncIOConnection).
- """
import asyncio
import contextlib
import uuid

import pytest
import uvicorn
from fastapi import FastAPI

from gpustack.websocket_proxy.proxy_server import HTTPSProxyServer
from gpustack.websocket_proxy.message_server import MessageServerHandler, router
from gpustack.websocket_proxy.message_client import MessageClient
def get_free_port(host: str) -> int:
    """Return a TCP port number that the OS reports as unused on ``host``.

    A throwaway IPv4 stream socket is bound to port 0 so the kernel picks the
    port. The socket is closed before returning, so the port is released again
    and another process could claim it before the caller binds to it.
    """
    import socket

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind((host, 0))
        bound_address = probe.getsockname()
        # For AF_INET the address is (ip, port); the port is what callers need.
        return bound_address[1]
    finally:
        probe.close()
async def start_websocket_server(
    message_handler: MessageServerHandler, host: str, port: int
):
    """Start the WebSocket message server with uvicorn in a background task.

    Args:
        message_handler: Handler stored on ``app.state.message_server_handler``
            before the message-server ``router`` is included in the app.
        host: Interface the uvicorn server binds to.
        port: Port to bind; ``0`` means "pick one with ``get_free_port``".

    Returns:
        A ``(server, server_task, actual_port)`` tuple: the ``uvicorn.Server``,
        the asyncio task running ``server.serve()``, and the port in use.

    The function waits until uvicorn reports that startup finished instead of
    sleeping for a fixed interval. The wait is bounded and never raises: if
    the server task ends early or the deadline passes, the tuple is returned
    anyway, exactly as the fixed-sleep version did, and the caller's own
    connection attempt surfaces the failure.
    """
    app = FastAPI()
    app.state.message_server_handler = message_handler
    app.include_router(router)

    actual_port = get_free_port(host) if port == 0 else port
    config = uvicorn.Config(app, host=host, port=actual_port, log_level="error")
    server = uvicorn.Server(config)
    server_task = asyncio.create_task(server.serve())

    # Poll uvicorn's readiness flag: returns as soon as the server is up on a
    # fast machine and tolerates a slow start (up to 5 s) on a loaded one.
    loop = asyncio.get_running_loop()
    give_up_at = loop.time() + 5.0
    while not server.started and not server_task.done():
        if loop.time() >= give_up_at:
            break
        await asyncio.sleep(0.01)

    return server, server_task, actual_port
class TestConnectTunnel:
    """Test CONNECT tunnel functionality.

    These tests verify that POST request bodies can be sent through a CONNECT
    tunnel. In a proper implementation, the proxy should establish a direct
    TCP connection to the target and relay data bidirectionally.
    """

    @staticmethod
    async def _cancel_and_wait(task):
        """Cancel a background task and wait until it has actually stopped.

        The ``CancelledError`` produced by the cancellation is swallowed; any
        other exception the task already failed with still propagates.
        """
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    @staticmethod
    async def _close_stream(writer):
        """Close a stream writer and wait for the transport to shut down."""
        writer.close()
        await writer.wait_closed()

    @staticmethod
    async def _close_server(server):
        """Close an asyncio server and wait until it is fully closed."""
        server.close()
        await server.wait_closed()

    @pytest.mark.asyncio
    async def test_connect_tunnel_with_post_body(self):  # noqa: C901
        """Test POST request body through CONNECT tunnel.

        This test sends a raw TCP connection through the proxy using CONNECT,
        then sends a POST request through that tunnel. The target server should
        receive the POST body and respond.

        Note: This test may fail with the current implementation because
        WebSocket tunnel mode is designed for HTTP proxy, not raw TCP tunneling.
        """
        # Track what body we received at the target server
        received_body = None

        async def handle_request(reader, writer):
            """Minimal HTTP target: record the request body, answer 200 OK."""
            nonlocal received_body
            try:
                # Read headers first
                header_data = b""
                while b"\r\n\r\n" not in header_data:
                    chunk = await reader.read(8192)
                    if not chunk:
                        break
                    header_data += chunk

                header_end = header_data.find(b"\r\n\r\n")
                if header_end < 0:
                    return

                headers = header_data[:header_end].decode("utf-8", errors="ignore")
                # Anything after the blank line is the start of the body.
                body_received = header_data[header_end + 4 :]

                # Parse Content-Length
                content_length = 0
                for line in headers.split("\r\n"):
                    if line.lower().startswith("content-length:"):
                        content_length = int(line.split(":")[1].strip())
                        break

                # Continue reading body if needed
                while len(body_received) < content_length:
                    chunk = await reader.read(8192)
                    if not chunk:
                        break
                    body_received += chunk

                received_body = body_received
                print(f"[Test Server] Received body: {received_body!r}")

                # Send response
                response_body = b'POST_RECEIVED'
                response = (
                    f"HTTP/1.1 200 OK\r\n"
                    f"Content-Length: {len(response_body)}\r\n"
                    f"\r\n"
                ).encode() + response_body
                writer.write(response)
                await writer.drain()
            except Exception as e:
                print(f"[Test Server] Error: {e}")
            finally:
                writer.close()
                await writer.wait_closed()

        # Every resource registers its teardown as soon as it exists, so a
        # failure at any point (setup included) still releases what was
        # started, and one failing teardown step cannot skip the others.
        # Callbacks run last-in-first-out: client stream, client, proxy,
        # target server, WebSocket server.
        async with contextlib.AsyncExitStack() as cleanup:
            # Setup WebSocket server
            message_handler = MessageServerHandler(
                listen_address="127.0.0.1",
                listen_port=0,
                proxy_port=0,
            )
            _, ws_task, ws_port = await start_websocket_server(
                message_handler, "127.0.0.1", 0
            )
            cleanup.push_async_callback(self._cancel_and_wait, ws_task)

            # Start test HTTP server. It is started before the proxy so that
            # the LIFO teardown stops the proxy before the target server.
            test_server = await asyncio.start_server(handle_request, "127.0.0.1", 0)
            cleanup.push_async_callback(self._close_server, test_server)
            test_port = test_server.sockets[0].getsockname()[1]

            # Setup HTTP proxy server
            proxy_server = HTTPSProxyServer(
                host="127.0.0.1",
                port=0,
                connection_manager_getter=message_handler.get_connection_manager_by_ip_in_cidr,
            )
            proxy_task = asyncio.create_task(proxy_server.start())
            # Registered in reverse of the desired order: the task is
            # cancelled first, then the proxy server is stopped.
            cleanup.push_async_callback(proxy_server.stop)
            cleanup.push_async_callback(self._cancel_and_wait, proxy_task)
            await asyncio.sleep(0.5)
            proxy_addr = proxy_server.server.sockets[0].getsockname()
            proxy_host, proxy_port = proxy_addr[0], proxy_addr[1]

            # Connect MessageClient
            client = MessageClient(
                server_endpoint=f"ws://127.0.0.1:{ws_port}",
                client_id=uuid.uuid4(),
                cidrs=["127.0.0.1/32"],
            )
            client_task = asyncio.create_task(client.run())
            cleanup.push_async_callback(self._cancel_and_wait, client_task)
            await asyncio.sleep(0.5)

            # Make a raw TCP connection through the proxy using CONNECT
            reader, writer = await asyncio.open_connection(proxy_host, proxy_port)
            cleanup.push_async_callback(self._close_stream, writer)

            # Send CONNECT request
            connect_request = (
                f"CONNECT 127.0.0.1:{test_port} HTTP/1.1\r\n"
                f"Host: 127.0.0.1:{test_port}\r\n\r\n"
            )
            writer.write(connect_request.encode())
            await writer.drain()

            # Read CONNECT response with timeout
            connect_response = await asyncio.wait_for(
                reader.readuntil(b"\r\n\r\n"), timeout=5.0
            )
            assert connect_response.startswith(
                b"HTTP/1.1 200"
            ), f"CONNECT failed: {connect_response}"

            # Now send a POST request through the tunnel
            post_body = b"This is the POST body sent through CONNECT tunnel"
            post_request = (
                f"POST / HTTP/1.1\r\n"
                f"Host: 127.0.0.1:{test_port}\r\n"
                f"Content-Length: {len(post_body)}\r\n"
                f"Connection: keep-alive\r\n"
                f"\r\n"
            ).encode() + post_body
            writer.write(post_request)
            await writer.drain()

            # Read response with timeout
            response = await asyncio.wait_for(reader.read(8192), timeout=5.0)
            assert b"200 OK" in response, f"Expected 200 OK in response: {response}"
            # The target stores the body before replying, so it is set by now.
            assert (
                received_body == post_body
            ), f"Body mismatch: {received_body!r} != {post_body!r}"
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
|