connection.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. #!/usr/bin/env python3
  2. """
  3. TunnelConnection - Represents a tunnel connection to target server
  4. """
  5. from abc import ABC, abstractmethod
  6. import asyncio
  7. import logging
  8. import uuid
  9. from functools import partial
  10. from typing import Optional, Union, TYPE_CHECKING, Callable, Coroutine, Any
  11. from fastapi import HTTPException
  12. if TYPE_CHECKING:
  13. from websockets.client import ClientConnection
  14. from websockets.server import ServerConnection
  15. from starlette.websockets import WebSocket as StarletteWebSocket
  16. from .message import DataMessage, DisconnectMessage, pack_message
  17. logger = logging.getLogger(__name__)
  18. class IOConnection(ABC):
  19. @abstractmethod
  20. async def read(self, n: int = -1, timeout: Optional[float] = None) -> bytes:
  21. pass
  22. @abstractmethod
  23. async def write(self, data: bytes) -> None:
  24. pass
  25. @abstractmethod
  26. async def close(self) -> None:
  27. pass
  28. class AsyncIOConnection(IOConnection):
  29. writer: asyncio.StreamWriter
  30. reader: asyncio.StreamReader
  31. def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
  32. self.reader = reader
  33. self.writer = writer
  34. async def read(self, n: int = -1, timeout: Optional[float] = None) -> bytes:
  35. if timeout is not None:
  36. return await asyncio.wait_for(self.reader.read(n), timeout)
  37. return await self.reader.read(n)
  38. async def write(self, data: bytes) -> None:
  39. self.writer.write(data)
  40. await self.writer.drain()
  41. async def close(self) -> None:
  42. self.writer.close()
  43. await self.writer.wait_closed()
  44. async def send_error(self, code: int, message: str) -> None:
  45. """Send error response to client with message in body"""
  46. reason_phrases = {
  47. 400: "Bad Request",
  48. 502: "Bad Gateway",
  49. 503: "Service Unavailable",
  50. }
  51. reason = reason_phrases.get(code, "Error")
  52. response = f"HTTP/1.1 {code} {reason}\r\n"
  53. response += "Content-Type: text/plain\r\n"
  54. response += f"Content-Length: {len(message)}\r\n"
  55. response += "Connection: close\r\n"
  56. response += "\r\n"
  57. response += message
  58. await self.write(response.encode())
  59. async def write_connect_established(self) -> None:
  60. """Send 200 Connection Established response to client"""
  61. response = b"HTTP/1.1 200 Connection Established\r\n\r\n"
  62. await self.write(response)
  63. class TunnelConnection(IOConnection):
  64. """Represents a tunnel connection to target server"""
  65. def __init__(
  66. self,
  67. session_id: uuid.UUID,
  68. websocket: Union[
  69. "ClientConnection", "ServerConnection", "StarletteWebSocket", None
  70. ],
  71. ) -> None:
  72. self.session_id = session_id
  73. self.websocket = websocket
  74. # Connection state
  75. self._pending_future: Optional[asyncio.Future[bool]] = (
  76. asyncio.get_running_loop().create_future()
  77. )
  78. # Response tracking queue for WebSocket tunnel mode
  79. self._response_queue: asyncio.Queue[bytes] = asyncio.Queue()
  80. self._connection_error: Optional[Exception] = None
  81. @property
  82. def is_pending(self) -> bool:
  83. """Check if connection is in pending state (waiting for response)"""
  84. return self._pending_future is not None and not self._pending_future.done()
  85. @property
  86. def is_connected(self) -> bool:
  87. """Check if connection is established"""
  88. return self._pending_future is None and self._connection_error is None
  89. @property
  90. def connect_result(self) -> asyncio.Future[bool]:
  91. """Get the future representing the connection state"""
  92. return self._pending_future
  93. def set_connected(self) -> None:
  94. """Mark connection as connected"""
  95. if self._pending_future is None:
  96. logger.warning(
  97. f"set_connected called but no pending future, session_id={self.session_id}"
  98. )
  99. return
  100. if self._pending_future and not self._pending_future.done():
  101. self._pending_future.set_result(True)
  102. self._pending_future = None
  103. def connect_error(self, error: Exception) -> None:
  104. """Mark connection as failed with error"""
  105. if self._pending_future is None:
  106. logger.warning(
  107. f"connect_error called but no pending future, session_id={self.session_id}, error={error}"
  108. )
  109. return
  110. if self._pending_future and not self._pending_future.done():
  111. self._pending_future.set_exception(error)
  112. self._connection_error = error
  113. self._pending_future = None
  114. async def _send_to_websocket(self, data: bytes) -> None:
  115. """Send data to WebSocket, compatible with Starlette and websockets library"""
  116. if hasattr(self.websocket, 'send_bytes'):
  117. await self.websocket.send_bytes(data)
  118. else:
  119. await self.websocket.send(data)
  120. async def handle_data(self, data: bytes) -> None:
  121. """Handle data received from WebSocket, forward to target"""
  122. logger.trace(
  123. f"[Tunnel] handle_data: session_id={self.session_id}, {len(data)} bytes"
  124. )
  125. if not self.is_connected:
  126. logger.warning(
  127. f"[Tunnel] Connection is pending or failed, ignoring data until connected, session_id={self.session_id}"
  128. )
  129. return
  130. if not data:
  131. logger.trace("[Tunnel] Empty data received, signaling EOF")
  132. else:
  133. logger.trace(f"[Tunnel] Queuing {len(data)} bytes for response tracking")
  134. await self._response_queue.put(data)
  135. return
  136. # Followings methods are for compatibility with IOConnection interface, used in tunnel function
  137. async def close(self) -> None:
  138. """Close connection"""
  139. # Send disconnect message to WebSocket
  140. logger.trace(
  141. f"[Tunnel] Closing connection, sending DisconnectMessage, session_id={self.session_id}"
  142. )
  143. try:
  144. msg = DisconnectMessage(session_id=self.session_id)
  145. await self._send_to_websocket(pack_message(msg))
  146. except Exception as e:
  147. logger.trace(
  148. f"[Tunnel] Failed to send DisconnectMessage (websocket may already be closed): {e}, session_id={self.session_id}"
  149. )
  150. await self._response_queue.put(b"") # Unblock any pending reads
  151. async def read(self, _n: int = -1, timeout: float = 3000.0) -> bytes:
  152. return await asyncio.wait_for(self._response_queue.get(), timeout=timeout)
  153. async def write(self, data: bytes) -> None:
  154. """Send data to WebSocket"""
  155. logger.trace(
  156. f"[Tunnel] Sending {len(data)} bytes to WebSocket, session_id={self.session_id}"
  157. )
  158. msg = DataMessage(session_id=self.session_id, data=data)
  159. await self._send_to_websocket(pack_message(msg))
  160. async def relay(
  161. reader: IOConnection,
  162. writer: IOConnection,
  163. name: str,
  164. ) -> None:
  165. """Relay data from ``reader`` to ``writer`` until EOF or error.
  166. Reads in 8 KiB chunks and writes each to the writer. Closes the writer
  167. when the reader signals EOF (empty bytes) or raises an exception.
  168. """
  169. try:
  170. while True:
  171. data = await reader.read(8192)
  172. if not data:
  173. logger.trace(f"{name}: read EOF")
  174. break
  175. logger.trace(f"{name}: forwarding {len(data)} bytes")
  176. await writer.write(data)
  177. except Exception as e:
  178. logger.error(f"{name}: error {e}")
  179. finally:
  180. logger.trace(f"{name}: closing {type(writer).__name__} writer")
  181. await writer.close()
  182. async def tunnel(
  183. client_connection: IOConnection,
  184. remote_connection: IOConnection,
  185. name: str = "Tunnel",
  186. request: Optional[str] = None,
  187. response_relay: Optional[
  188. Callable[[IOConnection, IOConnection], Coroutine[Any, Any, None]]
  189. ] = None,
  190. ) -> None:
  191. """Tunnel data between client and remote connections"""
  192. if request is not None:
  193. logger.trace(
  194. f"[{name}] client->remote: Sending initial request data through tunnel:\n{request[:500]}..."
  195. )
  196. try:
  197. await remote_connection.write(request.encode())
  198. except Exception as e:
  199. logger.error(f"[{name}] Error sending initial request data: {e}")
  200. await remote_connection.close()
  201. raise HTTPException(
  202. status_code=502, detail="Failed to send initial request data"
  203. )
  204. if response_relay is None:
  205. response_relay = partial(relay, name=f"[{name}] remote->client")
  206. await asyncio.gather(
  207. response_relay(remote_connection, client_connection),
  208. relay(client_connection, remote_connection, f"[{name}] client->remote"),
  209. return_exceptions=True,
  210. )
  211. logger.trace(f"[{name}] Tunnel closed")