Efficient weight initialization#

Overview#

This guide describes how to initialize model parameters more efficiently by using Cerebras hardware. Initializing parameters on a CPU can be slow for large models and can run into memory limits: the parameters may not fit in the available RAM, which leads to spilling into swap memory or, in some cases, a complete failure to allocate the necessary memory. To mitigate these issues, you can use a cerebras.pytorch.backend instance to initialize parameters on a Cerebras device, in much the same way as you would use a torch.device.

For example:

import cerebras.pytorch as cstorch

# Initialize the Cerebras backend for efficient processing.
backend = cstorch.backend("CSX")

# Use the backend's device context manager for initializing the model.
# Model is a placeholder for your own torch.nn.Module subclass.
with backend.device:
    model = Model()

# Compile the model using the Cerebras backend for optimized execution.
compiled_model = cstorch.compile(model, backend)

This method automatically moves the parameters to the Cerebras device, which reduces host memory usage and speeds up initialization. Moving parameters off the host frees up memory for subsequent parameters and keeps the overall memory usage low.
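The memory effect can be illustrated with simple arithmetic. The sketch below is not Cerebras code; it assumes, for illustration only, that each parameter is released from host memory as soon as it has been moved, and it uses made-up parameter sizes.

```python
# Hypothetical parameter sizes in bytes (illustrative numbers only).
param_bytes = [4_000, 16_000, 1_000, 8_000]


def peak_all_at_once(sizes):
    """Peak host memory if every parameter stays resident until the end."""
    return sum(sizes)


def peak_streaming(sizes):
    """Peak host memory if each parameter is released once it is moved."""
    return max(sizes)


eager_peak = peak_all_at_once(param_bytes)
streamed_peak = peak_streaming(param_bytes)
print(eager_peak, streamed_peak)
```

Under these assumptions the peak is bounded by the largest single parameter rather than by the sum of all parameters.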

This approach simplifies the process of achieving more efficient weight initialization. For a deeper understanding of the underlying mechanics, refer to the following subsections.

Lazy Weight Initialization#

Lazy initialization improves the model initialization process by tracing a model’s initialization instead of executing it immediately. It also removes redundant computations from the initialization process, which significantly decreases the time to the first loss. The result is a more efficient and resource-aware start to training.

To invoke lazy initialization, proceed with the following steps:

import cerebras.pytorch as cstorch

# Initialize the Cerebras backend.
backend = cstorch.backend("CSX")

# Initialize the model within the backend's context for lazy initialization.
with backend.device:
    model = Model()

Initializing the model from within the backend’s device context manager captures the full initialization graph so that it can be computed at a later time. Doing so creates many opportunities to optimize the initialization, such as:

  • Speeding up weight initialization and starting model compilation sooner

  • Minimizing the underlying computation

  • Increasing the parallelization of weight initialization computations

  • Substantially reducing overall memory usage

  • Reducing the number of file write operations needed

Because of these benefits, lazy initialization is enabled by default and is applied automatically whenever a model is initialized within the backend device’s context manager.
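The idea of capturing initialization and computing it later can be sketched in plain Python. This is a conceptual illustration only, not how cerebras.pytorch implements tracing; the LazyTensor class and its methods are hypothetical names invented for this example.

```python
import random


class LazyTensor:
    """Records initialization steps instead of running them immediately."""

    def __init__(self, size):
        self.size = size
        self.trace = []     # recorded (op_name, function) pairs
        self.values = None  # stays empty until materialize() is called

    def zero_(self):
        self.trace.append(("zero", lambda n: [0.0] * n))
        return self

    def uniform_(self, seed):
        def run(n):
            rng = random.Random(seed)
            return [rng.random() for _ in range(n)]

        self.trace.append(("uniform", run))
        return self

    def materialize(self):
        """Replay the recorded steps in order and store the final values."""
        for _name, fn in self.trace:
            self.values = fn(self.size)
        return self.values


w = LazyTensor(4).zero_().uniform_(seed=0)
recorded = [name for name, _ in w.trace]
untouched_before = w.values is None  # nothing has been computed yet
values = w.materialize()
```

Because the steps are only recorded at first, the whole sequence is available for inspection, reordering, or pruning before any values are computed.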

If you encounter numerical or convergence issues, you can disable this feature to check whether model initialization is the cause:

backend.device.config.lazy_initialization = False

Disabling lazy initialization shifts the initialization scheme closer to eager model initialization, which substantially lengthens the time needed to initialize the model and delays the receipt of the first result from the cluster.

To learn more about how the benefits described above are achieved, read the following subsections.

Parallelizing Weight Initialization#

Once the model’s initialization has been traced, the weight initialization graph is divided into distinct subgraphs. Because these subgraphs have no dependencies between them, they can be executed in parallel.

By default, four initialization subgraphs are executed in parallel. This default was chosen empirically using data from many different experiments.

To change the number of subgraphs that are initialized at once, set the following:

# Set the maximum number of parallel subgraphs for initialization.
backend.device.config.max_async_parallel_compute = 4

Increasing the level of parallelization may cause out-of-memory issues. In the worst case, with a maximum of N parallel subgraphs, N extremely large weights could be initialized at the same time. This is unlikely in practice because smaller weights typically outnumber large ones. If out-of-memory issues do occur, reducing the maximum parallelization level with this configuration variable might help.
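The trade-off between parallelism and memory can be sketched with Python’s standard thread pool. This is an analogy, not the Cerebras implementation: init_subgraph, run_parallel, worst_case_elements, and the subgraph list are hypothetical names and data made up for illustration.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical subgraphs: (weight name, number of elements, fill value).
subgraphs = [
    ("w1", 3, 0.0),
    ("w2", 2, 1.0),
    ("b1", 3, 0.5),
    ("b2", 2, 0.0),
    ("w3", 4, 2.0),
]


def init_subgraph(spec):
    """Stand-in for running one independent initialization subgraph."""
    name, size, fill = spec
    return name, [fill] * size


def run_parallel(specs, max_parallel=4):
    """Run subgraphs with at most `max_parallel` in flight at once."""
    with ThreadPoolExecutor(max_workers=max_parallel) as pool:
        return dict(pool.map(init_subgraph, specs))


def worst_case_elements(specs, max_parallel):
    """Elements resident at once if the largest weights run together."""
    sizes = sorted((size for _, size, _ in specs), reverse=True)
    return sum(sizes[:max_parallel])


weights = run_parallel(subgraphs, max_parallel=4)
sequential = dict(init_subgraph(s) for s in subgraphs)
```

Since the subgraphs are independent, the result does not depend on the parallelism limit; only the worst-case amount of data in flight changes, which is why lowering the limit can relieve memory pressure.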

Deterministic Pseudo-Random Weight Initialization#

When several weights that rely on random sampling are initialized in parallel from a shared pseudo-random number generator, the order in which they draw from that generator is not fixed. This raises concerns about the determinism of weight initialization.

The adopted solution is to assign a unique pseudo-random number generator to each random sampling operation. These generators are individually seeded using the default random number generator, allowing for deterministic seeding with torch.manual_seed(seed).

For example, consider the following initialization code:

import torch

# Initialize a linear layer.
model = torch.nn.Linear(15, 10)

# Re-initialize the weight with random sampling; no generator is passed explicitly.
with torch.no_grad():
    model.weight.uniform_()

A generator is injected so that what is effectively executed is:

model = torch.nn.Linear(15, 10)

with torch.no_grad():
    g = torch.Generator()
    g.manual_seed(torch.randint(..., generator=torch.default_generator))
    model.weight.uniform_(generator=g)

This is done for every single random operator in the initialization graph, thus making each initialization subgraph completely independent from one another. Consequently, this approach guarantees that even when multiple weights are initialized in parallel, the outcomes are deterministic.
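The same seeding scheme can be demonstrated with Python’s standard random module standing in for torch generators. This is a sketch of the idea rather than the actual implementation; the initialize function and the weight names are hypothetical.

```python
import random


def initialize(master_seed, execution_order):
    """Initialize three weights, executing them in the given order."""
    names = ["layer1.weight", "layer1.bias", "layer2.weight"]
    master = random.Random(master_seed)
    # Per-operation seeds are drawn in a fixed order, before anything runs.
    seeds = {name: master.randrange(2**32) for name in names}

    weights = {}
    for name in execution_order:
        rng = random.Random(seeds[name])  # private generator for this op
        weights[name] = [rng.uniform(0.0, 1.0) for _ in range(3)]
    return weights


forward = initialize(1234, ["layer1.weight", "layer1.bias", "layer2.weight"])
shuffled = initialize(1234, ["layer2.weight", "layer1.weight", "layer1.bias"])
```

Each weight depends only on its own seed, so changing the execution order, as parallel execution might, leaves every value unchanged.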

Optimizing Weight Initialization#

After capturing and dividing the initialization graph into subgraphs, there’s an opportunity to refine and optimize these subgraphs prior to their execution.

For example, take the initialization of the following weight:

import torch

# Define a linear layer.
model = torch.nn.Linear(15, 10)

with torch.no_grad():
    # Normal initialization.
    model.weight.normal_()
    # Uniform initialization.
    model.weight.uniform_()

The call to normal_ is redundant because the very next call to uniform_ overwrites all the values in model.weight anyway. Eliminating the call to normal_ does not affect the final result, so the outcome is the same and the process is shorter.

By default, this optimization relies on a specific set of operations known to overwrite all the values of the input tensor, and prunes the work that such an operation makes redundant. These operations include, but are not limited to:

  • fill

  • normal

  • random

  • uniform

  • zero
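The pruning described above can be sketched as a pass over a recorded list of operations. This is a simplified illustration and not the actual optimizer: it assumes every operation works in place on a single named tensor and that no operation reads another tensor. The prune_overwritten function and the "scale" operation are hypothetical.

```python
# Operations assumed, for this sketch, to overwrite every element of the target.
FULL_OVERWRITE = {"fill", "normal", "random", "uniform", "zero"}


def prune_overwritten(ops):
    """Drop operations on a tensor that precede its last full overwrite."""
    last_overwrite = {}
    for index, (target, op) in enumerate(ops):
        if op in FULL_OVERWRITE:
            last_overwrite[target] = index
    return [
        (target, op)
        for index, (target, op) in enumerate(ops)
        if index >= last_overwrite.get(target, 0)
    ]


trace = [
    ("weight", "normal"),   # made redundant by the uniform that follows
    ("weight", "uniform"),
    ("bias", "zero"),
    ("weight", "scale"),    # hypothetical op that modifies existing values
]
pruned = prune_overwritten(trace)
```

Here the normal step is removed, while the later scale step is kept because it runs after the last full overwrite of the weight.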

If numerical values, such as the weight distribution, do not match what you expect, turn off the initialization optimization with the following setting:

# Disable initialization optimization.
backend.device.config.optimize_initialization = False

Implementation Notes#

  1. Ensure that the Cerebras backend is correctly initialized and accessible.

  2. When working with large models, monitor memory usage to prevent overflow.

  3. Know which operations are pruned by default and how this affects the initialization process.

Best Practices#

  1. While lazy initialization can speed up the setup process, be aware of its implications and switch to eager initialization if you face issues related to model convergence or performance.

  2. Validate the initialization’s impact on model training to ensure that efficiency gains don’t compromise model performance.

  3. Ensure that your system has enough memory to support the chosen level of parallelization without running into out-of-memory issues.

Conclusion#

By adopting these advanced initialization techniques, you can significantly enhance the efficiency and speed of weight initialization, particularly for large-scale models. These methods leverage the power of Cerebras, offering a practical solution to the challenges associated with traditional initialization approaches on CPUs.