| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837 |
- import asyncio
- from concurrent.futures import ProcessPoolExecutor
- from functools import partial
- import glob
- from itertools import chain
- import logging
- from pathlib import Path
- import platform
- import time
- import threading
- from typing import Dict, Tuple
- from filelock import Timeout
- from modelscope.hub.constants import TEMPORARY_FOLDER_NAME, API_FILE_DOWNLOAD_CHUNK_SIZE
- from multiprocessing import Manager, cpu_count
- from huggingface_hub._local_folder import get_local_download_paths
- from huggingface_hub.file_download import get_hf_file_metadata, hf_hub_url
- import huggingface_hub.constants
- from huggingface_hub.utils import build_hf_headers
- from gpustack.api.exceptions import NotFoundException
- from gpustack.config.config import Config
- from gpustack.logging import setup_logging
- from gpustack.schemas.model_files import ModelFile, ModelFileUpdate, ModelFileStateEnum
- from gpustack.client import ClientSet
- from gpustack.schemas.models import SourceEnum
- from gpustack.server.bus import Event, EventType
- from gpustack.utils import hub
- from gpustack.utils.file import delete_path, get_local_file_size_in_byte
- from gpustack.worker import downloaders
- from gpustack.config.registration import read_worker_token
- from gpustack.utils.locks import read_lock_info, get_lock_path
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Upper bound on the size of the download process pool; the pool actually
# uses min(max_concurrent_downloads, cpu_count()) workers.
max_concurrent_downloads = 5
def _cleanup_download_log(config_log_dir, model_file_id):
    """
    Remove the download log of a model file, if it exists.

    The log is looked up at
    ``<config_log_dir>/serve/model_file_<model_file_id>.download.log``.
    Any failure is logged as a warning and never raised to the caller.
    """
    try:
        log_path = (
            Path(config_log_dir) / "serve" / f"model_file_{model_file_id}.download.log"
        )
        # Nothing to do (and nothing to log) when the file is absent.
        if log_path.exists():
            log_path.unlink()
            logger.debug(f"Cleaned up download log file: {log_path}")
    except Exception as e:
        logger.warning(
            f"Failed to clean up download log file for model file {model_file_id}: {e}"
        )
