Saving/Loading checkpoints

Overview

To save and load weights in a Cerebras run, we provide a custom H5-based checkpoint format. It is more performant and efficient than the core PyTorch pickle-based checkpoint format, especially for models with extremely large weights, such as Large Language Models (LLMs).

To save a checkpoint, we provide a cerebras.pytorch.save function that you can use in exactly the same way as torch.save:

import cerebras.pytorch as cstorch

state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    ...
}
cstorch.save(state_dict, "<filepath for saving model checkpoint>")

Similarly, we provide a cerebras.pytorch.load function that can also be used in exactly the same way as torch.load:

state_dict = cstorch.load("<filepath to saved model checkpoint>")

model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
...

Checkpoint Closures

Weights can only be fetched on predetermined checkpoint steps, which are configured using the DataExecutor. This restriction exists to make training more performant.

For example, with checkpoint_steps=100, you can fetch the weights to take a checkpoint only on every 100th step and on the final step.
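The schedule in this example can be sketched in plain Python. This is only an illustration of the rule (every `checkpoint_steps`-th step, plus the final step), not Cerebras code: `is_checkpoint_step` and `total_steps` are hypothetical names, and steps are assumed to be counted from 1.

```python
def is_checkpoint_step(step: int, checkpoint_steps: int, total_steps: int) -> bool:
    """Return True if weights may be fetched at this 1-indexed step.

    Hypothetical helper that mirrors the rule described in the text.
    """
    if checkpoint_steps <= 0:
        raise ValueError("checkpoint_steps must be a positive integer")
    # A checkpoint step is every N-th step, or the very last step of the run.
    return step % checkpoint_steps == 0 or step == total_steps


# For a 250-step run with checkpoint_steps=100, list the steps
# on which a checkpoint could be taken.
schedule = [s for s in range(1, 251) if is_checkpoint_step(s, 100, 250)]
print(schedule)  # [100, 200, 250]
```

Step 250 appears even though it is not a multiple of 100, because the last step of a run always qualifies.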

To help with this, you can use the checkpoint_closure decorator. It is a step closure that checks that the current step is a checkpoint step before calling the function. Using this decorator also ensures that the weights are available to fetch from the server before they are saved to the checkpoint file.

@cstorch.checkpoint_closure
def save_checkpoint():
    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        ...
    }
    cstorch.save(state_dict, "<filepath for saving model checkpoint>")
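To show the gating behavior in isolation, here is a minimal plain-Python sketch of a decorator that skips the call unless the current step is a checkpoint step. It is not the implementation of `cstorch.checkpoint_closure`: `RunState` and `gated_on_checkpoint_step` are hypothetical names, and the sketch does not model waiting for weights to become available from the server.

```python
import functools


class RunState:
    """Hypothetical stand-in for the run's step bookkeeping."""

    def __init__(self, checkpoint_steps: int, total_steps: int):
        self.checkpoint_steps = checkpoint_steps
        self.total_steps = total_steps
        self.step = 0

    def is_checkpoint_step(self) -> bool:
        # Every N-th step, or the final step of the run.
        return (
            self.step % self.checkpoint_steps == 0
            or self.step == self.total_steps
        )


def gated_on_checkpoint_step(run: RunState):
    """Build a decorator that only calls the function on checkpoint steps."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if not run.is_checkpoint_step():
                return None  # not a checkpoint step: do nothing
            return fn(*args, **kwargs)

        return wrapper

    return decorator


run = RunState(checkpoint_steps=100, total_steps=250)
saved_at = []


@gated_on_checkpoint_step(run)
def save_checkpoint():
    # A real closure would build the state dict and save it here.
    saved_at.append(run.step)


for step in range(1, run.total_steps + 1):
    run.step = step
    save_checkpoint()  # safe to call on every step

print(saved_at)  # [100, 200, 250]
```

The point of the pattern is that the training loop can call `save_checkpoint()` unconditionally on every step, and the decorator decides whether anything happens.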

Converting checkpoints to a Pickle-based format

If you have a checkpoint in the Cerebras H5-based format and wish to use it in a CPU/GPU workflow, it can easily be converted to a PyTorch pickle-based format:

import torch
import cerebras.pytorch as cstorch
state_dict = cstorch.load("<filepath to checkpoint>", map_location="cpu")
torch.save(state_dict, "<filepath to torch checkpoint>")

Note

This eagerly loads the entire checkpoint into memory, so it may cause memory issues when loading checkpoints for very large models.

In most cases, converting the checkpoint is unnecessary as the lazily loaded checkpoint acquired via the cerebras.pytorch.load function will also work in CPU/GPU workflows.
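The difference between eager and lazy loading can be illustrated with a conceptual plain-Python sketch. It does not reflect how cerebras.pytorch.load is implemented: `LazyStateDict` and `make_loader` are hypothetical, and the loader callables stand in for tensors that are read from disk only when accessed.

```python
from collections.abc import Mapping


class LazyStateDict(Mapping):
    """Read-only mapping that materializes each value on first access."""

    def __init__(self, loaders):
        self._loaders = dict(loaders)  # name -> zero-argument callable
        self._cache = {}

    def __getitem__(self, key):
        if key not in self._cache:
            self._cache[key] = self._loaders[key]()  # load on demand
        return self._cache[key]

    def __iter__(self):
        return iter(self._loaders)

    def __len__(self):
        return len(self._loaders)

    @property
    def loaded_keys(self):
        """Names of the entries that have been materialized so far."""
        return sorted(self._cache)


def make_loader(size):
    # Stand-in for reading one tensor of `size` elements from a file.
    return lambda: [0.0] * size


lazy = LazyStateDict({
    "embedding.weight": make_loader(1000),
    "head.weight": make_loader(10),
})

_ = lazy["head.weight"]   # only this entry is materialized
print(lazy.loaded_keys)   # ['head.weight']

eager = dict(lazy)        # copying into a dict touches every entry at once
print(lazy.loaded_keys)   # ['embedding.weight', 'head.weight']
```

Copying everything into a regular dict is analogous to the eager conversion described in the note: every entry is held in memory at the same time.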