connection_manager.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. #!/usr/bin/env python3
  2. """
  3. Connection Managers - Handle tunnel connections lifecycle
  4. """
  5. import asyncio
  6. import logging
  7. import urllib.parse
  8. import uuid
  9. from typing import Optional, Dict, TYPE_CHECKING, Union, Protocol
  10. if TYPE_CHECKING:
  11. from websockets.client import ClientConnection
  12. from websockets.server import ServerConnection
  13. from starlette.websockets import WebSocket as StarletteWebSocket
  14. from .connection import TunnelConnection, IOConnection, AsyncIOConnection, tunnel
  15. from .message import (
  16. SessionBaseMessage,
  17. ConnectRequestMessage,
  18. ConnectResponseMessage,
  19. DataMessage,
  20. DisconnectMessage,
  21. pack_message,
  22. )
  # Module-level logger shared by every manager below.
  # NOTE(review): this module calls ``logger.trace(...)``, which the standard
  # ``logging.Logger`` class does not define; presumably a TRACE level/method is
  # registered elsewhere in the package before these code paths run — confirm.
  logger = logging.getLogger(name=__name__)
  # ==================== Connection managers ====================
  class ConnectionManager:
      """Server-side registry and factory for tunnel connections.

      With a WebSocket, :meth:`connect` sends a CONNECT_REQUEST to the peer on
      the other end of that WebSocket and returns a ``TunnelConnection`` once the
      peer's CONNECT_RESPONSE (fed in through :meth:`dispatch`) arrives.  Without
      a WebSocket, :meth:`connect` opens the TCP/Unix target directly.  Either
      way the connection is tracked in an internal ``session_id -> connection``
      map.
      """

      # Seconds to wait for the peer's CONNECT_RESPONSE.  Class-level so it can
      # be tuned per subclass or per instance without changing the constructor.
      connect_timeout: float = 30

      def __init__(
          self,
          websocket: Optional[
              Union["ClientConnection", "ServerConnection", "StarletteWebSocket"]
          ] = None,
      ) -> None:
          self.websocket = websocket
          self._connections: Dict[uuid.UUID, TunnelConnection] = {}

      async def _send_to_websocket(self, data: bytes) -> None:
          """Send *data* as a binary frame; a silent no-op when there is no WebSocket.

          Starlette WebSockets expose ``send_bytes``; ``websockets`` library
          connections expose ``send``.
          """
          if self.websocket is None:
              return
          if hasattr(self.websocket, 'send_bytes'):
              await self.websocket.send_bytes(data)
          else:
              await self.websocket.send(data)

      async def _direct_connect(
          self, session_id: uuid.UUID, target_url: str
      ) -> IOConnection:
          """Open *target_url* directly (``unix://`` path, otherwise TCP host:port).

          The connection is registered under *session_id*.  NOTE(review): this
          class only removes it again on a dispatched DISCONNECT for the session
          or an explicit :meth:`pop_connection` by the caller.
          """
          parsed = urllib.parse.urlparse(target_url)
          if parsed.scheme == "unix":
              reader, writer = await asyncio.open_unix_connection(parsed.path)
          else:
              reader, writer = await asyncio.open_connection(parsed.hostname, parsed.port)
          connection = AsyncIOConnection(reader=reader, writer=writer)
          self._connections[session_id] = connection
          return connection

      async def _websocket_connect(
          self, session_id: uuid.UUID, target_url: str
      ) -> TunnelConnection:
          """Ask the WebSocket peer to open *target_url* and wait for its answer.

          The connection is registered *before* the request is sent so that the
          peer's CONNECT_RESPONSE, routed through :meth:`dispatch`, can find it.
          On any failure the registration is rolled back.

          Raises:
              TimeoutError: no answer within ``connect_timeout`` seconds.
              Exception: whatever awaiting ``connect_result`` raises — presumably
                  the exception :meth:`dispatch` passes to ``connect_error()`` when
                  the peer reports failure — or any error from sending the request.
          """
          connection = TunnelConnection(session_id, self.websocket)
          self._connections[session_id] = connection
          try:
              message = ConnectRequestMessage(session_id=session_id, target_url=target_url)
              await self._send_to_websocket(pack_message(message))
              logger.trace(
                  f"[ConnectionManager] Sent CONNECT_REQUEST for {target_url} (session={session_id})"
              )
              try:
                  await asyncio.wait_for(
                      connection.connect_result, timeout=self.connect_timeout
                  )
              except (asyncio.TimeoutError, TimeoutError) as exc:
                  raise TimeoutError(f"Connection to {target_url} timed out") from exc
          except BaseException:
              # Roll back on *any* failure (peer-reported error, send failure,
              # timeout, cancellation) so a dead session never lingers in the map.
              self._connections.pop(session_id, None)
              raise
          return connection

      async def connect(
          self,
          target_url: str,
      ) -> IOConnection:
          """Open a new connection to *target_url* under a fresh session id.

          URL format: tcp://host:port or unix:///path/to/socket
          Goes through the WebSocket tunnel when one was supplied, otherwise
          dials the target directly.
          """
          session_id = uuid.uuid4()
          if self.websocket is None:
              return await self._direct_connect(session_id, target_url)
          return await self._websocket_connect(session_id, target_url)

      def get_connection(self, session_id: uuid.UUID) -> Optional[TunnelConnection]:
          """Return the connection registered under *session_id*, or ``None``."""
          return self._connections.get(session_id)

      def pop_connection(self, session_id: uuid.UUID) -> Optional[TunnelConnection]:
          """Unregister and return the connection for *session_id*, or ``None``."""
          return self._connections.pop(session_id, None)

      def connections(self) -> Dict[uuid.UUID, TunnelConnection]:
          """Return the live (not copied) ``session_id -> connection`` map."""
          return self._connections

      async def dispatch(self, msg: SessionBaseMessage) -> None:
          """Route a message received from the WebSocket peer to its session.

          - DISCONNECT: unregister and close the session (silently ignored when
            the session is unknown).
          - CONNECT_RESPONSE: mark the pending connect as succeeded or failed.
          - DATA: hand the payload to the connection.
          Any non-DISCONNECT message for an unknown session is logged and dropped.
          """
          if isinstance(msg, DisconnectMessage):
              closing = self.pop_connection(msg.session_id)
              if closing:
                  await closing.close()
              return

          connection = self.get_connection(msg.session_id)
          if connection is None:
              logger.error(
                  f"[ConnectionManager] WARNING: No connection found for session_id={msg.session_id}, message type={type(msg).__name__}"
              )
              return

          if isinstance(msg, ConnectResponseMessage):
              if msg.success:
                  connection.set_connected()
              else:
                  connection.connect_error(Exception(f"Connection failed: {msg.error}"))
          elif isinstance(msg, DataMessage):
              await connection.handle_data(msg.data)
  class BaseConnectionManager(Protocol):
      """Structural interface for connection managers.

      This is a ``typing.Protocol``, not an abstract base class: any class that
      provides matching ``connections()`` and ``connect()`` methods satisfies it
      without inheriting from it (e.g. ``ConnectionManager`` and
      ``RemoteConnectionManager`` in this module).
      """

      def connections(self) -> Dict:
          """Return the connections currently tracked by the manager."""

      async def connect(self, target_url: str) -> IOConnection:
          """Open a connection to *target_url* and return it as an IOConnection.

          URL format: tcp://host:port or unix:///path/to/socket
          """
  class ClientConnectionManager:
      """Client-side connection manager: serves CONNECT_REQUESTs from the peer.

      For each CONNECT_REQUEST received over the WebSocket it dials the requested
      TCP/Unix target locally and answers with a CONNECT_RESPONSE.  On success
      it runs ``tunnel()`` between a ``TunnelConnection`` and the target in a
      background task, then unregisters and closes the session when that ends.
      """

      # Seconds allowed for dialing a local target before reporting failure.
      connect_timeout: float = 5.0

      def __init__(
          self,
          websocket: Union["ClientConnection", "ServerConnection", "StarletteWebSocket"],
      ) -> None:
          self.websocket = websocket
          self._connections: Dict[uuid.UUID, TunnelConnection] = {}
          # Strong references to running tunnel tasks: the event loop only keeps
          # weak references, so an unreferenced task could be collected mid-run.
          self._tasks: set[asyncio.Task] = set()

      async def _send_to_websocket(self, data: bytes) -> None:
          """Send *data* as a binary frame (Starlette ``send_bytes`` or websockets ``send``)."""
          logger.trace(
              f"[ClientConnectionManager] Sending {len(data)} bytes to WebSocket"
          )
          if hasattr(self.websocket, 'send_bytes'):
              await self.websocket.send_bytes(data)
          else:
              await self.websocket.send(data)

      def get_connection(self, session_id: uuid.UUID) -> Optional[TunnelConnection]:
          """Return the connection registered under *session_id*, or ``None``."""
          return self._connections.get(session_id)

      def pop_connection(self, session_id: uuid.UUID) -> Optional[TunnelConnection]:
          """Unregister and return the connection for *session_id*, or ``None``."""
          return self._connections.pop(session_id, None)

      def connections(self) -> Dict[uuid.UUID, TunnelConnection]:
          """Return the live (not copied) ``session_id -> connection`` map."""
          return self._connections

      async def dispatch(self, msg: SessionBaseMessage) -> None:
          """Route a message received from the WebSocket peer.

          CONNECT_REQUEST dials the target; DATA is forwarded to the session's
          connection (dropped if the session is unknown); DISCONNECT unregisters
          and closes the session.
          """
          if isinstance(msg, ConnectRequestMessage):
              await self.handle_client_connect_request(msg)
          elif isinstance(msg, DataMessage):
              connection = self.get_connection(msg.session_id)
              if connection:
                  await connection.handle_data(msg.data)
          elif isinstance(msg, DisconnectMessage):
              connection = self.pop_connection(msg.session_id)
              if connection:
                  await connection.close()

      async def _open_target(self, target_url: str):
          """Dial *target_url* (``unix://`` path, otherwise TCP host:port) with a timeout.

          Returns the ``(reader, writer)`` stream pair.
          """
          parsed = urllib.parse.urlparse(target_url)
          if parsed.scheme == "unix":
              opening = asyncio.open_unix_connection(parsed.path)
          else:
              opening = asyncio.open_connection(parsed.hostname, parsed.port)
          return await asyncio.wait_for(opening, timeout=self.connect_timeout)

      async def _run_tunnel(
          self,
          session_id: uuid.UUID,
          connection: TunnelConnection,
          target_connection: IOConnection,
      ) -> None:
          """Run ``tunnel()`` for one session, then unregister and close it.

          NOTE(review): only the registered ``TunnelConnection`` is closed here;
          whether ``tunnel()`` itself closes *target_connection* is not visible
          from this file — confirm.
          """
          try:
              await tunnel(
                  connection,
                  target_connection,
                  name="ClientConnectionManager Tunnel",
              )
          except Exception as e:
              logger.error(f"[ClientConnectionManager] Tunnel error: {e}")
          finally:
              conn = self.pop_connection(session_id)
              if conn:
                  await conn.close()

      async def handle_client_connect_request(self, msg: ConnectRequestMessage) -> None:
          """Handle CONNECT_REQUEST: dial the target and report the outcome to the peer.

          On success a CONNECT_RESPONSE(success=True) is sent and a background
          task starts tunneling.  On any failure a CONNECT_RESPONSE(success=False)
          with the error text is sent instead.  Anything set up so far (target
          socket, session registration) is released first.
          """
          session_id = msg.session_id
          logger.trace(
              f"[ClientConnectionManager] Handling CONNECT_REQUEST for {msg.target_url} (session_id={session_id})"
          )
          writer = None
          connection = None
          try:
              reader, writer = await self._open_target(msg.target_url)
              target_connection = AsyncIOConnection(reader=reader, writer=writer)
              connection = TunnelConnection(
                  session_id=session_id,
                  websocket=self.websocket,
              )
              connection.set_connected()
              self._connections[session_id] = connection
              response = ConnectResponseMessage(session_id=session_id, success=True)
              await self._send_to_websocket(pack_message(response))
          except Exception as e:
              logger.error(f"[ClientConnectionManager] Failed to connect: {e}")
              # Release partial setup so a failed request (e.g. the success
              # response could not be sent) leaks neither the target socket nor
              # the session entry.
              if connection is not None and self._connections.get(session_id) is connection:
                  del self._connections[session_id]
              if writer is not None:
                  writer.close()
              # Timeout exceptions stringify to "", which would leave the peer
              # with an empty error message; fall back to the exception type.
              response = ConnectResponseMessage(
                  session_id=session_id, success=False, error=str(e) or type(e).__name__
              )
              await self._send_to_websocket(pack_message(response))
              return

          logger.trace(f"[ClientConnectionManager] Connected to {msg.target_url}")
          task = asyncio.create_task(
              self._run_tunnel(session_id, connection, target_connection)
          )
          self._tasks.add(task)
          task.add_done_callback(self._tasks.discard)
  class RemoteConnectionManager:
      """Connection manager that reaches targets through a remote peer's HTTP proxy.

      Each :meth:`connect` opens a TCP connection to ``peer_address:proxy_port``
      and sends an HTTP ``CONNECT`` for the target.  Once the proxy answers with
      a 2xx status, the same socket carries the tunneled bytes.
      Only supports TCP connections. Unix socket targets will raise an error.
      """

      def __init__(self, peer_address: str, proxy_port: int):
          self.peer_address = peer_address
          self.proxy_port = proxy_port

      @staticmethod
      def _authority(host: str, port: int) -> str:
          """Format ``host:port`` for a CONNECT request, bracketing IPv6 literals."""
          return f"[{host}]:{port}" if ":" in host else f"{host}:{port}"

      @staticmethod
      def _is_success_status(status_line: str) -> bool:
          """Return True if *status_line* is an ``HTTP/x.y 2xx`` status line.

          Any 2xx answer to CONNECT means the tunnel is established, and the
          reason phrase (and its leading space) may be absent.
          """
          parts = status_line.split(None, 2)
          if len(parts) < 2 or not parts[0].startswith("HTTP/"):
              return False
          code = parts[1]
          return len(code) == 3 and code.isascii() and code.isdigit() and code[0] == "2"

      @staticmethod
      async def _send_connect(
          reader: asyncio.StreamReader, writer: asyncio.StreamWriter, target: str
      ) -> tuple[bytes, bool]:
          """Send ``CONNECT target`` and read the proxy's response head.

          Returns ``(head, complete)``; *complete* is False when the proxy closed
          before the blank line ending its headers, or the head was too large.
          Reading exactly up to the blank line leaves any tunneled bytes that
          follow it in *reader* for the caller.
          """
          connect_request = f"CONNECT {target} HTTP/1.1\r\nHost: {target}\r\n\r\n"
          writer.write(connect_request.encode())
          await writer.drain()
          try:
              return await reader.readuntil(b"\r\n\r\n"), True
          except asyncio.IncompleteReadError as exc:
              # The proxy hung up before finishing its response: no tunnel.
              return exc.partial, False
          except asyncio.LimitOverrunError:
              # Head exceeds the stream limit; keep a prefix for the error message.
              return await reader.read(4096), False

      async def connect(
          self,
          target_url: str,
      ) -> IOConnection:
          """Open a tunnel to *target_url* through the peer's HTTP CONNECT proxy.

          URL format: tcp://host:port
          Note: Unix socket URLs are not supported and will raise an error.

          Raises:
              ValueError: for ``unix://`` URLs or URLs lacking a host or port.
              Exception: "Proxy connection failed: ..." if the proxy closes early
                  or answers with a non-2xx status.
          """
          parsed = urllib.parse.urlparse(target_url)
          if parsed.scheme == "unix":
              raise ValueError(
                  "RemoteConnectionManager does not support Unix socket connections"
              )
          # TCP connection
          if not parsed.hostname or parsed.port is None:
              raise ValueError(
                  f"RemoteConnectionManager requires a tcp://host:port URL, got {target_url!r}"
              )
          target = self._authority(parsed.hostname, parsed.port)
          logger.trace(
              f"[RemoteConnectionManager] Forwarding to {self.peer_address}:{self.proxy_port} -> {target}"
          )
          reader, writer = await asyncio.open_connection(
              self.peer_address, self.proxy_port
          )
          try:
              head, complete = await self._send_connect(reader, writer, target)
          except BaseException:
              # Don't leak the proxy socket if the exchange itself fails.
              writer.close()
              raise

          response_str = head.decode('utf-8', errors='ignore')
          status_line = response_str.split("\r\n", 1)[0]
          if not (complete and self._is_success_status(status_line)):
              logger.error(
                  f"[RemoteConnectionManager] Proxy rejected connection: {response_str}"
              )
              writer.close()
              await writer.wait_closed()
              raise Exception(f"Proxy connection failed: {response_str}")

          logger.trace(
              f"[RemoteConnectionManager] Connected to {target} via remote proxy"
          )
          # The proxy socket now carries the tunneled bytes directly.
          return AsyncIOConnection(reader=reader, writer=writer)

      def connections(self) -> Dict:
          """RemoteConnectionManager does not track active connections, so return empty dict"""
          return {}