main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. #!/usr/bin/env python3
  2. """
  3. Message Protocol Example - Connect-based Proxy Demo
  4. This example demonstrates the connect-based proxy flow:
  5. 1. Client connects to Server and registers
  6. 2. Server sends CONNECT_REQUEST to Client (request to connect to target)
  7. 3. Client connects to the target and sends CONNECT_RESPONSE (connection succeeded)
  8. 4. Data flows through the tunnel via persistent socket connections
  9. Server Federation:
  10. - Multiple servers can form a federation to share client registrations
  11. - Each server maintains a registry of clients from all peers
  12. - When a client connects/disconnects, all peers are notified
  13. Run:
  14. - Start server: python -m src.main server  (run as a module: this file uses package-relative imports)
  15. - Start client: python -m src.main client
  16. Client API (HTTP endpoints served by the server on --port):
  17. - GET /clients - List all connected clients
  18. Returns: {"clients": [...], "total": N}
  19. - GET /clients/{client_id} - Get details for a specific client
  20. Returns: client info including CIDRs, unix sockets, active sessions
  21. - GET /clients/{client_id}/connections - Get active tunnel connections for a client
  22. Returns: {"connections": [...], "total": N}
  23. Federation API:
  24. - POST /register-peer - Register a peer server
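  25. Body: {"address": "host:port", "port": 8765} ("port" defaults to 8765; "proxy_port" is not read by this endpoint)
  26. Returns: {"status": "ok", "server_id": "uuid"}, or {"status": "error", "server_id": null} if the peer could not be added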
  27. - DELETE /register-peer/{server_id} - Remove a peer server by UUID (a value that is not a valid UUID is treated as a peer address)
  28. Returns: {"status": "ok"}
  29. - GET /peers - List all connected peers (outgoing and incoming)
  30. Returns: {"peers": [...]}
  31. WebSocket Endpoints:
  32. - ws://host:port/connect - Client and peer server connections
  33. """
  34. import asyncio
  35. import argparse
  36. import logging
  37. import uuid
  38. import uvicorn
  39. from typing import Optional
  40. from fastapi import FastAPI, HTTPException
  41. from pydantic import BaseModel
  42. from .proxy_server import HTTPSProxyServer
  43. from .message_server import MessageServerHandler, router
  44. from .message_client import MessageClient
logger = logging.getLogger(__name__)


# Pydantic models for the federation API request/response bodies.
# NOTE: these are documented with `#` comments rather than class docstrings
# on purpose: pydantic uses a model's docstring as its JSON-schema
# "description", so adding one would change the generated OpenAPI document.
class RegisterPeerRequest(BaseModel):
    # Body of POST /register-peer. Both fields are passed straight to
    # MessageServerHandler.add_peer(address, port); any other keys in the
    # request body (e.g. "proxy_port") are not read.
    address: str
    port: int = 8765


class RegisterPeerResponse(BaseModel):
    # "ok" together with the new peer's id, or "error" with server_id=None
    # when add_peer() returned a falsy id. Both are sent with HTTP 200.
    status: str
    server_id: Optional[uuid.UUID] = None


class RemovePeerResponse(BaseModel):
    # Body of DELETE /register-peer/{peer_id}; the route always reports "ok".
    status: str


# Module-level reference to the running server's handler, assigned by
# _run_server(). The routes built in _create_server_app() use the handler
# passed to that function (a closure), not this global; presumably kept for
# other modules or debugging — verify before removing.
_message_handler: Optional[MessageServerHandler] = None
  57. def _create_server_app(message_handler: MessageServerHandler): # noqa: C901
  58. """Create FastAPI app with all routes for the server."""
  59. app = FastAPI()
  60. app.state.message_server_handler = message_handler
  61. # WebSocket endpoint
  62. app.include_router(router)
  63. # Federation API endpoints
  64. @app.post("/register-peer", response_model=RegisterPeerResponse)
  65. async def register_peer(request: RegisterPeerRequest):
  66. """Register a peer server"""
  67. peer_id = await message_handler.add_peer(request.address, request.port)
  68. if peer_id:
  69. return RegisterPeerResponse(status="ok", server_id=peer_id)
  70. return RegisterPeerResponse(status="error", server_id=None)
  71. @app.delete("/register-peer/{peer_id}", response_model=RemovePeerResponse)
  72. async def remove_peer(peer_id: str):
  73. """Remove a peer server by UUID"""
  74. try:
  75. peer_uuid = uuid.UUID(peer_id)
  76. await message_handler.remove_peer(peer_uuid)
  77. except ValueError:
  78. # Try to remove by address instead
  79. await message_handler.remove_peer_by_address(peer_id)
  80. return RemovePeerResponse(status="ok")
  81. @app.get("/clients")
  82. async def list_clients():
  83. """List all connected clients."""
  84. clients = []
  85. for client_id, info in message_handler.client_registry.items():
  86. conn_mgr = message_handler.connection_managers.get(client_id)
  87. active_sessions = list(conn_mgr.connections.keys()) if conn_mgr else []
  88. clients.append(
  89. {
  90. "client_id": str(client_id),
  91. "cidrs": info.cidrs,
  92. "unix_sockets": info.unix_sockets,
  93. "active_sessions": [str(s) for s in active_sessions],
  94. "session_count": len(active_sessions),
  95. }
  96. )
  97. return {"clients": clients, "total": len(clients)}
  98. @app.get("/clients/{client_id}")
  99. async def get_client(client_id: str):
  100. """Get details for a specific client."""
  101. try:
  102. client_uuid = uuid.UUID(client_id)
  103. except ValueError:
  104. raise HTTPException(404, detail=f"Invalid client ID format: {client_id}")
  105. info = message_handler.client_registry.get(client_uuid)
  106. if not info:
  107. raise HTTPException(404, detail=f"Client not found: {client_id}")
  108. conn_mgr = message_handler.connection_managers.get(client_uuid)
  109. active_sessions = list(conn_mgr.connections.keys()) if conn_mgr else []
  110. return {
  111. "client_id": client_id,
  112. "cidrs": info.cidrs,
  113. "unix_sockets": info.unix_sockets,
  114. "server_id": str(info.server_id) if info.server_id else None,
  115. "active_sessions": [str(s) for s in active_sessions],
  116. "session_count": len(active_sessions),
  117. }
  118. @app.get("/clients/{client_id}/connections")
  119. async def get_client_connections(client_id: str):
  120. """Get active tunnel connections for a specific client."""
  121. try:
  122. client_uuid = uuid.UUID(client_id)
  123. except ValueError:
  124. raise HTTPException(404, detail=f"Invalid client ID format: {client_id}")
  125. info = message_handler.client_registry.get(client_uuid)
  126. if not info:
  127. raise HTTPException(404, detail=f"Client not found: {client_id}")
  128. conn_mgr = message_handler.connection_managers.get(client_uuid)
  129. if not conn_mgr:
  130. return {"connections": [], "total": 0}
  131. connections = []
  132. for session_id, conn in conn_mgr.connections().items():
  133. connections.append(
  134. {
  135. "session_id": str(session_id),
  136. "is_pending": conn.is_pending,
  137. "is_connected": conn.is_connected,
  138. }
  139. )
  140. return {"connections": connections, "total": len(connections)}
  141. @app.get("/peers")
  142. async def list_peers():
  143. """List all connected peers (outgoing and incoming)"""
  144. all_peers = []
  145. # Outgoing peers (we connected to)
  146. for peer in message_handler.peers.values():
  147. all_peers.append(
  148. {
  149. "server_id": str(peer.server_id),
  150. "address": (
  151. f"{peer.listen_address}:{peer.listen_port}"
  152. if peer.listen_address
  153. else ""
  154. ),
  155. "proxy_port": peer.proxy_port,
  156. "connected": peer.connected,
  157. "type": "outgoing",
  158. }
  159. )
  160. # Incoming serving_peers (connected to us)
  161. for peer in message_handler.serving_peers.values():
  162. all_peers.append(
  163. {
  164. "server_id": str(peer.server_id),
  165. "address": (
  166. f"{peer.listen_address}:{peer.listen_port}"
  167. if peer.listen_address
  168. else ""
  169. ),
  170. "proxy_port": peer.proxy_port,
  171. "connected": peer.connected,
  172. "type": "incoming",
  173. }
  174. )
  175. return {"peers": all_peers}
  176. return app
  177. async def _run_server(args):
  178. """Run the server role."""
  179. global _message_handler
  180. server_id = uuid.UUID(args.server_id) if args.server_id else uuid.uuid4()
  181. message_handler = MessageServerHandler(
  182. server_id=server_id,
  183. listen_address=args.host,
  184. listen_port=args.port,
  185. proxy_port=args.proxy_port,
  186. )
  187. _message_handler = message_handler
  188. proxy = HTTPSProxyServer(
  189. host=args.host,
  190. port=args.proxy_port,
  191. connection_manager_getter=message_handler.get_connection_manager,
  192. )
  193. logger.debug(f"[Server] Starting WebSocket server on {args.host}:{args.port}")
  194. logger.debug(f"[Server] WebSocket endpoint: ws://{args.host}:{args.port}/connect")
  195. logger.debug(f"[Server] HTTP proxy endpoint: http://{args.host}:{args.proxy_port}/")
  196. logger.debug(
  197. f"[Server] Federation API: http://{args.host}:{args.port}/register-peer"
  198. )
  199. app = _create_server_app(message_handler)
  200. config = uvicorn.Config(app, host=args.host, port=args.port, log_level="info")
  201. uvicorn_server = uvicorn.Server(config)
  202. # Start proxy server in background
  203. proxy_task = asyncio.create_task(proxy.start())
  204. # Run uvicorn (it handles signals internally)
  205. await uvicorn_server.serve()
  206. # Uvicorn exited, stop proxy
  207. proxy_task.cancel()
  208. try:
  209. await proxy_task
  210. except asyncio.CancelledError:
  211. pass
  212. await proxy.stop()
  213. async def _run_client(args):
  214. """Run the client role."""
  215. client = MessageClient(
  216. server_endpoint=f"http://{args.host}:{args.port}",
  217. client_id=args.client_id,
  218. cidrs=args.cidr,
  219. unix_sockets=args.unix_sockets,
  220. )
  221. await client.run()
  222. def _parse_args():
  223. """Parse command line arguments."""
  224. parser = argparse.ArgumentParser(description='WebSocket Message Protocol Example')
  225. parser.add_argument(
  226. 'role', choices=['server', 'client'], help='Run as server or client'
  227. )
  228. parser.add_argument('--host', default='localhost', help='Server host')
  229. parser.add_argument('--port', type=int, default=8765, help='WebSocket server port')
  230. parser.add_argument(
  231. '--proxy-port', type=int, default=8000, help='HTTP proxy port (server only)'
  232. )
  233. # Server-specific options
  234. parser.add_argument(
  235. '--server-id', default=None, help='Server ID (auto-generated if not provided)'
  236. )
  237. # Client-specific options
  238. parser.add_argument(
  239. '--client-id', default=None, help='Client ID (auto-generated if not provided)'
  240. )
  241. parser.add_argument(
  242. '--cidr',
  243. action='append',
  244. default=[],
  245. help='CIDR to register (can be specified multiple times)',
  246. )
  247. parser.add_argument(
  248. '--unix-socket',
  249. action='append',
  250. default=[],
  251. dest='unix_sockets',
  252. help='Unix socket path to register (can be specified multiple times)',
  253. )
  254. return parser.parse_args()
  255. async def main():
  256. args = _parse_args()
  257. if args.role == 'server':
  258. await _run_server(args)
  259. else:
  260. await _run_client(args)
  261. if __name__ == '__main__':
  262. logging.basicConfig(level=logging.DEBUG)
  263. asyncio.run(main())