| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- #!/usr/bin/env python3
- """
- TunnelConnection - Represents a tunnel connection to target server
- """
- from abc import ABC, abstractmethod
- import asyncio
- import logging
- import uuid
- from functools import partial
- from typing import Optional, Union, TYPE_CHECKING, Callable, Coroutine, Any
- from fastapi import HTTPException
- if TYPE_CHECKING:
- from websockets.client import ClientConnection
- from websockets.server import ServerConnection
- from starlette.websockets import WebSocket as StarletteWebSocket
- from .message import DataMessage, DisconnectMessage, pack_message
- logger = logging.getLogger(__name__)
- class IOConnection(ABC):
- @abstractmethod
- async def read(self, n: int = -1, timeout: Optional[float] = None) -> bytes:
- pass
- @abstractmethod
- async def write(self, data: bytes) -> None:
- pass
- @abstractmethod
- async def close(self) -> None:
- pass
- class AsyncIOConnection(IOConnection):
- writer: asyncio.StreamWriter
- reader: asyncio.StreamReader
- def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
- self.reader = reader
- self.writer = writer
- async def read(self, n: int = -1, timeout: Optional[float] = None) -> bytes:
- if timeout is not None:
- return await asyncio.wait_for(self.reader.read(n), timeout)
- return await self.reader.read(n)
- async def write(self, data: bytes) -> None:
- self.writer.write(data)
- await self.writer.drain()
- async def close(self) -> None:
- self.writer.close()
- await self.writer.wait_closed()
- async def send_error(self, code: int, message: str) -> None:
- """Send error response to client with message in body"""
- reason_phrases = {
- 400: "Bad Request",
- 502: "Bad Gateway",
- 503: "Service Unavailable",
- }
- reason = reason_phrases.get(code, "Error")
- response = f"HTTP/1.1 {code} {reason}\r\n"
- response += "Content-Type: text/plain\r\n"
- response += f"Content-Length: {len(message)}\r\n"
- response += "Connection: close\r\n"
- response += "\r\n"
- response += message
- await self.write(response.encode())
- async def write_connect_established(self) -> None:
- """Send 200 Connection Established response to client"""
- response = b"HTTP/1.1 200 Connection Established\r\n\r\n"
- await self.write(response)
- class TunnelConnection(IOConnection):
- """Represents a tunnel connection to target server"""
- def __init__(
- self,
- session_id: uuid.UUID,
- websocket: Union[
- "ClientConnection", "ServerConnection", "StarletteWebSocket", None
- ],
- ) -> None:
- self.session_id = session_id
- self.websocket = websocket
- # Connection state
- self._pending_future: Optional[asyncio.Future[bool]] = (
- asyncio.get_running_loop().create_future()
- )
- # Response tracking queue for WebSocket tunnel mode
- self._response_queue: asyncio.Queue[bytes] = asyncio.Queue()
- self._connection_error: Optional[Exception] = None
- @property
- def is_pending(self) -> bool:
- """Check if connection is in pending state (waiting for response)"""
- return self._pending_future is not None and not self._pending_future.done()
- @property
- def is_connected(self) -> bool:
- """Check if connection is established"""
- return self._pending_future is None and self._connection_error is None
- @property
- def connect_result(self) -> asyncio.Future[bool]:
- """Get the future representing the connection state"""
- return self._pending_future
- def set_connected(self) -> None:
- """Mark connection as connected"""
- if self._pending_future is None:
- logger.warning(
- f"set_connected called but no pending future, session_id={self.session_id}"
- )
- return
- if self._pending_future and not self._pending_future.done():
- self._pending_future.set_result(True)
- self._pending_future = None
- def connect_error(self, error: Exception) -> None:
- """Mark connection as failed with error"""
- if self._pending_future is None:
- logger.warning(
- f"connect_error called but no pending future, session_id={self.session_id}, error={error}"
- )
- return
- if self._pending_future and not self._pending_future.done():
- self._pending_future.set_exception(error)
- self._connection_error = error
- self._pending_future = None
- async def _send_to_websocket(self, data: bytes) -> None:
- """Send data to WebSocket, compatible with Starlette and websockets library"""
- if hasattr(self.websocket, 'send_bytes'):
- await self.websocket.send_bytes(data)
- else:
- await self.websocket.send(data)
- async def handle_data(self, data: bytes) -> None:
- """Handle data received from WebSocket, forward to target"""
- logger.trace(
- f"[Tunnel] handle_data: session_id={self.session_id}, {len(data)} bytes"
- )
- if not self.is_connected:
- logger.warning(
- f"[Tunnel] Connection is pending or failed, ignoring data until connected, session_id={self.session_id}"
- )
- return
- if not data:
- logger.trace("[Tunnel] Empty data received, signaling EOF")
- else:
- logger.trace(f"[Tunnel] Queuing {len(data)} bytes for response tracking")
- await self._response_queue.put(data)
- return
- # Followings methods are for compatibility with IOConnection interface, used in tunnel function
- async def close(self) -> None:
- """Close connection"""
- # Send disconnect message to WebSocket
- logger.trace(
- f"[Tunnel] Closing connection, sending DisconnectMessage, session_id={self.session_id}"
- )
- try:
- msg = DisconnectMessage(session_id=self.session_id)
- await self._send_to_websocket(pack_message(msg))
- except Exception as e:
- logger.trace(
- f"[Tunnel] Failed to send DisconnectMessage (websocket may already be closed): {e}, session_id={self.session_id}"
- )
- await self._response_queue.put(b"") # Unblock any pending reads
- async def read(self, _n: int = -1, timeout: float = 3000.0) -> bytes:
- return await asyncio.wait_for(self._response_queue.get(), timeout=timeout)
- async def write(self, data: bytes) -> None:
- """Send data to WebSocket"""
- logger.trace(
- f"[Tunnel] Sending {len(data)} bytes to WebSocket, session_id={self.session_id}"
- )
- msg = DataMessage(session_id=self.session_id, data=data)
- await self._send_to_websocket(pack_message(msg))
- async def relay(
- reader: IOConnection,
- writer: IOConnection,
- name: str,
- ) -> None:
- """Relay data from ``reader`` to ``writer`` until EOF or error.
- Reads in 8 KiB chunks and writes each to the writer. Closes the writer
- when the reader signals EOF (empty bytes) or raises an exception.
- """
- try:
- while True:
- data = await reader.read(8192)
- if not data:
- logger.trace(f"{name}: read EOF")
- break
- logger.trace(f"{name}: forwarding {len(data)} bytes")
- await writer.write(data)
- except Exception as e:
- logger.error(f"{name}: error {e}")
- finally:
- logger.trace(f"{name}: closing {type(writer).__name__} writer")
- await writer.close()
- async def tunnel(
- client_connection: IOConnection,
- remote_connection: IOConnection,
- name: str = "Tunnel",
- request: Optional[str] = None,
- response_relay: Optional[
- Callable[[IOConnection, IOConnection], Coroutine[Any, Any, None]]
- ] = None,
- ) -> None:
- """Tunnel data between client and remote connections"""
- if request is not None:
- logger.trace(
- f"[{name}] client->remote: Sending initial request data through tunnel:\n{request[:500]}..."
- )
- try:
- await remote_connection.write(request.encode())
- except Exception as e:
- logger.error(f"[{name}] Error sending initial request data: {e}")
- await remote_connection.close()
- raise HTTPException(
- status_code=502, detail="Failed to send initial request data"
- )
- if response_relay is None:
- response_relay = partial(relay, name=f"[{name}] remote->client")
- await asyncio.gather(
- response_relay(remote_connection, client_connection),
- relay(client_connection, remote_connection, f"[{name}] client->remote"),
- return_exceptions=True,
- )
- logger.trace(f"[{name}] Tunnel closed")
|