Multi-Phase Training

On this page, you will learn how to set up Multi-Phase training using the Trainer class. Multi-Phase training lets you combine multiple training phases with different batch sizes or maximum sequence lengths in a single config file or Python script.

Prerequisites

Before you begin, make sure you have read the following tutorials:

The rest of this page assumes that you already have at least a cursory understanding of what the Cerebras Model Zoo Trainer is and how to use the Python API.

Multi-Phase Training

Multi-Phase training splits a run into several distinct training phases. For example, a training pipeline for the Llama-3 model might use different batch sizes or maximum sequence lengths in different phases. Each phase is defined by its own Trainer instance.
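Changing the batch size or maximum sequence length between phases changes how much data each step consumes. As a rough worked example, the Python sketch below computes token positions per step (batch_size × maximum sequence length) for the two phases configured later on this page. It assumes the maximum sequence lengths can be read from the dataset directory names (msl8192, msl512) and that every sequence is padded or packed to full length. The function names are hypothetical and are not a Model Zoo utility.

```python
# Per-phase values taken from the two-phase YAML example on this page.
# Assumption: max_seq_len is read from the dataset directory names
# (..._msl8192 and ..._msl512), and every sample fills the full length.
PHASES = {
    "phase_1": {"batch_size": 80, "max_seq_len": 8192, "num_steps": 10_000},
    # Phase 2 inherits loop.num_steps from phase 1 through "<<: *init".
    "phase_2": {"batch_size": 40, "max_seq_len": 512, "num_steps": 10_000},
}


def tokens_per_step(batch_size: int, max_seq_len: int) -> int:
    """Token positions processed per training step (padding included)."""
    return batch_size * max_seq_len


def phase_token_budget(phase: dict) -> int:
    """Token positions processed over a whole phase."""
    return tokens_per_step(phase["batch_size"], phase["max_seq_len"]) * phase["num_steps"]


if __name__ == "__main__":
    for name, phase in PHASES.items():
        per_step = tokens_per_step(phase["batch_size"], phase["max_seq_len"])
        print(f"{name}: {per_step:,} tokens/step, {phase_token_budget(phase):,} tokens in total")
```

Under these assumptions, phase 1 processes 655,360 token positions per step and phase 2 processes 20,480. Each step in the second phase is therefore 32 times smaller.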

Consider an example. In the Pretraining with Upstream Validation tutorial, you learned how to construct the Trainer for the Llama-3 model. Now, let’s add a second training phase with a different batch size and a new maximum sequence length.

To define each phase, construct a separate Trainer instance. In YAML, each phase is one entry in the list under the top-level trainer key. For example:

trainer:
- trainer:
    ...
- trainer:
    ...

Note

There is no limit on the number of Trainer instances, and each Trainer can have its own parameters. This means you can build arbitrary training and validation pipelines that combine different models, dataloaders, and so on.

Each phase below uses a different batch size and a different maximum sequence length.

trainer:
- trainer:
    init: &init
      backend:
        backend_type: CSX
        cluster_config:
          num_csx: 16
      seed: 2024
      model:
        # Embedding
        vocab_size: 128256
        hidden_size: 4096
        position_embedding_type: "rotary"
        pos_scaling_factor: 1.0
        rope_theta: 500000.0
        rotary_dim: 128
        share_embedding_weights: false
        max_position_embeddings: 8192
        embedding_dropout_rate: 0.0
        embedding_layer_norm: false

        # Decoder
        num_hidden_layers: 32
        dropout_rate: 0.0
        layer_norm_epsilon: 1.0e-5
        norm_type: "rmsnorm"

        # Decoder - Attention
        num_heads: 32
        attention_type: "scaled_dot_product"
        attention_module: "multiquery_attention"
        attention_dropout_rate: 0.0
        use_projection_bias_in_attention: false
        use_ffn_bias_in_attention: false
        extra_attention_params:
            num_kv_groups: 8

        # Decoder - ffn
        filter_size: 14336
        nonlinearity: "swiglu"
        use_ffn_bias: false

        # Task-specific
        use_bias_in_output: false
        loss_scaling: "num_tokens"
        loss_weight: 1.0

        # Initializer
        initializer_range: 0.02

        # Cerebras parameters
        mixed_precision: True
        fp16_type: "cbfloat16"

      optimizer:
        AdamW:
          betas: [0.9, 0.95]
          correct_bias: True
          weight_decay: 0.1

      schedulers:
      - CosineDecayLR:
          initial_learning_rate: 3.0e-5
          end_learning_rate: 3.0e-6
          total_iters: 528

      precision:
        fp16_type: cbfloat16
        loss_scaling_factor: dynamic
        max_gradient_norm: 1.0

      loop:
        num_steps: 10000
        eval_frequency: 1000
        eval_steps: 1000

      checkpoint:
        steps: 1000

      callbacks:
      - ComputeNorm: {}
      - CheckLoss: {}
      - ModelEvalMetrics: {}

      loggers:
      - ProgressLogger: {}
      - TensorBoardLogger: {}
    fit:
      train_dataloader:
        data_processor: GptHDF5MapDataProcessor
        data_dir: "/data/llama_v3_dataset_vocab128256_msl8192/train"
        batch_size: 80
        shuffle: False
        shuffle_seed: 1337
        num_workers: 8
        prefetch_factor: 10
        persistent_workers: True # Important to avoid seeding at each epoch
      val_dataloader:
      - data_processor: GptHDF5MapDataProcessor
        data_dir: "/data/llama_v3_dataset_vocab128256_msl8192/val"
        batch_size: 80
        shuffle: False
        shuffle_seed: 1337
        num_workers: 8
        prefetch_factor: 10
        persistent_workers: True # Important to avoid seeding at each epoch
- trainer:
    init:
      <<: *init
    fit:
      train_dataloader:
        data_processor: GptHDF5MapDataProcessor
        data_dir: "/data/llama_v3_dataset_vocab128256_msl512/train"
        batch_size: 40
        shuffle: False
        shuffle_seed: 1337
        num_workers: 8
        prefetch_factor: 10
        persistent_workers: True # Important to avoid seeding at each epoch
      val_dataloader:
      - data_processor: GptHDF5MapDataProcessor
        data_dir: "/data/llama_v3_dataset_vocab128256_msl512/val"
        batch_size: 40
        shuffle: False
        shuffle_seed: 1337
        num_workers: 8
        prefetch_factor: 10
        persistent_workers: True # Important to avoid seeding at each epoch

Note

When using YAML, you must construct a separate Trainer instance for each phase. This adds overhead to the run, because each phase spends time on compilation and weight transfer. With the Python API, you can instead construct a single Trainer object and call fit multiple times with different DataLoader objects.
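Phase 2 in the two-phase example reuses phase 1's init section through a YAML anchor (&init) and merge key (<<: *init). Under the YAML merge-key convention, the merge is shallow. Keys copied from the anchor only fill in what the new mapping does not define, and any key you write explicitly replaces the inherited value as a whole. The sketch below mimics that behavior with plain Python dictionaries. The yaml_merge helper is hypothetical, and the init dictionary is a trimmed copy of a few keys from the config above.

```python
import copy

# Trimmed copy of the "&init" mapping from phase 1 (a subset of its keys).
init = {
    "seed": 2024,
    "schedulers": [
        {"CosineDecayLR": {"initial_learning_rate": 3.0e-5,
                           "end_learning_rate": 3.0e-6,
                           "total_iters": 528}}
    ],
    "loop": {"num_steps": 10000, "eval_frequency": 1000, "eval_steps": 1000},
    "callbacks": [{"ComputeNorm": {}}, {"CheckLoss": {}}, {"ModelEvalMetrics": {}}],
}


def yaml_merge(anchor: dict, explicit: dict) -> dict:
    """Mimic a mapping written as "<<: *anchor" plus explicit keys (hypothetical helper).

    The merge is shallow: explicit keys win and replace the inherited value
    as a whole, and nested mappings and lists are never merged recursively.
    Deep copies keep this demo's inputs unmodified (a real YAML loader may
    instead share the aliased objects).
    """
    merged = copy.deepcopy(anchor)
    merged.update(copy.deepcopy(explicit))
    return merged


# Basic example, phase 2: "<<: *init" with no explicit keys inherits everything.
phase2_init = yaml_merge(init, {})

# Hypothetical override of one nested field: the whole "loop" mapping is replaced.
partial = yaml_merge(init, {"loop": {"num_steps": 2000}})
print(partial["loop"])  # {'num_steps': 2000}; eval_frequency and eval_steps are gone

# Redefining "callbacks" replaces the inherited list instead of extending it.
with_ckpt = yaml_merge(init, {"callbacks": [
    {"LoadCheckpointStates": {"load_checkpoint_states": "model,grad_scaler,optimizer,global_step"}}
]})
print([next(iter(cb)) for cb in with_ckpt["callbacks"]])  # ['LoadCheckpointStates']
```

This matters whenever a later phase overrides part of init. Suppose the elided ... in the advanced example below stands for <<: *init (an assumption). In that case, its callbacks list would replace ComputeNorm, CheckLoss, and ModelEvalMetrics rather than add to them, so any callbacks you still want must be listed again.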

Multi-Phase Training (Advanced)

A more advanced use of Multi-Phase training changes the training configuration between phases. For instance, you might want to switch the learning rate scheduler from CosineDecayLR to ConstantLR. To do this, create two Trainer instances and control which states are loaded from the checkpoint between phases, so that the saved state does not conflict with the changed configuration.
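To see what the switch means in practice, the sketch below evaluates both schedules using the values from the example that follows. It uses the textbook cosine-decay formula and assumes the rate is held at end_learning_rate once total_iters is reached. The actual CosineDecayLR implementation in cerebras.pytorch may differ in details, and the function names here are hypothetical.

```python
import math


def cosine_decay_lr(step: int, initial_lr: float, end_lr: float, total_iters: int) -> float:
    """Textbook cosine decay from initial_lr to end_lr over total_iters steps.

    Assumption: after total_iters, the rate is held at end_lr.
    """
    progress = min(step, total_iters) / total_iters
    return end_lr + (initial_lr - end_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


def constant_lr(step: int, learning_rate: float) -> float:
    """Constant schedule: the step argument is ignored."""
    return learning_rate


# Phase 1 values from the CosineDecayLR block; phase 2 from the ConstantLR block.
PHASE1 = dict(initial_lr=3.0e-5, end_lr=3.0e-6, total_iters=528)
PHASE2_LR = 1.0e-6

if __name__ == "__main__":
    for step in (0, 264, 528, 10_000):
        print(f"phase 1, step {step:>6}: {cosine_decay_lr(step, **PHASE1):.2e}")
    print(f"phase 2, any step:    {constant_lr(0, PHASE2_LR):.2e}")
```

Under these assumptions, phase 1 decays from 3.0e-5 to 3.0e-6 by step 528 and then stays flat. Phase 2 starts over with its own schedule and runs at a constant 1.0e-6.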

Note

In the example below, the model, optimizer, and other parameters are similar to those in the previous example. They have been omitted to keep the example short.

trainer:
- trainer:
    init:
      ...
      schedulers:
      - CosineDecayLR:
          initial_learning_rate: 3.0e-5
          end_learning_rate: 3.0e-6
          total_iters: 528
- trainer:
    init:
      ...
      schedulers:
      - ConstantLR:
          learning_rate: 1.0e-6
      callbacks:
      - LoadCheckpointStates:
          load_checkpoint_states: "model,grad_scaler,optimizer,global_step"

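In this example, each Trainer constructs and compiles its own model. Because the second phase switches the scheduler to ConstantLR, the LoadCheckpointStates callback lists exactly which states to restore from the checkpoint (model, grad_scaler, optimizer, and global_step). This avoids checkpoint-loading conflicts, because state from the previous scheduler is not applied to the new one. For further reading, see Checkpointing.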
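One way to picture the effect of load_checkpoint_states is as a filter over the named states stored in a checkpoint. The sketch below is a hypothetical, pure-Python illustration, not the LoadCheckpointStates implementation. The checkpoint layout, including the lr_scheduler and dataloader keys, is an assumption made for the example.

```python
# Hypothetical stand-in for a phase-1 checkpoint; the real layout may differ.
# The "lr_scheduler" and "dataloader" entries are assumptions for illustration.
phase1_checkpoint = {
    "model": {"weights": "..."},
    "optimizer": {"adamw_state": "..."},
    "grad_scaler": {"scale": 2.0 ** 15},
    "global_step": 10_000,
    "lr_scheduler": {"CosineDecayLR": {"last_step": 10_000}},
    "dataloader": {"samples_seen": 800_000},
}


def select_states(checkpoint: dict, load_checkpoint_states: str) -> dict:
    """Keep only the comma-separated state names (hypothetical helper)."""
    wanted = {name.strip() for name in load_checkpoint_states.split(",") if name.strip()}
    missing = wanted - set(checkpoint)
    if missing:
        # Fail loudly so a typo in the list is not silently ignored.
        raise KeyError(f"checkpoint has no state(s): {sorted(missing)}")
    return {name: state for name, state in checkpoint.items() if name in wanted}


restored = select_states(phase1_checkpoint, "model,grad_scaler,optimizer,global_step")
print(sorted(restored))  # ['global_step', 'grad_scaler', 'model', 'optimizer']
```

Only the four listed states survive, so in this picture the new ConstantLR scheduler starts from its own configuration. Rejecting unknown names is a design choice in this sketch: it turns a typo in the list into an immediate error rather than a silently skipped state.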

Caveats

When running Multi-Phase training with the Python API, you may hit the following error:

RuntimeError: Cannot instantiate multiple backends. A backend with type CSX has already been instantiated.

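This error occurs when a backend is instantiated more than once in the same process, for example once per Trainer. Instead, instantiate the backend a single time and pass the same object to every Trainer. For example: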

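import cerebras.pytorch as cstorch
from cerebras.modelzoo import Trainer

# Instantiate the backend exactly once per process.
backend = cstorch.backend(
    "CSX",
    ...
)

# Every phase reuses the same backend object.
trainer1 = Trainer(
    backend=backend,
    ...
)

trainer2 = Trainer(
    backend=backend,
    ...
)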

Conclusion

This tutorial showed a few use cases for Multi-Phase training. You are not limited to these examples: you can construct as many Trainers as you need, combining different models, schedulers, optimizers, dataloaders, and more.