# tools_manager.py
  import json
  import logging
  import os
  import shutil
  import stat
  import time
  import zipfile
  from pathlib import Path
  from typing import Dict, Optional

  import requests

  from gpustack.utils.compat_importlib import pkg_resources
  from gpustack.utils import platform
  from gpustack.worker.backend_dependency_manager import BackendDependencyManager

  # Module-level logger shared by everything in this file.
  logger = logging.getLogger(__name__)

  # Release tag of gguf-parser-go that is downloaded by default; it is also the
  # tag used to build the probe URL when picking a download mirror.
  BUILTIN_GGUF_PARSER_VERSION = "v0.24.0"
  class ToolsManager:
      """
      ToolsManager is responsible for managing prebuilt binary tools including the following:
      - `fastfetch`
      - `gguf-parser`

      Binaries are stored under the ``bin`` resource directory of the
      ``gpustack.third_party`` package, next to a ``versions.json`` file that
      maps each binary file name to the version currently on disk.
      """

      def __init__(
          self,
          tools_download_base_url: Optional[str] = None,
          data_dir: Optional[str] = None,
          bin_dir: Optional[str] = None,
          pipx_path: Optional[str] = None,
          system: Optional[str] = None,
          arch: Optional[str] = None,
      ):
          """
          Create a manager for the bundled third-party binaries.

          Args:
              tools_download_base_url: Base URL to download from. When empty, a
                  mirror is probed and selected on the first download.
              data_dir: Stored on the instance; not used by this class itself.
              bin_dir: Stored on the instance; not used by this class itself.
              pipx_path: Stored on the instance; not used by this class itself.
              system: OS name override; defaults to ``platform.system()``.
                  Compared against "darwin", "linux" and "windows".
              arch: Architecture override; defaults to ``platform.arch()``.
                  Compared against "amd64" and "arm64".
          """
          # NOTE(review): the resolved path is kept after the context manager
          # exits; presumably the package is installed unpacked on disk so the
          # path stays valid - confirm.
          with pkg_resources.path("gpustack.third_party", "bin") as third_party_bin_path:
              self.third_party_bin_path: Path = third_party_bin_path
              self.versions_file = third_party_bin_path.joinpath("versions.json")

          self._current_tools_version = self._load_versions_file()

          self._os = system if system else platform.system()
          self._arch = arch if arch else platform.arch()
          self._download_base_url = tools_download_base_url
          self._data_dir = data_dir
          self._bin_dir = bin_dir
          self._pipx_path = pipx_path

          # Initialize backend dependency manager
          self._dependency_manager = None

      def _load_versions_file(self) -> Dict[str, str]:
          """
          Read ``versions.json`` from disk.

          Returns:
              The parsed mapping, or an empty dict when the file is missing,
              unreadable, or does not contain a JSON object. Failures are logged
              as warnings and never raised, because a missing version record only
              means the tools get downloaded again.
          """
          if not os.path.exists(self.versions_file):
              return {}

          try:
              with open(self.versions_file, 'r', encoding='utf-8') as file:
                  versions = json.load(file)
          except Exception as e:
              logger.warning(f"Failed to load versions.json: {e}")
              return {}

          # Later lookups call `.get()`, so anything but an object is unusable.
          if not isinstance(versions, dict):
              logger.warning(
                  "Failed to load versions.json: top-level value is not an object"
              )
              return {}
          return versions

      def init_dependency_manager(
          self, backend: str, version: str, model_env: Dict[str, str]
      ):
          """
          Init dependency_manager for custom backend version and dependencies.
          No need for other scenarios.
          """
          self._dependency_manager = BackendDependencyManager(
              backend=backend, version=version, model_env=model_env
          )

      def _check_and_set_download_base_url(self):
          """
          Probe the candidate mirrors and keep the fastest one as the base URL.

          Each mirror is asked for the first 512KB of the builtin gguf-parser
          linux-amd64 release asset. A mirror only counts when it answers with a
          ``Content-Range`` header and a non-empty body.

          Raises:
              Exception: If none of the mirrors produced a usable response.
          """
          urls = [
              "https://github.com",
              "https://gpustack-1303613262.cos.ap-guangzhou.myqcloud.com",
              "https://gh-proxy.com/https://github.com",
          ]

          test_path = f"/gpustack/gguf-parser-go/releases/download/{BUILTIN_GGUF_PARSER_VERSION}/gguf-parser-linux-amd64"
          test_size = 512 * 1024  # 512KB
          headers = {"Range": f"bytes=0-{test_size - 1}"}

          download_tests = []
          for url in urls:
              test_url = f"{url}{test_path}"
              try:
                  start_time = time.monotonic()
                  # The context manager releases the streamed connection on every
                  # path, including the early `continue`s below.
                  with requests.get(
                      test_url, headers=headers, timeout=5, stream=True
                  ) as response:
                      response.raise_for_status()

                      if "Content-Range" not in response.headers:
                          continue

                      if len(response.content) == 0:
                          continue

                      elapsed_time = time.monotonic() - start_time
                  download_tests.append((url, elapsed_time))
                  logger.debug(f"Tested {url}, elapsed time {elapsed_time:.2f} seconds")
              except Exception as e:
                  logger.debug(f"Failed to connect to {url}: {e}")

          if not download_tests:
              raise Exception(
                  f"It is required to download dependency tools from the internet, but failed to connect to any of {urls}"
              )

          best_url, _ = min(download_tests, key=lambda x: x[1])
          self._download_base_url = best_url
          logger.debug(
              f"Using {best_url} as the base URL for downloading dependency tools"
          )

      def prepare_tools(self):
          """
          Prepare prebuilt binary tools.
          """
          logger.debug("Preparing dependency tools")
          logger.debug(f"OS: {self._os}, Arch: {self._arch}")
          self.download_gguf_parser()
          self.download_fastfetch()

      def remove_cached_tools(self):
          """
          Remove all cached tools.
          """
          if os.path.exists(self.third_party_bin_path):
              shutil.rmtree(self.third_party_bin_path)
          # versions.json lived inside the removed directory, so the in-memory
          # record must not claim that any tool is still installed.
          self._current_tools_version = {}

      def save_archive(self, archive_path: str):
          """
          Save all downloaded tools as a tar archive.

          The archive is gzip-compressed and always ends up at exactly
          ``archive_path``, whatever extension that path carries.

          Args:
              archive_path: Destination file of the archive.
          """
          archive_path = os.fspath(archive_path)

          # Ensure the directory exists
          target_dir = os.path.dirname(archive_path)
          if target_dir:
              os.makedirs(target_dir, exist_ok=True)

          # make_archive appends ".tar.gz" itself, so only that exact suffix is
          # stripped. Any other name (e.g. "tools-v1.2.tgz") is used whole and
          # the result is moved into place afterwards.
          gztar_suffix = ".tar.gz"
          if archive_path.endswith(gztar_suffix):
              base_name = archive_path[: -len(gztar_suffix)]
          else:
              base_name = archive_path

          logger.info(f"Saving dependency tools to {archive_path}")
          created_path = shutil.make_archive(
              base_name, "gztar", self.third_party_bin_path
          )
          if os.path.abspath(created_path) != os.path.abspath(archive_path):
              os.replace(created_path, archive_path)

      def load_archive(self, archive_path: str):
          """
          Load downloaded tools from a tar archive.

          Args:
              archive_path: Archive to unpack into the third-party bin directory.

          Raises:
              FileNotFoundError: If ``archive_path`` is not an existing file.
          """
          if not os.path.isfile(archive_path):
              raise FileNotFoundError(f"Archive file not found: {archive_path}")

          os.makedirs(self.third_party_bin_path, exist_ok=True)

          logger.info(f"Loading dependency tools from {archive_path}")
          # NOTE(review): no extraction filter is passed here, so the archive is
          # assumed to come from a trusted source - confirm.
          shutil.unpack_archive(archive_path, self.third_party_bin_path)

          # The archive may have replaced versions.json; re-read it so that the
          # next version update does not overwrite it with stale data.
          self._current_tools_version = self._load_versions_file()

      def download_gguf_parser(self):
          """
          Download the builtin gguf-parser binary unless it is already present.

          The download is skipped when the binary exists and versions.json
          records the builtin version for it.
          """
          version = BUILTIN_GGUF_PARSER_VERSION
          gguf_parser_dir = self.third_party_bin_path.joinpath("gguf-parser")
          os.makedirs(gguf_parser_dir, exist_ok=True)

          suffix = ".exe" if self._os == "windows" else ""
          file_name = f"gguf-parser{suffix}"
          target_file = gguf_parser_dir.joinpath(file_name)

          if (
              os.path.isfile(target_file)
              and self._current_tools_version.get(file_name) == version
          ):
              logger.debug(f"{file_name} already exists, skipping download")
              return

          platform_name = self._get_gguf_parser_platform_name()
          url_path = f"gpustack/gguf-parser-go/releases/download/{version}/gguf-parser-{platform_name}{suffix}"

          logger.info(f"Downloading gguf-parser-{platform_name} '{version}'")
          self._download_file(url_path, target_file)

          if self._os != "windows":
              st = os.stat(target_file)
              os.chmod(target_file, st.st_mode | stat.S_IEXEC)

          # Update versions.json
          self._update_versions_file(file_name, version)

      def _get_gguf_parser_platform_name(self) -> str:
          """
          Return the gguf-parser release asset suffix for this OS/arch.

          Raises:
              Exception: If the OS/arch combination has no known asset.
          """
          # macOS uses one universal asset regardless of architecture.
          if self._os == "darwin":
              return "darwin-universal"

          known_platforms = {
              ("linux", "amd64"): "linux-amd64",
              ("linux", "arm64"): "linux-arm64",
              ("windows", "amd64"): "windows-amd64",
              ("windows", "arm64"): "windows-arm64",
          }
          platform_name = known_platforms.get((self._os, self._arch))
          if platform_name is None:
              raise Exception(f"Unsupported platform: {self._os} {self._arch}")
          return platform_name

      def download_fastfetch(self):
          """
          Download and unpack the fastfetch binary unless it is already present.

          The release zip is downloaded into a temporary ``tmp`` directory below
          the fastfetch directory, the binary is copied out of it, and the
          temporary directory is removed again, also when a step fails.

          Raises:
              Exception: If the platform is unsupported, the download or the
                  extraction fails, or the archive does not contain the binary.
          """
          version = "2.25.0.1"
          fastfetch_dir = self.third_party_bin_path.joinpath("fastfetch")
          fastfetch_tmp_dir = fastfetch_dir.joinpath("tmp")

          platform_name = self._get_fastfetch_platform_name()

          file_name = "fastfetch"
          if self._os == "windows":
              file_name += ".exe"
          target_file = os.path.join(fastfetch_dir, file_name)

          if (
              os.path.isfile(target_file)
              and self._current_tools_version.get(file_name) == version
          ):
              logger.debug(f"{file_name} already exists, skipping download")
              return

          logger.info(f"Downloading fastfetch-{platform_name} '{version}'")

          # Start from an empty scratch directory so leftovers of an earlier run
          # cannot be mistaken for freshly extracted files.
          if os.path.exists(fastfetch_tmp_dir):
              shutil.rmtree(fastfetch_tmp_dir)
          os.makedirs(fastfetch_tmp_dir, exist_ok=True)

          try:
              tmp_file = fastfetch_tmp_dir.joinpath(f"fastfetch-{platform_name}.zip")
              url_path = f"gpustack/fastfetch/releases/download/{version}/fastfetch-{platform_name}.zip"
              self._download_file(url_path, tmp_file)
              self._extract_file(tmp_file, fastfetch_tmp_dir)

              # The Windows zip holds the executable at its root; the other
              # zips nest it under fastfetch-<platform>/usr/bin/.
              if self._os == "windows":
                  extracted_fastfetch = fastfetch_tmp_dir.joinpath("fastfetch.exe")
              else:
                  extracted_fastfetch = fastfetch_tmp_dir.joinpath(
                      f"fastfetch-{platform_name}",
                      "usr",
                      "bin",
                      "fastfetch",
                  )

              if not os.path.exists(extracted_fastfetch):
                  raise Exception("failed to find fastfetch binary in extracted archive")
              shutil.copy(extracted_fastfetch, target_file)
          finally:
              # Clean up. Errors are ignored so they cannot hide a failure from
              # the steps above; leftovers are wiped at the start of the next run.
              shutil.rmtree(fastfetch_tmp_dir, ignore_errors=True)

          if self._os != "windows":
              st = os.stat(target_file)
              os.chmod(target_file, st.st_mode | stat.S_IEXEC)

          # Update versions.json
          self._update_versions_file(file_name, version)

      def _update_versions_file(self, tool_name: str, version: str):
          """
          Record ``version`` for ``tool_name`` in versions.json.

          The in-memory record is only updated after the file was written.
          A failed write is logged as an error and not raised.
          """
          updated_versions = self._current_tools_version.copy()
          updated_versions[tool_name] = version

          try:
              with open(self.versions_file, 'w', encoding='utf-8') as file:
                  json.dump(updated_versions, file, indent=4)
              self._current_tools_version[tool_name] = version
          except Exception as e:
              logger.error(f"Failed to update versions.json: {e}")

      def _get_fastfetch_platform_name(self) -> str:
          """
          Return the fastfetch release asset suffix for this OS/arch.

          Raises:
              Exception: If the OS/arch combination has no known asset.
          """
          if self._os == "darwin":
              return "macos-universal"
          # Windows maps to the amd64 asset whatever the architecture is.
          if self._os == "windows":
              return "windows-amd64"

          linux_platforms = {
              "amd64": "linux-amd64",
              "arm64": "linux-aarch64",
          }
          if self._os == "linux" and self._arch in linux_platforms:
              return linux_platforms[self._arch]
          raise Exception(f"unsupported platform: {self._os} {self._arch}")

      def _download_file(
          self,
          url_path: str,
          target_path: str,
          base_url: Optional[str] = None,
          headers: Optional[Dict[str, str]] = None,
      ):
          """
          Download a file from the URL to the target path.

          The body is streamed into ``<target_path>.part`` and moved onto
          ``target_path`` only after the transfer finished, so a failed attempt
          never leaves a truncated file at the target. Up to 5 attempts are made,
          2 seconds apart.

          Args:
              url_path: Path appended to the base URL.
              target_path: Destination file.
              base_url: Base URL for this call; defaults to the manager's base
                  URL, which is probed first when it has not been set.
              headers: Extra request headers.

          Raises:
              Exception: If all attempts failed; the last error is chained.
          """
          if not base_url and not self._download_base_url:
              self._check_and_set_download_base_url()

          final_base_url = (base_url or self._download_base_url).rstrip("/")
          url = f"{final_base_url}/{url_path}"
          partial_path = f"{target_path}.part"

          max_retries = 5
          for attempt in range(1, max_retries + 1):
              try:
                  with requests.get(
                      url,
                      stream=True,
                      timeout=30,
                      headers=headers,
                  ) as response:
                      response.raise_for_status()
                      with open(partial_path, 'wb') as f:
                          for chunk in response.iter_content(chunk_size=8192):
                              if chunk:
                                  f.write(chunk)
                  os.replace(partial_path, target_path)
                  return
              except Exception as e:
                  # Best-effort removal of the incomplete download.
                  try:
                      os.remove(partial_path)
                  except OSError:
                      pass

                  if attempt >= max_retries:
                      raise Exception(f"Error downloading from {url}: {e}") from e
                  logger.debug(
                      f"Attempt {attempt} failed: {e}. Retrying in 2 seconds..."
                  )
                  time.sleep(2)

      @staticmethod
      def _extract_file(file_path, target_dir):
          """
          Extract a zip file to the target directory.

          Raises:
              Exception: If the file is not a valid zip archive or extraction
                  fails for any other reason; the original error is chained.
          """
          try:
              with zipfile.ZipFile(file_path, 'r') as zip_ref:
                  zip_ref.extractall(target_dir)
          except Exception as e:
              raise Exception(f"error extracting {file_path}: {e}") from e
  283. raise Exception(f"error extracting {file_path}: {e}")