Work with Cerebras checkpoints

Reading, writing, and manipulating checkpoints can become a training bottleneck for extremely large models, because the checkpoint files themselves become very large. To address this, Cerebras stores checkpoints in an HDF5-based file format. The contents of a checkpoint are the same as in other common model checkpoint formats, so checkpoints can be readily converted to other formats. Details about the utilities and interfaces are provided for each supported framework.

PyTorch Checkpoint Format

Our checkpoint format, optimized for large models, is based on the standard HDF5 file format. At a high level, when saving a checkpoint, the Cerebras stack takes a PyTorch state dictionary, flattens it, and stores it in an HDF5 file. For example, the following state dict:

{
    "a": {
        "b": 0.1,
        "c": 0.001,
    },
    "d": [0.1, 0.2, 0.3]
}

would be flattened and stored in the HDF5 file as follows:

{
    "a.b": 0.1,
    "a.c": 0.001,
    "d.0": 0.1,
    "d.1": 0.2,
    "d.2": 0.3,
}
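The flattening shown above can be illustrated with a short, standalone sketch. This is not the Cerebras implementation. The helper name flatten_state_dict is hypothetical, and the sketch only assumes that nested dict keys and list indices are joined with "." as in the example. It does not write an HDF5 file.

```python
from typing import Any, Dict


def flatten_state_dict(obj: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts, lists, and tuples into one dict with dotted keys.

    Hypothetical helper for illustration only; the actual Cerebras
    flattening logic may differ (for example, in how it treats empty
    containers or non-string keys).
    """
    flat: Dict[str, Any] = {}
    if isinstance(obj, dict):
        items = obj.items()
    elif isinstance(obj, (list, tuple)):
        # List and tuple positions become numeric key segments.
        items = enumerate(obj)
    else:
        # Leaf value: store it under the key path accumulated so far.
        flat[prefix] = obj
        return flat

    for key, value in items:
        child_key = f"{prefix}.{key}" if prefix else str(key)
        flat.update(flatten_state_dict(value, child_key))
    return flat


state_dict = {
    "a": {"b": 0.1, "c": 0.001},
    "d": [0.1, 0.2, 0.3],
}
flat = flatten_state_dict(state_dict)
print(flat)
# {'a.b': 0.1, 'a.c': 0.001, 'd.0': 0.1, 'd.1': 0.2, 'd.2': 0.3}
```

Note that in this sketch empty dicts and lists produce no entries, so a real checkpoint format would need extra metadata to restore them.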

A model and optimizer state dict can be saved in this checkpoint format using the cbtorch.save method. For example:

import cerebras.pytorch as cbtorch

...

state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
cbtorch.save(state_dict, "path/to/checkpoint")

...

A checkpoint saved this way can be loaded using the cbtorch.load method. For example:

import cerebras.pytorch as cbtorch

...

state_dict = cbtorch.load("path/to/checkpoint")

model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])

...

Note

If you use the run.py scripts provided in the ModelZoo, the runners already handle saving and loading checkpoints as described above.

Converting Checkpoint Formats

If loading the checkpoint with cbtorch.load is not sufficient for your needs, the checkpoint can be converted to the pickle-based format that PyTorch uses, as follows:

import torch
import cerebras.pytorch as cbtorch

state_dict = cbtorch.load("path/to/checkpoint")
torch.save(state_dict, "path/to/new/checkpoint")