test_server_federation.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
"""
Test suite for server-to-server federation functionality.

Tests:
- Server connection via WebSocket handshake with header-based registration
- Client registration broadcast to peers
- Existing clients sent to a server when it connects as a peer
- Peer removal, including removal of clients synced from the removed peer
- Bidirectional peer connections
- Peer authentication with HMAC (matching secret, mismatched secret, no authenticator)
"""
  10. import asyncio
  11. import pytest
  12. import uuid
  13. import uvicorn
  14. from fastapi import FastAPI
  15. from gpustack.websocket_proxy.message_server import MessageServerHandler, router
  16. from gpustack.websocket_proxy.message_client import MessageClient
  17. from gpustack.websocket_proxy.authenticator import create_authenticator
  18. def get_free_port(host: str = "127.0.0.1") -> int:
  19. """Get a free port on the host"""
  20. import socket
  21. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  22. s.bind((host, 0))
  23. return s.getsockname()[1]
  24. async def start_server(message_handler: MessageServerHandler, host: str, port: int):
  25. """Start WebSocket server using uvicorn"""
  26. app = FastAPI()
  27. app.state.message_server_handler = message_handler
  28. app.include_router(router)
  29. # Get actual port if port=0 (dynamic allocation)
  30. actual_port = get_free_port(host) if port == 0 else port
  31. config = uvicorn.Config(app, host=host, port=actual_port, log_level="error")
  32. server = uvicorn.Server(config)
  33. server_task = asyncio.create_task(server.serve())
  34. await asyncio.sleep(0.5) # Wait for server to start
  35. return server, server_task, actual_port
  36. class TestServerFederation:
  37. """Test server-to-server federation"""
  38. @pytest.mark.asyncio
  39. async def test_server_connection_via_headers(self):
  40. """Test that two servers can connect using header-based registration"""
  41. server1_id = uuid.uuid4()
  42. server2_id = uuid.uuid4()
  43. # Create two server handlers
  44. handler1 = MessageServerHandler(
  45. server_id=server1_id,
  46. listen_address="127.0.0.1",
  47. listen_port=0,
  48. proxy_port=0,
  49. )
  50. handler2 = MessageServerHandler(
  51. server_id=server2_id,
  52. listen_address="127.0.0.1",
  53. listen_port=0,
  54. proxy_port=0,
  55. )
  56. # Start both servers
  57. _, task1, port1 = await start_server(handler1, "127.0.0.1", 0)
  58. _, task2, port2 = await start_server(handler2, "127.0.0.1", 0)
  59. try:
  60. # Server1 connects to Server2 as a peer
  61. peer_id = await handler1.add_peer("127.0.0.1", port2)
  62. assert peer_id == server2_id
  63. # Server2 should have server1 in serving_peers
  64. assert server1_id in handler2.serving_peers
  65. assert handler2.serving_peers[server1_id].server_id == server1_id
  66. # Server1 should have server2 in peers
  67. assert server2_id in handler1.peers
  68. assert handler1.peers[server2_id].server_id == server2_id
  69. finally:
  70. task1.cancel()
  71. task2.cancel()
  72. try:
  73. await task1
  74. except asyncio.CancelledError:
  75. pass
  76. try:
  77. await task2
  78. except asyncio.CancelledError:
  79. pass
  80. @pytest.mark.asyncio
  81. async def test_client_registration_broadcast(self):
  82. """Test that client registration is broadcast to peers"""
  83. server1_id = uuid.uuid4()
  84. server2_id = uuid.uuid4()
  85. handler1 = MessageServerHandler(
  86. server_id=server1_id,
  87. listen_address="127.0.0.1",
  88. listen_port=0,
  89. proxy_port=0,
  90. )
  91. handler2 = MessageServerHandler(
  92. server_id=server2_id,
  93. listen_address="127.0.0.1",
  94. listen_port=0,
  95. proxy_port=0,
  96. )
  97. _, task1, port1 = await start_server(handler1, "127.0.0.1", 0)
  98. _, task2, port2 = await start_server(handler2, "127.0.0.1", 0)
  99. try:
  100. # Connect servers as peers
  101. await handler1.add_peer("127.0.0.1", port2)
  102. await asyncio.sleep(0.5) # Wait for connection to establish
  103. # Now connect a client to server1 (use valid UUID for client_id)
  104. client_uuid = uuid.uuid4()
  105. client = MessageClient(
  106. server_endpoint=f"ws://127.0.0.1:{port1}",
  107. client_id=client_uuid,
  108. cidrs=["192.168.1.0/24"],
  109. )
  110. client_task = asyncio.create_task(client.run())
  111. await asyncio.sleep(0.5) # Wait for client registration
  112. # Server1 should have the client (key is UUID)
  113. assert client_uuid in handler1.client_registry
  114. assert handler1.client_registry[client_uuid].cidrs == ["192.168.1.0/24"]
  115. # Server2 should receive the client update from server1
  116. # (it takes a moment for the broadcast to propagate)
  117. await asyncio.sleep(0.5)
  118. # The client should be registered in server2's registry via peer update
  119. assert client_uuid in handler2.client_registry
  120. assert handler2.client_registry[client_uuid].cidrs == ["192.168.1.0/24"]
  121. assert handler2.client_registry[client_uuid].server_id == server1_id
  122. client_task.cancel()
  123. try:
  124. await client_task
  125. except asyncio.CancelledError:
  126. pass
  127. finally:
  128. task1.cancel()
  129. task2.cancel()
  130. try:
  131. await task1
  132. except asyncio.CancelledError:
  133. pass
  134. try:
  135. await task2
  136. except asyncio.CancelledError:
  137. pass
  138. @pytest.mark.asyncio
  139. async def test_peer_sends_existing_clients_on_connect(self):
  140. """Test that when ServerA connects to ServerB, ServerA receives ServerB's existing clients"""
  141. server1_id = uuid.uuid4()
  142. server2_id = uuid.uuid4()
  143. handler1 = MessageServerHandler(
  144. server_id=server1_id,
  145. listen_address="127.0.0.1",
  146. listen_port=0,
  147. proxy_port=0,
  148. )
  149. handler2 = MessageServerHandler(
  150. server_id=server2_id,
  151. listen_address="127.0.0.1",
  152. listen_port=0,
  153. proxy_port=0,
  154. )
  155. _, task1, port1 = await start_server(handler1, "127.0.0.1", 0)
  156. _, task2, port2 = await start_server(handler2, "127.0.0.1", 0)
  157. try:
  158. # First, connect a client to Server2 (before Server1 connects)
  159. client_uuid = uuid.uuid4()
  160. client = MessageClient(
  161. server_endpoint=f"ws://127.0.0.1:{port2}",
  162. client_id=client_uuid,
  163. cidrs=["10.0.0.0/8"],
  164. )
  165. client_task = asyncio.create_task(client.run())
  166. await asyncio.sleep(0.5) # Wait for client registration
  167. # Server2 should have the client (key is UUID)
  168. assert client_uuid in handler2.client_registry
  169. assert handler2.client_registry[client_uuid].cidrs == ["10.0.0.0/8"]
  170. # Now Server1 connects to Server2
  171. await handler1.add_peer("127.0.0.1", port2)
  172. await asyncio.sleep(0.5) # Wait for connection and client list sync
  173. # Server1 should have received Server2's existing clients
  174. assert (
  175. client_uuid in handler1.client_registry
  176. ), f"Server1 did not receive Server2's client. Registry keys: {list(handler1.client_registry.keys())}"
  177. assert handler1.client_registry[client_uuid].cidrs == ["10.0.0.0/8"]
  178. assert handler1.client_registry[client_uuid].server_id == server2_id
  179. client_task.cancel()
  180. try:
  181. await client_task
  182. except asyncio.CancelledError:
  183. pass
  184. finally:
  185. task1.cancel()
  186. task2.cancel()
  187. try:
  188. await task1
  189. except asyncio.CancelledError:
  190. pass
  191. try:
  192. await task2
  193. except asyncio.CancelledError:
  194. pass
  195. try:
  196. await task1
  197. except asyncio.CancelledError:
  198. pass
  199. try:
  200. await task2
  201. except asyncio.CancelledError:
  202. pass
  203. @pytest.mark.asyncio
  204. async def test_remove_peer(self):
  205. """Test that peer removal works correctly"""
  206. server1_id = uuid.uuid4()
  207. server2_id = uuid.uuid4()
  208. handler1 = MessageServerHandler(
  209. server_id=server1_id,
  210. listen_address="127.0.0.1",
  211. listen_port=0,
  212. proxy_port=0,
  213. )
  214. handler2 = MessageServerHandler(
  215. server_id=server2_id,
  216. listen_address="127.0.0.1",
  217. listen_port=0,
  218. proxy_port=0,
  219. )
  220. _, task1, port1 = await start_server(handler1, "127.0.0.1", 0)
  221. _, task2, port2 = await start_server(handler2, "127.0.0.1", 0)
  222. try:
  223. # Connect servers as peers
  224. await handler1.add_peer("127.0.0.1", port2)
  225. await asyncio.sleep(0.5)
  226. # Verify connection
  227. assert server2_id in handler1.peers
  228. assert server1_id in handler2.serving_peers
  229. # Remove peer by UUID
  230. result = await handler1.remove_peer(server2_id)
  231. assert result is True
  232. # Verify removal
  233. assert server2_id not in handler1.peers
  234. # server2 should still have server1 in serving_peers until it detects disconnect
  235. await asyncio.sleep(0.5)
  236. finally:
  237. task1.cancel()
  238. task2.cancel()
  239. try:
  240. await task1
  241. except asyncio.CancelledError:
  242. pass
  243. try:
  244. await task2
  245. except asyncio.CancelledError:
  246. pass
  247. @pytest.mark.asyncio
  248. async def test_bidirectional_peer_connection(self):
  249. """Test that two servers can connect to each other as peers"""
  250. server1_id = uuid.uuid4()
  251. server2_id = uuid.uuid4()
  252. handler1 = MessageServerHandler(
  253. server_id=server1_id,
  254. listen_address="127.0.0.1",
  255. listen_port=0,
  256. proxy_port=0,
  257. )
  258. handler2 = MessageServerHandler(
  259. server_id=server2_id,
  260. listen_address="127.0.0.1",
  261. listen_port=0,
  262. proxy_port=0,
  263. )
  264. _, task1, port1 = await start_server(handler1, "127.0.0.1", 0)
  265. _, task2, port2 = await start_server(handler2, "127.0.0.1", 0)
  266. try:
  267. # Server1 connects to Server2
  268. await handler1.add_peer("127.0.0.1", port2)
  269. await asyncio.sleep(0.3)
  270. # Server2 connects to Server1
  271. await handler2.add_peer("127.0.0.1", port1)
  272. await asyncio.sleep(0.3)
  273. # Both should have each other as peers (server1 has server2 as outgoing,
  274. # server2 has server1 as outgoing, but server1 is also in server2's serving_peers)
  275. assert server2_id in handler1.peers
  276. assert server1_id in handler2.peers
  277. assert server1_id in handler2.serving_peers # server1 connected to server2
  278. finally:
  279. task1.cancel()
  280. task2.cancel()
  281. try:
  282. await task1
  283. except asyncio.CancelledError:
  284. pass
  285. try:
  286. await task2
  287. except asyncio.CancelledError:
  288. pass
  289. @pytest.mark.asyncio
  290. async def test_clients_removed_when_peer_disconnects(self):
  291. """Test that clients synced from a peer are removed when the peer connection is closed.
  292. Scenario:
  293. 1. Server1 connects to Server2 (server1.add_peer(server2))
  294. 2. Server2 connects a client to itself
  295. 3. Server2 syncs its client to Server1 (stored with server_id=server2_id)
  296. 4. Server1 disconnects from Server2 (server1.remove_peer(server2))
  297. 5. Server1 should clean up clients with server_id=server2_id
  298. """
  299. server1_id = uuid.uuid4()
  300. server2_id = uuid.uuid4()
  301. handler1 = MessageServerHandler(
  302. server_id=server1_id,
  303. listen_address="127.0.0.1",
  304. listen_port=0,
  305. proxy_port=0,
  306. )
  307. handler2 = MessageServerHandler(
  308. server_id=server2_id,
  309. listen_address="127.0.0.1",
  310. listen_port=0,
  311. proxy_port=0,
  312. )
  313. _, task1, _ = await start_server(handler1, "127.0.0.1", 0)
  314. _, task2, port2 = await start_server(handler2, "127.0.0.1", 0)
  315. try:
  316. # Step 1: Server1 connects to Server2 as peer
  317. await handler1.add_peer("127.0.0.1", port2)
  318. await asyncio.sleep(0.5)
  319. # Step 2: Connect a client to Server2 (not Server1)
  320. client_uuid = uuid.uuid4()
  321. client = MessageClient(
  322. server_endpoint=f"ws://127.0.0.1:{port2}",
  323. client_id=client_uuid,
  324. cidrs=["192.168.1.0/24"],
  325. )
  326. client_task = asyncio.create_task(client.run())
  327. await asyncio.sleep(0.5)
  328. # Verify client is in Server2 (its own registry)
  329. assert client_uuid in handler2.client_registry
  330. # Step 3: Server2 syncs its client to Server1 via peer connection
  331. await asyncio.sleep(0.5)
  332. assert client_uuid in handler1.client_registry
  333. assert handler1.client_registry[client_uuid].server_id == server2_id
  334. # Step 4: Server1 disconnects from Server2
  335. await handler1.remove_peer(server2_id)
  336. await asyncio.sleep(0.5)
  337. # Step 5: Server1 should clean up clients synced from Server2
  338. assert client_uuid not in handler1.client_registry
  339. # Server2 should still have its own client (disconnect was on Server1's outgoing connection)
  340. assert client_uuid in handler2.client_registry
  341. client_task.cancel()
  342. try:
  343. await client_task
  344. except asyncio.CancelledError:
  345. pass
  346. finally:
  347. task1.cancel()
  348. task2.cancel()
  349. try:
  350. await task1
  351. except asyncio.CancelledError:
  352. pass
  353. try:
  354. await task2
  355. except asyncio.CancelledError:
  356. pass
  357. class TestServerFederationWithAuth:
  358. """Integration tests for server federation with authentication"""
  359. @pytest.mark.asyncio
  360. async def test_peer_connection_with_authenticator(self):
  361. """Test that two servers can connect when both use the same authenticator"""
  362. server1_id = uuid.uuid4()
  363. server2_id = uuid.uuid4()
  364. secret = "shared-secret"
  365. handler1 = MessageServerHandler(
  366. server_id=server1_id,
  367. listen_address="127.0.0.1",
  368. listen_port=0,
  369. proxy_port=0,
  370. authenticator=create_authenticator(secret),
  371. )
  372. handler2 = MessageServerHandler(
  373. server_id=server2_id,
  374. listen_address="127.0.0.1",
  375. listen_port=0,
  376. proxy_port=0,
  377. authenticator=create_authenticator(secret),
  378. )
  379. _, task1, port1 = await start_server(handler1, "127.0.0.1", 0)
  380. _, task2, port2 = await start_server(handler2, "127.0.0.1", 0)
  381. try:
  382. # Server1 connects to Server2
  383. peer_id = await handler1.add_peer("127.0.0.1", port2)
  384. assert peer_id == server2_id
  385. # Verify connection established
  386. await asyncio.sleep(0.5)
  387. assert server2_id in handler1.peers
  388. assert server1_id in handler2.serving_peers
  389. finally:
  390. task1.cancel()
  391. task2.cancel()
  392. try:
  393. await task1
  394. except asyncio.CancelledError:
  395. pass
  396. try:
  397. await task2
  398. except asyncio.CancelledError:
  399. pass
  400. @pytest.mark.asyncio
  401. async def test_peer_connection_fails_with_wrong_secret(self):
  402. """Test that peer connection fails when secrets don't match"""
  403. server1_id = uuid.uuid4()
  404. server2_id = uuid.uuid4()
  405. handler1 = MessageServerHandler(
  406. server_id=server1_id,
  407. listen_address="127.0.0.1",
  408. listen_port=0,
  409. proxy_port=0,
  410. authenticator=create_authenticator("secret-a"),
  411. )
  412. handler2 = MessageServerHandler(
  413. server_id=server2_id,
  414. listen_address="127.0.0.1",
  415. listen_port=0,
  416. proxy_port=0,
  417. authenticator=create_authenticator("secret-b"),
  418. )
  419. _, task1, _ = await start_server(handler1, "127.0.0.1", 0)
  420. _, task2, port2 = await start_server(handler2, "127.0.0.1", 0)
  421. try:
  422. # Server1 tries to connect to Server2 with different secret
  423. _ = await handler1.add_peer("127.0.0.1", port2)
  424. except Exception as e:
  425. # Expect connection to fail due to authentication error
  426. # Server sends HTTP 403 during WebSocket handshake
  427. assert "403" in str(e) or "Authentication failed" in str(e)
  428. finally:
  429. task1.cancel()
  430. task2.cancel()
  431. try:
  432. await task1
  433. except asyncio.CancelledError:
  434. pass
  435. try:
  436. await task2
  437. except asyncio.CancelledError:
  438. pass
  439. @pytest.mark.asyncio
  440. async def test_peer_connection_with_noop_authenticator(self):
  441. """Test that peer connection works when no authenticator is set"""
  442. server1_id = uuid.uuid4()
  443. server2_id = uuid.uuid4()
  444. handler1 = MessageServerHandler(
  445. server_id=server1_id,
  446. listen_address="127.0.0.1",
  447. listen_port=0,
  448. proxy_port=0,
  449. # No authenticator - uses default NoOpAuthenticator
  450. )
  451. handler2 = MessageServerHandler(
  452. server_id=server2_id,
  453. listen_address="127.0.0.1",
  454. listen_port=0,
  455. proxy_port=0,
  456. # No authenticator - uses default NoOpAuthenticator
  457. )
  458. _, task1, port1 = await start_server(handler1, "127.0.0.1", 0)
  459. _, task2, port2 = await start_server(handler2, "127.0.0.1", 0)
  460. try:
  461. # Server1 connects to Server2
  462. peer_id = await handler1.add_peer("127.0.0.1", port2)
  463. assert peer_id == server2_id
  464. await asyncio.sleep(0.5)
  465. assert server2_id in handler1.peers
  466. assert server1_id in handler2.serving_peers
  467. finally:
  468. task1.cancel()
  469. task2.cancel()
  470. try:
  471. await task1
  472. except asyncio.CancelledError:
  473. pass
  474. try:
  475. await task2
  476. except asyncio.CancelledError:
  477. pass
  478. if __name__ == "__main__":
  479. pytest.main([__file__, "-v"])