modelzoo.transformers.data_processing.h5_map_dataset.samplers.CBSampler#

class modelzoo.transformers.data_processing.h5_map_dataset.samplers.CBSampler[source]#

Bases: torch.utils.data.Sampler, modelzoo.transformers.data_processing.h5_map_dataset.samplers.PaddedSampler

A sampler that handles sharding, batching, and skipping for map-style datasets, intended for use on CSX. Sharding is performed in such a way that the data order is independent of the number of systems being used and the number of workers per system.

Create a sampler that handles sharding as well as shuffling in a deterministic and restartable way.

Parameters
  • data_source (torch.utils.data.Dataset) – dataset to sample from

  • shuffle (bool) – whether or not to shuffle the dataset

  • seed (int) – The seed used to make shuffling deterministic

  • start_index (int) – The index of the first sample to yield

  • shard (bool) – Whether or not to shard the dataset across Cerebras data streamer nodes

  • batch_size (int) – The batch size used to compute sharded indices and to group samples into batches. If None, no batching is performed. When running on worker nodes, this should be the per-system batch size rather than the global batch size or the microbatch size; the per-system batch size is defined as global_batch_size / num_csx. When running on the coordinator node, this should be the global batch size. In either case, modelzoo.common.pytorch.input_utils.get_streaming_batch_size returns the appropriate value (see the usage sketch after this parameter list).

  • drop_last (bool) – Whether to drop the last, smaller batch when the number of samples is not evenly divisible by batch_size. Only used if batch_size is not None.

  • num_samples (int) – The number of samples to shuffle over. In multi-epoch training, it is common to set this to the total number of samples you plan to see over the course of the run, which gives smoother loss curves and improved convergence.

  • pad_last (bool) – Whether to pad the last batch so that it has the same size as the other batches. Only used if batch_size is not None and drop_last is False.
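A minimal usage sketch follows. The stand-in dataset, the global batch size, the epoch count, and the DataLoader wiring are illustrative assumptions; only the CBSampler and get_streaming_batch_size calls follow the parameters documented above.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from modelzoo.common.pytorch.input_utils import get_streaming_batch_size
    from modelzoo.transformers.data_processing.h5_map_dataset.samplers import CBSampler

    # Stand-in for a map-style dataset; in practice this would be the HDF5
    # map dataset this sampler is designed for (assumption).
    dataset = TensorDataset(torch.arange(100_000))

    global_batch_size = 1024  # illustrative value
    num_epochs = 3            # illustrative value

    # On a worker node this returns global_batch_size / num_csx; on the
    # coordinator node it returns the global batch size.
    batch_size = get_streaming_batch_size(global_batch_size)

    sampler = CBSampler(
        dataset,
        shuffle=True,
        seed=0,
        start_index=0,
        shard=True,
        batch_size=batch_size,
        drop_last=True,
        num_samples=num_epochs * len(dataset),  # shuffle over the full run
    )

    # With batch_size set, the sampler yields batches of indices, so it is
    # passed as a batch_sampler rather than a per-sample sampler (assumption).
    dataloader = DataLoader(dataset, batch_sampler=sampler)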

Methods

set_state

Sets the state of the sampler to continue deterministically from a prior run.

Attributes

pad_index

__call__(*args: Any, **kwargs: Any) → Any#

Call self as a function.

__init__(data_source, shuffle=True, seed=None, start_index=0, shard=True, batch_size=None, drop_last=True, num_samples=None, pad_last=False)[source]#

Create a sampler that handles sharding as well as shuffling in a deterministic and restartable way.

Parameters
  • data_source (torch.utils.data.Dataset) – dataset to sample from

  • shuffle (bool) – whether or not to shuffle the dataset

  • seed (int) – The seed used to make shuffling deterministic

  • start_index (int) – The index of the first sample to yield

  • shard (bool) – Whether or not to shard the dataset across Cerebras data streamer nodes

  • batch_size (int) – The batch size used to compute sharded indices and to group samples into batches. If None, no batching is performed. When running on worker nodes, this should be the per-system batch size rather than the global batch size or the microbatch size; the per-system batch size is defined as global_batch_size / num_csx. When running on the coordinator node, this should be the global batch size. In either case, modelzoo.common.pytorch.input_utils.get_streaming_batch_size returns the appropriate value.

  • drop_last (bool) – Whether to drop the last, smaller batch when the number of samples is not evenly divisible by batch_size. Only used if batch_size is not None.

  • num_samples (int) – The number of samples to shuffle over. In multi-epoch training, it is common to set this to the total number of samples you plan to see over the course of the run, which gives smoother loss curves and improved convergence.

  • pad_last (bool) – Whether to pad the last batch so that it has the same size as the other batches. Only used if batch_size is not None and drop_last is False; see the sketch after this list.
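The sketch below illustrates the pad_last / drop_last interaction described above. The tiny dataset, the local (unsharded) setup, and the use of the pad_index attribute for the padding entries are assumptions.

    import torch
    from torch.utils.data import TensorDataset

    from modelzoo.transformers.data_processing.h5_map_dataset.samplers import CBSampler

    # 10 samples with batch_size=4 leaves a short final batch of 2.
    dataset = TensorDataset(torch.arange(10))

    sampler = CBSampler(
        dataset,
        shuffle=False,
        shard=False,       # keep the example to a single local process
        batch_size=4,
        drop_last=False,   # keep the final, partial batch ...
        pad_last=True,     # ... and pad it up to the full batch size
    )

    for batch in sampler:
        # Every batch should have length 4; the padding entries in the last
        # batch are expected to use the sampler's pad_index attribute (assumption).
        print(list(batch))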

static __new__(cls, *args: Any, **kwargs: Any) → Any#

set_state(start_index)[source]#

Sets the state of the sampler to continue deterministically from a prior run.

Parameters

start_index – The total number of samples streamed globally across all workers in the previous run.
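For example, a run resumed from a checkpoint might restore the sampler state as sketched below; the checkpointed sample count and the unsharded setup are illustrative assumptions.

    import torch
    from torch.utils.data import TensorDataset

    from modelzoo.transformers.data_processing.h5_map_dataset.samplers import CBSampler

    dataset = TensorDataset(torch.arange(100_000))
    sampler = CBSampler(dataset, shuffle=True, seed=0, shard=False, batch_size=8)

    # Suppose a previous run streamed 640 samples globally before stopping;
    # in practice this count would be restored from a checkpoint (assumption).
    sampler.set_state(640)

    # Iteration now continues from sample 640 of the same deterministic order
    # that the previous run followed.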