Restartable dataloaders

Overview

In the Cerebras PyTorch API 2.0, we provide revamped support for deterministically restarting any custom input-generating dataloader used for a run. This feature lets you save and load the dataloader state, and it integrates with our existing mechanism for capturing checkpoints during a run.

Saving DataLoader State

You call state_dict on components such as the model and optimizer to fetch their state and save it in our Cerebras H5-based checkpoint format. In the same way, you can save the state of your dataloader by calling state_dict on the Cerebras PyTorch dataloader wrapper, cstorch.utils.data.DataLoader. This wrapper must be initialized at the beginning of a custom training loop:

import cerebras.pytorch as cstorch

cerebras_dataloader = cstorch.utils.data.DataLoader(input_fn, *args, **kwargs)
...
state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "dataloader": cerebras_dataloader.state_dict(),
    # ... any other state to save
}
cstorch.save(state_dict, "<path to save checkpoint to>")

Note

  • Our typical workflow in Model Zoo already includes this call on the Cerebras PyTorch dataloader wrapper to save the state of the dataloader being used for the run.

  • The dataloader state can only be saved at a checkpoint step. The function that saves the dataloader state should therefore be decorated with cstorch.checkpoint_closure.

Loading DataLoader State

When you restart a run from a Cerebras checkpoint file, fetch the saved dataloader state (if it exists) from the loaded checkpoint. Then pass it to the load_state_dict method of the Cerebras PyTorch dataloader wrapper to restore your dataloader's state, for example:

state_dict = cstorch.load("<path to H5 checkpoint file>")

model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])

if "dataloader" in state_dict:
    cerebras_dataloader.load_state_dict(state_dict["dataloader"])

And that is all!

Your dataloader must conform to the protocol class RestartableDataLoader. This protocol controls which "state" information is saved in a checkpoint when state_dict is called on the Cerebras PyTorch dataloader wrapper. It also controls how that information is loaded to rewind your dataloader.

Restartable DataLoader API

If your dataloader implements the methods state_dict, aggregate_state_dict, deaggregate_state_dict and load_state_dict with the appropriate signatures, it is guaranteed to be restartable. That is, you can save its state in a checkpoint and load it through the mechanism described above.
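Conformance depends on which four methods a class implements, so it can be modeled structurally. The sketch below is for illustration only. It uses Python's `typing.Protocol` with `runtime_checkable` to express that four-method contract. The name `RestartableDataLoaderLike` and the two toy classes are hypothetical. This is not cstorch's actual `RestartableDataLoader` definition, and it does not reflect how cstorch checks conformance.

```python
from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class RestartableDataLoaderLike(Protocol):
    """Hypothetical stand-in listing the four methods described above."""

    def state_dict(self) -> Dict[str, Any]:
        """Return this worker's state at a checkpoint step."""
        ...

    def aggregate_state_dict(self, worker_states: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Combine all workers' states into one dataloader state."""
        ...

    def deaggregate_state_dict(self, aggregated_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Turn the saved dataloader state into the state each worker loads."""
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore this worker from the deaggregated state."""
        ...


class Complete:
    # Defines all four methods, so it matches the protocol.
    def state_dict(self):
        return {}

    def aggregate_state_dict(self, worker_states):
        return {}

    def deaggregate_state_dict(self, aggregated_state_dict):
        return {}

    def load_state_dict(self, state_dict):
        pass


class MissingDeaggregate:
    # Lacks deaggregate_state_dict, so it does not match.
    def state_dict(self):
        return {}

    def aggregate_state_dict(self, worker_states):
        return {}

    def load_state_dict(self, state_dict):
        pass


print(isinstance(Complete(), RestartableDataLoaderLike))            # True
print(isinstance(MissingDeaggregate(), RestartableDataLoaderLike))  # False
```

`runtime_checkable` only verifies that the methods exist. It does not check their signatures, so treat this as a quick sanity check rather than a guarantee.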

Recall that in a distributed setting, each input worker per CSX creates its own instance of the dataloader for parallelism. These four methods therefore determine how your dataloader's state is saved and loaded, so that such runs can be restarted deterministically.

To illustrate this protocol, we define the CustomRestartableDataLoader class below. The following subsections describe each method, first in general and then in the context of this class.

from typing import Any, Dict, List

import torch

import cerebras.pytorch as cstorch


class CustomRestartableDataLoader(torch.utils.data.DataLoader):

    def state_dict(self) -> Dict[str, Any]:
        # Fetch this worker's internal state (only valid inside state_dict)
        worker_state = cstorch.distributed.get_worker_state()
        state_dict = {
            "worker_step": worker_state.worker_step,
            "worker_id": worker_state.global_worker_id,
            # "some_state_info": <any other state info>,
        }

        return state_dict

    def aggregate_state_dict(self, worker_states: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Combine the per-worker states; this example assumes two workers
        return {
            "step_worker_0": worker_states[0]["worker_step"],
            "id_worker_1": worker_states[1]["worker_id"],
            "combined_step_sum": worker_states[0]["worker_step"] + worker_states[1]["worker_step"],
        }

    def deaggregate_state_dict(self, aggregated_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        # Reject checkpoints written by an incompatible dataloader
        if "combined_step_sum" not in aggregated_state_dict:
            raise RuntimeError(
                "The aggregated state dict must contain key `combined_step_sum`, but it does not. "
                "This means that the dataloader state in the checkpoint you are "
                "loading from is not compatible with the dataloader currently "
                "in use."
            )

        # Every worker receives the combined step count
        return {
            "combined_step": aggregated_state_dict["combined_step_sum"]
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Reject states not produced by deaggregate_state_dict above
        if "combined_step" not in state_dict:
            raise RuntimeError(
                "The state dict must contain key `combined_step`, but it does not. "
                "This means that the dataloader state in the checkpoint you are "
                "loading from is not compatible with the dataloader currently "
                "in use."
            )

        print(f"Loading state using combined steps: {state_dict['combined_step']}")
        ...

state_dict

Use this method to specify what state information each input-generating worker should capture at an appliance checkpoint step. By default, each worker captures some internal state info in our new Cerebras dataloader checkpoint format, which is defined by the DataLoaderCheckpoint dataclass. Refer to the documentation of this class for details on each attribute. In your definition of state_dict, you may choose to save any of this internal state info per worker. To fetch the worker's internal state info inside your implementation of state_dict, use the API method cstorch.distributed.get_worker_state, for example:

def state_dict(self) -> Dict[str, Any]:
    worker_state = cstorch.distributed.get_worker_state()
    state_dict = {
        "worker_step": worker_state.worker_step,
        "worker_id": worker_state.global_worker_id,
        # "some_state_info": <any other state info>,
    }

    return state_dict

Note

  • The call to get_worker_state is well-defined only inside your implementation of state_dict. Calling it anywhere else raises a RuntimeError.

  • Any other state info you choose to save must be picklable with the dill package.
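Before you rely on a custom field, you can check that a worker's state dict survives a serialization round trip. The helper below is a hypothetical sketch. It uses `dill.dumps` and `dill.loads` when dill is installed, and falls back to the standard `pickle` module otherwise. Because dill can serialize more object types than pickle, the fallback is stricter. The `seen_files` key is an invented example field.

```python
import pickle

try:
    import dill as serializer  # the package named in the note above
except ImportError:  # fallback so the sketch runs without dill installed
    serializer = pickle


def is_picklable(obj) -> bool:
    """Return True if `obj` survives a dumps/loads round trip."""
    try:
        serializer.loads(serializer.dumps(obj))
    except Exception:
        return False
    return True


# Hypothetical per-worker state with an extra custom field
worker_state = {"worker_step": 10, "worker_id": 0, "seen_files": ["a.h5", "b.h5"]}
print(is_picklable(worker_state))         # True
print(is_picklable(x for x in range(3)))  # False: generators cannot be pickled
```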

aggregate_state_dict

This method accepts the list of individual worker state dicts as an argument. Each state dict in this list holds per-worker state information, as defined by your implementation of state_dict.

Use this method to specify how to combine the state information of all workers into a single, consolidated state dict, for example:

def aggregate_state_dict(self, worker_states):
    return {
        "step_worker_0": worker_states[0]["worker_step"],
        "id_worker_1": worker_states[1]["worker_id"],
        "combined_step_sum": worker_states[0]["worker_step"] + worker_states[1]["worker_step"]
    }

Note

  • The aggregated state dict represents the state of your dataloader and will eventually be saved in our Cerebras H5 checkpoint file when state_dict is invoked on the Cerebras PyTorch dataloader wrapper to save your dataloader’s state.

  • The example above assumes the run uses two workers in total. In the aggregated state dict, we save three values as the state of our dataloader: worker 0's step, worker 1's global worker id, and the summed step count of both workers.

  • You can expect the worker_states list to be ordered by the global worker id of each worker.
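The two-worker example hard-codes indices 0 and 1. The sketch below generalizes the same idea to any number of workers by iterating over the list. It assumes that every worker saves the `worker_step` and `worker_id` keys shown earlier. The `num_workers` and `per_worker_steps` keys are hypothetical additions for illustration. The function is written standalone rather than as a method, so it runs without cstorch.

```python
from typing import Any, Dict, List


def aggregate_state_dict(worker_states: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine per-worker states for any number of workers (hypothetical sketch)."""
    # worker_states is ordered by global worker id, so index i is worker i.
    return {
        "num_workers": len(worker_states),
        "per_worker_steps": [s["worker_step"] for s in worker_states],
        "combined_step_sum": sum(s["worker_step"] for s in worker_states),
    }


states = [
    {"worker_step": 5, "worker_id": 0},
    {"worker_step": 5, "worker_id": 1},
    {"worker_step": 4, "worker_id": 2},
]
print(aggregate_state_dict(states))
# {'num_workers': 3, 'per_worker_steps': [5, 5, 4], 'combined_step_sum': 14}
```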

deaggregate_state_dict

This method accepts an aggregated state dict as an argument. The aggregated state dict represents the state of your dataloader, as returned by your implementation of aggregate_state_dict.

To load your dataloader's state, use this method to specify how the consolidated dataloader state from a checkpoint should be disaggregated into a single state dict. That state dict defines how each worker loads its state, for example:

def deaggregate_state_dict(self, aggregated_state_dict):
    if "combined_step_sum" not in aggregated_state_dict:
        raise RuntimeError(
            "The aggregated state dict must contain key `combined_step_sum`, but it does not. "
            "This means that the dataloader state in the checkpoint you are "
            "loading from is not compatible with the dataloader currently "
            "in use."
        )

    return {
        "combined_step": aggregated_state_dict["combined_step_sum"]
    }

In the example above, our implementation explicitly checks that we are loading state captured by this dataloader. Upon restart, we assume that each worker needs the combined step count of all workers from the previous run, taken at the checkpoint being loaded. The deaggregation method therefore constructs and returns a single state dict that holds the combined step info.

Note

This method is particularly useful when the number of workers per box changes between subsequent runs. Use it to specify which state dict each worker should load upon restart.
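You can simulate the aggregate, checkpoint, deaggregate and load round trip in plain Python to see what happens when the worker count changes. In the sketch below, `ToyLoader` is a hypothetical class that reuses the logic of the example methods, generalized to any number of workers. `state_dict` is omitted because it depends on `cstorch.distributed.get_worker_state`, so the previous run's worker states are written out by hand. Passing one deaggregated dict to every new worker is also an assumption. This sketch does not model how the Cerebras wrapper actually distributes it.

```python
from typing import Any, Dict, List


class ToyLoader:
    """Hypothetical loader reusing the example aggregate/deaggregate/load logic.

    state_dict() is omitted because the real one calls
    cstorch.distributed.get_worker_state(); worker states are written by hand.
    """

    def __init__(self) -> None:
        self.combined_step = None  # set by load_state_dict

    def aggregate_state_dict(self, worker_states: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Sums over any number of workers, unlike the two-worker example.
        return {"combined_step_sum": sum(s["worker_step"] for s in worker_states)}

    def deaggregate_state_dict(self, aggregated_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        if "combined_step_sum" not in aggregated_state_dict:
            raise RuntimeError("Checkpoint holds an incompatible dataloader state.")
        return {"combined_step": aggregated_state_dict["combined_step_sum"]}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.combined_step = state_dict["combined_step"]


# Previous run: two workers, each at step 3 when the checkpoint was taken.
old_worker_states = [
    {"worker_step": 3, "worker_id": 0},
    {"worker_step": 3, "worker_id": 1},
]
saved = ToyLoader().aggregate_state_dict(old_worker_states)  # goes into the checkpoint

# Restarted run: four workers now, each loading the same deaggregated dict.
new_workers = [ToyLoader() for _ in range(4)]
per_worker_state = new_workers[0].deaggregate_state_dict(saved)
for worker in new_workers:
    worker.load_state_dict(per_worker_state)

print([w.combined_step for w in new_workers])  # [6, 6, 6, 6]
```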

load_state_dict

This method accepts a disaggregated state dict as an argument, as defined in your implementation of deaggregate_state_dict.

Use this method to specify how the worker should load its state from the provided, disaggregated state dict, e.g.

def load_state_dict(self, state_dict):
    if "combined_step" not in state_dict:
        raise RuntimeError(
            "The state dict must contain key `combined_step`, but it does not. "
            "This means that the dataloader state in the checkpoint you are "
            "loading from is not compatible with the dataloader currently "
            "in use."
        )

    print(f"Loading state using combined steps: {state_dict['combined_step']}")
    ...

Again, we explicitly check that the disaggregated state dict each worker uses to load its state upon restart matches the one produced by our dataloader's implementation of deaggregate_state_dict. In this example, each worker simply prints the combined step count from the previous run. You could instead use this step count to set other properties on your dataloader that enable it to restart deterministically.
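The sketch below shows one hypothetical way to act on the loaded step count: it fast-forwards past batches that were already consumed. It assumes that `combined_step` counts batches drawn from a single, deterministically ordered stream. With shuffling, the order would also need to be reproducible (for example, from a fixed seed) for this to resume at the right place. `SkippingLoader` is an invented name, and plain lists stand in for tensor batches.

```python
import itertools
from typing import Any, Dict, Iterator, List


class SkippingLoader:
    """Hypothetical loader that fast-forwards past already-consumed batches.

    Assumes "combined_step" counts batches already drawn from one shared,
    deterministically ordered stream.
    """

    def __init__(self, batches: List[Any]) -> None:
        self.batches = batches
        self.skip = 0  # number of leading batches to drop on iteration

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if "combined_step" not in state_dict:
            raise RuntimeError("The state dict must contain key `combined_step`.")
        self.skip = state_dict["combined_step"]

    def __iter__(self) -> Iterator[Any]:
        # islice drops the first `skip` batches, so iteration resumes where
        # the previous run stopped instead of starting over.
        return itertools.islice(iter(self.batches), self.skip, None)


loader = SkippingLoader(batches=[[0, 1], [2, 3], [4, 5], [6, 7]])
loader.load_state_dict({"combined_step": 2})
print(list(loader))  # [[4, 5], [6, 7]]
```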

Putting it All Together

Combining all of the above, we can set up our custom restartable dataloader, whose state can be captured via checkpointing, as follows:

import torch

import cerebras.pytorch as cstorch


def restartable_torch_dataloader(batch_size):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "/path/to/data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                transforms.Lambda(
                    lambda x: torch.as_tensor(x, dtype=torch.float16)
                ),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        )
    )

    return CustomRestartableDataLoader(train_dataset, batch_size=batch_size, shuffle=True)

cerebras_dataloader = cstorch.utils.data.DataLoader(restartable_torch_dataloader, batch_size=64)
...
@cstorch.checkpoint_closure
def save_checkpoint():
    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "dataloader": cerebras_dataloader.state_dict(),
        # ... any other state to save
    }
    cstorch.save(state_dict, "<path to save checkpoint to>")
...
def load_checkpoint(checkpoint_file_path):
    state_dict = cstorch.load(checkpoint_file_path)

    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optimizer"])

    if "dataloader" in state_dict:
        cerebras_dataloader.load_state_dict(state_dict["dataloader"])

Note

Your dataloader does not need to be of type torch.utils.data.DataLoader for its state to be saved and loaded. Any iterable that yields structures of torch tensors can be made restartable, as long as it implements the four methods of the Cerebras RestartableDataLoader protocol class.
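To illustrate this point, here is a hypothetical iterable that does not subclass `torch.utils.data.DataLoader` yet implements all four methods. So that the sketch runs without a Cerebras appliance, `state_dict` tracks its own step counter instead of calling `cstorch.distributed.get_worker_state`. The sketch also assumes a single worker and uses integers in place of tensor batches.

```python
from typing import Any, Dict, Iterator, List


class CountingIterable:
    """Hypothetical restartable iterable that is not a torch DataLoader.

    Yields integers standing in for tensor batches and tracks its own step
    count; a real implementation would read the step from
    cstorch.distributed.get_worker_state() inside state_dict().
    """

    def __init__(self, num_batches: int) -> None:
        self.num_batches = num_batches
        self.step = 0  # batches yielded so far

    def __iter__(self) -> Iterator[int]:
        while self.step < self.num_batches:
            batch = self.step
            self.step += 1
            yield batch

    def state_dict(self) -> Dict[str, Any]:
        return {"worker_step": self.step}

    def aggregate_state_dict(self, worker_states: List[Dict[str, Any]]) -> Dict[str, Any]:
        return {"step": worker_states[0]["worker_step"]}  # single worker assumed

    def deaggregate_state_dict(self, aggregated_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        return {"step": aggregated_state_dict["step"]}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict["step"]


first = CountingIterable(num_batches=5)
it = iter(first)
consumed = [next(it), next(it)]  # [0, 1]
saved = first.aggregate_state_dict([first.state_dict()])

resumed = CountingIterable(num_batches=5)
resumed.load_state_dict(resumed.deaggregate_state_dict(saved))
print(consumed, list(resumed))  # [0, 1] [2, 3, 4]
```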