Efficient Weight Initialization

Overview

Initializing the model parameters on the CPU torch.device is completely valid, but for extremely large models it can be slow and can cause memory issues. The parameters may not fit in RAM, in which case they spill over into swap memory or fail to allocate altogether.

To address this, create a backend with cerebras_pytorch.backend (imported as cstorch in the examples below). The backend contains a Cerebras device that can be used as a context manager, much like a torch.device.

For example:

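import cerebras_pytorch as cstorch

backend = cstorch.backend("CSX")

# Parameters created inside this context are placed on the Cerebras device
with backend.device:
    model = Model()  # Model is your own torch.nn.Module subclass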

# compile the model the same way as before
compiled_model = cstorch.compile(model, backend)

As each parameter is created, it is moved to the Cerebras device automatically, and the device keeps the parameter data that will be sent to the Cerebras Wafer-Scale Cluster. This frees host memory for subsequent parameters and keeps overall memory usage low.
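To make the memory argument concrete, here is a toy pure-Python model. All names in it (HostMemory, ToyDevice, make_param) are hypothetical and are not part of the cerebras_pytorch API. In this model each parameter is briefly materialized on the host. When a device context is active, the parameter is handed off and its host copy is released immediately. Peak host usage then tracks the largest single parameter rather than the sum of all parameters.

```python
# Toy model of host memory accounting; hypothetical names, not the cerebras_pytorch API.
class HostMemory:
    def __init__(self):
        self.in_use = 0
        self.peak = 0

    def alloc(self, nbytes):
        self.in_use += nbytes
        self.peak = max(self.peak, self.in_use)

    def free(self, nbytes):
        self.in_use -= nbytes


class ToyDevice:
    """Hypothetical stand-in for a device that is used as a context manager."""

    _active = None

    def __init__(self):
        self.stored = []  # sizes of parameters handed off to the "device"

    def __enter__(self):
        ToyDevice._active = self
        return self

    def __exit__(self, *exc):
        ToyDevice._active = None
        return False


def make_param(host, nbytes):
    host.alloc(nbytes)  # the parameter is first materialized on the host
    device = ToyDevice._active
    if device is not None:
        device.stored.append(nbytes)  # hand the data off to the device...
        host.free(nbytes)             # ...and release the host copy right away
    return nbytes


sizes = [400, 300, 500]

eager_host = HostMemory()
for s in sizes:
    make_param(eager_host, s)

device_host = HostMemory()
with ToyDevice() as dev:
    for s in sizes:
        make_param(device_host, s)

print(eager_host.peak, device_host.peak)  # 1200 500
```

Without the context, every parameter stays resident, so the peak is the sum of all parameters. With the context, the peak is the largest single parameter.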

Lazy Weight Initialization

Lazy initialization traces a model's initialization instead of running it eagerly. It also eliminates unnecessary pre-initialization steps, which reduces time to first loss by up to 25% and makes model initialization faster and more resource-efficient.

To enable lazy initialization, set the lazy_initialization configuration option on the backend's device before constructing the model:

backend = cstorch.backend("CSX")
backend.device.config.lazy_initialization = True

with backend.device:
    model = Model()

Setting lazy_initialization to True enables tracing of model initialization. The entire initialization process is captured, which allows it to run concurrently with compilation.
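The idea that "a traced initialization can run concurrently with compilation" can be illustrated with a toy recorder. The sketch below assumes a simplified model in which tracing only appends (weight, op, args) records to a list. Replaying that list is an ordinary function that can overlap with a stand-in compile step. Every name here (record, materialize, compile_model) is hypothetical and does not describe the actual Cerebras implementation.

```python
import random
from concurrent.futures import ThreadPoolExecutor

# Toy sketch of traced initialization; hypothetical names, not the cerebras_pytorch implementation.
trace = []  # recorded (weight_name, op_name, args) tuples


def record(weight, op, *args):
    """While tracing, remember an init op instead of executing it."""
    trace.append((weight, op, args))


# "Constructing" the model only records what should happen to each weight.
record("w1", "uniform", -1.0, 1.0)
record("w2", "normal", 0.0, 1.0)


def materialize(ops, numel=4, seed=0):
    """Replay the recorded ops to produce concrete weight values."""
    rng = random.Random(seed)
    weights = {}
    for weight, op, args in ops:
        draw = rng.uniform if op == "uniform" else rng.gauss
        weights[weight] = [draw(*args) for _ in range(numel)]
    return weights


def compile_model():
    """Stand-in for the (expensive) compile step."""
    return "compiled"


# Because the trace is just data, replaying it can overlap with compilation.
with ThreadPoolExecutor(max_workers=2) as pool:
    compiled = pool.submit(compile_model)
    weights = pool.submit(materialize, list(trace))

print(compiled.result(), sorted(weights.result()))
```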

Note

In release 2.1.0, lazy traced initialization is disabled by default because models cannot yet be fully traced with the current PyTorch version. These limitations stem from bugs and inherent constraints in both our current compile stack and PyTorch. We anticipate addressing them in a future release, at which point lazy traced initialization will be enabled by default.

Limitations

Change in the order of weight initialization

After tracing model initialization, we group the weight initialization operations per weight for efficiency. However, this grouping can have unexpected side effects when operations modify internal state. For example, every call to a random operator advances the random number generator (RNG) state. If a weight is initialized by more than one random operation, the grouped (lazy) flow and the ungrouped (eager) flow can execute those operations in a different order relative to the operations on other weights. As a result, each operation may receive different random values.

For example:

import torch
import cerebras_pytorch as cstorch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.w1 = torch.nn.Parameter(torch.empty(2, 2))
        self.w2 = torch.nn.Parameter(torch.empty(2, 2))

        self.reset_parameters()

    def reset_parameters(self):
        self.w1.data.uniform_(-1, 1)
        self.w2.data.normal_()

backend = cstorch.backend("CSX")
backend.device.config.lazy_initialization = True

with backend.device:
    model = Model()

def init_weights(model):
    torch.nn.init.kaiming_uniform_(model.w1)
    torch.nn.init.kaiming_normal_(model.w2)

model.apply(init_weights)

Notice that both model.w1 and model.w2 are initialized twice. The first time is in the call to model.reset_parameters() during construction. There, model.w1 is sampled from a uniform distribution on [-1, 1] and model.w2 is sampled from a standard normal distribution. After construction, the model is initialized again via model.apply(init_weights). This time, model.w1 is sampled from the Kaiming uniform distribution and model.w2 from the Kaiming normal distribution.

In the eager initialization flow, the following operations run in this order:

model.w1.data.uniform_(-1, 1)
model.w2.data.normal_()

torch.nn.init.kaiming_uniform_(model.w1)
torch.nn.init.kaiming_normal_(model.w2)

In the lazy initialization flow, however, the initialization operations are grouped per weight, so what effectively runs is:

model.w1.data.uniform_(-1, 1)
torch.nn.init.kaiming_uniform_(model.w1)

model.w2.data.normal_()
torch.nn.init.kaiming_normal_(model.w2)

Notice that the order that the initialization operations are run has changed.
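The reordering above can be modeled as a stable group-by over the recorded operations. In this model, weights keep their first-appearance order, and so do the operations applied to each weight. The sketch below is a hypothetical illustration of that behavior. It does not reflect how the Cerebras stack actually stores or schedules its trace.

```python
# Toy sketch: derive the per-weight grouped (lazy) order from the eager order.
eager_order = [
    ("w1", "uniform_(-1, 1)"),
    ("w2", "normal_()"),
    ("w1", "kaiming_uniform_"),
    ("w2", "kaiming_normal_"),
]


def group_by_weight(ops):
    """Stable grouping: weights keep first-appearance order, as do each weight's ops."""
    groups = {}
    for weight, op in ops:
        groups.setdefault(weight, []).append(op)  # dicts preserve insertion order
    return [(w, op) for w, ops_for_w in groups.items() for op in ops_for_w]


lazy_order = group_by_weight(eager_order)
for weight, op in lazy_order:
    print(weight, op)
```

The same four operations run in both flows. Only their interleaving across weights changes.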

Grouping weight initialization during tracing improves efficiency, but as the example shows, it can have side effects for operations that modify state, such as the random number generator. Such operations may execute in a different order in the traced (lazy) and non-traced (eager) initialization flows, so the final weight values can differ.

For example:

cpu_model = Model()
cpu_model.reset_parameters()

with backend.device:
    model = Model()
    model.reset_parameters()

# This assert will fail!
assert torch.allclose(cpu_model.w1, model.w1.to("cpu"))
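A stdlib-only analogue of this mismatch follows. It assumes Python's random.Random as a stand-in for the device RNG and treats each weight as a single scalar. The same seed and the same four draws are consumed first in eager order and then in per-weight grouped order. This mirrors reset_parameters() running once in __init__ and once explicitly. The final value of w1 ends up different, yet it is still a valid uniform sample.

```python
import random


# Toy analogue: random.Random stands in for the device RNG (an assumption).
def run(order, seed=0):
    """Apply init ops in the given order from one shared RNG stream."""
    rng = random.Random(seed)
    weights = {}
    for weight, dist in order:
        if dist == "uniform":
            weights[weight] = rng.uniform(-1.0, 1.0)
        else:
            weights[weight] = rng.gauss(0.0, 1.0)
    return weights


# reset_parameters() runs twice: once in __init__ and once explicitly.
eager = [("w1", "uniform"), ("w2", "normal"), ("w1", "uniform"), ("w2", "normal")]
lazy = [("w1", "uniform"), ("w1", "uniform"), ("w2", "normal"), ("w2", "normal")]

eager_weights = run(eager)
lazy_weights = run(lazy)
print(eager_weights["w1"], lazy_weights["w1"])
```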

Note that although the initial weights won't match those from the eager initialization flow, they are still valid weights. The order of initialization affects the sampled values, but each weight is still sampled from the same distributions. Lazy weight initialization should therefore not introduce numerical stability or convergence issues.

Explicit aliases on weights cause compile failures

One known limitation of model initialization tracing affects certain operators applied directly to weights during the forward pass. More specifically, lazy initialization affects any operation applied directly to a weight that:

  • creates an alias (e.g. transpose())

  • is both a view and a mutation (e.g. transpose_())

For example:

import torch
import cerebras_pytorch as cstorch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 2)

    def forward(self, x):
        return x @ self.fc1.weight.T + self.fc1.bias

backend = cstorch.backend("CSX")
backend.device.config.lazy_initialization = True

with backend.device:
    model = Model()

When this model's initialization is traced, it runs into a known limitation: operators that are applied directly to weights during the forward pass and that create aliases can cause compile failures in the Cerebras stack. Here, the transpose of fc1.weight (fc1.weight.T) creates an alias that will cause compile issues.

However, when the same operators are applied implicitly, that is, not directly to the weight tensor, there is no issue, as in the example below:

import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 2)

    def forward(self, x):
        return F.linear(x, self.fc1.weight, self.fc1.bias)

F.linear is translated to the addmm operator, which applies the transpose to fc1.weight only implicitly. Because no explicit aliases of the weights are created, this model does not hit the compile issue.

We will address this compile issue in our future releases.
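To see how F.linear can compute the same result as x @ self.fc1.weight.T + self.fc1.bias without an explicit transposed weight, consider the pure-Python toy below. It is a hypothetical illustration, not the PyTorch or Cerebras implementation of addmm. It shows that the transpose can live entirely in the indexing. One simplification: the explicit variant builds a transposed copy, whereas in PyTorch weight.T is a view (alias) of the weight.

```python
# Toy sketch of y = x @ W.T + b for a Linear weight W of shape (out_features, in_features).
def transpose(m):
    return [list(col) for col in zip(*m)]


def linear_explicit(x, weight, bias):
    """Mirrors x @ weight.T + bias: builds a transposed weight object first."""
    w_t = transpose(weight)  # in PyTorch, weight.T would be an alias of weight
    return [
        [sum(xi * w_t[i][j] for i, xi in enumerate(row)) + bias[j] for j in range(len(bias))]
        for row in x
    ]


def linear_implicit(x, weight, bias):
    """Mirrors F.linear: reads weight rows directly; no transposed object exists."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b for w_row, b in zip(weight, bias)]
        for row in x
    ]


x = [[1, 2, 3, 4]]                     # batch of 1, in_features = 4
weight = [[1, 0, 0, 0], [0, 1, 1, 0]]  # out_features = 2
bias = [10, 20]
print(linear_explicit(x, weight, bias), linear_implicit(x, weight, bias))  # [[11, 25]] [[11, 25]]
```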