| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311 |
- #!/usr/bin/env python3
- """
- Message Protocol Example - Connect-based Proxy Demo
- This example demonstrates the connect-based proxy flow:
- 1. Client connects to Server and registers
- 2. Server sends CONNECT_REQUEST to Client (request to connect to target)
3. Client connects to target, sends CONNECT_RESPONSE (connect succeeded)
- 4. Data flows through the tunnel via persistent socket connections
- Server Federation:
- - Multiple servers can form a federation to share client registrations
- - Each server maintains a registry of clients from all peers
- - When a client connects/disconnects, all peers are notified
- Run:
- - Start server: python src/main.py server
- - Start client: python src/main.py client
- Client API:
- - GET /clients - List all connected clients
- Returns: {"clients": [...], "total": N}
- - GET /clients/{client_id} - Get details for a specific client
- Returns: client info including CIDRs, unix sockets, active sessions
- - GET /clients/{client_id}/connections - Get active tunnel connections for a client
- Returns: {"connections": [...], "total": N}
- Federation API:
- - POST /register-peer - Register a peer server
  Body: {"address": "host:port", "port": 8765}  ("port" is optional, default 8765)
- Returns: {"status": "ok", "server_id": "uuid"}
- DELETE /register-peer/{peer_id} - Remove a peer server by its server UUID (if peer_id is not a valid UUID, it is treated as a peer address)
- Returns: {"status": "ok"}
- - GET /peers - List all connected peers (outgoing and incoming)
- Returns: {"peers": [...]}
- WebSocket Endpoints:
- - ws://host:port/connect - Client and peer server connections
- """
- import asyncio
- import argparse
- import logging
- import uuid
- import uvicorn
- from typing import Optional
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- from .proxy_server import HTTPSProxyServer
- from .message_server import MessageServerHandler, router
- from .message_client import MessageClient
# Module-level logger; the server role logs its startup endpoints at DEBUG level.
logger = logging.getLogger(__name__)


# Pydantic models for API requests
class RegisterPeerRequest(BaseModel):
    """Request body for ``POST /register-peer``."""

    # Peer server address; forwarded unchanged to MessageServerHandler.add_peer.
    address: str
    # Peer's WebSocket port; the default matches the CLI --port default.
    port: int = 8765


class RegisterPeerResponse(BaseModel):
    """Response for ``POST /register-peer``."""

    # "ok" when add_peer returned an ID, "error" otherwise.
    status: str
    # ID returned by add_peer on success; None on failure.
    server_id: Optional[uuid.UUID] = None


class RemovePeerResponse(BaseModel):
    """Response for ``DELETE /register-peer/{peer_id}``."""

    # Always "ok" in the current route implementation.
    status: str


# Global message handler reference for API endpoints.
# Assigned by _run_server. NOTE(review): the routes built in _create_server_app
# use the handler passed to that function, not this global; presumably kept for
# external access — confirm before removing.
_message_handler: Optional[MessageServerHandler] = None
def _peer_summary(peer, peer_type: str) -> dict:
    """Serialize one peer connection into the JSON shape used by ``GET /peers``.

    Args:
        peer: A record from ``MessageServerHandler.peers`` or
            ``serving_peers``. Its ``server_id``, ``listen_address``,
            ``listen_port``, ``proxy_port`` and ``connected`` attributes are read.
        peer_type: ``"outgoing"`` (we connected to it) or ``"incoming"``
            (it connected to us).

    Returns:
        A dict with ``server_id``, ``address`` (``"host:port"``, or ``""``
        when the peer has no listen address), ``proxy_port``, ``connected``
        and ``type``.
    """
    address = (
        f"{peer.listen_address}:{peer.listen_port}" if peer.listen_address else ""
    )
    return {
        "server_id": str(peer.server_id),
        "address": address,
        "proxy_port": peer.proxy_port,
        "connected": peer.connected,
        "type": peer_type,
    }


def _active_session_ids(message_handler, client_uuid) -> list:
    """Return the session IDs, as strings, of a client's active connections.

    Returns an empty list when the client has no connection manager.
    """
    conn_mgr = message_handler.connection_managers.get(client_uuid)
    if not conn_mgr:
        return []
    return [str(session_id) for session_id in conn_mgr.connections.keys()]


