cerebras.modelzoo.data.vision.classification.sampler.RepeatedAugSampler

class cerebras.modelzoo.data.vision.classification.sampler.RepeatedAugSampler

Bases: torch.utils.data.Sampler

Sampler that restricts data loading to a subset of the dataset for distributed training, with repeated augmentation. It ensures that each augmented copy of a sample is seen by a different process (GPU). Heavily based on torch.utils.data.DistributedSampler.

This sampler is borrowed from the DeiT repository: https://github.com/facebookresearch/deit/blob/main/samplers.py
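A minimal sketch of the repeated-augmentation split, assuming the DeiT-style scheme in which each sample index is repeated `num_repeats` times back to back and the resulting list is dealt out to processes with a stride of `num_replicas`. The helper `assign_copies` is hypothetical and omits shuffling and padding; it is not the Cerebras implementation.

```python
from typing import Dict, List


def assign_copies(num_samples: int, num_repeats: int, num_replicas: int) -> Dict[int, List[int]]:
    """Hypothetical helper: map each rank to the sample indices it receives.

    Assumes the DeiT-style scheme; no shuffling and no padding are applied.
    """
    # Each sample index appears num_repeats times back to back: [0, 0, 0, 1, 1, 1, ...]
    repeated = [i for i in range(num_samples) for _ in range(num_repeats)]
    # Rank r takes every num_replicas-th element starting at offset r,
    # so consecutive copies of one sample land on different ranks.
    return {r: repeated[r::num_replicas] for r in range(num_replicas)}


print(assign_copies(4, 3, 3))  # {0: [0, 1, 2, 3], 1: [0, 1, 2, 3], 2: [0, 1, 2, 3]}
```

With three repeats and three processes, every process sees every sample exactly once, each time as a different augmented copy. With fewer processes than repeats (for example `assign_copies(4, 3, 2)`), some processes receive two copies of the same sample.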

Methods

set_epoch

__init__(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, num_repeats=3, batch_size=256)
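The following sketch illustrates how the `__init__` arguments and `set_epoch` might interact. `ToyRepeatedAugSampler` is a hypothetical, dependency-free stand-in that takes a dataset length instead of a dataset. It assumes the DeiT-style behavior of padding the repeated index list so it divides evenly across `num_replicas`, and it seeds the shuffle with `seed + epoch` in the way `torch.utils.data.DistributedSampler` does. `batch_size` is deliberately left out because its exact effect is not documented here.

```python
import math
import random
from typing import Iterator, List


class ToyRepeatedAugSampler:
    """Hypothetical stand-in for RepeatedAugSampler (not the Cerebras code).

    Mirrors num_replicas, rank, shuffle, seed and num_repeats; batch_size
    is omitted because its exact role is not documented.
    """

    def __init__(self, dataset_len: int, num_replicas: int, rank: int,
                 shuffle: bool = True, seed: int = 0, num_repeats: int = 3):
        self.dataset_len = dataset_len
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.num_repeats = num_repeats
        self.epoch = 0
        # Per-rank length after padding the repeated index list.
        self.num_samples = math.ceil(dataset_len * num_repeats / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def set_epoch(self, epoch: int) -> None:
        # Call with the same value on every rank before each epoch.
        self.epoch = epoch

    def _indices(self) -> List[int]:
        order = list(range(self.dataset_len))
        if self.shuffle:
            # Same seed + epoch on every rank gives an identical permutation.
            random.Random(self.seed + self.epoch).shuffle(order)
        # Repeat each index num_repeats times back to back.
        repeated = [i for i in order for _ in range(self.num_repeats)]
        # Pad by cycling from the start until the length divides evenly.
        while len(repeated) < self.total_size:
            repeated += repeated[: self.total_size - len(repeated)]
        # Strided split: copies of one sample go to different ranks.
        return repeated[self.rank:self.total_size:self.num_replicas]

    def __iter__(self) -> Iterator[int]:
        return iter(self._indices())

    def __len__(self) -> int:
        return self.num_samples


if __name__ == "__main__":
    samplers = [ToyRepeatedAugSampler(10, num_replicas=4, rank=r) for r in range(4)]
    for epoch in range(2):
        for s in samplers:
            s.set_epoch(epoch)  # same epoch value on every rank
        print(epoch, [list(s) for s in samplers])
```

Calling `set_epoch` with the same value on every rank before each epoch keeps the permutation consistent across processes while still changing it from one epoch to the next.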
__call__(*args: Any, **kwargs: Any) -> Any

Call self as a function.

static __new__(cls, *args: Any, **kwargs: Any) -> Any