Saving/Loading Checkpoints#
Overview#
To save and load weights in a Cerebras run, we provide a custom Cerebras H5-based checkpoint format that is far more performant and efficient than the core PyTorch pickle-based checkpoint format, especially for models with extremely large weights, such as Large Language Models (LLMs).
To save a checkpoint, we provide a cerebras_pytorch.save function that you can use in exactly the same way as torch.save:
import cerebras_pytorch as cstorch

state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    ...
}
cstorch.save(state_dict, "<filepath for saving model checkpoint>")
Similarly, we provide a cerebras_pytorch.load function that can also be used in exactly the same way as torch.load:
state_dict = cstorch.load("<filepath for saving model checkpoint>")
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
...
Checkpoint Closures#
Weights can only be fetched on predetermined checkpoint steps, which are configured using the DataExecutor. This restriction exists to keep training performant.
For example, with checkpoint_steps=100, you can only fetch the weights to take a checkpoint on every 100th step and on the very last step.
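As a rough illustration of that rule, the sketch below lists the steps on which weights could be fetched. The helper name `checkpoint_step_numbers` and the 1-based step counting are assumptions made for this example; they are not part of the cerebras_pytorch API.

```python
def checkpoint_step_numbers(total_steps: int, checkpoint_steps: int) -> list[int]:
    """Return the 1-based steps on which a checkpoint may be taken.

    Hypothetical helper for illustration only: every `checkpoint_steps`-th
    step qualifies, and so does the final step.
    """
    if total_steps < 1 or checkpoint_steps < 1:
        raise ValueError("total_steps and checkpoint_steps must be positive")
    # Multiples of checkpoint_steps that fall within the run.
    steps = list(range(checkpoint_steps, total_steps + 1, checkpoint_steps))
    # The last step always qualifies, even if it is not a multiple.
    if not steps or steps[-1] != total_steps:
        steps.append(total_steps)
    return steps


print(checkpoint_step_numbers(total_steps=250, checkpoint_steps=100))
# [100, 200, 250]
```

With 250 total steps, the final step is not a multiple of 100 but still appears, matching the "very last step" case described above.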
To help with this, you can use the checkpoint_closure decorator, a step closure that checks that the current step is a checkpoint step before calling the function. In addition, using this decorator ensures that the weights are available to fetch from the server before they are saved to the checkpoint file.
@cstorch.checkpoint_closure
def save_checkpoint():
    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        ...
    }
    cstorch.save(state_dict, "<filepath for saving model checkpoint>")
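To make the gating behaviour concrete, here is a plain-Python analogue of such a decorator. It is not how cstorch.checkpoint_closure is implemented: `StepTracker`, `checkpoint_gate`, and the 1-based step counter are hypothetical names invented for this sketch, and nothing here contacts a Cerebras system or fetches real weights.

```python
import functools


class StepTracker:
    """Hypothetical stand-in for run state that knows the current step."""

    def __init__(self, total_steps: int, checkpoint_steps: int):
        self.total_steps = total_steps
        self.checkpoint_steps = checkpoint_steps
        self.step = 0  # 1-based once the loop starts

    def is_checkpoint_step(self) -> bool:
        # Every `checkpoint_steps`-th step, plus the very last step.
        if self.step < 1:
            return False
        return (
            self.step % self.checkpoint_steps == 0
            or self.step == self.total_steps
        )


def checkpoint_gate(tracker: StepTracker):
    """Decorator factory: run the wrapped function only on checkpoint steps."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if tracker.is_checkpoint_step():
                return fn(*args, **kwargs)
            return None  # skipped on all other steps

        return wrapper

    return decorator


tracker = StepTracker(total_steps=250, checkpoint_steps=100)
saved_at = []


@checkpoint_gate(tracker)
def save_checkpoint():
    # A real run would build the state dict and save it here.
    saved_at.append(tracker.step)


for step in range(1, tracker.total_steps + 1):
    tracker.step = step
    save_checkpoint()  # called every step; the gate decides whether it runs

print(saved_at)  # [100, 200, 250]
```

The point of the pattern is that the training loop can call the save function unconditionally, while the decorator decides whether the call actually does anything.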
Converting checkpoints to a Pickle-based format#
If you have a checkpoint in the Cerebras H5-based format and wish to use it in a CPU/GPU workflow, it can easily be converted to a PyTorch pickle-based format:
state_dict = cstorch.load("<filepath to checkpoint>", map_location="cpu")
torch.save(state_dict, "<filepath to torch checkpoint>")
Note
This will eagerly load the entirety of the checkpoint into memory. Thus, it may cause memory issues when loading checkpoints for very large models.
In most cases, converting the checkpoint is unnecessary, as the lazily loaded checkpoint acquired via the cerebras_pytorch.load function will also work in CPU/GPU workflows.
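If you do convert, the eager load noted above makes it worth checking the checkpoint's size first. The following is only a rough heuristic sketch: `fits_memory_budget` and the budget values are hypothetical, and on-disk size is an approximation that may underestimate the memory needed if the file's contents are compressed.

```python
import os
import tempfile


def fits_memory_budget(path: str, budget_bytes: int) -> bool:
    """Hypothetical pre-check: is the file's on-disk size within the budget?

    On-disk size is only a rough proxy for the memory an eager load needs.
    """
    return os.path.getsize(path) <= budget_bytes


# Demonstrate with a throwaway 1 KiB file standing in for a checkpoint.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"\0" * 1024)
    demo_path = tmp.name

small_budget_ok = fits_memory_budget(demo_path, budget_bytes=512)
large_budget_ok = fits_memory_budget(demo_path, budget_bytes=4096)
os.remove(demo_path)

print(small_budget_ok, large_budget_ok)  # False True
```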