def _client_summary(message_handler, client_uuid, info, client_id: str) -> dict:
    """Serialize a registered client for the ``/clients`` endpoints.

    Args:
        message_handler: Handler whose ``connection_managers`` are consulted.
        client_uuid: Registry key of the client.
        info: Registry entry. Its ``cidrs``, ``unix_sockets`` and
            ``server_id`` are read.
        client_id: Text echoed back as ``"client_id"``. ``GET
            /clients/{client_id}`` passes the path value verbatim.
    """
    sessions = _active_session_ids(message_handler, client_uuid)
    return {
        "client_id": client_id,
        "cidrs": info.cidrs,
        "unix_sockets": info.unix_sockets,
        "server_id": str(info.server_id) if info.server_id else None,
        "active_sessions": sessions,
        "session_count": len(sessions),
    }


def _lookup_client(message_handler, client_id: str):
    """Resolve a ``client_id`` path parameter to ``(client_uuid, info)``.

    Raises:
        HTTPException: 404 if ``client_id`` is not a valid UUID, or if no
            client with that ID is in ``message_handler.client_registry``.
    """
    try:
        client_uuid = uuid.UUID(client_id)
    except ValueError:
        raise HTTPException(
            404, detail=f"Invalid client ID format: {client_id}"
        ) from None
    info = message_handler.client_registry.get(client_uuid)
    if not info:
        raise HTTPException(404, detail=f"Client not found: {client_id}")
    return client_uuid, info


def _create_server_app(message_handler: MessageServerHandler):  # noqa: C901
    """Create the FastAPI app with all routes for the server.

    Mounts the WebSocket ``router`` and adds the federation API
    (``/register-peer``, ``/peers``) and the client API (``/clients...``).
    Every route closes over ``message_handler``. The handler is also stored on
    ``app.state.message_server_handler``, presumably for the WebSocket router
    to look up (TODO confirm in message_server).

    Args:
        message_handler: The server's message/federation handler.

    Returns:
        The configured ``FastAPI`` application.
    """
    app = FastAPI()
    app.state.message_server_handler = message_handler
    # WebSocket endpoint
    app.include_router(router)

    # Federation API endpoints
    @app.post("/register-peer", response_model=RegisterPeerResponse)
    async def register_peer(request: RegisterPeerRequest):
        """Register a peer server.

        Responds with status "ok" and the peer's ID when ``add_peer`` returns
        one, otherwise with status "error".
        """
        peer_id = await message_handler.add_peer(request.address, request.port)
        if peer_id:
            return RegisterPeerResponse(status="ok", server_id=peer_id)
        return RegisterPeerResponse(status="error", server_id=None)

    @app.delete("/register-peer/{peer_id}", response_model=RemovePeerResponse)
    async def remove_peer(peer_id: str):
        """Remove a peer server by UUID, or by address if ``peer_id`` isn't a UUID."""
        try:
            peer_uuid = uuid.UUID(peer_id)
        except ValueError:
            # Not a UUID: treat the path segment as a peer address.
            await message_handler.remove_peer_by_address(peer_id)
        else:
            # Kept outside the ``try`` so a ValueError raised by remove_peer
            # itself is not misread as "not a UUID" and retried by address.
            await message_handler.remove_peer(peer_uuid)
        return RemovePeerResponse(status="ok")

    @app.get("/clients")
    async def list_clients():
        """List all connected clients.

        Each entry has the same fields as ``GET /clients/{client_id}``,
        including ``server_id``.
        """
        clients = [
            _client_summary(message_handler, client_id, info, str(client_id))
            for client_id, info in message_handler.client_registry.items()
        ]
        return {"clients": clients, "total": len(clients)}

    @app.get("/clients/{client_id}")
    async def get_client(client_id: str):
        """Get details for a specific client (404 if malformed or unknown)."""
        client_uuid, info = _lookup_client(message_handler, client_id)
        return _client_summary(message_handler, client_uuid, info, client_id)

    @app.get("/clients/{client_id}/connections")
    async def get_client_connections(client_id: str):
        """Get active tunnel connections for a specific client."""
        client_uuid, _info = _lookup_client(message_handler, client_id)
        conn_mgr = message_handler.connection_managers.get(client_uuid)
        if not conn_mgr:
            return {"connections": [], "total": 0}
        # ``connections`` is a mapping attribute (as used by the other client
        # routes), not a method.
        connections = [
            {
                "session_id": str(session_id),
                "is_pending": conn.is_pending,
                "is_connected": conn.is_connected,
            }
            for session_id, conn in conn_mgr.connections.items()
        ]
        return {"connections": connections, "total": len(connections)}

    @app.get("/peers")
    async def list_peers():
        """List all peers: outgoing (we connected) first, then incoming."""
        outgoing = [
            _peer_summary(peer, "outgoing")
            for peer in message_handler.peers.values()
        ]
        incoming = [
            _peer_summary(peer, "incoming")
            for peer in message_handler.serving_peers.values()
        ]
        return {"peers": outgoing + incoming}

    return app
