cerebras.modelzoo.data.common.h5_map_dataset.dataset.RestartableDataLoader

class cerebras.modelzoo.data.common.h5_map_dataset.dataset.RestartableDataLoader[source]

Bases: torch.utils.data.DataLoader

The state needed for a deterministic restart of an HDF5Dataset instance is the total number of samples streamed globally, which is consumed by the sampler. Accordingly, each worker saves the number of samples it has streamed in state_dict(). These per-worker counts are summed to obtain the global number of samples streamed across all workers, and that same global count is used to set the sampler's state when the state dict is loaded.
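
The round trip is easiest to see end to end. Below is a minimal, self-contained sketch of the flow described above; the "samples_streamed" key, the _SketchSampler class, and its set_state method are illustrative assumptions, not the actual Model Zoo API.

    class _SketchSampler:
        """Stand-in for a sampler whose only restart state is the global
        number of samples already streamed."""

        def __init__(self):
            self.samples_streamed = 0

        def set_state(self, samples_streamed):
            # On restart, skip everything already delivered.
            self.samples_streamed = samples_streamed

    def save_checkpoint(per_worker_states):
        # Each worker reports only its own counter (state_dict); summation
        # produces the single global figure (aggregate_state_dict).
        total = sum(s["samples_streamed"] for s in per_worker_states)
        return {"samples_streamed": total}

    def restore_checkpoint(sampler, aggregated_state):
        # The sampler consumes the global count directly (load_state_dict),
        # so no per-worker deaggregation is required.
        sampler.set_state(aggregated_state["samples_streamed"])

    worker_states = [{"samples_streamed": 128}, {"samples_streamed": 130}]
    checkpoint = save_checkpoint(worker_states)   # {"samples_streamed": 258}
    sampler = _SketchSampler()
    restore_checkpoint(sampler, checkpoint)
    assert sampler.samples_streamed == 258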

Methods

aggregate_state_dict      Sum the samples streamed across all workers to get the number of samples streamed globally.
deaggregate_state_dict    No deaggregation is needed since the sampler needs the global number of samples streamed.
load_state_dict           Set the sampler state with the total number of samples streamed globally.
state_dict                Save the number of samples streamed by the current worker.
validate_state_dict

__init__(*args, **kwargs)[source]
state_dict()[source]

Save the number of samples streamed by the current worker.
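
As a sketch of the shape this state could take, the following returns one integer counter for the current worker; the "samples_streamed" key is an assumption for illustration, not the real schema.

    def state_dict_sketch(samples_streamed_by_this_worker):
        # One integer per worker: how many samples this worker has emitted
        # since the start of training.
        return {"samples_streamed": samples_streamed_by_this_worker}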

load_state_dict(state_dict)[source]

Set the sampler state with the total number of samples streamed globally.
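
A sketch of the load step, assuming a sampler that exposes a hypothetical set_state method and the same illustrative "samples_streamed" key as above:

    def load_state_dict_sketch(sampler, state_dict):
        # The incoming state is already the global (aggregated) count,
        # so it can be handed to the sampler directly.
        sampler.set_state(state_dict["samples_streamed"])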

aggregate_state_dict(worker_states)[source]

Sum the samples streamed across all workers to get the number of samples streamed globally.
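
The aggregation is a plain summation; a sketch under the same illustrative key:

    def aggregate_state_dict_sketch(worker_states):
        # Summing the per-worker counters yields the global number of
        # samples streamed, the only state the sampler needs.
        total = sum(state["samples_streamed"] for state in worker_states)
        return {"samples_streamed": total}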

deaggregate_state_dict(aggregated_state_dict)[source]

No deaggregation is needed since the sampler needs the global number of samples streamed.
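
Because every worker needs the same global count, deaggregation reduces to the identity; a sketch:

    def deaggregate_state_dict_sketch(aggregated_state_dict):
        # Each worker receives the aggregated (global) state unchanged.
        return aggregated_state_dict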

__call__(*args: Any, **kwargs: Any) → Any

Call self as a function.

static __new__(cls, *args: Any, **kwargs: Any) → Any