Defer Weight Initialization

On this page, you will learn how to defer model weight initialization when using the Trainer class. Deferring weight initialization can significantly reduce the time-to-first-loss. This shortens iteration times and makes the training process more efficient.

By the end of this guide, you will understand how to implement deferred initialization and the advantages it brings to your model training.

Prerequisites

Please read the Trainer Overview before continuing. The rest of this page assumes that you have at least a cursory understanding of what the Cerebras Model Zoo Trainer is and how to use its Python API.

Model Function

Basic Usage showed that you can pass any torch.nn.Module into the Trainer. However, this requires a concrete torch.nn.Module object to exist before you construct the Trainer. Because PyTorch executes eagerly, constructing a model initializes its weights immediately, and this can be very time-consuming for extremely large models.

To address this, the Trainer lets you defer your model's weight initialization. Instead of a module, pass the Trainer's model argument a function that takes no arguments and returns a torch.nn.Module object.

import torch
from cerebras.modelzoo import Trainer

def model_fn() -> torch.nn.Module:
    # Construct and return your model here, for example:
    return torch.nn.Linear(784, 10)

trainer = Trainer(
    device="CSX",
    model=model_fn,
    ...
)
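How might a trainer tell a ready-made module from a module factory? The sketch below is a minimal, hypothetical illustration. It is not the Trainer's actual implementation. It uses a stand-in `Module` class instead of `torch.nn.Module` so that it runs without PyTorch, and `resolve_model` is a made-up name. The `isinstance` check must come first. Like `torch.nn.Module`, the stand-in defines `__call__`, so a plain `callable(model)` test could not distinguish a built model from a factory.

```python
from typing import Callable, Union


class Module:
    """Stand-in for torch.nn.Module (assumption: used only so the sketch runs without PyTorch).

    Like torch.nn.Module, instances are callable, which is why resolve_model
    cannot rely on callable() alone.
    """

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        raise NotImplementedError


class Linear(Module):
    """Hypothetical toy layer that records its shape instead of allocating weights."""

    def __init__(self, in_features: int, out_features: int):
        self.in_features = in_features
        self.out_features = out_features


ModelOrFactory = Union[Module, Callable[[], Module]]


def resolve_model(model: ModelOrFactory) -> Module:
    """Hypothetical helper: return a concrete module, calling the factory if needed."""
    # Check for an already-built module first: modules are callable too.
    if isinstance(model, Module):
        return model
    if callable(model):
        built = model()  # the factory runs here, at a time the trainer chooses
        if not isinstance(built, Module):
            raise TypeError(f"model function returned {type(built).__name__}, expected a Module")
        return built
    raise TypeError("model must be a Module or a zero-argument callable returning one")


eager = Linear(784, 10)
assert resolve_model(eager) is eager  # a concrete module passes through unchanged
deferred = resolve_model(lambda: Linear(784, 10))
print(type(deferred).__name__, deferred.out_features)  # Linear 10
```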

Any callable that matches this signature is accepted, so you can also construct your model inline with a lambda function:

import torch
from cerebras.modelzoo import Trainer

trainer = Trainer(
    device="CSX",
    model=lambda: torch.nn.Linear(784, 10),
    ...
)
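Lambdas are not the only option. Anything that can be called with no arguments fits the same schema. Examples include a `functools.partial` over a builder that needs arguments, or a class whose constructor has defaults for every parameter. The hypothetical, PyTorch-free sketch below uses made-up names (`ToyMLP`, `make_mlp`). It also counts constructions to show the key property: defining a factory builds nothing, and the model is only constructed when the factory is called.

```python
import functools


class ToyMLP:
    """Hypothetical model with an 'expensive' constructor (stands in for a torch.nn.Module)."""

    built = 0  # how many instances have actually been constructed

    def __init__(self, in_features: int = 784, hidden: int = 128, out_features: int = 10):
        ToyMLP.built += 1
        self.sizes = (in_features, hidden, out_features)


def make_mlp(hidden: int) -> ToyMLP:
    """Hypothetical builder that needs an argument, so it is wrapped with functools.partial."""
    return ToyMLP(hidden=hidden)


# Three zero-argument callables that all satisfy the "() -> model" schema.
factories = {
    "lambda": lambda: ToyMLP(784, 256, 10),
    "partial": functools.partial(make_mlp, hidden=512),
    "class": ToyMLP,  # a class whose constructor needs no arguments is itself a factory
}
assert ToyMLP.built == 0  # defining the factories constructed nothing

models = {name: fn() for name, fn in factories.items()}
print(ToyMLP.built)             # 3
print(models["partial"].sizes)  # (784, 512, 10)
```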

Passing a model function allows the Trainer to use the efficient weight initialization mechanism built into the Cerebras PyTorch API (see the Efficient weight initialization documentation for more details).

Empirically, deferring model weight initialization can reduce the time-to-first-loss (the time it takes to get the first loss value back from the Wafer-Scale Cluster) by over 50%. This means less time spent waiting and faster iteration overall.

Optimizer/Scheduler Functions

You may be wondering about the optimizer. The optimizer constructor takes the model parameters as input, so if model initialization is deferred, how can the optimizer be constructed?

The answer is that the Trainer's optimizer argument also accepts a function. This function takes a torch.nn.Module and returns an Optimizer object.

import cerebras.pytorch as cstorch
import torch
from cerebras.modelzoo import Trainer

trainer = Trainer(
    device="CSX",
    model=lambda: torch.nn.Linear(784, 10),
    optimizer=lambda model: cstorch.optim.SGD(
        model.parameters(),
        lr=0.01,
        momentum=0.9,
    ),
    ...
)
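The ordering constraint can be sketched as follows. This is a hypothetical, PyTorch-free illustration, and `Param`, `ToyModel`, `ToySGD`, and `build_training_objects` are made-up names. The trainer first calls the model function, then passes the resulting model to the optimizer function. This guarantees that the optimizer holds the very same parameter objects that will be trained.

```python
from typing import Callable, List


class Param:
    """Toy stand-in for a torch.nn.Parameter."""

    def __init__(self, value: float):
        self.value = value


class ToyModel:
    """Toy model exposing parameters() like a torch.nn.Module."""

    def __init__(self):
        self.weight = Param(0.5)
        self.bias = Param(0.0)

    def parameters(self) -> List[Param]:
        return [self.weight, self.bias]


class ToySGD:
    """Hypothetical optimizer: keeps references to the parameters it updates."""

    def __init__(self, params, lr: float):
        self.params = list(params)
        self.lr = lr


def build_training_objects(model_fn: Callable[[], ToyModel],
                           optimizer_fn: Callable[[ToyModel], ToySGD]):
    """Hypothetical trainer logic: the model must exist before the optimizer can be built."""
    model = model_fn()
    optimizer = optimizer_fn(model)
    return model, optimizer


model, optimizer = build_training_objects(
    lambda: ToyModel(),
    lambda m: ToySGD(m.parameters(), lr=0.01),
)
# The optimizer references the trainer-built parameters, not copies.
print(all(p is q for p, q in zip(optimizer.params, model.parameters())))  # True
```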

Similarly, each entry in the Trainer's schedulers argument can be a function that takes an Optimizer object and returns a Scheduler object.

import cerebras.pytorch as cstorch
import torch
from cerebras.modelzoo import Trainer

trainer = Trainer(
    device="CSX",
    model=lambda: torch.nn.Linear(784, 10),
    optimizer=lambda model: cstorch.optim.SGD(
        model.parameters(),
        lr=0.01,
        momentum=0.9,
    ),
    schedulers=[
        lambda optimizer: cstorch.optim.lr_scheduler.LinearLR(
            optimizer,
            initial_learning_rate=0.01,
            end_learning_rate=0.01,
            total_iters=100,
        ),
    ],
)
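Continuing in the same hypothetical style, the scheduler functions are resolved last, and each one receives the optimizer that was just built. `ToyLinearLR` only mimics the idea of a linear schedule, with arguments named after the example above. It is not `cstorch.optim.lr_scheduler.LinearLR`, and its update rule (linear interpolation, clamped at `total_iters`) is an assumption. It also uses different start and end rates so that the schedule visibly changes.

```python
from typing import Callable, List


class ToyOptimizer:
    """Toy stand-in for an optimizer: it only carries a learning rate."""

    def __init__(self, lr: float):
        self.lr = lr


class ToyLinearLR:
    """Hypothetical scheduler: moves lr linearly from initial to end over total_iters steps."""

    def __init__(self, optimizer: ToyOptimizer, initial_learning_rate: float,
                 end_learning_rate: float, total_iters: int):
        self.optimizer = optimizer
        self.initial = initial_learning_rate
        self.end = end_learning_rate
        self.total_iters = total_iters
        self.step_count = 0
        optimizer.lr = initial_learning_rate

    def step(self) -> None:
        self.step_count += 1
        frac = min(self.step_count, self.total_iters) / self.total_iters  # clamp at the end
        self.optimizer.lr = self.initial + (self.end - self.initial) * frac


def build_schedulers(scheduler_fns: List[Callable[[ToyOptimizer], ToyLinearLR]],
                     optimizer: ToyOptimizer) -> List[ToyLinearLR]:
    """Hypothetical trainer logic: call every scheduler factory with the same optimizer."""
    return [fn(optimizer) for fn in scheduler_fns]


optimizer = ToyOptimizer(lr=0.0)
(scheduler,) = build_schedulers(
    [lambda opt: ToyLinearLR(opt, initial_learning_rate=0.01,
                             end_learning_rate=0.001, total_iters=100)],
    optimizer,
)
for _ in range(50):
    scheduler.step()
print(round(optimizer.lr, 6))  # 0.0055
```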

Conclusion

That is all there is to deferring model, optimizer, and scheduler initialization. By wrapping these components in callables, you can substantially reduce iteration time and make better use of resources, which gives you a faster and more efficient training workflow with the Cerebras Model Zoo.

Further Reading

To learn how to configure a Trainer instance using a YAML configuration file, see:

  • Trainer YAML Overview

To learn more about using the Trainer in some core workflows, see:

To learn more about extending the capabilities of the Trainer class, see: