cerebras.modelzoo.data.common.restartable_dataloader.RestartableDataLoader

class cerebras.modelzoo.data.common.restartable_dataloader.RestartableDataLoader(*args, **kwargs)

Bases: torch.utils.data.DataLoader

Restartable dataloader for a torch.utils.data.Dataset.

Deterministic restart of a Dataset instance depends on one piece of state: the total number of samples streamed globally, which is consumed by the sampler. Accordingly, each worker saves the number of samples it has streamed in state_dict(). These per-worker counts are summed to give the global number of samples streamed across all workers, and that same global count is used to set the state of the sampler when the state dict is loaded.
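The aggregation by summation described above can be illustrated with a minimal sketch. This is not the library's implementation: the key name `samples_streamed` and the helper `aggregate` are hypothetical, chosen only for illustration, since the actual layout of the state dict is not documented here.

```python
from typing import Dict, List

# Hypothetical key name; the real state dict layout is not specified on this page.
KEY = "samples_streamed"


def aggregate(worker_states: List[Dict[str, int]]) -> Dict[str, int]:
    """Sum the per-worker sample counts into a single global count."""
    return {KEY: sum(state[KEY] for state in worker_states)}


# Three hypothetical workers, each reporting how many samples it has streamed.
worker_states = [{KEY: 128}, {KEY: 128}, {KEY: 96}]
global_state = aggregate(worker_states)
print(global_state)  # {'samples_streamed': 352}
```

Only the sum matters for restart in this sketch, so the per-worker breakdown can be discarded once the global count is known.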

Constructs a RestartableDataLoader instance.

Methods

aggregate_state_dict

Aggregates states across all dataloaders into a single state.

deaggregate_state_dict

Deaggregates state from all dataloaders.

load_state_dict

Loads given state into the dataloader.

state_dict

Returns the state of the current dataloader.

state_dict()

Returns the state of the current dataloader.

load_state_dict(state_dict, strict=True)

Loads given state into the dataloader.

aggregate_state_dict(worker_states)

Aggregates states across all dataloaders into a single state.

deaggregate_state_dict(aggregated_state_dict, strict=True)

Deaggregates state from all dataloaders.
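To see why a single global sample count is enough for a deterministic restart, consider the toy model below. It does not use RestartableDataLoader or torch at all. It assumes a sequential sampler over a ten-sample dataset within one epoch, and the `stream` helper and the `samples_streamed` key are hypothetical names introduced for illustration.

```python
from typing import Dict, List

NUM_SAMPLES = 10  # toy dataset whose samples are simply their own indices


def stream(start: int, count: int) -> List[int]:
    """Return up to `count` sample indices in sequential order, beginning at `start`."""
    return list(range(start, min(start + count, NUM_SAMPLES)))


# Reference: one uninterrupted pass over the data.
reference = stream(0, NUM_SAMPLES)

# Interrupted run: two hypothetical workers streamed 3 and 1 samples before the checkpoint.
worker_states: List[Dict[str, int]] = [
    {"samples_streamed": 3},
    {"samples_streamed": 1},
]
consumed = sum(state["samples_streamed"] for state in worker_states)
before = stream(0, consumed)

# On restart, the sampler skips the samples that were already consumed globally.
after = stream(consumed, NUM_SAMPLES - consumed)

# The interrupted run plus the resumed run reproduces the uninterrupted order.
assert before + after == reference
print(after)  # [4, 5, 6, 7, 8, 9]
```

In this model the resumed run does not need to know which worker streamed which sample; the summed count alone fixes the position at which the sampler resumes.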