Source code for modelzoo.vision.pytorch.input.utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import filecmp
import math
import os
import random
import shutil

import torch
import torch.distributed  # used fully qualified; `dist` below refers to cerebras_pytorch.distributed
from tqdm import tqdm

import cerebras_pytorch as cstorch
import cerebras_pytorch.distributed as dist


def is_gpu_distributed():
    """Returns True if DDP is enabled."""
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
    )


def task_id():
    if dist.is_streamer():
        return dist.get_streaming_rank()
    elif is_gpu_distributed():
        # DDP rank from torch.distributed (`dist` is cerebras_pytorch.distributed).
        return torch.distributed.get_rank()
    else:
        return 0


def num_tasks():
    if dist.is_streamer():
        return dist.num_streamers()
    elif is_gpu_distributed():
        # DDP world size from torch.distributed.
        return torch.distributed.get_world_size()
    else:
        return 1
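

# Usage sketch (illustrative, not part of this module's API): task_id() and
# num_tasks() give a rank/world-size view of the input pipeline that is the
# same whether the run uses Cerebras streamers, GPU DDP, or a single process,
# so a dataloader can shard work without extra coordination. The shard-file
# glob below is a hypothetical example path.
def _example_shard_file_list():
    import glob

    files = sorted(glob.glob("/path/to/shards/*.h5"))  # hypothetical path
    # Strided slicing gives each task a disjoint, roughly equal subset.
    return files[task_id() :: num_tasks()]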


class ShardedSampler(torch.utils.data.Sampler):
    """
    Modified from:
    https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler

    Sampler that restricts data loading to a subset of the dataset.
    The dataset is assumed to be of constant size.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        shuffle (bool, optional): If `True` (default), the sampler shuffles
            the indices.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: `None`.
        drop_last (bool, optional): If `True`, the sampler drops the tail of
            the data to make it evenly divisible across the number of
            replicas. If `False`, the sampler adds extra indices to make the
            data evenly divisible across the replicas. Default: `False`.
    """

    def __init__(self, dataset, shuffle=True, seed=None, drop_last=False):
        self.num_tasks = num_tasks()
        self.task_id = task_id()
        self.dataset = dataset
        self.dataset_len = len(self.dataset)
        self.drop_last = drop_last

        if cstorch.use_cs() and not self.drop_last:
            raise ValueError(
                "On CS2 we do not support unequal batch sizes, so `drop_last` "
                "must be set to `True`."
            )

        # If the dataset length is evenly divisible by the number of tasks,
        # there is no need to drop any data, since the dataset will be split
        # equally.
        if self.drop_last and len(self.dataset) % self.num_tasks:
            # Split to the nearest available length that is evenly divisible.
            # This is to ensure each task receives the same amount of data
            # when using this sampler.
            self.num_samples = len(self.dataset) // self.num_tasks
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_tasks)
        self.total_size = self.num_samples * self.num_tasks
        self.shuffle = shuffle
        self.seed = seed

        self.indices = list(range(self.dataset_len))
        if not self.drop_last:
            # Add extra samples to make the data evenly divisible across tasks.
            padding_indices_size = self.total_size - self.dataset_len
            # Choose padding indices at random to reduce the chance of
            # reusing samples.
            random.seed(self.seed)
            padding_indices = random.sample(self.indices, padding_indices_size)
            self.indices += padding_indices
        else:
            # Remove the tail of the data to make it evenly divisible.
            self.indices = self.indices[: self.total_size]
        assert len(self.indices) == self.total_size, (
            f"Total `indices` after dropping/padding indices must be equal "
            f"to `total_size` of the dataset. Received total indices: "
            f"`{len(self.indices)}` and total size is: `{self.total_size}`."
        )

    def __iter__(self):
        if self.shuffle:
            random.seed(self.seed)
            random.shuffle(self.indices)

        # Subsample the indices belonging to this task.
        indices = self.indices[self.task_id : self.total_size : self.num_tasks]
        assert len(indices) == self.num_samples, (
            f"Total `indices` for tasks must be equal to `num_samples` in a "
            f"task. Received total indices: `{len(indices)}` and samples in "
            f"task are: `{self.num_samples}`."
        )
        yield from indices

    def __len__(self):
        return self.num_samples
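

# Usage sketch for ShardedSampler: wrap a map-style dataset and pass the
# sampler to a DataLoader so that each task iterates only over its own strided
# slice of the indices. The TensorDataset below is a stand-in for a real
# dataset; on CS systems `drop_last=True` is required (see __init__ above).
def _example_sharded_sampler():
    dataset = torch.utils.data.TensorDataset(torch.arange(1000).float())
    sampler = ShardedSampler(dataset, shuffle=True, seed=0, drop_last=True)
    # With drop_last=True each task draws len(dataset) // num_tasks() samples.
    return torch.utils.data.DataLoader(
        dataset, batch_size=16, sampler=sampler, drop_last=True
    )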


##### Experimental to reduce first batch loading times for MAP style only #####
class _RepeatSampler(object):
    """Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class FastDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(
            self, 'batch_sampler', _RepeatSampler(self.batch_sampler)
        )
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)
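

# Usage sketch for FastDataLoader: wrapping the batch sampler in _RepeatSampler
# and creating a single iterator up front keeps the worker processes (and
# whatever they have cached) alive across epochs instead of re-spawning them,
# which is where the first-batch speedup comes from. Recent PyTorch releases
# expose similar behavior via `persistent_workers=True`. The `dataset`
# argument below is a placeholder for any map-style dataset.
def _example_fast_dataloader(dataset, num_epochs=2):
    loader = FastDataLoader(
        dataset, batch_size=16, num_workers=2, shuffle=True, drop_last=True
    )
    for _ in range(num_epochs):
        # Each epoch draws exactly len(loader) batches from the persistent
        # iterator; the underlying workers are reused between epochs.
        for _batch in loader:
            pass
    return loader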


def _get_worker_cache_dir(src_dir):
    """Gets the path to the worker cache dir corresponding to `src_dir`."""
    src_dir = os.path.abspath(src_dir)
    cache_dir = os.path.normpath("/".join([dist.WORKER_CACHE_ROOT, src_dir]))
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def _same_dirs_shallow(src_dir, dest_dir):
    """Does a shallow comparison between the dirs `src_dir` and `dest_dir`.

    The shallow comparison does a recursive check of the following:
    1. Check if the dirs exist; if they don't, return False.
    2. Check if any files differ, or if there are additional files in either
       of the two dirs; if different, return False.
    3. Repeat 1 and 2 on subdirs.
    """

    def _same_dirs_shallow_helper(dcmp: filecmp.dircmp):
        if not os.path.exists(dcmp.left) or not os.path.exists(dcmp.right):
            return False
        if dcmp.left_only:
            # If the diff consists only of broken symlinks, it is still a
            # match.
            parent = dcmp.left
            for left_file in dcmp.left_only:
                if os.path.isdir(
                    os.path.join(parent, left_file)
                ) or os.path.isfile(os.path.join(parent, left_file)):
                    return False
        if dcmp.diff_files or dcmp.right_only:
            return False
        for sub_dcmp in dcmp.subdirs.values():
            if not _same_dirs_shallow_helper(sub_dcmp):
                return False
        return True

    return _same_dirs_shallow_helper(filecmp.dircmp(src_dir, dest_dir))
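

# Sketch of the shallow comparison above: after a faithful copy the trees
# compare equal (common files are only checked via filecmp's default
# stat-signature "shallow" mode), while an extra file on either side flips the
# result. Uses temporary directories; `dirs_exist_ok` needs Python 3.8+.
def _example_same_dirs_shallow():
    import tempfile

    src = tempfile.mkdtemp()
    dst = tempfile.mkdtemp()
    with open(os.path.join(src, "a.txt"), "w") as f:
        f.write("data")
    shutil.copytree(src, dst, dirs_exist_ok=True)
    assert _same_dirs_shallow(src, dst)  # identical trees -> cache hit
    with open(os.path.join(src, "b.txt"), "w") as f:
        f.write("new")
    assert not _same_dirs_shallow(src, dst)  # extra file -> cache miss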


def create_worker_cache(src_dir: str, force_overwrite: bool = False):
    """Checks for the dir in the worker_cache (SSD) on the worker node
    corresponding to `src_dir`. If the directory exists and is the same as
    `src_dir`, returns the dir path on the worker_cache. Otherwise, writes the
    directory to the worker_cache and returns the dir path.

    Writing to the cache can take a while, depending on the size of `src_dir`.
    Displays a progress bar (in the worker logs) which shows the progress of
    the cache write.

    Forces a cache overwrite, irrespective of a cache hit, when
    `force_overwrite` is True.
    """
    from filelock import FileLock

    if (
        os.path.commonprefix([src_dir, dist.WORKER_CACHE_ROOT])
        == dist.WORKER_CACHE_ROOT
    ):
        raise RuntimeError(
            f"Ensure that the src_dir path does not have "
            f"a worker_cache path prefix: {dist.WORKER_CACHE_ROOT}"
        )
    if not dist.is_streamer():
        raise RuntimeError(
            "Ensure that create_worker_cache is called only for a worker node."
        )

    dest_dir = _get_worker_cache_dir(src_dir)
    # Provide read/write permissions for the lock for all users.
    with FileLock(f"{dest_dir}.lock", mode=0o666):
        if _same_dirs_shallow(src_dir, dest_dir) and not force_overwrite:
            print("WORKER CACHE HIT: Skipping overwrite")
        else:
            (
                is_limit_hit,
                dir_size,
                available_space_for_copy,
            ) = dist.hit_worker_cache_limit(src_dir, dest_dir)
            if is_limit_hit:
                raise RuntimeError(
                    f"Failed when copying the directory to the worker cache: "
                    f"{src_dir}, directory size: {dir_size} exceeds the "
                    f"available space on the worker cache: "
                    f"{available_space_for_copy}. Please contact your system "
                    f"administrator to clear the worker cache."
                )
            if os.path.exists(dest_dir):
                shutil.rmtree(dest_dir)

            # Copy dirs to the destination.
            # Get the total number of files to copy.
            total_files = sum(
                [len(files) for root, dirs, files in os.walk(src_dir)]
            )

            # Copy the directory with a progress bar.
            def copy2_with_progress(src_path, dst_path, update):
                # Skip if it is a broken symlink.
                if os.path.isfile(src_path):
                    shutil.copy2(src_path, dst_path)
                update(1)

            with tqdm(
                total=total_files,
                desc="Overwriting cache",
                unit="files",
                dynamic_ncols=True,
            ) as pbar:
                shutil.copytree(
                    src_dir,
                    dest_dir,
                    symlinks=False,
                    ignore=None,
                    ignore_dangling_symlinks=True,
                    copy_function=lambda f, d: copy2_with_progress(
                        f, d, pbar.update
                    ),
                )

    return dest_dir
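

# Usage sketch for create_worker_cache: inside an input pipeline running on a
# streamer worker, redirect the data directory to its SSD-cached copy before
# building the dataset. The `data_dir` path and `use_worker_cache` flag are
# hypothetical configuration values, not names defined by this module.
def _example_use_worker_cache(data_dir="/path/to/dataset", use_worker_cache=True):
    if use_worker_cache and dist.is_streamer():
        # The first call copies the data to the worker SSD cache (showing a
        # progress bar); subsequent calls that hit the cache simply return
        # the cached path.
        data_dir = create_worker_cache(data_dir)
    # ... construct the dataset / dataloader from data_dir as usual ...
    return data_dir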