| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
import json
import sys
import uuid

import pytest

from gpustack.websocket_proxy.message import (
    # Classes
    ConnectRequestMessage,
    ConnectResponseMessage,
    DataMessage,
    DisconnectMessage,
    HeartbeatMessage,
    ListClientsMessage,
    ListClientsResponseMessage,
    ClientInfo,
    # Server <-> Server messages
    ClientUpdateMessage,
    ClientUpdateInfo,
    # Functions
    parse_message,
    message_to_json,
    # Constants
    DATA_COMPRESSION_NONE,
    DATA_COMPRESSION_GZIP,
    PROTOCOL_VERSION,
)


class TestConnectRequestMessage:
    """Round-trip tests for ``ConnectRequestMessage`` via ``pack``/``parse_message``."""

    def test_tcp(self):
        session_id = uuid.uuid4()
        msg = ConnectRequestMessage(
            session_id=session_id,
            target_url="tcp://192.168.1.1:8080",
        )
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, ConnectRequestMessage)
        assert parsed.session_id == session_id
        assert parsed.target_url == "tcp://192.168.1.1:8080"

    def test_unix(self):
        session_id = uuid.uuid4()
        msg = ConnectRequestMessage(
            session_id=session_id,
            target_url="unix:///var/run/app.sock",
        )
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, ConnectRequestMessage)
        # The session id must survive the round trip for Unix-socket targets
        # just as it does for TCP targets.
        assert parsed.session_id == session_id
        assert parsed.target_url == "unix:///var/run/app.sock"


class TestConnectResponseMessage:
    """Round-trip tests for the success and failure forms of ``ConnectResponseMessage``."""

    def test_success(self):
        session_id = uuid.uuid4()
        msg = ConnectResponseMessage(session_id=session_id, success=True)
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, ConnectResponseMessage)
        assert parsed.session_id == session_id
        assert parsed.success is True
        # No error was supplied, so none should come back after parsing.
        assert parsed.error is None

    def test_failure(self):
        session_id = uuid.uuid4()
        msg = ConnectResponseMessage(
            session_id=session_id, success=False, error="Connection refused"
        )
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, ConnectResponseMessage)
        assert parsed.session_id == session_id
        assert parsed.success is False
        assert parsed.error == "Connection refused"


class TestDataMessage:
    """Round-trip tests for ``DataMessage`` payloads with and without gzip compression."""

    def test_no_compression(self):
        session_id = uuid.uuid4()
        test_data = b"Hello, World!"
        msg = DataMessage(
            session_id=session_id, data=test_data, compression=DATA_COMPRESSION_NONE
        )
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, DataMessage)
        assert parsed.session_id == session_id
        assert parsed.data == test_data
        assert parsed.compression == DATA_COMPRESSION_NONE

    def test_with_compression(self):
        session_id = uuid.uuid4()
        # Highly repetitive payload so that gzip shrinks it substantially.
        test_data = b"A" * 1000
        msg = DataMessage(
            session_id=session_id, data=test_data, compression=DATA_COMPRESSION_GZIP
        )
        data = msg.pack()
        # The packed message should be smaller than the raw 1000-byte payload.
        # Otherwise the gzip flag is recorded without the payload actually
        # being compressed. (Assumes pack() compresses when
        # DATA_COMPRESSION_GZIP is requested.)
        assert len(data) < len(test_data)
        parsed = parse_message(data)
        assert isinstance(parsed, DataMessage)
        assert parsed.session_id == session_id
        assert parsed.data == test_data
        assert parsed.compression == DATA_COMPRESSION_GZIP


class TestDisconnectMessage:
    """Round-trip tests for ``DisconnectMessage`` with and without an error string."""

    def test_normal_disconnect(self):
        session_id = uuid.uuid4()
        msg = DisconnectMessage(session_id=session_id)
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, DisconnectMessage)
        assert parsed.session_id == session_id
        # No error was supplied, so none should come back after parsing.
        assert parsed.error is None

    def test_with_error(self):
        session_id = uuid.uuid4()
        msg = DisconnectMessage(session_id=session_id, error="Server closed")
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, DisconnectMessage)
        assert parsed.session_id == session_id
        assert parsed.error == "Server closed"


class TestHeartbeatMessage:
    """Round-trip test for ``HeartbeatMessage``."""

    # Fixed, realistic Unix timestamp in whole seconds, used instead of
    # int(time.time()) so the test input does not depend on the wall clock.
    TIMESTAMP = 1_700_000_000

    def test_pack_and_parse(self):
        msg = HeartbeatMessage(timestamp=self.TIMESTAMP)
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, HeartbeatMessage)
        assert parsed.timestamp == self.TIMESTAMP


class TestListClientsMessage:
    """``ListClientsMessage`` is built with no arguments; it only has to survive a round trip."""

    def test_pack_and_parse(self):
        packed = ListClientsMessage().pack()
        assert isinstance(parse_message(packed), ListClientsMessage)


