Deterministically restart dataloader after pausing training
As you train a large language model, you might need to pause the training run for various reasons, such as changing the batch size or another hyperparameter, addressing training instabilities, or moving to more hardware. Model weights are typically checkpointed so that training can restart from the point where it stopped. This feature offers a similar capability for input-generating dataloaders.
A model trained on duplicate samples performs worse than one trained on deduplicated data [1, 2]. Even if a dataset is deduplicated, a dataloader that is not deterministically restarted after a pause may feed the model samples or batches it has already seen. This can lead to memorization and degrade the model's upstream and downstream performance.
With this feature, you can resume training deterministically from the same point in the input-generating dataloader where the previous run was halted, so that the model is not trained on repeated data samples.
This feature is currently only supported for our input-generating workers in PyTorch for Weight Streaming execution. It requires train_input.num_workers=1.
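To make the idea concrete, here is a minimal Python sketch of a resumable iterator, assuming the sample order is fully deterministic (for example, produced with a fixed shuffle seed). The class ResumableIterator and its state_dict/from_state_dict methods are hypothetical and only illustrate the principle; they are not the Cerebras dataloader API. Because the only state needed is the number of samples already consumed, restoring it lets a resumed run continue where the previous run stopped, while a naive restart replays samples from the beginning.

```python
class ResumableIterator:
    """Hypothetical sketch: iterate over a deterministically ordered list of
    samples while tracking how many have been consumed."""

    def __init__(self, samples, start=0):
        self.samples = samples
        self.position = start  # number of samples already consumed

    def __iter__(self):
        return self

    def __next__(self):
        if self.position >= len(self.samples):
            raise StopIteration
        sample = self.samples[self.position]
        self.position += 1
        return sample

    def state_dict(self):
        # With a deterministic sample order, the position is all we need to save.
        return {"position": self.position}

    @classmethod
    def from_state_dict(cls, samples, state):
        return cls(samples, start=state["position"])


data = list(range(10))

# Original run: consume four samples, then "pause" and save the iterator state.
original = ResumableIterator(data)
seen_before_pause = [next(original) for _ in range(4)]
saved_state = original.state_dict()

# Deterministic restart: resume from the saved position.
resumed = ResumableIterator.from_state_dict(data, saved_state)
seen_after_restart = list(resumed)

# Naive restart: start from scratch and repeat samples already seen.
naive = ResumableIterator(data)
naive_first_samples = [next(naive) for _ in range(4)]
```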
How to enable
Enable deterministic data restarts for your training run with just one simple step!
For your original run, starting from global step 0, set the parameter cerebras.save_iter_state_path in the YAML config file. This parameter accepts a string path to the mounted directory where the data checkpoints will be written.
This directory must be visible to the worker nodes in Weight Streaming mode; that is, it must be a mounted path specified under --mount_dirs. The directory is created if it does not already exist.
As an example, we can set this parameter to the mounted modelzoo path:
cerebras:
  save_iter_state_path: </path/to/mounted/modelzoo/dir>
And that’s all!
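Because a path outside the mounted directories is not visible to the worker nodes, it can help to check the configured value before launching. The sketch below is a hypothetical pre-flight helper, is_under_mount_dirs, that compares cerebras.save_iter_state_path against the directories you plan to pass to --mount_dirs. The example paths are made up, and the helper only compares path components, so it does not resolve symlinks or ".." segments.

```python
from pathlib import PurePosixPath


def is_under_mount_dirs(path, mount_dirs):
    """Hypothetical helper: return True if `path` equals or lies inside one of
    `mount_dirs`. Pure path comparison; symlinks and '..' are not resolved."""
    target = PurePosixPath(path)
    for mount in mount_dirs:
        mount_path = PurePosixPath(mount)
        # Comparing path components (not raw string prefixes) avoids treating
        # /mnt/modelzoo2 as if it were inside /mnt/modelzoo.
        if target == mount_path or mount_path in target.parents:
            return True
    return False


# Example values (hypothetical) mirroring the YAML config above.
config = {"cerebras": {"save_iter_state_path": "/mnt/modelzoo/data_ckpts"}}
mount_dirs = ["/mnt/modelzoo"]

state_path_ok = is_under_mount_dirs(config["cerebras"]["save_iter_state_path"], mount_dirs)
```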
Once the run has started, you can confirm that the dataloader is saving its state under the provided path. You should see two types of files in that directory:
- data_iter_checkpoint_state_file_global records the integer step at which the last weight checkpoint of the run was captured.
- The set of files data_iter_state_file_worker_<worker_id>_step_<global_step>.txt records the iterator state at each global step where a checkpoint was saved.
The number of individual worker checkpoints matches the number of weight checkpoints; that is, the dataloader state is saved at the same frequency as the model checkpoints.
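To inspect the state directory programmatically, the sketch below reads the global file and indexes the worker state files by worker ID and step. The file names come from the description above; the helper names read_global_step and list_worker_states, and the assumption that the global file contains a single integer as plain text, are hypothetical. The demo builds a throwaway directory that mimics a run checkpointed at steps 100 and 200.

```python
import re
import tempfile
from pathlib import Path

# File names as described above. The parsing logic and the assumption that the
# global file holds a single integer in plain text are hypothetical.
GLOBAL_FILE = "data_iter_checkpoint_state_file_global"
WORKER_FILE_RE = re.compile(r"^data_iter_state_file_worker_(\d+)_step_(\d+)\.txt$")


def read_global_step(state_dir):
    """Return the step recorded in the global state file (assumed plain integer)."""
    return int((Path(state_dir) / GLOBAL_FILE).read_text().strip())


def list_worker_states(state_dir):
    """Map (worker_id, global_step) to the matching worker state file."""
    states = {}
    for path in Path(state_dir).iterdir():
        match = WORKER_FILE_RE.match(path.name)
        if match:
            states[(int(match.group(1)), int(match.group(2)))] = path
    return states


# Demo on a throwaway directory that mimics a run checkpointed at steps 100 and 200.
with tempfile.TemporaryDirectory() as demo_dir:
    Path(demo_dir, GLOBAL_FILE).write_text("200\n")
    for step in (100, 200):
        Path(demo_dir, f"data_iter_state_file_worker_0_step_{step}.txt").write_text("")
    demo_global_step = read_global_step(demo_dir)
    demo_saved = sorted(list_worker_states(demo_dir))
```

Because the dataloader state is saved at the same frequency as the model checkpoints, the steps returned by list_worker_states should line up with the steps of your weight checkpoints.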
At restart, provide the same path for cerebras.save_iter_state_path in the YAML, and the input dataloader will automatically resume from the state saved at the last checkpoint of the original run.
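Conceptually, the restart lookup reads the step from the global file and picks the matching worker state file. The following hypothetical sketch, find_restart_state, shows that lookup; it assumes the global file contains a plain integer and that the single worker has ID 0, and it is an illustration rather than the dataloader's actual code.

```python
import tempfile
from pathlib import Path

GLOBAL_FILE = "data_iter_checkpoint_state_file_global"


def find_restart_state(state_dir, worker_id=0):
    """Hypothetical sketch: return (step, path) of the worker state to restore.
    Assumes the global file holds a plain integer and worker IDs start at 0."""
    state_dir = Path(state_dir)
    step = int((state_dir / GLOBAL_FILE).read_text().strip())
    path = state_dir / f"data_iter_state_file_worker_{worker_id}_step_{step}.txt"
    if not path.exists():
        raise FileNotFoundError(f"no dataloader state for worker {worker_id} at step {step}")
    return step, path


# Demo: a run checkpointed at steps 100 and 200, with 200 recorded as the latest.
with tempfile.TemporaryDirectory() as demo_dir:
    Path(demo_dir, GLOBAL_FILE).write_text("200\n")
    for step in (100, 200):
        Path(demo_dir, f"data_iter_state_file_worker_0_step_{step}.txt").write_text("")
    restart_step, restart_file = find_restart_state(demo_dir)
    restart_file_name = restart_file.name
```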
In some cases, you might want to rewind and restart from a different step of the previous training run. To do this, edit data_iter_checkpoint_state_file_global and set it to the global step from which the run should restart. Note that this global step must be one of the steps at which a dataloader checkpoint was saved.
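Editing the global file by hand works, but a small script can also enforce the rule that the target step must have a saved dataloader checkpoint. The helper rewind_to_step below is hypothetical; it assumes the global file stores the step as a plain integer and that, with a single worker, the worker ID is 0.

```python
import tempfile
from pathlib import Path

GLOBAL_FILE = "data_iter_checkpoint_state_file_global"


def rewind_to_step(state_dir, step, worker_ids=(0,)):
    """Hypothetical helper: point the global state file at an earlier saved step.
    Refuses steps without a saved dataloader checkpoint for every listed worker."""
    state_dir = Path(state_dir)
    for worker_id in worker_ids:
        worker_file = state_dir / f"data_iter_state_file_worker_{worker_id}_step_{step}.txt"
        if not worker_file.exists():
            raise ValueError(f"no dataloader checkpoint for worker {worker_id} at step {step}")
    # Assumes the global file stores the step as a plain integer.
    (state_dir / GLOBAL_FILE).write_text(f"{step}\n")


# Demo: checkpoints exist at steps 100 and 200; rewind to 100, then try step 150.
with tempfile.TemporaryDirectory() as demo_dir:
    Path(demo_dir, GLOBAL_FILE).write_text("200\n")
    for saved_step in (100, 200):
        Path(demo_dir, f"data_iter_state_file_worker_0_step_{saved_step}.txt").write_text("")

    rewind_to_step(demo_dir, 100)  # valid: a checkpoint exists at step 100
    rewound_to = int(Path(demo_dir, GLOBAL_FILE).read_text())

    try:
        rewind_to_step(demo_dir, 150)  # invalid: no checkpoint at step 150
        rejected_unsaved_step = False
    except ValueError:
        rejected_unsaved_step = True
```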
You can also restart the run with a different batch size, and the deterministic data restart feature will still apply.
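One general way to keep the data order deterministic across a batch-size change is to track the number of samples consumed rather than the number of batches. The sketch below illustrates that idea under this assumption; the helpers samples_consumed and next_batch are hypothetical, and this is not a description of how the Cerebras dataloader stores its state.

```python
def samples_consumed(segments):
    """Total samples seen, given (num_steps, batch_size) pairs for each stretch of training."""
    return sum(num_steps * batch_size for num_steps, batch_size in segments)


def next_batch(samples, offset, batch_size):
    """Return the next batch starting at a sample offset (ordered, deterministic data)."""
    return samples[offset:offset + batch_size]


data = list(range(100))

# Original run: 3 steps at batch size 4, so samples 0..11 have been seen.
offset = samples_consumed([(3, 4)])

# Restart with batch size 6: continue from sample 12 with no repeats or gaps.
first_batch_after_restart = next_batch(data, offset, 6)
```

Counting samples instead of steps is what makes the resume point independent of the batch size used before or after the pause.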
You also need to set the parameter runconfig.num_workers_per_csx in the config. For large language model training, we recommend setting it to 1:
runconfig:
  num_workers_per_csx: 1
This setting can also be passed to the run command, as shown below:
python run.py CSX weight_streaming -p </path/to/params> -m <mode> --model_dir </path/to/model/dir> --num_workers_per_csx 1
To restart a run from a checkpoint other than the latest one, you need to manually modify the global state file, data_iter_checkpoint_state_file_global.