| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318 |
- import ssl
- from typing import Any, Dict, Optional, Union
- import truststore
- import httpx
- from attrs import define, evolve, field
# Path segment appended to ``base_url`` by both client classes unless a
# different ``versioned_prefix`` is supplied to the constructor.
default_versioned_prefix: str = "/v2"
@define
class HTTPClient:
    """A class for keeping track of data related to the API.

    The following are accepted as keyword arguments and will be used to construct httpx Clients internally:

        ``base_url``: The base URL for the API; all requests are made to a path relative to this URL
        (after ``versioned_prefix`` has been appended).

        ``cookies``: A dictionary of cookies to be sent with every request.

        ``headers``: A dictionary of headers to be sent with every request.

        ``timeout``: The maximum amount of time a request can take. API functions will raise
        httpx.TimeoutException if this is exceeded.

        ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. ``True`` (the
        default) verifies against the operating-system trust store via ``truststore``. ``False`` disables
        verification and should only be used for testing. A CA-bundle path or an ``ssl.SSLContext`` is
        handed to httpx unchanged.

        ``follow_redirects``: Whether or not to follow redirects. Default value is False.

        ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and
        ``httpx.AsyncClient`` constructors.

        ``versioned_prefix``: A string appended to ``base_url`` for versioning purposes (e.g., "/v2").
        Default is "/v2".

    Attributes:
        raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
            status code that was not documented in the source OpenAPI document. Can also be provided as a
            keyword argument to the constructor.
    """

    raise_on_unexpected_status: bool = field(default=False, kw_only=True)
    _base_url: str = field(alias="base_url")
    _versioned_prefix: str = field(
        alias="versioned_prefix", kw_only=True, default=default_versioned_prefix
    )
    _cookies: Dict[str, str] = field(factory=dict, kw_only=True, alias="cookies")
    _headers: Dict[str, str] = field(factory=dict, kw_only=True, alias="headers")
    _timeout: Optional[httpx.Timeout] = field(
        default=None, kw_only=True, alias="timeout"
    )
    _verify_ssl: Union[str, bool, ssl.SSLContext] = field(
        default=True, kw_only=True, alias="verify_ssl"
    )
    _follow_redirects: bool = field(
        default=False, kw_only=True, alias="follow_redirects"
    )
    _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args")
    # Built lazily by get_httpx_client / get_async_httpx_client, or injected via the set_* methods.
    _client: Optional[httpx.Client] = field(default=None, init=False)
    _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)

    @property
    def versioned_url(self) -> str:
        """Return ``base_url`` with ``versioned_prefix`` appended.

        When the prefix starts with "/", trailing slashes on ``base_url`` are dropped, so
        ``"https://host/"`` + ``"/v2"`` gives ``"https://host/v2"`` instead of ``"https://host//v2"``.
        """
        base_url = self._base_url
        if self._versioned_prefix.startswith("/"):
            base_url = base_url.rstrip("/")
        return base_url + self._versioned_prefix

    def _resolve_verify(self) -> Union[str, bool, ssl.SSLContext]:
        """Return the ``verify`` argument to pass to a newly built httpx client.

        ``True`` (and ``None``) mean "verify certificates" and are mapped to a ``truststore``
        context, so the operating-system trust store is used. Any other value (``False``, a
        CA-bundle path, or an explicit ``ssl.SSLContext``) is passed through unchanged.
        """
        # Identity checks: ``1 == True`` must not be mistaken for the boolean flag.
        if self._verify_ssl is True or self._verify_ssl is None:
            # Use system trust store.
            return truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        return self._verify_ssl

    def with_headers(self, headers: Dict[str, str]) -> "HTTPClient":
        """Get a new client matching this one with additional headers.

        The headers are also applied to any httpx clients this instance has already built.
        """
        if self._client is not None:
            self._client.headers.update(headers)
        if self._async_client is not None:
            self._async_client.headers.update(headers)
        return evolve(self, headers={**self._headers, **headers})

    def with_cookies(self, cookies: Dict[str, str]) -> "HTTPClient":
        """Get a new client matching this one with additional cookies.

        The cookies are also applied to any httpx clients this instance has already built.
        """
        if self._client is not None:
            self._client.cookies.update(cookies)
        if self._async_client is not None:
            self._async_client.cookies.update(cookies)
        return evolve(self, cookies={**self._cookies, **cookies})

    def with_timeout(self, timeout: httpx.Timeout) -> "HTTPClient":
        """Get a new client matching this one with a new ``httpx.Timeout``.

        The timeout is also applied to any httpx clients this instance has already built.
        """
        if self._client is not None:
            self._client.timeout = timeout
        if self._async_client is not None:
            self._async_client.timeout = timeout
        return evolve(self, timeout=timeout)

    def set_httpx_client(self, client: httpx.Client) -> "HTTPClient":
        """Manually set the underlying httpx.Client.

        **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
        """
        self._client = client
        return self

    def get_httpx_client(self) -> httpx.Client:
        """Get the underlying httpx.Client, constructing a new one if not previously set."""
        if self._client is None:
            self._client = httpx.Client(
                base_url=self.versioned_url,
                cookies=self._cookies,
                headers=self._headers,
                timeout=self._timeout,
                verify=self._resolve_verify(),
                follow_redirects=self._follow_redirects,
                **self._httpx_args,
            )
        return self._client

    def __enter__(self) -> "HTTPClient":
        """Enter a context manager for the underlying httpx.Client; you cannot enter twice (see httpx docs)."""
        self.get_httpx_client().__enter__()
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Exit a context manager for the underlying httpx.Client (see httpx docs)."""
        self.get_httpx_client().__exit__(*args, **kwargs)

    def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "HTTPClient":
        """Manually set the underlying httpx.AsyncClient.

        **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
        """
        self._async_client = async_client
        return self

    def get_async_httpx_client(self) -> httpx.AsyncClient:
        """Get the underlying httpx.AsyncClient, constructing a new one if not previously set."""
        if self._async_client is None:
            self._async_client = httpx.AsyncClient(
                base_url=self.versioned_url,
                cookies=self._cookies,
                headers=self._headers,
                timeout=self._timeout,
                verify=self._resolve_verify(),
                follow_redirects=self._follow_redirects,
                **self._httpx_args,
            )
        return self._async_client

    async def __aenter__(self) -> "HTTPClient":
        """Enter a context manager for the underlying httpx.AsyncClient; you cannot enter twice (see httpx docs)."""
        await self.get_async_httpx_client().__aenter__()
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        """Exit a context manager for the underlying httpx.AsyncClient (see httpx docs)."""
        await self.get_async_httpx_client().__aexit__(*args, **kwargs)
@define
class AuthenticatedHTTPClient:
    """A Client which has been authenticated for use on secured endpoints.

    The following are accepted as keyword arguments and will be used to construct httpx Clients internally:

        ``base_url``: The base URL for the API; all requests are made to a path relative to this URL
        (after ``versioned_prefix`` has been appended).

        ``cookies``: A dictionary of cookies to be sent with every request.

        ``headers``: A dictionary of headers to be sent with every request. The dictionary you pass is
        not modified; the authentication header is merged into a copy when the httpx client is built.

        ``timeout``: The maximum amount of time a request can take. API functions will raise
        httpx.TimeoutException if this is exceeded.

        ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. ``True`` (the
        default) verifies against the operating-system trust store via ``truststore``. ``False`` disables
        verification and should only be used for testing. A CA-bundle path or an ``ssl.SSLContext`` is
        handed to httpx unchanged.

        ``follow_redirects``: Whether or not to follow redirects. Default value is False.

        ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and
        ``httpx.AsyncClient`` constructors.

        ``versioned_prefix``: A string appended to ``base_url`` for versioning purposes (e.g., "/v2").
        Default is "/v2".

    Attributes:
        raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
            status code that was not documented in the source OpenAPI document. Can also be provided as a
            keyword argument to the constructor.
        token: The token to use for authentication.
        prefix: The prefix to use for the Authorization header. If empty, the bare token is sent.
        auth_header_name: The name of the Authorization header.
    """

    raise_on_unexpected_status: bool = field(default=False, kw_only=True)
    _base_url: str = field(alias="base_url")
    _versioned_prefix: str = field(
        alias="versioned_prefix", kw_only=True, default=default_versioned_prefix
    )
    _cookies: Dict[str, str] = field(factory=dict, kw_only=True, alias="cookies")
    _headers: Dict[str, str] = field(factory=dict, kw_only=True, alias="headers")
    _timeout: Optional[httpx.Timeout] = field(
        default=None, kw_only=True, alias="timeout"
    )
    _verify_ssl: Union[str, bool, ssl.SSLContext] = field(
        default=True, kw_only=True, alias="verify_ssl"
    )
    _follow_redirects: bool = field(
        default=False, kw_only=True, alias="follow_redirects"
    )
    _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args")
    # Built lazily by get_httpx_client / get_async_httpx_client, or injected via the set_* methods.
    _client: Optional[httpx.Client] = field(default=None, init=False)
    _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)
    token: str
    prefix: str = "Bearer"
    auth_header_name: str = "Authorization"

    @property
    def versioned_url(self) -> str:
        """Return ``base_url`` with ``versioned_prefix`` appended.

        When the prefix starts with "/", trailing slashes on ``base_url`` are dropped, so
        ``"https://host/"`` + ``"/v2"`` gives ``"https://host/v2"`` instead of ``"https://host//v2"``.
        """
        base_url = self._base_url
        if self._versioned_prefix.startswith("/"):
            base_url = base_url.rstrip("/")
        return base_url + self._versioned_prefix

    def _resolve_verify(self) -> Union[str, bool, ssl.SSLContext]:
        """Return the ``verify`` argument to pass to a newly built httpx client.

        ``True`` (and ``None``) mean "verify certificates" and are mapped to a ``truststore``
        context, so the operating-system trust store is used. Any other value (``False``, a
        CA-bundle path, or an explicit ``ssl.SSLContext``) is passed through unchanged.
        """
        # Identity checks: ``1 == True`` must not be mistaken for the boolean flag.
        if self._verify_ssl is True or self._verify_ssl is None:
            # Use system trust store.
            return truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        return self._verify_ssl

    def _auth_headers(self) -> Dict[str, str]:
        """Return the configured headers plus the authentication header.

        The header value is ``"<prefix> <token>"``, or just the token when ``prefix`` is empty.
        A new dict is returned so the caller-supplied ``headers`` mapping is never mutated;
        the authentication header overrides any same-named entry in it.
        """
        auth_value = f"{self.prefix} {self.token}" if self.prefix else self.token
        return {**self._headers, self.auth_header_name: auth_value}

    def with_headers(self, headers: Dict[str, str]) -> "AuthenticatedHTTPClient":
        """Get a new client matching this one with additional headers.

        The headers are also applied to any httpx clients this instance has already built.
        """
        if self._client is not None:
            self._client.headers.update(headers)
        if self._async_client is not None:
            self._async_client.headers.update(headers)
        return evolve(self, headers={**self._headers, **headers})

    def with_cookies(self, cookies: Dict[str, str]) -> "AuthenticatedHTTPClient":
        """Get a new client matching this one with additional cookies.

        The cookies are also applied to any httpx clients this instance has already built.
        """
        if self._client is not None:
            self._client.cookies.update(cookies)
        if self._async_client is not None:
            self._async_client.cookies.update(cookies)
        return evolve(self, cookies={**self._cookies, **cookies})

    def with_timeout(self, timeout: httpx.Timeout) -> "AuthenticatedHTTPClient":
        """Get a new client matching this one with a new ``httpx.Timeout``.

        The timeout is also applied to any httpx clients this instance has already built.
        """
        if self._client is not None:
            self._client.timeout = timeout
        if self._async_client is not None:
            self._async_client.timeout = timeout
        return evolve(self, timeout=timeout)

    def set_httpx_client(self, client: httpx.Client) -> "AuthenticatedHTTPClient":
        """Manually set the underlying httpx.Client.

        **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
        """
        self._client = client
        return self

    def get_httpx_client(self) -> httpx.Client:
        """Get the underlying httpx.Client, constructing a new one if not previously set.

        A newly constructed client carries the authentication header (see ``prefix``/``token``).
        """
        if self._client is None:
            self._client = httpx.Client(
                base_url=self.versioned_url,
                cookies=self._cookies,
                headers=self._auth_headers(),
                timeout=self._timeout,
                verify=self._resolve_verify(),
                follow_redirects=self._follow_redirects,
                **self._httpx_args,
            )
        return self._client

    def __enter__(self) -> "AuthenticatedHTTPClient":
        """Enter a context manager for the underlying httpx.Client; you cannot enter twice (see httpx docs)."""
        self.get_httpx_client().__enter__()
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Exit a context manager for the underlying httpx.Client (see httpx docs)."""
        self.get_httpx_client().__exit__(*args, **kwargs)

    def set_async_httpx_client(
        self, async_client: httpx.AsyncClient
    ) -> "AuthenticatedHTTPClient":
        """Manually set the underlying httpx.AsyncClient.

        **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
        """
        self._async_client = async_client
        return self

    def get_async_httpx_client(self) -> httpx.AsyncClient:
        """Get the underlying httpx.AsyncClient, constructing a new one if not previously set.

        A newly constructed client carries the authentication header and uses the same
        certificate-verification policy as ``get_httpx_client``.
        """
        if self._async_client is None:
            self._async_client = httpx.AsyncClient(
                base_url=self.versioned_url,
                cookies=self._cookies,
                headers=self._auth_headers(),
                timeout=self._timeout,
                verify=self._resolve_verify(),
                follow_redirects=self._follow_redirects,
                **self._httpx_args,
            )
        return self._async_client

    async def __aenter__(self) -> "AuthenticatedHTTPClient":
        """Enter a context manager for the underlying httpx.AsyncClient; you cannot enter twice (see httpx docs)."""
        await self.get_async_httpx_client().__aenter__()
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        """Exit a context manager for the underlying httpx.AsyncClient (see httpx docs)."""
        await self.get_async_httpx_client().__aexit__(*args, **kwargs)
|