test_authenticator.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
"""
Tests for the WebSocket proxy authenticator module.
"""
  4. import pytest
  5. import uuid
  6. from unittest.mock import MagicMock
  7. from gpustack.websocket_proxy.authenticator import (
  8. Authenticator,
  9. HMACAuthenticator,
  10. NoOpAuthenticator,
  11. create_authenticator,
  12. )
  13. def create_mock_websocket(headers: dict) -> MagicMock:
  14. """Create a mock WebSocket with the given headers."""
  15. mock_ws = MagicMock()
  16. mock_ws.headers = headers
  17. return mock_ws
  18. class TestAuthenticatorClass:
  19. """Tests for Authenticator class"""
  20. @pytest.mark.asyncio
  21. async def test_authenticate_valid_signature(self):
  22. """Test that authenticate returns True for valid signature"""
  23. auth = HMACAuthenticator("secret-key")
  24. server_id = uuid.uuid4()
  25. # Generate headers with signature
  26. headers = {'x-server-id': str(server_id)}
  27. auth.inject_headers(headers)
  28. # Authenticate should succeed
  29. mock_ws = create_mock_websocket(headers)
  30. assert await auth.authenticate(mock_ws) is True
  31. @pytest.mark.asyncio
  32. async def test_authenticate_invalid_signature(self):
  33. """Test that authenticate returns False for invalid signature"""
  34. auth = HMACAuthenticator("secret-key")
  35. server_id = uuid.uuid4()
  36. headers = {
  37. 'x-server-id': str(server_id),
  38. 'x-auth-signature': 'invalid-signature',
  39. }
  40. mock_ws = create_mock_websocket(headers)
  41. assert await auth.authenticate(mock_ws) is False
  42. @pytest.mark.asyncio
  43. async def test_authenticate_missing_signature(self):
  44. """Test that authenticate returns False when signature is missing"""
  45. auth = HMACAuthenticator("secret-key")
  46. server_id = uuid.uuid4()
  47. headers = {
  48. 'x-server-id': str(server_id),
  49. }
  50. mock_ws = create_mock_websocket(headers)
  51. assert await auth.authenticate(mock_ws) is False
  52. @pytest.mark.asyncio
  53. async def test_authenticate_wrong_key(self):
  54. """Test that authenticate fails with wrong key"""
  55. auth1 = HMACAuthenticator("secret-key")
  56. auth2 = HMACAuthenticator("wrong-key")
  57. server_id = uuid.uuid4()
  58. headers = {'x-server-id': str(server_id)}
  59. auth1.inject_headers(headers)
  60. mock_ws = create_mock_websocket(headers)
  61. # auth2 should reject headers signed by auth1
  62. assert await auth2.authenticate(mock_ws) is False
  63. @pytest.mark.asyncio
  64. async def test_authenticate_wrong_server_id(self):
  65. """Test that authenticate fails when server_id header is tampered"""
  66. auth = HMACAuthenticator("secret-key")
  67. server_id_a = uuid.uuid4()
  68. server_id_b = uuid.uuid4()
  69. headers = {'x-server-id': str(server_id_a)}
  70. auth.inject_headers(headers)
  71. # Tamper with x-server-id header
  72. headers['x-server-id'] = str(server_id_b)
  73. mock_ws = create_mock_websocket(headers)
  74. # Should fail since signature was computed for different server_id
  75. assert await auth.authenticate(mock_ws) is False
  76. class TestNoOpAuthenticator:
  77. """Tests for NoOpAuthenticator class"""
  78. def test_inject_headers_no_signature(self):
  79. """Test that NoOpAuthenticator injects no headers (auth headers should be injected by caller)."""
  80. auth = NoOpAuthenticator()
  81. server_id = uuid.uuid4()
  82. headers = {'x-server-id': str(server_id)}
  83. auth.inject_headers(headers)
  84. # No auth headers should be added
  85. assert headers == {'x-server-id': str(server_id)}
  86. @pytest.mark.asyncio
  87. async def test_authenticate_always_true(self):
  88. """Test that NoOpAuthenticator always returns True"""
  89. auth = NoOpAuthenticator()
  90. assert await auth.authenticate(create_mock_websocket({})) is True
  91. assert (
  92. await auth.authenticate(
  93. create_mock_websocket({'x-auth-signature': 'anything'})
  94. )
  95. is True
  96. )
  97. assert (
  98. await auth.authenticate(create_mock_websocket({'x-server-id': 'server'}))
  99. is True
  100. )
  101. class TestCreateAuthenticator:
  102. """Tests for create_authenticator factory"""
  103. def test_with_secret_returns_authenticator(self):
  104. """Test that create_authenticator returns Authenticator when secret is provided"""
  105. auth = create_authenticator("my-secret")
  106. assert isinstance(auth, Authenticator)
  107. def test_without_secret_returns_noop(self):
  108. """Test that create_authenticator returns NoOpAuthenticator when secret is None"""
  109. auth = create_authenticator(None)
  110. assert isinstance(auth, NoOpAuthenticator)
  111. def test_with_empty_secret_returns_noop(self):
  112. """Test that create_authenticator returns NoOpAuthenticator when secret is empty"""
  113. auth = create_authenticator("")
  114. assert isinstance(auth, NoOpAuthenticator)
  115. if __name__ == "__main__":
  116. pytest.main([__file__, "-v"])