cerebras.modelzoo.data.common.HDF5IterableDataset.RestartableDataLoader

class cerebras.modelzoo.data.common.HDF5IterableDataset.RestartableDataLoader

Bases: torch.utils.data.DataLoader

This custom dataloader class defines the ‘state_dict’, ‘aggregate_state_dict’, ‘load_state_dict’ and ‘deaggregate_state_dict’ methods. These methods determine what worker state information is stored (local or global streaming info) and how it is aggregated and retrieved. Restarting an instance of HDF5IterableDataset deterministically requires the number of samples already seen in the previous run. This number is stored under the samples_streamed key of the worker state dict. On restart, the load_state_dict method sets the samples_seen class variable, which determines the number of samples to skip.
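The save-and-skip idea described above can be illustrated with a minimal, pure-Python sketch. This is not the Cerebras implementation: the class `ToyRestartableStream`, its constructor, and the method signatures are hypothetical, and only the names `state_dict`, `load_state_dict`, `samples_streamed` and `samples_seen` come from the description. Aggregation across workers is left out because its format is not described here.

```python
from typing import Dict, Iterator, List


class ToyRestartableStream:
    """Hypothetical stand-in for a restartable stream (not the Cerebras class)."""

    def __init__(self, data: List[int]) -> None:
        self.data = data
        self.samples_seen = 0      # samples to skip when iteration starts
        self.samples_streamed = 0  # samples yielded so far, across runs

    def __iter__(self) -> Iterator[int]:
        # Skip everything a previous run already consumed.
        for sample in self.data[self.samples_seen:]:
            self.samples_streamed += 1
            yield sample

    def state_dict(self) -> Dict[str, int]:
        # Assumed layout: a single counter under the samples_streamed key.
        return {"samples_streamed": self.samples_streamed}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        # On restart, the saved counter becomes the number of samples to skip.
        self.samples_seen = state["samples_streamed"]
        self.samples_streamed = state["samples_streamed"]


# First run: consume four samples, then checkpoint.
first_run = ToyRestartableStream(list(range(10)))
iterator = iter(first_run)
consumed = [next(iterator) for _ in range(4)]
checkpoint = first_run.state_dict()

# Restart: a fresh object resumes where the first run stopped.
resumed_run = ToyRestartableStream(list(range(10)))
resumed_run.load_state_dict(checkpoint)
remaining = list(resumed_run)
```

In this sketch the restarted stream yields exactly the samples the first run did not reach, so no sample is repeated or dropped across the restart.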

Methods

aggregate_state_dict

deaggregate_state_dict

load_state_dict

state_dict

__init__(*args, **kwargs)

__call__(*args: Any, **kwargs: Any) → Any

Call self as a function.

static __new__(cls, *args: Any, **kwargs: Any) → Any