async def _run_server(args):
    """Run the server role: the FastAPI/WebSocket app plus the HTTP proxy.

    uvicorn serves the app in the foreground, and the proxy runs as a
    background task. However uvicorn exits (normally, via an exception, or via
    ``SystemExit``), the proxy task is cancelled and ``proxy.stop()`` is
    awaited.

    Args:
        args: Parsed CLI namespace. Reads ``server_id``, ``host``, ``port``
            and ``proxy_port``.
    """
    global _message_handler
    server_id = uuid.UUID(args.server_id) if args.server_id else uuid.uuid4()
    message_handler = MessageServerHandler(
        server_id=server_id,
        listen_address=args.host,
        listen_port=args.port,
        proxy_port=args.proxy_port,
    )
    _message_handler = message_handler
    proxy = HTTPSProxyServer(
        host=args.host,
        port=args.proxy_port,
        connection_manager_getter=message_handler.get_connection_manager,
    )
    logger.debug("[Server] Starting WebSocket server on %s:%s", args.host, args.port)
    logger.debug("[Server] WebSocket endpoint: ws://%s:%s/connect", args.host, args.port)
    logger.debug(
        "[Server] HTTP proxy endpoint: http://%s:%s/", args.host, args.proxy_port
    )
    logger.debug(
        "[Server] Federation API: http://%s:%s/register-peer", args.host, args.port
    )
    app = _create_server_app(message_handler)
    config = uvicorn.Config(app, host=args.host, port=args.port, log_level="info")
    uvicorn_server = uvicorn.Server(config)
    # Start proxy server in background
    proxy_task = asyncio.create_task(proxy.start())
    try:
        # Run uvicorn in the foreground (it handles signals internally).
        await uvicorn_server.serve()
    finally:
        # Tear the proxy down on every exit path so an exception from serve()
        # cannot leave the proxy task running.
        proxy_task.cancel()
        try:
            await proxy_task
        except asyncio.CancelledError:
            pass
        finally:
            # Stop the proxy even if its task ended with its own error; that
            # error still propagates after stop() completes.
            await proxy.stop()
async def _run_client(args):
    """Client role: build a MessageClient for the configured server and run it.

    The client targets ``http://<host>:<port>``. The client ID and the CIDR
    and unix-socket lists from the command line are passed through unchanged.
    """
    endpoint = f"http://{args.host}:{args.port}"
    message_client = MessageClient(
        server_endpoint=endpoint,
        client_id=args.client_id,
        cidrs=args.cidr,
        unix_sockets=args.unix_sockets,
    )
    await message_client.run()
- def _parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description='WebSocket Message Protocol Example')
- parser.add_argument(
- 'role', choices=['server', 'client'], help='Run as server or client'
- )
- parser.add_argument('--host', default='localhost', help='Server host')
- parser.add_argument('--port', type=int, default=8765, help='WebSocket server port')
- parser.add_argument(
- '--proxy-port', type=int, default=8000, help='HTTP proxy port (server only)'
- )
- # Server-specific options
- parser.add_argument(
- '--server-id', default=None, help='Server ID (auto-generated if not provided)'
- )
- # Client-specific options
- parser.add_argument(
- '--client-id', default=None, help='Client ID (auto-generated if not provided)'
- )
- parser.add_argument(
- '--cidr',
- action='append',
- default=[],
- help='CIDR to register (can be specified multiple times)',
- )
- parser.add_argument(
- '--unix-socket',
- action='append',
- default=[],
- dest='unix_sockets',
- help='Unix socket path to register (can be specified multiple times)',
- )
- return parser.parse_args()
async def main():
    """CLI entry point: parse arguments and dispatch to the chosen role's runner."""
    args = _parse_args()
    run_role = _run_server if args.role == 'server' else _run_client
    await run_role(args)
if __name__ == '__main__':
    # Configure root logging at DEBUG first so this module's logger.debug
    # startup messages are emitted, then run the async entry point.
    logging.basicConfig(level=logging.DEBUG)
    asyncio.run(main())
|