Cannot load Cerebras checkpoints on GPUs

You trained a model on a Cerebras cluster and want to reload it for continued training or inference on GPUs. However, when the target device is anything other than the Cerebras Wafer-Scale cluster, checkpoints are loaded through pickle, whereas checkpoints written on the Cerebras cluster are in HDF5 format.
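To confirm that a given checkpoint file is affected, one option is to inspect its first bytes. The sketch below is not part of the Cerebras tooling; the function name `detect_checkpoint_format` is hypothetical. It assumes the HDF5 signature sits at offset 0 of the file (the HDF5 format also allows it at later offsets, which this check ignores) and that pickle-based files were written either as a zip archive (the default for `torch.save` since PyTorch 1.6) or as a raw pickle stream.

```python
# Standard file signatures; checked only at the start of the file.
HDF5_MAGIC = b"\x89HDF\r\n\x1a\n"  # 8-byte HDF5 format signature
ZIP_MAGIC = b"PK\x03\x04"          # zip archive (default torch.save container)
PICKLE_MAGIC = b"\x80"             # PROTO opcode of pickle protocol 2 and later


def detect_checkpoint_format(path):
    """Guess a checkpoint's on-disk format from its first 8 bytes."""
    with open(path, "rb") as f:
        head = f.read(8)
    if head.startswith(HDF5_MAGIC):
        return "hdf5"
    if head.startswith(ZIP_MAGIC):
        return "torch-zip"
    if head.startswith(PICKLE_MAGIC):
        return "pickle"
    return "unknown"
```

Calling `detect_checkpoint_format` on a checkpoint written by the Cerebras cluster would be expected to return `"hdf5"`, which indicates that the file needs to be converted before a pickle-based loader can read it.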

Workaround

When moving to extremely large models, reading, writing, and manipulating checkpoints becomes a bottleneck. For that reason, Cerebras moved to an HDF5-based file format for storing checkpoints. Cerebras provides scripts to convert between checkpoint file formats, as explained in Work with Cerebras checkpoints.

Here is an example:

  1. Train a GPT2 small model from the Cerebras Model Zoo on a Cerebras cluster for 200 steps.

  2. Convert the checkpoint collected during training on the Cerebras cluster (CS-2) to .pkl format. This step requires the Cerebras virtual environment created in Setup Cerebras virtual environment.

> source venv_cerebras_pt/bin/activate
(venv_cerebras_pt) > python

>>> import sys
>>> sys.path.append('<base_dir_of_modelzoo>')
>>> from modelzoo.common.pytorch import cbtorch
>>> import torch
>>> # Load the HDF5 checkpoint written on the Cerebras cluster
>>> h5_ckpt = cbtorch.load('<path_to_checkpoint_200.mdl>')
>>> # torch.save returns None, so its result is not assigned
>>> torch.save(h5_ckpt, '<path_to_ckpt_200_in_pkl>')

  3. Copy the checkpoint saved in .pkl format to the GPU setup.

  4. Check out modelzoo in the GPU environment and install the GPU dependencies for PyTorch, as explained in Clone Cerebras Model Zoo and GPU Requirements for Running Cerebras ModelZoo.

  5. Adjust the gpt2_small params to set train_input.batch_size=4.

  6. Resume training using the converted checkpoint:

> source venv/bin/activate
(venv) > python modelzoo/transformers/pytorch/gpt2/run.py \
      GPU \
      -p gpt2_small_gpu.yaml \
      -m train \
      -o gpt2_small_gpu_model_dir \
      --checkpoint_path gpt2_small_model_dir/ckpt_200.pkl

WARNING:root:The following model params are unused: allow_multireplica
2023-04-28 16:06:55,102 INFO:   Loading weights from checkpoint gpt2_small_model_dir/ckpt_200.pkl
2023-04-28 16:06:56,704 INFO:   Saving checkpoint at step : 200.
2023-04-28 16:07:13,434 INFO:   | Train Device=cuda, Step=205, Loss=8.74219, Time=16:07:13
2023-04-28 16:07:14,815 INFO:   | Train Device=cuda, Step=210, Loss=8.52344, Time=16:07:14
2023-04-28 16:07:16,194 INFO:   | Train Device=cuda, Step=215, Loss=8.60156, Time=16:07:16
2023-04-28 16:07:17,576 INFO:   | Train Device=cuda, Step=220, Loss=8.57812, Time=16:07:17
2023-04-28 16:07:18,959 INFO:   | Train Device=cuda, Step=225, Loss=8.84375, Time=16:07:18
2023-04-28 16:07:20,362 INFO:   | Train Device=cuda, Step=230, Loss=8.89062, Time=16:07:20
2023-04-28 16:07:21,760 INFO:   | Train Device=cuda, Step=235, Loss=8.71875, Time=16:07:21
2023-04-28 16:07:23,143 INFO:   | Train Device=cuda, Step=240, Loss=8.32812, Time=16:07:23
2023-04-28 16:07:24,566 INFO:   | Train Device=cuda, Step=245, Loss=8.35156, Time=16:07:24
2023-04-28 16:07:25,950 INFO:   | Train Device=cuda, Step=250, Loss=8.23438, Time=16:07:25
2023-04-28 16:07:25,951 INFO:   Saving checkpoint at step : 250.
2023-04-28 16:07:41,007 INFO:   Training Completed Successfully!
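
One way to check that the run resumed from the converted checkpoint, rather than restarting from step 0, is to extract the step and loss values from the log and confirm that every logged step comes after step 200. The sketch below is illustrative only: the helper names `parse_train_log` and `resumed_after` are hypothetical, and the regular expression assumes the `Step=..., Loss=...` layout shown in the log above, with a plain decimal loss value.

```python
import re

# Matches the "Step=205, Loss=8.74219" portion of a training log line.
STEP_LOSS = re.compile(r"Step=(\d+), Loss=([0-9.]+)")


def parse_train_log(lines):
    """Return (step, loss) pairs for every log line that reports a loss."""
    pairs = []
    for line in lines:
        match = STEP_LOSS.search(line)
        if match:
            pairs.append((int(match.group(1)), float(match.group(2))))
    return pairs


def resumed_after(pairs, checkpoint_step):
    """True if at least one step was logged and all steps exceed checkpoint_step."""
    return bool(pairs) and all(step > checkpoint_step for step, _ in pairs)
```

For the log above, `parse_train_log` would yield pairs starting at step 205, so `resumed_after(pairs, 200)` evaluates to `True`. Lines without a `Step=` field, such as the checkpoint-saving messages, are skipped.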