test_auth.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import ssl
  2. from unittest.mock import AsyncMock
  3. import pytest
  4. from fastapi.security import HTTPAuthorizationCredentials
  5. from gpustack.api.auth import get_current_user, worker_auth
  6. from gpustack.api.exceptions import UnauthorizedException
  7. from gpustack.routes.auth import oidc_callback
  8. class DummyWorkerConfig:
  9. token = "registration-token"
  10. def get_server_url(self):
  11. return "http://example.com"
  12. @pytest.mark.asyncio
  13. async def test_get_current_user_accepts_x_api_key(monkeypatch):
  14. session = object()
  15. request = type("Request", (), {})()
  16. request.state = type("State", (), {})()
  17. request.headers = {}
  18. request.client = type("Client", (), {"host": "10.0.0.1"})()
  19. request.app = type("App", (), {})()
  20. request.app.state = type("State", (), {})()
  21. request.app.state.server_config = type(
  22. "Config", (), {"gateway_mode": None, "force_auth_localhost": True}
  23. )()
  24. expected_user = type("User", (), {"is_active": True})()
  25. expected_key = object()
  26. auth_mock = AsyncMock(return_value=(expected_user, expected_key))
  27. monkeypatch.setattr("gpustack.api.auth.get_user_from_api_token", auth_mock)
  28. user = await get_current_user(
  29. request=request,
  30. session=session,
  31. x_api_key="sk_test_value",
  32. )
  33. auth_mock.assert_awaited_once_with(session, "sk_test_value")
  34. assert user is expected_user
  35. assert request.state.user is expected_user
  36. assert request.state.api_key is expected_key
  37. @pytest.mark.asyncio
  38. async def test_worker_auth_accepts_x_api_key():
  39. request = type("Request", (), {})()
  40. request.headers = {"X-Higress-Llm-Model": "claude-sonnet"}
  41. request.app = type("App", (), {})()
  42. request.app.state = type("State", (), {})()
  43. request.app.state.token = "worker-token"
  44. request.app.state.config = DummyWorkerConfig()
  45. request.app.state.http_client_no_proxy = object()
  46. assert await worker_auth(request=request, x_api_key="worker-token") is None
  47. @pytest.mark.asyncio
  48. async def test_worker_auth_rejects_missing_credentials():
  49. request = type("Request", (), {})()
  50. request.headers = {"X-Higress-Llm-Model": "claude-sonnet"}
  51. request.app = type("App", (), {})()
  52. request.app.state = type("State", (), {})()
  53. request.app.state.token = "worker-token"
  54. request.app.state.config = DummyWorkerConfig()
  55. request.app.state.http_client_no_proxy = object()
  56. with pytest.raises(UnauthorizedException):
  57. await worker_auth(request=request)
  58. @pytest.mark.asyncio
  59. async def test_get_current_user_falls_back_to_x_api_key_when_bearer_empty(
  60. monkeypatch,
  61. ):
  62. session = object()
  63. request = type("Request", (), {})()
  64. request.state = type("State", (), {})()
  65. request.headers = {}
  66. request.client = type("Client", (), {"host": "10.0.0.1"})()
  67. request.app = type("App", (), {})()
  68. request.app.state = type("State", (), {})()
  69. request.app.state.server_config = type(
  70. "Config", (), {"gateway_mode": None, "force_auth_localhost": True}
  71. )()
  72. expected_user = type("User", (), {"is_active": True})()
  73. expected_key = object()
  74. auth_mock = AsyncMock(return_value=(expected_user, expected_key))
  75. monkeypatch.setattr("gpustack.api.auth.get_user_from_api_token", auth_mock)
  76. user = await get_current_user(
  77. request=request,
  78. session=session,
  79. bearer_token=HTTPAuthorizationCredentials(scheme="Bearer", credentials=""),
  80. x_api_key="sk_test_value",
  81. )
  82. auth_mock.assert_awaited_once_with(session, "sk_test_value")
  83. assert user is expected_user
  84. @pytest.mark.asyncio
  85. async def test_worker_auth_falls_back_to_x_api_key_when_bearer_empty():
  86. request = type("Request", (), {})()
  87. request.headers = {"X-Higress-Llm-Model": "claude-sonnet"}
  88. request.app = type("App", (), {})()
  89. request.app.state = type("State", (), {})()
  90. request.app.state.token = "worker-token"
  91. request.app.state.config = DummyWorkerConfig()
  92. request.app.state.http_client_no_proxy = object()
  93. assert (
  94. await worker_auth(
  95. request=request,
  96. bearer_token=HTTPAuthorizationCredentials(scheme="Bearer", credentials=""),
  97. x_api_key="worker-token",
  98. )
  99. is None
  100. )
  101. @pytest.mark.asyncio
  102. async def test_oidc_callback_uses_system_trust_store(monkeypatch):
  103. captured = {}
  104. class FakeAsyncClient:
  105. def __init__(self, **kwargs):
  106. captured.update(kwargs)
  107. async def __aenter__(self):
  108. return self
  109. async def __aexit__(self, exc_type, exc, tb):
  110. return None
  111. async def request(self, method, url, data=None):
  112. return type(
  113. "Resp",
  114. (),
  115. {
  116. "status_code": 200,
  117. "text": '{"access_token":"token","id_token":"id"}',
  118. },
  119. )()
  120. request = type("Request", (), {})()
  121. request.app = type("App", (), {})()
  122. request.app.state = type("State", (), {})()
  123. request.app.state.server_config = type(
  124. "Config",
  125. (),
  126. {
  127. "oidc_client_id": "client-id",
  128. "oidc_client_secret": "client-secret",
  129. "oidc_redirect_uri": "https://gpustack.example.com/auth/oidc/callback",
  130. "openid_configuration": {
  131. "token_endpoint": "https://issuer.example.com/token"
  132. },
  133. "external_auth_name": None,
  134. "external_auth_full_name": None,
  135. "external_auth_avatar_url": None,
  136. "external_auth_default_inactive": False,
  137. },
  138. )()
  139. request.app.state.jwt_manager = type(
  140. "JWTManager", (), {"create_jwt_token": lambda self, username: "jwt-token"}
  141. )()
  142. request.query_params = {"code": "auth-code"}
  143. monkeypatch.setattr("gpustack.routes.auth.httpx.AsyncClient", FakeAsyncClient)
  144. monkeypatch.setattr("gpustack.routes.auth.use_proxy_env_for_url", lambda url: False)
  145. monkeypatch.setattr(
  146. "gpustack.routes.auth.get_oidc_user_data",
  147. AsyncMock(return_value={"email": "user@example.com", "name": "Test User"}),
  148. )
  149. monkeypatch.setattr(
  150. "gpustack.routes.auth.User.first_by_field", AsyncMock(return_value=object())
  151. )
  152. response = await oidc_callback(request=request, session=object())
  153. assert response.status_code in (302, 307)
  154. assert captured["trust_env"] is False
  155. assert captured["timeout"] is not None
  156. assert isinstance(captured["verify"], ssl.SSLContext)