class TestListClientsResponseMessage:
    """Round-trip tests for ``ListClientsResponseMessage`` and its ``ClientInfo`` entries."""

    def test_empty_clients(self):
        msg = ListClientsResponseMessage(clients=[])
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, ListClientsResponseMessage)
        assert parsed.clients == []

    def test_multiple_clients(self):
        clients = [
            ClientInfo(
                client_id=uuid.uuid4(),
                cidrs=["192.168.1.100"],
                unix_sockets=["/var/run/a.sock"],
            ),
            ClientInfo(
                client_id=uuid.uuid4(),
                cidrs=["10.0.0.1", "10.0.0.2"],
                unix_sockets=[],
            ),
        ]
        msg = ListClientsResponseMessage(clients=clients)
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, ListClientsResponseMessage)
        assert len(parsed.clients) == 2
        assert parsed.clients[0].cidrs == ["192.168.1.100"]
        assert parsed.clients[1].cidrs == ["10.0.0.1", "10.0.0.2"]
        # Also check the fields that were previously unchecked, entry by entry
        # and in order.
        for original, restored in zip(clients, parsed.clients):
            assert restored.client_id == original.client_id
            assert restored.unix_sockets == original.unix_sockets


class TestMessageToJson:
    """Checks the JSON rendering produced by ``message_to_json``."""

    def test_connect_request_message(self):
        url = "tcp://192.168.1.1:8080"
        request = ConnectRequestMessage(session_id=uuid.uuid4(), target_url=url)
        rendered = json.loads(message_to_json(request))
        assert rendered["type"] == "connect_request"
        assert "session_id" in rendered
        assert rendered["target_url"] == url

    def test_data_message(self):
        payload = bytes(range(4))  # b"\x00\x01\x02\x03"
        rendered = json.loads(
            message_to_json(DataMessage(session_id=uuid.uuid4(), data=payload))
        )
        assert rendered["type"] == "data"
        # The binary payload is rendered as a hex string.
        assert rendered["data"] == "00010203"


class TestProtocolVersion:
    """``parse_message`` must reject input whose version byte is not supported."""

    def test_invalid_version(self):
        packed = ConnectRequestMessage(
            session_id=uuid.uuid4(), target_url="tcp://1.2.3.4:80"
        ).pack()
        # Overwrite the leading version byte with 0xFF and keep the rest intact.
        corrupted = b"\xff" + packed[1:]
        with pytest.raises(ValueError, match="Unsupported protocol version"):
            parse_message(corrupted)


class TestInvalidMessages:
    """Malformed input must be rejected by ``parse_message`` with ``ValueError``."""

    def test_too_short(self):
        # Both an empty buffer and a single byte are shorter than the two-byte
        # (version, type) header built in test_unknown_type.
        for payload in (b"", b"\x01"):
            with pytest.raises(ValueError, match="Message too short"):
                parse_message(payload)

    def test_unknown_type(self):
        # Valid header (current protocol version) with a type byte that no
        # message uses.
        data = bytes([PROTOCOL_VERSION, 0xFF])
        with pytest.raises(ValueError, match="Unknown binary message type"):
            parse_message(data)


class TestClientUpdateMessage:
    """Round-trip tests for server-to-server ``ClientUpdateMessage`` add/remove updates."""

    def test_add_client(self):
        server_id = uuid.uuid4()
        client_id = uuid.uuid4()
        msg = ClientUpdateMessage(
            server_id=server_id,
            updates=[
                ClientUpdateInfo(
                    client_id=client_id,
                    action="add",
                    cidrs=["192.168.1.100"],
                    unix_sockets=["/var/run/app.sock"],
                )
            ],
        )
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, ClientUpdateMessage)
        assert parsed.server_id == server_id
        assert len(parsed.updates) == 1
        update = parsed.updates[0]
        assert update.client_id == client_id
        assert update.action == "add"
        assert update.cidrs == ["192.168.1.100"]
        assert update.unix_sockets == ["/var/run/app.sock"]

    def test_remove_client(self):
        server_id = uuid.uuid4()
        client_id = uuid.uuid4()
        msg = ClientUpdateMessage(
            server_id=server_id,
            updates=[ClientUpdateInfo(client_id=client_id, action="remove")],
        )
        data = msg.pack()
        parsed = parse_message(data)
        assert isinstance(parsed, ClientUpdateMessage)
        assert parsed.server_id == server_id
        assert len(parsed.updates) == 1
        assert parsed.updates[0].client_id == client_id
        assert parsed.updates[0].action == "remove"


if __name__ == "__main__":
    # Pass pytest's exit status to the shell so running this file directly
    # (e.g. from CI or a script) reports failure when any test fails.
    sys.exit(pytest.main([__file__, "-v"]))
|