class ModelFileManager:
    """
    Watch model-file events for one worker and run downloads and deletions.

    Downloads are submitted to a ``ProcessPoolExecutor``. Every active download
    is tracked as ``(future, cancel_flag)`` keyed by model file id so that a
    later DELETED event can cancel it.
    """

    def __init__(
        self,
        worker_id: int,
        clientset: ClientSet,
        cfg: Config,
    ):
        self._worker_id = worker_id
        self._config = cfg
        self._clientset = clientset
        # model file id -> (download future, cancel flag handed to the task)
        self._active_downloads: Dict[int, Tuple] = {}
        # Both are created by _prerun() when watching starts.
        self._download_pool = None
        self._mp_manager = None
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._background_tasks = set()

    async def watch_model_files(self):
        """
        Watch model-file events until cancelled.

        Any error from the watch call is logged and the watch is restarted
        after 5 seconds. Cancellation ends the loop.
        """
        self._prerun()
        while True:
            try:
                logger.debug("Started watching model files.")
                await self._clientset.model_files.awatch(
                    callback=self._handle_model_file_event
                )
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Failed to watch model files: {e}")
                await asyncio.sleep(5)

    def _prerun(self):
        """Create the multiprocessing manager and the download process pool."""
        self._mp_manager = Manager()
        self._download_pool = ProcessPoolExecutor(
            max_workers=min(max_concurrent_downloads, cpu_count()),
        )

    def _spawn_background_task(self, coro):
        """Schedule ``coro`` as a task and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def _handle_model_file_event(self, event: Event):
        """
        Dispatch one model-file event.

        Events for other workers are ignored. DELETED schedules a deletion;
        CREATED/UPDATED start a download only when the file is in the
        DOWNLOADING state.
        """
        mf = ModelFile.model_validate(event.data)

        if mf.worker_id != self._worker_id:
            # Ignore model files that are not assigned to this worker.
            return

        logger.trace(f"Received model file event: {event.type} {mf.id} {mf.state}")

        if event.type == EventType.DELETED:
            self._spawn_background_task(self._handle_deletion(mf))
        elif event.type in {EventType.CREATED, EventType.UPDATED}:
            if mf.state != ModelFileStateEnum.DOWNLOADING:
                return
            self._create_download_task(mf)

    def _update_model_file(self, id: int, **kwargs):
        """
        Fetch the model file, overwrite the given fields and push the update.

        This is a synchronous (blocking) call; coroutines should run it via
        ``asyncio.to_thread``.
        """
        model_file_public = self._clientset.model_files.get(id=id)
        model_file_update = ModelFileUpdate(**model_file_public.model_dump())
        for key, value in kwargs.items():
            setattr(model_file_update, key, value)
        self._clientset.model_files.update(id=id, model_update=model_file_update)

    async def _report_error_state(self, model_file_id: int, message: str):
        """
        Best-effort: mark the model file as ERROR with ``message``.

        The blocking client call runs in a worker thread. A failure to report
        (for example because the record no longer exists) is logged instead of
        raised, so it cannot turn into an unhandled task exception.
        """
        try:
            await asyncio.to_thread(
                self._update_model_file,
                model_file_id,
                state=ModelFileStateEnum.ERROR,
                state_message=message,
            )
        except Exception as e:
            logger.warning(
                f"Failed to report error state for model file {model_file_id}: {e}"
            )

    async def _handle_deletion(self, model_file: ModelFile):
        """
        Cancel an in-flight download of ``model_file`` (if any) and, when
        ``cleanup_on_delete`` is set, remove its files from disk.
        """
        entry = self._active_downloads.pop(model_file.id, None)
        if entry:
            future, cancel_flag = entry
            # The flag is polled by the running task; future.cancel() only
            # helps when the task has not started yet.
            cancel_flag.set()
            future.cancel()
            try:
                await asyncio.wrap_future(future)
            except (asyncio.CancelledError, NotFoundException):
                pass
            except Exception as e:
                logger.error(
                    f"Error while cancelling download for {model_file.readable_source}(id: {model_file.id}): {e}"
                )
            finally:
                logger.info(
                    f"Cancelled download for deleted model: {model_file.readable_source}(id: {model_file.id})"
                )

        if model_file.cleanup_on_delete:
            await self._delete_model_file(model_file)

    async def get_hf_file_metadata(self, model_file: ModelFile, filename: str):
        """
        Fetch Hugging Face metadata for ``filename`` in the model file's repo.

        The blocking hub call runs in a worker thread.
        """
        token = self._config.huggingface_token
        url = hf_hub_url(model_file.huggingface_repo_id, filename)
        headers = build_hf_headers(token=token)
        # Inside this method the bare name refers to the imported hub function,
        # not to this method.
        metadata = await asyncio.to_thread(
            get_hf_file_metadata,
            url=url,
            timeout=huggingface_hub.constants.DEFAULT_ETAG_TIMEOUT,
            headers=headers,
            token=token,
        )
        return metadata

    async def _get_incomplete_model_files(  # noqa: C901
        self, model_file: ModelFile
    ) -> set:
        """
        Finds cached files of models being downloaded.

        1. For models from Hugging Face, their .incomplete filenames are encoded.
           The process requires:
           [filename_pattern -> model_name -> etag -> incomplete_pattern -> .incomplete_filename]
           to ultimately confirm the file.
        2. For models from ModelScope, the incomplete files are stored in a
           temporary folder; we just need to find them by the filename pattern.

        Returns the set of paths to delete. On error, the paths collected so
        far are returned and the error is logged.
        """
        paths_to_delete = set()
        # Bound up front so the error log below can always reference it, even
        # when the failure happens before a filename was resolved.
        filename = ""
        try:
            if model_file.source == SourceEnum.HUGGING_FACE:
                if not model_file.huggingface_filename:
                    # The resolved_paths in vLLM model points to entire dir of cache, delete it directly
                    paths_to_delete.update(model_file.resolved_paths)
                    return paths_to_delete

                for path in model_file.resolved_paths:
                    path_obj = Path(str(path))
                    filename_pattern = path_obj.name
                    local_dir = path_obj.parent
                    download_paths = get_local_download_paths(
                        local_dir, filename_pattern
                    )
                    cache_dir = download_paths.lock_path.parent

                    # Get actual filename by pattern
                    matches = await asyncio.to_thread(
                        glob.glob, str(cache_dir / filename_pattern) + "*"
                    )
                    if not matches:
                        # No cache entry for this pattern, so there is nothing
                        # to look up or delete for it.
                        logger.debug(
                            f"No cached download files found for '{filename_pattern}'"
                        )
                        continue

                    # Keep only the basename (portable across path separators)
                    # and cut off the trailing extension.
                    filename = Path(matches[0]).name.rsplit(".", 1)[0]

                    metadata = await self.get_hf_file_metadata(model_file, filename)

                    # Collect lock files and incomplete files
                    paths_to_delete.add(str(cache_dir / (filename + ".lock")))
                    paths_to_delete.add(str(cache_dir / (filename + ".metadata")))
                    incomplete_files = await asyncio.to_thread(
                        glob.glob, str(cache_dir / f"*.{metadata.etag}.incomplete")
                    )
                    paths_to_delete.update(incomplete_files)

            elif model_file.source == SourceEnum.MODEL_SCOPE:
                if not model_file.model_scope_file_path:
                    # The resolved_paths in vLLM model points to entire dir of cache, delete it directly
                    paths_to_delete.update(model_file.resolved_paths)
                    return paths_to_delete

                for path in model_file.resolved_paths:
                    path_obj = Path(str(path))
                    filename_pattern = path_obj.name
                    local_dir = path_obj.parent
                    temp_files = await asyncio.to_thread(
                        glob.glob,
                        str(local_dir / f"{TEMPORARY_FOLDER_NAME}/{filename_pattern}"),
                    )
                    paths_to_delete.update(temp_files)
        except Exception as e:
            logger.error(
                f"Error deleting incomplete Download files for "
                f"file '{filename}': {e}"
            )
        return paths_to_delete

    async def _delete_incomplete_model_files(self, model_file: ModelFile):
        """Delete every leftover partial-download file of ``model_file``."""
        paths_to_delete = await self._get_incomplete_model_files(model_file)
        for delete_file in paths_to_delete:
            logger.info(f"Attempting to delete incomplete file: {delete_file}")
            await asyncio.to_thread(delete_path, delete_file)

    async def _delete_model_file(self, model_file: ModelFile):
        """
        Delete the resolved paths, partial downloads and download log of
        ``model_file``. On failure the error is logged and reported to the
        server as an ERROR state.
        """
        try:
            if model_file.resolved_paths:
                # Entries containing '*' are glob patterns; others are literal.
                paths = chain.from_iterable(
                    glob.glob(p) if '*' in p else [p] for p in model_file.resolved_paths
                )
                for path in paths:
                    delete_path(path)

                await self._delete_incomplete_model_files(model_file)

            # Clean up download log file when deleting model file
            _cleanup_download_log(self._config.log_dir, model_file.id)

            logger.info(
                f"Deleted model file {model_file.readable_source}(id: {model_file.id}) from disk"
            )
        except Exception as e:
            logger.error(
                f"Failed to delete {model_file.readable_source}(id: {model_file.id}): {e}"
            )
            # _update_model_file is synchronous, so it must not be awaited
            # directly; the helper runs it in a thread.
            await self._report_error_state(
                model_file.id, f"Deletion failed: {str(e)}"
            )

    def _create_download_task(self, model_file: ModelFile):
        """
        Submit a download for ``model_file`` to the process pool unless one is
        already active, and schedule a watcher that reports failures and
        removes the entry from the active set when the download ends.
        """
        if model_file.id in self._active_downloads:
            return

        cancel_flag = self._mp_manager.Event()
        download_task = ModelFileDownloadTask(model_file, self._config, cancel_flag)
        future = self._download_pool.submit(download_task.run)
        self._active_downloads[model_file.id] = (future, cancel_flag)
        logger.debug(f"Created download task for {model_file.readable_source}")

        async def _check_completion():
            try:
                await asyncio.wrap_future(future)
            except NotFoundException:
                logger.info(
                    f"Model file {model_file.readable_source} not found. Maybe it was cancelled."
                )
            except Exception as e:
                logger.error(f"Failed to download model file: {e}")
                await self._report_error_state(model_file.id, str(e))
            finally:
                self._active_downloads.pop(model_file.id, None)
                logger.debug(f"Download completed for {model_file.readable_source}")

        self._spawn_background_task(_check_completion())
class ModelFileDownloadTask:
    """
    Download one model file and report progress to the server and to a shared
    download log file.

    Progress is captured by monkey patching ``tqdm`` (see
    ``hijack_tqdm_progress``). Each progress bar gets a fixed line in the log
    file, rewritten in place with ANSI cursor-positioning sequences.

    NOTE: ``ModelFileManager._create_download_task`` submits ``run`` to a
    process pool. State that only makes sense inside the executing process
    (API client, thread lock, counters) is therefore created in ``prerun()``
    rather than in ``__init__``.
    """

    def __init__(self, model_file: ModelFile, cfg: Config, cancel_flag):
        self._model_file = model_file
        self._config = cfg
        # Event-like object; when set, the next tracked progress update aborts
        # the download.
        self._cancel_flag = cancel_flag
        # Path of the download log file; None until prerun() sets it up.
        self._instance_download_log_file = None
        self._download_completed = False

        # Time control for log updates
        self._last_log_update_time = 0
        self._log_update_interval = 2.0  # 2 seconds interval

        # Multi-file progress tracking with ANSI cursor control
        # Counter for generating unique tqdm IDs
        self._tqdm_counter = 0
        # Dict[tqdm_id, line_number] - tracks which line each file occupies
        self._file_line_mapping = {}
        # Dict[tqdm_id, {'last_update_time': float, 'last_progress': float}]
        self._file_progress_tracking = {}
        # Dict[tqdm_id, basename of the resolved path the bar belongs to]
        self._tqdm_file_basename = {}
        # Number of header lines in the log file
        self._log_header_lines = 1

        # A single update larger than one download chunk is treated as a
        # resume/reconnect jump (see _handle_tqdm_update); 0 disables the check.
        self._resume_threshold = 0
        if self._model_file.source == SourceEnum.MODEL_SCOPE:
            self._resume_threshold = API_FILE_DOWNLOAD_CHUNK_SIZE
        elif self._model_file.source == SourceEnum.HUGGING_FACE:
            self._resume_threshold = huggingface_hub.constants.DOWNLOAD_CHUNK_SIZE

    def prerun(self):
        """
        Prepare per-run state: logging, API client, size and paths of the
        model file, progress counters, the log file and the tqdm patch.
        """
        setup_logging(self._config.debug)
        self._clientset = ClientSet(
            base_url=self._config.get_server_url(),
            api_key=read_worker_token(self._config.data_dir),
        )
        self._download_start_time = time.time()
        self._ensure_model_file_size_and_paths()

        # Lock for _model_downloaded_size/_last_download_update_time/_last_downloaded_size to avoid race condition
        self._speed_lock = threading.Lock()
        self._model_downloaded_size = 0
        self._last_download_update_time = 0
        self._last_downloaded_size = 0

        self._setup_instance_log_files()

        logger.debug(f"Initializing task for {self._model_file.readable_source}")
        self._update_progress_func = partial(
            self._update_model_file_progress, self._model_file.id
        )
        self._model_file_size = self._model_file.size
        # Patch tqdm last: the hooks rely on the state initialised above.
        self.hijack_tqdm_progress()

    def _setup_instance_log_files(self):
        """
        Point the task at ``<log_dir>/serve/model_file_<id>.download.log`` and
        remove any log left over from a previous download. Never raises.
        """
        try:
            log_dir = Path(self._config.log_dir) / "serve"
            # Use model file ID for shared download log across all instances using the same model file
            download_log_file_path = (
                log_dir / f"model_file_{self._model_file.id}.download.log"
            )

            # Delete existing download log file to avoid reading previous download logs
            # when redeploying the same model after deleting model_instance but keeping model_file
            if download_log_file_path.exists():
                try:
                    download_log_file_path.unlink()
                    logger.debug(
                        f"Deleted existing download log file: {download_log_file_path}"
                    )
                except Exception as e:
                    logger.warning(
                        f"Failed to delete existing download log file {download_log_file_path}: {e}"
                    )

            self._instance_download_log_file = str(download_log_file_path)
            logger.debug(f"Setup shared download log file: {download_log_file_path}")
        except Exception as e:
            logger.warning(f"Failed to setup instance download log files: {e}")

    def _write_log_with_windows_lock(self, log_file_path: str, log_message: str):
        """
        Write log message to file using Windows msvcrt file locking.

        Falls back to an unlocked write when ``msvcrt`` is unavailable or when
        the locked write fails.
        """
        try:
            import msvcrt
        except ImportError:
            # msvcrt not available, fallback to basic write
            self._write_log_without_lock(log_file_path, log_message)
            return

        with open(log_file_path, 'a', encoding='utf-8') as f:
            try:
                # Acquire exclusive lock on the file
                # Lock a single byte at the beginning of the file for coordination
                f.seek(0)
                msvcrt.locking(f.fileno(), msvcrt.LK_LOCK, 1)
                f.seek(0, 2)  # Move to end of file for appending
                f.write(log_message)
                f.flush()  # Ensure immediate write to disk
            except (OSError, IOError):
                # If locking fails, fallback to basic write
                f.seek(0, 2)  # Move to end of file for appending
                f.write(log_message)
                f.flush()
            finally:
                try:
                    f.seek(0)
                    msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)
                except (OSError, IOError):
                    pass  # Ignore unlock errors

    def _write_log_with_unix_lock(self, log_file_path: str, log_message: str):
        """
        Write log message to file using Unix/Linux fcntl file locking.

        Falls back to an unlocked write when ``fcntl`` is unavailable. Errors
        from opening, locking or writing propagate to the caller.
        """
        try:
            import fcntl
        except ImportError:
            # fcntl not available, fallback to basic write
            self._write_log_without_lock(log_file_path, log_message)
            return

        with open(log_file_path, 'a', encoding='utf-8') as f:
            try:
                # Acquire exclusive lock on the file
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                f.write(log_message)
                f.flush()  # Ensure immediate write to disk
            finally:
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)

    def _write_log_without_lock(self, log_file_path: str, log_message: str):
        """
        Write log message to file without file locking (fallback method).
        Failures are logged as warnings and not raised.
        """
        try:
            with open(log_file_path, 'a', encoding='utf-8') as f:
                f.write(log_message)
                f.flush()  # Ensure immediate write to disk
        except Exception as e:
            logger.warning(
                f"Failed to write to instance download log {log_file_path}: {e}"
            )

    def _write_to_instance_download_logs(
        self, message: str, is_error=False, use_tqdm_format=False
    ):
        """
        Append a message to the download log file (no-op when none is set up).

        :param message: text to write.
        :param is_error: use the ERROR level tag instead of INFO (ignored for
            tqdm-format messages).
        :param use_tqdm_format: write the message without timestamp/level; a
            message starting with an ANSI escape sequence is written verbatim,
            any other gets a trailing newline.

        Writing is best-effort: I/O errors are logged as warnings so that a
        log-file problem cannot fail the download itself.
        """
        if not self._instance_download_log_file:
            return

        if use_tqdm_format:
            # For tqdm-style progress with ANSI control sequences
            if message.startswith('\033[') or message.startswith('\r\033['):
                # This is an ANSI control message, write it directly without additional formatting
                log_message = message
            else:
                # Regular tqdm message without timestamp
                log_message = f"{message}\n"
        else:
            timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
            log_level = "ERROR" if is_error else "INFO"
            log_message = f"[{timestamp}] [{log_level}] {message}\n"
            # Increment header lines counter for non-tqdm messages
            self._log_header_lines += 1

        try:
            # Ensure log directory exists
            Path(self._instance_download_log_file).parent.mkdir(
                parents=True, exist_ok=True
            )
            # Use appropriate locking method based on platform
            if platform.system() == 'Windows':
                self._write_log_with_windows_lock(
                    self._instance_download_log_file, log_message
                )
            else:
                self._write_log_with_unix_lock(
                    self._instance_download_log_file, log_message
                )
        except OSError as e:
            logger.warning(
                f"Failed to write to instance download log {self._instance_download_log_file}: {e}"
            )

    def run(self):
        """
        Entry point of the task: prepare, download and record the outcome.

        Cancellation and lock timeouts are only logged. Any other exception is
        written to the download log and reported to the server as an ERROR
        state.
        """
        try:
            self.prerun()
            self._write_to_instance_download_logs(
                f"Model file download task started: {self._model_file.readable_source}"
            )
            self._download_model_file()
            self._write_to_instance_download_logs(
                f"Model file download task completed successfully: {self._model_file.readable_source}"
            )
        except asyncio.CancelledError:
            self._write_to_instance_download_logs(
                f"Download task cancelled: {self._model_file.readable_source}"
            )
        except Timeout:
            lock_path = get_lock_path(self._config.cache_dir, self._model_file)
            info = read_lock_info(lock_path) if lock_path else None
            owner_id = info.get("worker_id") if info else None
            current_worker_id = self._model_file.worker_id

            if owner_id is None or owner_id != current_worker_id:
                logger.warning(
                    f"Download model {self._model_file.readable_source} timed out: "
                    f"lock held by other worker, please try again later."
                )
                return

            logger.warning(
                f"Download model {self._model_file.readable_source} timed out waiting to acquire the lock. "
                f"There might be another download task currently downloading the same model to the same disk directory."
            )
        except Exception as e:
            self._write_to_instance_download_logs(
                f"Download task failed: {self._model_file.readable_source} - {str(e)}",
                is_error=True,
            )
            self._update_model_file(
                self._model_file.id,
                state=ModelFileStateEnum.ERROR,
                state_message=str(e),
            )

    def _download_model_file(self):
        """Download the model file and mark it READY with its resolved paths."""
        self._write_to_instance_download_logs(
            f"Downloading model file: {self._model_file.readable_source}"
        )

        model_paths = downloaders.download_model(
            self._model_file,
            local_dir=self._model_file.local_dir,
            cache_dir=self._config.cache_dir,
            huggingface_token=self._config.huggingface_token,
        )
        self._download_completed = True
        self._update_model_file(
            self._model_file.id,
            state=ModelFileStateEnum.READY,
            download_progress=100,
            resolved_paths=model_paths,
        )

        self._write_to_instance_download_logs(
            f"Successfully downloaded {self._model_file.readable_source}"
        )

    def hijack_tqdm_progress(task_self):
        """
        Monkey patch the tqdm progress bar to update the model instance download progress.
        tqdm is used by hf_hub_download under the hood.

        The first parameter is named ``task_self`` because the nested
        replacement methods use ``self`` for the tqdm instance.
        """
        from tqdm import tqdm

        # Reuse the pristine methods if tqdm was already patched, so repeated
        # patching never wraps a wrapper.
        _original_init = (
            tqdm._original_init if hasattr(tqdm, "_original_init") else tqdm.__init__
        )
        _original_update = (
            tqdm._original_update if hasattr(tqdm, "_original_update") else tqdm.update
        )

        def _new_init(self: tqdm, *args, **kwargs):
            task_self._handle_tqdm_init(self, _original_init, *args, **kwargs)

        def _new_update(self: tqdm, n=1):
            task_self._handle_tqdm_update(self, _original_update, n)

        tqdm.__init__ = _new_init
        tqdm.update = _new_update
        tqdm._original_init = _original_init
        tqdm._original_update = _original_update

    def _handle_tqdm_init(self, tqdm_instance, original_init, *args, **kwargs):
        """
        Initialise a tqdm bar and register it for tracking: assign an id and a
        log line, account for already-downloaded bytes and write the initial
        progress line.
        """
        kwargs["disable"] = False  # enable the progress bar anyway
        original_init(tqdm_instance, *args, **kwargs)

        # Assign unique ID and line number for this tqdm instance
        tqdm_id = self._tqdm_counter
        self._tqdm_counter += 1
        tqdm_instance._gpustack_id = tqdm_id

        # Assign a fixed line number for this file (same as tqdm_id)
        line_number = tqdm_id
        self._file_line_mapping[tqdm_id] = line_number

        # Initialize progress tracking for this file
        self._file_progress_tracking[tqdm_id] = {
            'last_update_time': 0,
            'last_progress': 0.0,
        }

        if hasattr(self, '_model_file_size'):
            # Resume downloading: the bar starts at the bytes already on disk.
            # Same lock as in _handle_tqdm_update, which updates this counter too.
            with self._speed_lock:
                self._model_downloaded_size += tqdm_instance.n

        # Write initial progress line for this file using ANSI cursor positioning
        file_desc = getattr(tqdm_instance, 'desc', None) or f"File {tqdm_id}"
        self._assign_file_basename(tqdm_id, file_desc)
        self._write_progress_with_cursor_positioning(
            line_number, f"{file_desc}: Initializing...", tqdm_id
        )

    def _handle_tqdm_update(self, tqdm_instance, original_update, n=1):
        """
        Advance a tqdm bar by ``n`` and propagate progress.

        Tracked bars update the overall downloaded size, and at most every
        ``_log_update_interval`` seconds (or on completion) push progress to
        the server and rewrite their log line.

        :raises asyncio.CancelledError: when the cancel flag is set.
        :raises Exception: when reporting progress fails.
        """
        # Get the tqdm ID and line number for this instance
        tqdm_id = getattr(tqdm_instance, '_gpustack_id', None)
        # Compare with None explicitly: the first bar gets id 0, which is falsy
        # and must still be tracked.
        if tqdm_id is None or tqdm_id not in self._file_line_mapping:
            # Not registered with this task: keep the bar itself advancing,
            # but skip all tracking.
            original_update(tqdm_instance, n)
            return

        if self._resume_threshold and n > self._resume_threshold:
            # https://github.com/modelscope/modelscope/blob/609442d271bd7ed106a0933b1937289be7c1ad01/modelscope/hub/file_download.py#L417-L422
            # During download reconnection events, the progress bar may recalculate based on the current downloaded size.
            # We need to intercept this behavior and read the actual cached file size to correct the progress display.
            n = self._adjust_downloaded_by_cache_size(tqdm_instance, n)

        original_update(tqdm_instance, n)

        if self._cancel_flag.is_set():
            raise asyncio.CancelledError("Download cancelled")

        line_number = self._file_line_mapping[tqdm_id]

        with self._speed_lock:
            self._model_downloaded_size += n

        try:
            # Update overall progress; an unknown or zero total size yields 0
            # instead of failing the download with a division error.
            if self._model_file_size:
                progress = round(
                    (self._model_downloaded_size / self._model_file_size) * 100, 2
                )
            else:
                progress = 0.0

            # Update individual file progress using ANSI cursor positioning
            current_time = time.time()

            # Get file-specific progress tracking info
            file_tracking = self._file_progress_tracking.get(
                tqdm_id, {'last_update_time': 0, 'last_progress': 0.0}
            )

            # Calculate individual file progress percentage
            if tqdm_instance.total and tqdm_instance.total > 0:
                file_progress = (tqdm_instance.n / tqdm_instance.total) * 100
            else:
                file_progress = 0.0

            # Log when the update interval has elapsed or the file is complete
            time_elapsed = current_time - file_tracking['last_update_time']
            should_log = (
                time_elapsed >= self._log_update_interval  # 2 seconds elapsed
                or file_progress >= 100.0  # Always log when complete
                or (
                    tqdm_instance.total is not None
                    and tqdm_instance.n >= tqdm_instance.total
                )  # Always log when download completes
            )

            if should_log:
                # Update progress to server
                self._update_progress_func(progress)

                # Format progress message using tqdm's string representation
                progress_str = str(tqdm_instance)
                self._write_progress_with_cursor_positioning(
                    line_number, progress_str, tqdm_id
                )

                # Update file-specific tracking info
                self._file_progress_tracking[tqdm_id] = {
                    'last_update_time': current_time,
                    'last_progress': file_progress,
                }

                # Keep global update time for backward compatibility
                self._last_log_update_time = current_time

            if file_progress >= 100.0:
                self._recover_cursor_to_end()
        except Exception as e:
            error_msg = f"Failed to update model file: {e}"
            self._write_to_instance_download_logs(
                f"Download error: {error_msg}", is_error=True
            )
            raise Exception(error_msg) from e

    def _adjust_downloaded_by_cache_size(self, tqdm_instance, n: int) -> int:
        """
        Return the increment to apply instead of ``n``: the size of the
        partially downloaded file on disk minus the bar's current position,
        never negative. Falls back to ``n`` when the size cannot be determined.
        """
        try:
            actual_size = self._get_cache_file_actual_size(tqdm_instance)
            if actual_size is None:
                return n
            base = tqdm_instance.n or 0
            delta = max(actual_size - base, 0)
            logger.debug(f"_adjust_downloaded_by_cache_size success, delta = {delta}")
            return delta
        except Exception:
            return n

    def _get_cache_file_actual_size(self, tqdm_instance) -> int | None:
        """
        Return the on-disk size of the partial download behind ``tqdm_instance``.

        Only ModelScope sources are looked up; returns None for other sources,
        when nothing is found, or on any error.
        """
        try:
            source = self._model_file.source
            paths = self._model_file.resolved_paths or []
            if not paths:
                return None

            target_basename = None
            tid = getattr(tqdm_instance, '_gpustack_id', None)
            if tid is not None:
                target_basename = self._tqdm_file_basename.get(tid)

            if source == SourceEnum.MODEL_SCOPE:
                for path in paths:
                    p = Path(str(path))
                    if p.is_dir():
                        size = self._get_incomplete_size_from_dir(
                            p, target_basename, tqdm_instance
                        )
                        if size is not None:
                            return size
                        continue

                    filename_pattern = p.name
                    local_dir = p.parent
                    # Skip paths that belong to a different progress bar.
                    if target_basename and target_basename != filename_pattern:
                        continue
                    incomplete_path = (
                        local_dir / TEMPORARY_FOLDER_NAME / filename_pattern
                    )
                    if incomplete_path.exists():
                        return get_local_file_size_in_byte(str(incomplete_path))
            return None
        except Exception:
            return None

    def _get_incomplete_size_from_dir(
        self, local_dir: Path, target_basename, tqdm_instance
    ) -> int | None:
        """
        Return the size of the partial file for this bar inside
        ``local_dir/<TEMPORARY_FOLDER_NAME>``, or None when it cannot be found.

        When ``target_basename`` is None, the first temporary file whose name
        appears in the bar's description is used.
        """
        try:
            tb = target_basename
            if tb is None:
                desc = getattr(tqdm_instance, 'desc', None)
                temp_dir = local_dir / TEMPORARY_FOLDER_NAME
                if desc and temp_dir.exists():
                    for f in temp_dir.iterdir():
                        name = f.name
                        if name and name in desc:
                            tb = name
                            break
            if tb:
                incomplete_path = local_dir / TEMPORARY_FOLDER_NAME / tb
                if incomplete_path.exists():
                    return get_local_file_size_in_byte(str(incomplete_path))
            return None
        except Exception:
            return None

    def _assign_file_basename(self, tqdm_id: int, desc: str | None):
        """
        Remember which resolved path a bar belongs to: the first resolved path
        whose basename occurs in ``desc``. Best-effort; never raises.
        """
        try:
            if not desc:
                return
            paths = self._model_file.resolved_paths or []
            for path in paths:
                b = Path(str(path)).name
                if b and b in desc:
                    self._tqdm_file_basename[tqdm_id] = b
                    return
        except Exception:
            return

    def _write_progress_with_cursor_positioning(
        self, line_number: int, message: str, tqdm_id: int
    ):
        """Write progress message to a specific line using ANSI cursor positioning"""
        if not self._instance_download_log_file:
            return

        try:
            # Calculate the actual line position in the file
            actual_line = line_number + self._log_header_lines

            # Create ANSI escape sequence to position cursor at specific line, column 1
            cursor_position = f"\033[{actual_line};1H"
            # Clear the entire line to remove any residual characters
            clear_line = "\033[2K"

            # Add timestamp and tqdm_id prefix to the message (id 0 gets no id tag)
            timestamp = time.strftime('%H:%M:%S')
            formatted_message = (
                f"[{timestamp}] [{tqdm_id}]" if tqdm_id > 0 else f"[{timestamp}]"
            )
            formatted_message = f"{formatted_message} {message}"

            # Combine cursor positioning, line clearing, and new content
            ansi_message = f"{cursor_position}{clear_line}{formatted_message}\n"

            # Write to log file using the existing infrastructure
            self._write_to_instance_download_logs(ansi_message, use_tqdm_format=True)
        except Exception as e:
            logger.warning(
                f"Failed to write progress with cursor positioning to line {line_number}: {e}"
            )

    def _recover_cursor_to_end(self):
        """Recover cursor to end of log file"""
        max_line_number = (
            max(self._file_line_mapping.values()) if self._file_line_mapping else 0
        )
        line_num = max_line_number + self._log_header_lines + 1
        self._write_to_instance_download_logs(
            f"\033[{line_num};1H", use_tqdm_format=True  # Move cursor to end of file
        )

    def _ensure_model_file_size_and_paths(self):
        """
        If the model file has no size yet, resolve size and file paths from
        the repository file list, push them to the server and store them on
        the local model file object.
        """
        if self._model_file.size is not None:
            return

        repo_file_list = downloaders.get_model_file_info(
            self._model_file,
            huggingface_token=self._config.huggingface_token,
            cache_dir=self._config.cache_dir,
        )
        (size, file_paths) = hub.match_file_and_calculate_size(
            files=repo_file_list,
            model=self._model_file,
            cache_dir=self._config.cache_dir,
        )
        self._update_model_file(
            self._model_file.id, size=size, resolved_paths=file_paths
        )
        self._model_file.size = size
        self._model_file.resolved_paths = file_paths

    def _update_model_file_progress(self, model_file_id: int, progress: float):
        """Push the overall download progress (percent) to the server."""
        self._update_model_file(model_file_id, download_progress=progress)

    def _update_model_file(self, id: int, **kwargs):
        """Fetch the model file, overwrite the given fields and push the update."""
        model_file_public = self._clientset.model_files.get(id=id)
        model_file_update = ModelFileUpdate(**model_file_public.model_dump())
        for key, value in kwargs.items():
            setattr(model_file_update, key, value)
        self._clientset.model_files.update(id=id, model_update=model_file_update)
|