network.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
import asyncio
import contextlib
import ipaddress
import os
import random
import socket
import ssl
import time
from datetime import datetime, timezone
from functools import lru_cache
from typing import List, Optional, Tuple, Union
from urllib.parse import urlparse

import aiohttp
import psutil
import requests
import truststore
  17. def normalize_route_path(path: str) -> str:
  18. """
  19. Normalize the route path by adding / at the beginning if not present.
  20. """
  21. if not path.startswith("/"):
  22. path = "/" + path
  23. return path
  24. def get_first_non_loopback_ip(expected_ifname: Optional[str] = None) -> str:
  25. """
  26. Get the first non-loopback IPv4 address of the machine.
  27. Returns:
  28. The IPv4 address as a string.
  29. """
  30. # Fallback to scanning all interfaces
  31. for name, addrs in psutil.net_if_addrs().items():
  32. if expected_ifname is not None and name != expected_ifname:
  33. continue
  34. for addr in addrs:
  35. if addr.family == socket.AF_INET and not addr.address.startswith(
  36. ("127.", "169.254.")
  37. ):
  38. return addr.address
  39. if expected_ifname is not None:
  40. raise Exception(
  41. f"No non-loopback IPv4 address found on interface {expected_ifname}."
  42. )
  43. raise Exception("No non-loopback IPv4 address found.")
  44. def is_ipaddress(ip_str: str) -> bool:
  45. """
  46. Check if the given string is a valid IP address.
  47. Returns:
  48. True if valid IP address, False otherwise.
  49. """
  50. try:
  51. ipaddress.ip_address(ip_str)
  52. return True
  53. except ValueError:
  54. return False
  55. def _get_ifname_by_local_ip(
  56. ip_address: str,
  57. address_family: socket.AddressFamily = socket.AF_INET,
  58. ) -> Optional[str]:
  59. """
  60. Given an IP address, return the interface name if it exists and is not loopback/link-local.
  61. Returns:
  62. The interface name as a string, or None if not found.
  63. """
  64. try:
  65. ip = ipaddress.ip_address(ip_address)
  66. except ValueError:
  67. return None
  68. if ip.is_loopback or ip.is_link_local:
  69. return None
  70. for ifname, addrs in psutil.net_if_addrs().items():
  71. for addr in addrs:
  72. if addr.family == address_family and addr.address == ip_address:
  73. return ifname
  74. return None
  75. def get_ifname_by_ip_hostname(
  76. ip_address_hostname: str,
  77. address_family: socket.AddressFamily = socket.AF_INET,
  78. ) -> Optional[str]:
  79. """
  80. Get the interface name by IP address using psutil.
  81. Args:
  82. ip_address_hostname:
  83. The IP address or hostname to look for. If a hostname is provided, it will be resolved to an IP address.
  84. address_family:
  85. The address family (default is socket.AF_INET).
  86. Returns:
  87. The interface name associated with the given IP address or hostname.
  88. """
  89. local_ifname = _get_ifname_by_local_ip(
  90. ip_address_hostname, address_family=address_family
  91. )
  92. if local_ifname is not None:
  93. return local_ifname
  94. cases: List[Tuple[socket.AddressFamily, str]] = [
  95. (address_family, ip_address_hostname),
  96. ]
  97. if address_family == socket.AF_INET:
  98. cases.append((socket.AF_INET, "8.8.8.8"))
  99. if address_family == socket.AF_INET6:
  100. cases.append((socket.AF_INET6, "2001:4860:4860::8888"))
  101. for af, test_ip in cases:
  102. with contextlib.suppress(Exception):
  103. with socket.socket(af, socket.SOCK_DGRAM) as s:
  104. # the port is arbitrary since we won't actually send any data
  105. s.connect((test_ip, 1))
  106. local_ifname = _get_ifname_by_local_ip(s.getsockname()[0], af)
  107. if local_ifname is not None:
  108. return local_ifname
  109. return None
  110. def parse_port_range(port_range: str) -> Tuple[int, int]:
  111. """
  112. Parse the port range string to a tuple of start and end port.
  113. """
  114. start, end = port_range.split("-")
  115. return int(start), int(end)
  116. def get_free_port(
  117. port_range: str,
  118. unavailable_ports: Optional[set[int]] = None,
  119. host: str = "127.0.0.1",
  120. ) -> int:
  121. start, end = parse_port_range(port_range)
  122. if unavailable_ports is None:
  123. unavailable_ports = set()
  124. if len(unavailable_ports) >= end - start + 1:
  125. raise Exception("No free port available in the port range.")
  126. while True:
  127. port = random.randint(start, end)
  128. if port in unavailable_ports:
  129. continue
  130. if is_port_available(port, host):
  131. return port
  132. else:
  133. unavailable_ports.add(port)
  134. if len(unavailable_ports) == end - start + 1:
  135. raise Exception("No free port available in the port range.")
  136. continue
  137. def is_port_available(port: int, host: str = "127.0.0.1") -> bool:
  138. """
  139. Test if a port is available.
  140. Returns:
  141. True if the port is available, False otherwise.
  142. """
  143. # Then, try to connect (if someone is listening, connect will succeed)
  144. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  145. try:
  146. s.settimeout(0.5)
  147. result = s.connect_ex((host, port))
  148. if result == 0:
  149. # Someone is listening, port is not available
  150. return False
  151. except Exception:
  152. pass
  153. return True
  154. async def is_url_reachable(
  155. url: str, timeout_in_second: int = 10, retry_interval_in_second: int = 3
  156. ) -> bool:
  157. """Check if a url is reachable.
  158. Args:
  159. url (str): url to check.
  160. timeout (int): timeout in seconds. Defaults to 10.
  161. retry_interval_in_second (int, optional): retry inteval. Defaults to 3.
  162. Returns:
  163. bool: True if the url is reachable, False otherwise
  164. """
  165. end_time = time.time() + timeout_in_second
  166. while time.time() < end_time:
  167. try:
  168. use_proxy_env = use_proxy_env_for_url(url)
  169. async with aiohttp.ClientSession(trust_env=use_proxy_env) as session:
  170. async with session.get(url, timeout=2) as response:
  171. if response.status == 200:
  172. return True
  173. except Exception:
  174. await asyncio.sleep(retry_interval_in_second)
  175. return False
  176. def is_offline(
  177. last_update: Optional[datetime],
  178. timeout_seconds: int,
  179. now: Optional[datetime] = None,
  180. ) -> Tuple[bool, Optional[str]]:
  181. """
  182. Check if the last_update time is offline based on the timeout_seconds.
  183. Args:
  184. last_update: The last update time (UTC datetime). If None, it means no record.
  185. timeout_seconds: The threshold in seconds to consider offline.
  186. now: The current time (UTC datetime), defaults to datetime.now(timezone.utc)
  187. Returns:
  188. Tuple[bool, Optional[str]]: (Whether offline, last_update readable string)
  189. - If last_update is None, returns "unknown"
  190. - Otherwise returns formatted time "%Y-%m-%d %H:%M:%S UTC"
  191. """
  192. if now is None:
  193. now = datetime.now(timezone.utc)
  194. if last_update is None:
  195. return True, "unknown"
  196. last_update_ts = int(last_update.timestamp())
  197. now_ts = int(now.timestamp())
  198. is_offline_flag = (now_ts - last_update_ts) > timeout_seconds
  199. last_update_str = last_update.strftime("%Y-%m-%d %H:%M:%S UTC")
  200. return is_offline_flag, last_update_str
  201. def check_registry_reachable(address: str) -> bool:
  202. """
  203. Check if the registry is reachable.
  204. To avoid frequent checks, cache the result for a short period via global lock.
  205. Returns:
  206. bool: True if the registry is reachable, False otherwise.
  207. """
  208. url = f"{address}/v2/"
  209. try:
  210. resp = requests.get(url, timeout=3)
  211. reachable = resp.status_code < 500
  212. except Exception:
  213. reachable = False
  214. return reachable
  215. @lru_cache(maxsize=1)
  216. def _get_no_proxy_cidrs() -> Tuple[ipaddress.IPv4Network, ...]:
  217. """
  218. Parse NO_PROXY environment variable to get a list of CIDR networks.
  219. """
  220. no_proxy = (os.getenv("NO_PROXY") or os.getenv("no_proxy") or "").strip()
  221. if not no_proxy:
  222. return ()
  223. cidrs = []
  224. for entry in no_proxy.split(","):
  225. entry = entry.strip()
  226. if not entry:
  227. continue
  228. try:
  229. net = ipaddress.IPv4Network(entry, strict=False)
  230. cidrs.append(net)
  231. except ValueError:
  232. # Ignore non-CIDR entries (including domain names, plain IPs, etc.)
  233. pass
  234. return tuple(cidrs)
  235. def use_proxy_env_for_url(url: str) -> bool:
  236. """
  237. Determine if proxy environment variables (HTTP_PROXY, HTTPS_PROXY, etc.)
  238. should be used for the given URL.
  239. This is a workaround for the fact that current HTTP clients (e.g., httpx)
  240. do not support CIDR notation in NO_PROXY.
  241. Ref: https://github.com/encode/httpx/issues/1536
  242. - If the host is an IP address:
  243. Do **not** use proxy if it falls within any CIDR defined in NO_PROXY.
  244. -> Return False in that case.
  245. - If the host is a domain name:
  246. Defer to the HTTP client's standard NO_PROXY logic (which doesn't support CIDR),
  247. so assume proxy **should** be used unless explicitly overridden elsewhere.
  248. -> Return True.
  249. Args:
  250. url (str): Full URL (e.g., 'http://192.168.1.10:8080/path')
  251. Returns:
  252. bool: True if proxy environment variables should be used, False if the request
  253. should bypass the proxy (e.g., due to NO_PROXY CIDR match).
  254. """
  255. try:
  256. parsed = urlparse(url)
  257. host = parsed.hostname
  258. if not host:
  259. return True
  260. try:
  261. ip = ipaddress.ip_address(host)
  262. except ValueError:
  263. # It's a domain name -> defer to standard NO_PROXY logic (no CIDR support)
  264. return True
  265. # Check against user-defined CIDRs in NO_PROXY
  266. for net in _get_no_proxy_cidrs():
  267. if ip in net:
  268. # Host is in a NO_PROXY CIDR -> bypass proxy
  269. return False
  270. return True
  271. except Exception:
  272. # On any error (e.g., malformed URL), default to using proxy
  273. return True
  274. @lru_cache(maxsize=1)
  275. def get_system_trust_store_ssl_context() -> ssl.SSLContext:
  276. """
  277. Return an SSL context backed by the operating system trust store.
  278. """
  279. return truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)