# generated_http_client.py (14 KB)
  1. import ssl
  2. from typing import Any, Dict, Optional, Union
  3. import truststore
  4. import httpx
  5. from attrs import define, evolve, field
  6. default_versioned_prefix = "/v2"
  7. @define
  8. class HTTPClient:
  9. """A class for keeping track of data related to the API
  10. The following are accepted as keyword arguments and will be used to construct httpx Clients internally:
  11. ``base_url``: The base URL for the API, all requests are made to a relative path to this URL
  12. ``cookies``: A dictionary of cookies to be sent with every request
  13. ``headers``: A dictionary of headers to be sent with every request
  14. ``timeout``: The maximum amount of a time a request can take. API functions will raise
  15. httpx.TimeoutException if this is exceeded.
  16. ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production,
  17. but can be set to False for testing purposes.
  18. ``follow_redirects``: Whether or not to follow redirects. Default value is False.
  19. ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor.
  20. ``versioned_prefix``: A string to append to the base_url for versioning purposes (e.g., "/v2"). Default is "/v2".
  21. Attributes:
  22. raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
  23. status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
  24. argument to the constructor.
  25. """
  26. raise_on_unexpected_status: bool = field(default=False, kw_only=True)
  27. _base_url: str = field(alias="base_url")
  28. _versioned_prefix: str = field(
  29. alias="versioned_prefix", kw_only=True, default=default_versioned_prefix
  30. )
  31. _cookies: Dict[str, str] = field(factory=dict, kw_only=True, alias="cookies")
  32. _headers: Dict[str, str] = field(factory=dict, kw_only=True, alias="headers")
  33. _timeout: Optional[httpx.Timeout] = field(
  34. default=None, kw_only=True, alias="timeout"
  35. )
  36. _verify_ssl: Union[str, bool, ssl.SSLContext] = field(
  37. default=True, kw_only=True, alias="verify_ssl"
  38. )
  39. _follow_redirects: bool = field(
  40. default=False, kw_only=True, alias="follow_redirects"
  41. )
  42. _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args")
  43. _client: Optional[httpx.Client] = field(default=None, init=False)
  44. _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)
  45. @property
  46. def versioned_url(self) -> str:
  47. return self._base_url + self._versioned_prefix
  48. def with_headers(self, headers: Dict[str, str]) -> "HTTPClient":
  49. """Get a new client matching this one with additional headers"""
  50. if self._client is not None:
  51. self._client.headers.update(headers)
  52. if self._async_client is not None:
  53. self._async_client.headers.update(headers)
  54. return evolve(self, headers={**self._headers, **headers})
  55. def with_cookies(self, cookies: Dict[str, str]) -> "HTTPClient":
  56. """Get a new client matching this one with additional cookies"""
  57. if self._client is not None:
  58. self._client.cookies.update(cookies)
  59. if self._async_client is not None:
  60. self._async_client.cookies.update(cookies)
  61. return evolve(self, cookies={**self._cookies, **cookies})
  62. def with_timeout(self, timeout: httpx.Timeout) -> "HTTPClient":
  63. """Get a new client matching this one with a new timeout (in seconds)"""
  64. if self._client is not None:
  65. self._client.timeout = timeout
  66. if self._async_client is not None:
  67. self._async_client.timeout = timeout
  68. return evolve(self, timeout=timeout)
  69. def set_httpx_client(self, client: httpx.Client) -> "HTTPClient":
  70. """Manually the underlying httpx.Client
  71. **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
  72. """
  73. self._client = client
  74. return self
  75. def get_httpx_client(self) -> httpx.Client:
  76. """Get the underlying httpx.Client, constructing a new one if not previously set"""
  77. # Use system trust store.
  78. verify = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  79. if self._verify_ssl is not None:
  80. verify = self._verify_ssl
  81. if self._client is None:
  82. self._client = httpx.Client(
  83. base_url=self.versioned_url,
  84. cookies=self._cookies,
  85. headers=self._headers,
  86. timeout=self._timeout,
  87. verify=verify,
  88. follow_redirects=self._follow_redirects,
  89. **self._httpx_args,
  90. )
  91. return self._client
  92. def __enter__(self) -> "HTTPClient":
  93. """Enter a context manager for self.client—you cannot enter twice (see httpx docs)"""
  94. self.get_httpx_client().__enter__()
  95. return self
  96. def __exit__(self, *args: Any, **kwargs: Any) -> None:
  97. """Exit a context manager for internal httpx.Client (see httpx docs)"""
  98. self.get_httpx_client().__exit__(*args, **kwargs)
  99. def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "HTTPClient":
  100. """Manually the underlying httpx.AsyncClient
  101. **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
  102. """
  103. self._async_client = async_client
  104. return self
  105. def get_async_httpx_client(self) -> httpx.AsyncClient:
  106. """Get the underlying httpx.AsyncClient, constructing a new one if not previously set"""
  107. # Use system trust store.
  108. verify = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  109. if self._verify_ssl is not None:
  110. verify = self._verify_ssl
  111. if self._async_client is None:
  112. self._async_client = httpx.AsyncClient(
  113. base_url=self.versioned_url,
  114. cookies=self._cookies,
  115. headers=self._headers,
  116. timeout=self._timeout,
  117. verify=verify,
  118. follow_redirects=self._follow_redirects,
  119. **self._httpx_args,
  120. )
  121. return self._async_client
  122. async def __aenter__(self) -> "HTTPClient":
  123. """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)"""
  124. await self.get_async_httpx_client().__aenter__()
  125. return self
  126. async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
  127. """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)"""
  128. await self.get_async_httpx_client().__aexit__(*args, **kwargs)
  129. @define
  130. class AuthenticatedHTTPClient:
  131. """A Client which has been authenticated for use on secured endpoints
  132. The following are accepted as keyword arguments and will be used to construct httpx Clients internally:
  133. ``base_url``: The base URL for the API, all requests are made to a relative path to this URL
  134. ``cookies``: A dictionary of cookies to be sent with every request
  135. ``headers``: A dictionary of headers to be sent with every request
  136. ``timeout``: The maximum amount of a time a request can take. API functions will raise
  137. httpx.TimeoutException if this is exceeded.
  138. ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production,
  139. but can be set to False for testing purposes.
  140. ``follow_redirects``: Whether or not to follow redirects. Default value is False.
  141. ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor.
  142. ``versioned_prefix``: A string to append to the base_url for versioning purposes (e.g., "/v2"). Default is "/v2".
  143. Attributes:
  144. raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
  145. status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
  146. argument to the constructor.
  147. token: The token to use for authentication
  148. prefix: The prefix to use for the Authorization header
  149. auth_header_name: The name of the Authorization header
  150. """
  151. raise_on_unexpected_status: bool = field(default=False, kw_only=True)
  152. _base_url: str = field(alias="base_url")
  153. _versioned_prefix: str = field(
  154. alias="versioned_prefix", kw_only=True, default=default_versioned_prefix
  155. )
  156. _cookies: Dict[str, str] = field(factory=dict, kw_only=True, alias="cookies")
  157. _headers: Dict[str, str] = field(factory=dict, kw_only=True, alias="headers")
  158. _timeout: Optional[httpx.Timeout] = field(
  159. default=None, kw_only=True, alias="timeout"
  160. )
  161. _verify_ssl: Union[str, bool, ssl.SSLContext] = field(
  162. default=True, kw_only=True, alias="verify_ssl"
  163. )
  164. _follow_redirects: bool = field(
  165. default=False, kw_only=True, alias="follow_redirects"
  166. )
  167. _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args")
  168. _client: Optional[httpx.Client] = field(default=None, init=False)
  169. _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)
  170. token: str
  171. prefix: str = "Bearer"
  172. auth_header_name: str = "Authorization"
  173. @property
  174. def versioned_url(self) -> str:
  175. return self._base_url + self._versioned_prefix
  176. def with_headers(self, headers: Dict[str, str]) -> "AuthenticatedHTTPClient":
  177. """Get a new client matching this one with additional headers"""
  178. if self._client is not None:
  179. self._client.headers.update(headers)
  180. if self._async_client is not None:
  181. self._async_client.headers.update(headers)
  182. return evolve(self, headers={**self._headers, **headers})
  183. def with_cookies(self, cookies: Dict[str, str]) -> "AuthenticatedHTTPClient":
  184. """Get a new client matching this one with additional cookies"""
  185. if self._client is not None:
  186. self._client.cookies.update(cookies)
  187. if self._async_client is not None:
  188. self._async_client.cookies.update(cookies)
  189. return evolve(self, cookies={**self._cookies, **cookies})
  190. def with_timeout(self, timeout: httpx.Timeout) -> "AuthenticatedHTTPClient":
  191. """Get a new client matching this one with a new timeout (in seconds)"""
  192. if self._client is not None:
  193. self._client.timeout = timeout
  194. if self._async_client is not None:
  195. self._async_client.timeout = timeout
  196. return evolve(self, timeout=timeout)
  197. def set_httpx_client(self, client: httpx.Client) -> "AuthenticatedHTTPClient":
  198. """Manually the underlying httpx.Client
  199. **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
  200. """
  201. self._client = client
  202. return self
  203. def get_httpx_client(self) -> httpx.Client:
  204. """Get the underlying httpx.Client, constructing a new one if not previously set"""
  205. # Use system trust store.
  206. verify = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  207. if self._verify_ssl is not None:
  208. verify = self._verify_ssl
  209. if self._client is None:
  210. self._headers[self.auth_header_name] = (
  211. f"{self.prefix} {self.token}" if self.prefix else self.token
  212. )
  213. self._client = httpx.Client(
  214. base_url=self.versioned_url,
  215. cookies=self._cookies,
  216. headers=self._headers,
  217. timeout=self._timeout,
  218. verify=verify,
  219. follow_redirects=self._follow_redirects,
  220. **self._httpx_args,
  221. )
  222. return self._client
  223. def __enter__(self) -> "AuthenticatedHTTPClient":
  224. """Enter a context manager for self.client—you cannot enter twice (see httpx docs)"""
  225. self.get_httpx_client().__enter__()
  226. return self
  227. def __exit__(self, *args: Any, **kwargs: Any) -> None:
  228. """Exit a context manager for internal httpx.Client (see httpx docs)"""
  229. self.get_httpx_client().__exit__(*args, **kwargs)
  230. def set_async_httpx_client(
  231. self, async_client: httpx.AsyncClient
  232. ) -> "AuthenticatedHTTPClient":
  233. """Manually the underlying httpx.AsyncClient
  234. **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
  235. """
  236. self._async_client = async_client
  237. return self
  238. def get_async_httpx_client(self) -> httpx.AsyncClient:
  239. """Get the underlying httpx.AsyncClient, constructing a new one if not previously set"""
  240. if self._async_client is None:
  241. self._headers[self.auth_header_name] = (
  242. f"{self.prefix} {self.token}" if self.prefix else self.token
  243. )
  244. self._async_client = httpx.AsyncClient(
  245. base_url=self.versioned_url,
  246. cookies=self._cookies,
  247. headers=self._headers,
  248. timeout=self._timeout,
  249. verify=self._verify_ssl,
  250. follow_redirects=self._follow_redirects,
  251. **self._httpx_args,
  252. )
  253. return self._async_client
  254. async def __aenter__(self) -> "AuthenticatedHTTPClient":
  255. """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)"""
  256. await self.get_async_httpx_client().__aenter__()
  257. return self
  258. async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
  259. """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)"""
  260. await self.get_async_httpx_client().__aexit__(*args, **kwargs)