| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- import ssl
- from unittest.mock import AsyncMock
- import pytest
- from fastapi.security import HTTPAuthorizationCredentials
- from gpustack.api.auth import get_current_user, worker_auth
- from gpustack.api.exceptions import UnauthorizedException
- from gpustack.routes.auth import oidc_callback
- class DummyWorkerConfig:
- token = "registration-token"
- def get_server_url(self):
- return "http://example.com"
- @pytest.mark.asyncio
- async def test_get_current_user_accepts_x_api_key(monkeypatch):
- session = object()
- request = type("Request", (), {})()
- request.state = type("State", (), {})()
- request.headers = {}
- request.client = type("Client", (), {"host": "10.0.0.1"})()
- request.app = type("App", (), {})()
- request.app.state = type("State", (), {})()
- request.app.state.server_config = type(
- "Config", (), {"gateway_mode": None, "force_auth_localhost": True}
- )()
- expected_user = type("User", (), {"is_active": True})()
- expected_key = object()
- auth_mock = AsyncMock(return_value=(expected_user, expected_key))
- monkeypatch.setattr("gpustack.api.auth.get_user_from_api_token", auth_mock)
- user = await get_current_user(
- request=request,
- session=session,
- x_api_key="sk_test_value",
- )
- auth_mock.assert_awaited_once_with(session, "sk_test_value")
- assert user is expected_user
- assert request.state.user is expected_user
- assert request.state.api_key is expected_key
- @pytest.mark.asyncio
- async def test_worker_auth_accepts_x_api_key():
- request = type("Request", (), {})()
- request.headers = {"X-Higress-Llm-Model": "claude-sonnet"}
- request.app = type("App", (), {})()
- request.app.state = type("State", (), {})()
- request.app.state.token = "worker-token"
- request.app.state.config = DummyWorkerConfig()
- request.app.state.http_client_no_proxy = object()
- assert await worker_auth(request=request, x_api_key="worker-token") is None
- @pytest.mark.asyncio
- async def test_worker_auth_rejects_missing_credentials():
- request = type("Request", (), {})()
- request.headers = {"X-Higress-Llm-Model": "claude-sonnet"}
- request.app = type("App", (), {})()
- request.app.state = type("State", (), {})()
- request.app.state.token = "worker-token"
- request.app.state.config = DummyWorkerConfig()
- request.app.state.http_client_no_proxy = object()
- with pytest.raises(UnauthorizedException):
- await worker_auth(request=request)
- @pytest.mark.asyncio
- async def test_get_current_user_falls_back_to_x_api_key_when_bearer_empty(
- monkeypatch,
- ):
- session = object()
- request = type("Request", (), {})()
- request.state = type("State", (), {})()
- request.headers = {}
- request.client = type("Client", (), {"host": "10.0.0.1"})()
- request.app = type("App", (), {})()
- request.app.state = type("State", (), {})()
- request.app.state.server_config = type(
- "Config", (), {"gateway_mode": None, "force_auth_localhost": True}
- )()
- expected_user = type("User", (), {"is_active": True})()
- expected_key = object()
- auth_mock = AsyncMock(return_value=(expected_user, expected_key))
- monkeypatch.setattr("gpustack.api.auth.get_user_from_api_token", auth_mock)
- user = await get_current_user(
- request=request,
- session=session,
- bearer_token=HTTPAuthorizationCredentials(scheme="Bearer", credentials=""),
- x_api_key="sk_test_value",
- )
- auth_mock.assert_awaited_once_with(session, "sk_test_value")
- assert user is expected_user
- @pytest.mark.asyncio
- async def test_worker_auth_falls_back_to_x_api_key_when_bearer_empty():
- request = type("Request", (), {})()
- request.headers = {"X-Higress-Llm-Model": "claude-sonnet"}
- request.app = type("App", (), {})()
- request.app.state = type("State", (), {})()
- request.app.state.token = "worker-token"
- request.app.state.config = DummyWorkerConfig()
- request.app.state.http_client_no_proxy = object()
- assert (
- await worker_auth(
- request=request,
- bearer_token=HTTPAuthorizationCredentials(scheme="Bearer", credentials=""),
- x_api_key="worker-token",
- )
- is None
- )
- @pytest.mark.asyncio
- async def test_oidc_callback_uses_system_trust_store(monkeypatch):
- captured = {}
- class FakeAsyncClient:
- def __init__(self, **kwargs):
- captured.update(kwargs)
- async def __aenter__(self):
- return self
- async def __aexit__(self, exc_type, exc, tb):
- return None
- async def request(self, method, url, data=None):
- return type(
- "Resp",
- (),
- {
- "status_code": 200,
- "text": '{"access_token":"token","id_token":"id"}',
- },
- )()
- request = type("Request", (), {})()
- request.app = type("App", (), {})()
- request.app.state = type("State", (), {})()
- request.app.state.server_config = type(
- "Config",
- (),
- {
- "oidc_client_id": "client-id",
- "oidc_client_secret": "client-secret",
- "oidc_redirect_uri": "https://gpustack.example.com/auth/oidc/callback",
- "openid_configuration": {
- "token_endpoint": "https://issuer.example.com/token"
- },
- "external_auth_name": None,
- "external_auth_full_name": None,
- "external_auth_avatar_url": None,
- "external_auth_default_inactive": False,
- },
- )()
- request.app.state.jwt_manager = type(
- "JWTManager", (), {"create_jwt_token": lambda self, username: "jwt-token"}
- )()
- request.query_params = {"code": "auth-code"}
- monkeypatch.setattr("gpustack.routes.auth.httpx.AsyncClient", FakeAsyncClient)
- monkeypatch.setattr("gpustack.routes.auth.use_proxy_env_for_url", lambda url: False)
- monkeypatch.setattr(
- "gpustack.routes.auth.get_oidc_user_data",
- AsyncMock(return_value={"email": "user@example.com", "name": "Test User"}),
- )
- monkeypatch.setattr(
- "gpustack.routes.auth.User.first_by_field", AsyncMock(return_value=object())
- )
- response = await oidc_callback(request=request, session=object())
- assert response.status_code in (302, 307)
- assert captured["trust_env"] is False
- assert captured["timeout"] is not None
- assert isinstance(captured["verify"], ssl.SSLContext)
|