Writing a Custom Training Loop

Overview

The typical workflow uses a training script provided in the Cerebras Model Zoo. However, if that training loop does not meet your model's needs, you can write your own training loop using the Cerebras PyTorch API.

The following steps show how to write a custom training loop that trains a simple, fully connected model on the MNIST dataset.

Prerequisites

You have installed the cerebras_pytorch package in your environment.

Validate the package installation

To check whether the cerebras_pytorch package is installed correctly, run the following import in Python. If it completes without raising an ImportError, the package is available:

import cerebras_pytorch as cstorch

Note

From here on, we use cstorch as the alias for cerebras_pytorch.
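If you prefer a check that reports availability instead of raising an ImportError, the following sketch uses the standard library's importlib.util.find_spec. The helper name is_installed is hypothetical, and it only confirms that the module can be located, not that the package is configured correctly for your cluster.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    # find_spec returns None when no importable module has this name.
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    status = "found" if is_installed("cerebras_pytorch") else "not found"
    print(f"cerebras_pytorch: {status}")
```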

Define your model

When using the Cerebras PyTorch API, you can define your model in the same way you would in a vanilla PyTorch workflow:

import torch
import torch.nn.functional as F

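class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)  # 28 x 28 = 784 input pixels
        self.fc2 = torch.nn.Linear(256, 10)   # 10 digit classes

    def forward(self, inputs):
        # Flatten (N, 1, 28, 28) images into (N, 784) vectors.
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        # Return raw logits; CrossEntropyLoss applies log-softmax itself.
        return self.fc2(outputs)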

model = Model()
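The layer sizes follow from the MNIST data: each 28 x 28 grayscale image is flattened into 784 values, and there are 10 digit classes. As a quick sanity check, the following pure-Python sketch counts the trainable parameters in this model, using the fact that a torch.nn.Linear(in_features, out_features) layer with bias has in_features * out_features weights plus out_features biases. The helper name linear_params is illustrative only.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer with bias=True."""
    return in_features * out_features + out_features


# Each 28 x 28 MNIST image is flattened to a 784-element vector.
INPUT_FEATURES = 28 * 28
HIDDEN_FEATURES = 256
NUM_CLASSES = 10

fc1 = linear_params(INPUT_FEATURES, HIDDEN_FEATURES)  # 200,960
fc2 = linear_params(HIDDEN_FEATURES, NUM_CLASSES)     # 2,570
total = fc1 + fc2                                     # 203,530
print(f"fc1={fc1}, fc2={fc2}, total={total}")
```

With PyTorch installed, sum(p.numel() for p in model.parameters()) should report the same total for the model defined above.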

Note

Weight initialization for extremely large models can cause out-of-memory errors. See the page on Efficient Weight Initialization for how to work around this issue.

Compile your model

Once the model has been instantiated, compile it by calling cerebras_pytorch.compile, for example:

compiled_model = cstorch.compile(model, backend="CSX")

You must specify the backend to compile the model for. To use the default backend arguments, pass just the backend type (for example, "CSX"). To customize the backend, create a backend object with cerebras_pytorch.backend and pass it instead, for example:

backend = cstorch.backend("CSX", compile_dir="/path/to/compile")
compiled_model = cstorch.compile(model, backend)

Note

The call to cstorch.compile does not actually compile the model. Like torch.compile, it only prepares the model for compilation. Compilation is deferred until the first iteration, when the input shapes are known.
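To make the deferred-compilation behavior concrete, here is a toy, pure-Python analogy; it is not how cstorch.compile is implemented. The hypothetical LazyCompiled wrapper does nothing when it is created and only "compiles" (here, records the input length) on its first call, because that is the first point at which the input shape is known.

```python
from typing import Callable, Optional, Sequence


class LazyCompiled:
    """Toy analogy of deferred compilation (hypothetical, not the cstorch API)."""

    def __init__(self, fn: Callable[[Sequence[float]], float]):
        self.fn = fn
        self.compiled_for: Optional[int] = None  # input length seen at compile time
        self.compile_count = 0

    def _compile(self, input_len: int) -> None:
        # A real compiler would build a program specialized to this shape.
        self.compiled_for = input_len
        self.compile_count += 1

    def __call__(self, inputs: Sequence[float]) -> float:
        if self.compiled_for is None:
            # First call: the input shape is finally known, so compile now.
            self._compile(len(inputs))
        return self.fn(inputs)


model = LazyCompiled(sum)
print(model.compiled_for)   # None: wrapping alone does not compile
model([1.0, 2.0, 3.0])
print(model.compiled_for)   # 3: compiled on the first call
model([4.0, 5.0, 6.0])
print(model.compile_count)  # 1: later calls reuse the compiled program
```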

Optimize Model Parameters

To optimize model parameters using the Cerebras Wafer-Scale cluster, you must use a Cerebras-compliant optimizer. The cerebras_pytorch.optim module provides exact drop-in replacements for all commonly used optimizers, for example:

optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Note

For convenience, we also provide a configuration helper function, configure_optimizer.

To write your own Cerebras-compliant optimizer, see the page on Writing Custom Optimizers.
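For reference, the per-parameter update performed by SGD with momentum can be sketched in pure Python. This follows the formulation documented for torch.optim.SGD with dampening=0 and no weight decay (the momentum buffer starts as the gradient, then v = momentum * v + g and p = p - lr * v). It assumes the Cerebras drop-in replacement matches that behavior; the function sgd_momentum_step is illustrative only.

```python
from typing import List, Optional, Tuple


def sgd_momentum_step(
    params: List[float],
    grads: List[float],
    buffers: Optional[List[float]],
    lr: float = 0.001,
    momentum: float = 0.9,
) -> Tuple[List[float], List[float]]:
    """One SGD-with-momentum step (PyTorch-style, dampening=0, no weight decay)."""
    if buffers is None:
        # First step: the momentum buffer starts as the raw gradient.
        new_buffers = list(grads)
    else:
        new_buffers = [momentum * b + g for b, g in zip(buffers, grads)]
    new_params = [p - lr * b for p, b in zip(params, new_buffers)]
    return new_params, new_buffers


# Two steps on a single parameter with a constant gradient of 1.0.
params, bufs = [1.0], None
for _ in range(2):
    params, bufs = sgd_momentum_step(params, grads=[1.0], buffers=bufs)
print(params, bufs)
```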

DataLoaders

To send data to the Wafer-Scale cluster, you must wrap a function that creates your PyTorch dataloader with cerebras_pytorch.utils.data.DataLoader, for example:

def get_torch_dataloader(batch_size):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "/path/to/data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                transforms.Lambda(
                    lambda x: torch.as_tensor(x, dtype=torch.float16)
                ),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        )
    )

    return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

cerebras_loader = cstorch.utils.data.DataLoader(get_torch_dataloader, batch_size=64)

The Cerebras PyTorch DataLoader takes a callable that returns a PyTorch dataloader, together with the arguments to call it with (here, batch_size=64), rather than a dataloader instance. This allows every worker to create its own PyTorch dataloader instance, which maximizes distributed parallelism.
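The reason for passing a callable can be illustrated with a small pure-Python analogy. In the sketch below, the hypothetical WorkerLoader stores the factory and its keyword arguments (mirroring how batch_size=64 is passed alongside get_torch_dataloader), and each simulated worker calls the factory to build an independent loader, so no loader object has to be shared between processes. This illustrates the pattern only; it is not the Cerebras implementation.

```python
from typing import Any, Callable, Dict, List


def make_batches(batch_size: int) -> List[List[int]]:
    """Stand-in for get_torch_dataloader: returns batches of sample indices."""
    data = list(range(10))
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


class WorkerLoader:
    """Hypothetical wrapper that holds a factory, not a loader instance."""

    def __init__(self, factory: Callable[..., Any], **kwargs: Any):
        self.factory = factory
        self.kwargs: Dict[str, Any] = kwargs

    def build_for_worker(self) -> Any:
        # Each worker calls the factory itself and gets its own fresh loader.
        return self.factory(**self.kwargs)


loader = WorkerLoader(make_batches, batch_size=4)
worker_loaders = [loader.build_for_worker() for _ in range(3)]
print(worker_loaders[0])                           # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(worker_loaders[0] is not worker_loaders[1])  # True: independent instances
```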

Define the Training Step

To run a single training iteration on the Cerebras Wafer-Scale cluster, you must first capture everything that is intended to run on the cluster. To do this, define a function that contains everything that happens in a single training iteration and decorate it with cerebras_pytorch.trace.

For example:

loss_fn = torch.nn.CrossEntropyLoss()

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss

This function gets traced and sent to the cluster for compilation and execution.

Note

Currently, the decorated function is only traced once. Any changes to the computation graph in subsequent iterations are not seen. Please see the page on static graphs for more details.
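The consequence of tracing only once can be shown with a toy recorder in pure Python. The hypothetical trace_once decorator below records which operations the function issues on its first call and replays that recording afterward, so a Python-level branch that changes later is never picked up. This is only an analogy for the static-graph behavior described in the note, not how cstorch.trace works internally.

```python
from typing import Callable, List


def trace_once(fn: Callable[[List[str]], None]) -> Callable[[], List[str]]:
    """Record the ops issued by fn on the first call; replay them afterward."""
    recorded: List[List[str]] = []

    def wrapper() -> List[str]:
        if not recorded:
            ops: List[str] = []
            fn(ops)  # run the Python code once to capture the "graph"
            recorded.append(ops)
        return list(recorded[0])  # later calls reuse the captured graph

    return wrapper


use_extra_op = False


@trace_once
def step(ops: List[str]) -> None:
    ops.append("forward")
    if use_extra_op:  # Python control flow is evaluated only while tracing
        ops.append("extra_op")
    ops.append("backward")
    ops.append("optimizer_step")


first = step()
use_extra_op = True
second = step()
print(first == second)  # True: "extra_op" never appears in the graph
```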

Define an Execution

To set up an execution run on the Cerebras Wafer-Scale cluster, you must create an instance of cerebras_pytorch.utils.data.DataExecutor, for example:

executor = cstorch.utils.data.DataExecutor(
    cerebras_loader, num_steps=1000, checkpoint_steps=100
)

It takes the Cerebras PyTorch dataloader to use during the run, the total number of steps to run for (num_steps), and the interval, in steps, at which checkpoints are taken (checkpoint_steps).
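For the values above, the checkpoint schedule can be worked out with a little arithmetic. The sketch below assumes checkpoints are taken at every step that is a multiple of checkpoint_steps, counting steps from 1; check the DataExecutor documentation for the exact semantics. The helper checkpoint_schedule is hypothetical.

```python
from typing import List


def checkpoint_schedule(num_steps: int, checkpoint_steps: int) -> List[int]:
    """Steps at which a checkpoint would be taken, assuming multiples of the interval."""
    if checkpoint_steps <= 0:
        raise ValueError("checkpoint_steps must be positive")
    return list(range(checkpoint_steps, num_steps + 1, checkpoint_steps))


schedule = checkpoint_schedule(num_steps=1000, checkpoint_steps=100)
print(len(schedule), schedule[0], schedule[-1])  # 10 100 1000
```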

Configuring the Cerebras Wafer-Scale Cluster

To configure the Cerebras Wafer-Scale cluster, pass a cerebras_pytorch.utils.CSConfig object to the executor through the cs_config argument.

For example:

executor = cstorch.utils.data.DataExecutor(
    ...,
    cs_config=cstorch.utils.CSConfig(
        num_csx=2,
        num_workers_per_csx=3
    )
)

See the class documentation for CSConfig for all configurable options.

Note

Most options have reasonable defaults and do not need to be changed.
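As a rough illustration of what the example configuration requests, the sketch below multiplies the two settings to get a total worker count. The CSConfigSketch dataclass is hypothetical and only mirrors the two fields used above; it is not cerebras_pytorch.utils.CSConfig, its defaults are placeholders rather than the real defaults, and it assumes num_workers_per_csx is a per-system count that applies equally to every CS-X.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CSConfigSketch:
    """Hypothetical stand-in mirroring two CSConfig fields from the example."""

    num_csx: int = 1              # placeholder default, not the real CSConfig default
    num_workers_per_csx: int = 1  # placeholder default, not the real CSConfig default

    @property
    def total_workers(self) -> int:
        # Assumes every CS-X system gets the same number of workers.
        return self.num_csx * self.num_workers_per_csx


config = CSConfigSketch(num_csx=2, num_workers_per_csx=3)
print(config.total_workers)  # 6
```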

Train your model

Once the above is defined, you can iterate through the executor to train your model.

@cstorch.step_closure
def print_loss(loss: torch.Tensor):
    print(f"Loss: {loss.item()}")

for inputs, targets in executor:
    loss = training_step(inputs, targets)
    print_loss(loss)

Note

Notice that the loss is passed to a function decorated with step_closure. This is required because the loss value must first be retrieved from the Cerebras Wafer-Scale cluster before it can be used. Please see the page on step closures for more details.

Putting it all together

Combining all of the above steps, we can create a minimal training script for a simple, fully connected model trained on the MNIST dataset:

import cerebras_pytorch as cstorch

import torch
import torch.nn.functional as F

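class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        # Flatten (N, 1, 28, 28) MNIST images into (N, 784) vectors.
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        # Return raw logits; CrossEntropyLoss applies log-softmax itself.
        return self.fc2(outputs)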

model = Model()

compiled_model = cstorch.compile(model, backend="CSX")

optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

def get_torch_dataloader(batch_size):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "/path/to/data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                transforms.Lambda(
                    lambda x: torch.as_tensor(x, dtype=torch.float16)
                ),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        )
    )

    return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

cerebras_loader = cstorch.utils.data.DataLoader(get_torch_dataloader, batch_size=64)

loss_fn = torch.nn.CrossEntropyLoss()

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss

executor = cstorch.utils.data.DataExecutor(
    cerebras_loader, num_steps=1000, checkpoint_steps=100
)

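@cstorch.step_closure
def print_loss(loss: torch.Tensor):
    print(f"Loss: {loss.item()}")

for inputs, targets in executor:
    loss = training_step(inputs, targets)
    print_loss(loss)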

Note

For a full-fledged training script example, see End-to-end Examples.

Further reading