Checkpoint Formats

When moving to extremely large models, reading, writing, and manipulating checkpoints becomes a bottleneck. For that reason, Cerebras has moved to an HDF5-based file format for storing checkpoints. The content of the checkpoints remains the same, so they can be converted back to native formats for use outside of Cerebras. Specific details about the utilities and interfaces are provided below for each supported framework.
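
Because the format is standard HDF5, a checkpoint's contents can be inspected with generic tools. Below is a minimal sketch using h5py; the file path is illustrative, and the assumption that each weight is stored as an HDF5 dataset under its name may not match the actual layout, so inspect your own files to confirm.

import h5py

# Path is illustrative; point this at a real Cerebras H5 checkpoint.
with h5py.File("checkpoint.h5", "r") as f:
    def show(name, obj):
        # Print every dataset (weight) with its shape and dtype.
        if isinstance(obj, h5py.Dataset):
            print(name, obj.shape, obj.dtype)
    f.visititems(show)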

TensorFlow Checkpoint Format

Our use of the TensorFlow estimator interface leads to the use of the TensorFlow saver for checkpoint interactions. Unfortunately, the saver does not support iterative updates when writing to a single file.

Usage

Starting with an existing checkpoint

Our appliance estimator will use an existing checkpoint if one is provided in the model_dir or as a warm-start path. However, we also provide a utility to convert from the TensorFlow saver format to our H5 format. It is available either as a command-line utility (installed with the wheel) or as a package.

$ tensorflow-to-h5 --help
usage: tensorflow-to-h5 [-h] saver_path h5_path

Convert Tensorflow Saver Checkpoint to Cerebras H5 Format

positional arguments:
  saver_path  Path to existing saver checkpoint
  h5_path     Path to store converted h5 checkpoint
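
For example, to convert an existing saver checkpoint (the paths below are illustrative):

$ tensorflow-to-h5 model_dir/model.ckpt-1000 model_dir/model.ckpt-1000.h5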

$ python
>>> from cerebras_tensorflow.saver import tf_h5_saver

Converting a Cerebras checkpoint

After training on the appliance, converting from the Cerebras format back to the TensorFlow saver format can also be performed via a command-line utility or a package.

$ h5-to-tensorflow --help
usage: h5-to-tensorflow [-h] h5_path saver_path

Convert Cerebras H5 Checkpoint to Tensorflow Saver Format

positional arguments:
  h5_path     Path to existing h5 checkpoint
  saver_path  Path to store converted saver checkpoint
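
For example, to convert back (again with illustrative paths):

$ h5-to-tensorflow model_dir/model.ckpt-1000.h5 model_dir/model.ckpt-1000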

$ python
>>> from cerebras_tensorflow.saver import tf_h5_saver

For models larger than 20B parameters, writing a TensorFlow saver checkpoint can use a prohibitive amount of RAM or swap. In this case, the package provides more flexibility, allowing a checkpoint to be sharded by weight names into the desired layout.
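
Below is a minimal sketch of that sharding idea. It does not use the tf_h5_saver package (whose interface is not shown here); instead it reads the H5 file directly with h5py and writes each shard with the standard TensorFlow v1 saver. It assumes a flat H5 layout with one dataset per weight, and the paths and weight-name prefixes are hypothetical.

import os

import h5py
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def shard_to_saver(h5_path, out_prefix, name_prefix):
    """Write only the weights whose names start with name_prefix to a
    standalone saver checkpoint, so peak memory is bounded by the size
    of one shard rather than the whole model."""
    with h5py.File(h5_path, "r") as f, tf.Graph().as_default():
        variables = []
        for name, dataset in f.items():
            if not name.startswith(name_prefix):
                continue
            # Materialize one tensor at a time from disk.
            variables.append(tf.Variable(dataset[...], name=name))
        saver = tf.train.Saver(var_list=variables)
        with tf.Session() as sess:
            sess.run(tf.variables_initializer(variables))
            saver.save(sess, out_prefix)

# Hypothetical weight-name prefixes; one saver checkpoint per shard.
os.makedirs("shards", exist_ok=True)
shard_to_saver("checkpoint.h5", "shards/encoder.ckpt", "encoder")
shard_to_saver("checkpoint.h5", "shards/decoder.ckpt", "decoder")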

PyTorch Checkpoint Format