downloaders.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import logging
  2. import os
  3. from typing import List, Optional, Union
  4. from pathlib import Path
  5. from tqdm.contrib.concurrent import thread_map
  6. from huggingface_hub import HfApi, hf_hub_download, snapshot_download
  7. from modelscope.hub.api import HubApi
  8. from modelscope.hub.snapshot_download import (
  9. snapshot_download as modelscope_snapshot_download,
  10. )
  11. from modelscope.hub.utils.utils import model_id_to_group_owner_name
  12. from gpustack.schemas.models import Model, ModelSource, SourceEnum, get_mmproj_filename
  13. from gpustack.utils import file
  14. from gpustack.utils.hub import (
  15. match_hugging_face_files,
  16. match_model_scope_file_paths,
  17. FileEntry,
  18. )
  19. from gpustack.utils.locks import HeartbeatSoftFileLock
  20. logger = logging.getLogger(__name__)
  21. def download_model(
  22. model: ModelSource,
  23. local_dir: Optional[str] = None,
  24. cache_dir: Optional[str] = None,
  25. huggingface_token: Optional[str] = None,
  26. ) -> List[str]:
  27. if model.source == SourceEnum.HUGGING_FACE:
  28. return HfDownloader.download(
  29. repo_id=model.huggingface_repo_id,
  30. filename=model.huggingface_filename,
  31. extra_filename=get_mmproj_filename(model),
  32. token=huggingface_token,
  33. local_dir=local_dir,
  34. cache_dir=os.path.join(cache_dir, "huggingface"),
  35. owner_worker_id=getattr(model, "worker_id", None),
  36. )
  37. elif model.source == SourceEnum.MODEL_SCOPE:
  38. return ModelScopeDownloader.download(
  39. model_id=model.model_scope_model_id,
  40. file_path=model.model_scope_file_path,
  41. extra_file_path=get_mmproj_filename(model),
  42. local_dir=local_dir,
  43. cache_dir=os.path.join(cache_dir, "model_scope"),
  44. owner_worker_id=getattr(model, "worker_id", None),
  45. )
  46. elif model.source == SourceEnum.LOCAL_PATH:
  47. return file.get_sharded_file_paths(model.local_path)
  48. def get_model_file_info(
  49. model: Model,
  50. huggingface_token: Optional[str] = None,
  51. cache_dir: Optional[str] = None,
  52. ) -> List[FileEntry]:
  53. if model.source == SourceEnum.HUGGING_FACE:
  54. return HfDownloader.get_model_file_info(
  55. model=model,
  56. token=huggingface_token,
  57. )
  58. elif model.source == SourceEnum.MODEL_SCOPE:
  59. return ModelScopeDownloader.get_model_file_info(
  60. model=model,
  61. )
  62. elif model.source == SourceEnum.LOCAL_PATH:
  63. sharded_or_original_file_paths = file.get_sharded_file_paths(model.local_path)
  64. file_list = [
  65. FileEntry(f, file.getsize(f)) for f in sharded_or_original_file_paths
  66. ]
  67. return file_list
  68. raise ValueError(f"Unsupported model source: {model.source}")
  69. class HfDownloader:
  70. _registry_url = "https://huggingface.co"
  71. @classmethod
  72. def get_model_file_info(cls, model: Model, token: Optional[str]) -> List[FileEntry]:
  73. api = HfApi(token=token)
  74. repo_info = api.repo_info(model.huggingface_repo_id, files_metadata=True)
  75. file_list = [FileEntry(f.rfilename, f.size) for f in repo_info.siblings]
  76. return file_list
  77. @classmethod
  78. def download(
  79. cls,
  80. repo_id: str,
  81. filename: Optional[str],
  82. extra_filename: Optional[str],
  83. token: Optional[str] = None,
  84. local_dir: Optional[Union[str, os.PathLike[str]]] = None,
  85. cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
  86. max_workers: int = 8,
  87. owner_worker_id: Optional[int] = None,
  88. ) -> List[str]:
  89. """Download a model from the Hugging Face Hub.
  90. Args:
  91. repo_id:
  92. The model repo id.
  93. filename:
  94. A filename or glob pattern to match the model file in the repo.
  95. token:
  96. The Hugging Face API token.
  97. local_dir:
  98. The local directory to save the model to.
  99. max_workers (`int`, *optional*):
  100. Number of concurrent threads to download files (1 thread = 1 file download).
  101. Defaults to 8.
  102. Returns:
  103. The paths to the downloaded model files.
  104. """
  105. group_or_owner, name = model_id_to_group_owner_name(repo_id)
  106. lock_filename = os.path.join(cache_dir, group_or_owner, f"{name}.lock")
  107. if local_dir is None:
  108. local_dir = os.path.join(cache_dir, group_or_owner, name)
  109. logger.info(f"Retrieving file lock: {lock_filename}")
  110. with HeartbeatSoftFileLock(lock_filename, owner_worker_id=owner_worker_id):
  111. if filename:
  112. return cls.download_file(
  113. repo_id=repo_id,
  114. filename=filename,
  115. token=token,
  116. local_dir=local_dir,
  117. extra_filename=extra_filename,
  118. )
  119. snapshot_download(
  120. repo_id=repo_id,
  121. token=token,
  122. local_dir=local_dir,
  123. )
  124. return [local_dir]
  125. @classmethod
  126. def download_file(
  127. cls,
  128. repo_id: str,
  129. filename: Optional[str],
  130. token: Optional[str] = None,
  131. local_dir: Optional[Union[str, os.PathLike[str]]] = None,
  132. max_workers: int = 8,
  133. extra_filename: Optional[str] = None,
  134. ) -> List[str]:
  135. """Download a model from the Hugging Face Hub.
  136. Args:
  137. repo_id: The model repo id.
  138. filename: A filename or glob pattern to match the model file in the repo.
  139. token: The Hugging Face API token.
  140. local_dir: The local directory to save the model to.
  141. Returns:
  142. The path to the downloaded model.
  143. """
  144. matching_files = match_hugging_face_files(
  145. repo_id, filename, extra_filename, token
  146. )
  147. if len(matching_files) == 0:
  148. raise ValueError(f"No file found in {repo_id} that match {filename}")
  149. logger.info(f"Downloading model {repo_id}/{filename}")
  150. subfolder = (
  151. None
  152. if (subfolder := str(Path(matching_files[0]).parent)) == "."
  153. else subfolder
  154. )
  155. unfolder_matching_files = [Path(file).name for file in matching_files]
  156. downloaded_files = []
  157. def _inner_hf_hub_download(repo_file: str):
  158. downloaded_file = hf_hub_download(
  159. repo_id=repo_id,
  160. filename=repo_file,
  161. token=token,
  162. subfolder=subfolder,
  163. local_dir=local_dir,
  164. )
  165. downloaded_files.append(downloaded_file)
  166. thread_map(
  167. _inner_hf_hub_download,
  168. unfolder_matching_files,
  169. desc=f"Fetching {len(unfolder_matching_files)} files",
  170. max_workers=max_workers,
  171. )
  172. logger.info(f"Downloaded model {repo_id}/{filename}")
  173. return sorted(downloaded_files)
  174. def __call__(self):
  175. return self.download()
  176. class ModelScopeDownloader:
  177. @classmethod
  178. def get_model_file_info(cls, model: Model) -> List[FileEntry]:
  179. api = HubApi()
  180. repo_files = api.get_model_files(model.model_scope_model_id, recursive=True)
  181. file_list = [FileEntry(f.get("Path"), f.get("Size")) for f in repo_files]
  182. return file_list
  183. @classmethod
  184. def download(
  185. cls,
  186. model_id: str,
  187. file_path: Optional[str],
  188. extra_file_path: Optional[str],
  189. local_dir: Optional[Union[str, os.PathLike[str]]] = None,
  190. cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
  191. owner_worker_id: Optional[int] = None,
  192. ) -> List[str]:
  193. """Download a model from Model Scope.
  194. Args:
  195. model_id:
  196. The model id.
  197. file_path:
  198. A filename or glob pattern to match the model file in the repo.
  199. cache_dir:
  200. The cache directory to save the model to.
  201. Returns:
  202. The path to the downloaded model.
  203. """
  204. group_or_owner, name = model_id_to_group_owner_name(model_id)
  205. lock_filename = os.path.join(cache_dir, group_or_owner, f"{name}.lock")
  206. if local_dir is None:
  207. local_dir = os.path.join(cache_dir, group_or_owner, name)
  208. logger.info(f"Retrieving file lock: {lock_filename}")
  209. with HeartbeatSoftFileLock(lock_filename, owner_worker_id=owner_worker_id):
  210. if file_path:
  211. matching_files = match_model_scope_file_paths(
  212. model_id, file_path, extra_file_path
  213. )
  214. if len(matching_files) == 0:
  215. raise ValueError(
  216. f"No file found in {model_id} that match {file_path}"
  217. )
  218. model_dir = modelscope_snapshot_download(
  219. model_id=model_id,
  220. local_dir=local_dir,
  221. allow_patterns=matching_files,
  222. )
  223. return [os.path.join(model_dir, file) for file in matching_files]
  224. modelscope_snapshot_download(
  225. model_id=model_id,
  226. local_dir=local_dir,
  227. )
  228. return [local_dir]