Port PyTorch model to Cerebras
Option 1 (Easy): Modify reference models in Cerebras Model Zoo Git repository
The Cerebras Model Zoo repository contains reference implementations in PyTorch of popular neural networks such as BERT, GPT-2, T5, and UNet. These implementations have been modularized to separate data preprocessing, model implementation, and additional functions for execution.
If your primary goal is to use one of these models, even with some model or data preprocessing changes, we recommend starting from the Cerebras Model Zoo and adding the changes you need.
Example 1: Changing the data loader
For this example, we work with the PyTorch implementation of FC_MNIST in the Cerebras Model Zoo. We create a synthetic data loader to evaluate the performance of the network for different input sizes and numbers of classes.
In data.py, we create a function called get_random_dataloader that creates random images and labels. We instrument the function so that the params.yaml file specifies the number of examples, the batch size, the seed, the image size, and the number of classes of this dataset.
```python
import numpy as np
import torch


def get_random_dataloader(input_params, shuffle, num_classes):
    """Build a DataLoader over randomly generated images and labels."""
    num_examples = input_params.get("num_examples")
    batch_size = input_params.get("batch_size")
    seed = input_params.get("seed", 1)
    image_size = input_params.get("image_size", [1, 28, 28])

    np.random.seed(seed)
    image = np.random.random(size=[num_examples] + image_size).astype(np.float32)
    # Note: cast the labels to `np.int32` when running on CS-2 systems and
    # to `np.int64` when running on CPUs/GPUs.
    label = np.random.randint(
        low=0, high=num_classes, size=num_examples
    ).astype(np.int32)

    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(image),
        torch.from_numpy(label),
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        # Honor the drop_last_batch field set in params.yaml.
        drop_last=input_params.get("drop_last_batch", False),
        num_workers=input_params.get("num_workers", 0),
    )


def get_train_dataloader(params):
    return get_random_dataloader(
        params["train_input"],
        params["train_input"].get("shuffle"),
        params["model"].get("num_classes"),
    )


def get_eval_dataloader(params):
    return get_random_dataloader(
        params["eval_input"],
        False,
        params["model"].get("num_classes"),
    )
```
In model.py, we replace the fixed number of classes with a parameter read from the params.yaml file.
```python
import torch.nn as nn


class MNIST(nn.Module):
    def __init__(self, model_params):
        super().__init__()
        self.loss_fn = nn.NLLLoss()
        self.fc_layers = []
        input_size = model_params.get("input_size", 784)
        # Previously hard-coded; now read from the model section of params.yaml.
        num_classes = model_params.get("num_classes", 10)
        ...
        self.last_layer = nn.Linear(input_size, num_classes)
        ...
```
In configs/params.yaml, we add the additional fields used in the dataloader and model definition.
```yaml
train_input:
  batch_size: 128
  drop_last_batch: True
  num_examples: 1000
  seed: 123
  image_size: [1,28,28]
  shuffle: True

eval_input:
  data_dir: "./data/mnist/val"
  batch_size: 128
  num_examples: 1000
  drop_last_batch: True
  seed: 1234
  image_size: [1,28,28]

model:
  name: "fc_mnist"
  mixed_precision: True
  input_size: 784  # 1*28*28
  num_classes: 10
...
```
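When run loads configs/params.yaml, each top-level section becomes a key in the params dictionary that the data loader functions receive. The stdlib-only sketch below mirrors how get_train_dataloader and get_eval_dataloader pick their arguments out of that dictionary. The helper names random_loader_args and input_size_matches are hypothetical, and the consistency check assumes, as the `1*28*28` comment suggests, that model.input_size should equal the flattened image_size.

```python
import math

# A trimmed, hand-written stand-in for the params dict built from params.yaml.
params = {
    "train_input": {"batch_size": 128, "num_examples": 1000, "seed": 123,
                    "image_size": [1, 28, 28], "shuffle": True},
    "eval_input": {"batch_size": 128, "num_examples": 1000, "seed": 1234,
                   "image_size": [1, 28, 28]},
    "model": {"input_size": 784, "num_classes": 10},
}


def random_loader_args(params, mode):
    """Return the (input_params, shuffle, num_classes) arguments that
    get_train_dataloader / get_eval_dataloader pass to get_random_dataloader."""
    if mode == "train":
        input_params = params["train_input"]
        shuffle = input_params.get("shuffle")
    elif mode == "eval":
        input_params = params["eval_input"]
        shuffle = False  # get_eval_dataloader always passes shuffle=False
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return input_params, shuffle, params["model"].get("num_classes")


def input_size_matches(params, mode):
    """Hypothetical check: the flattened image_size must equal model.input_size."""
    image_size = params[f"{mode}_input"].get("image_size", [1, 28, 28])
    return math.prod(image_size) == params["model"].get("input_size", 784)


input_params, shuffle, num_classes = random_loader_args(params, "eval")
print(input_params["seed"], shuffle, num_classes)  # 1234 False 10
print(input_size_matches(params, "train"))  # True
```

If you change image_size (for example to [3, 32, 32]), the check fails until input_size is updated to match.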
Option 2 (Moderate): Create new models leveraging Cerebras run function available in Cerebras Model Zoo
All PyTorch implementations in the Cerebras Model Zoo use a common harness to manage execution on CS systems and other hardware. This harness implements the code changes needed to compile a model for a Cerebras cluster, run a compiled model on a Cerebras cluster, or run the model on CPU/GPU. Therefore, it provides a training/evaluation interface into which models and data preprocessing scripts can be plugged, without line-by-line modifications to make the code Cerebras-friendly.
If your primary goal is to develop new model and data preprocessing scripts, we suggest starting from the common backbone in the Cerebras Model Zoo: the run function.
As an example, let’s port a PyTorch implementation of a fully connected dense neural network for the MNIST dataset.
Prerequisites
To use the run function, you must have the version of the Cerebras Model Zoo repository compatible with your release installed in the target Cerebras cluster. The run function can be imported as:

```python
from modelzoo.common.pytorch.run_utils import run
```

All the code related to the run function lives inside the Cerebras Model Zoo and can be found in the common/pytorch folder.
How to use the run function
The run function modularizes the model implementation, the data loaders, the hyperparameters, and the execution. To use the run function, you need:
- A params YAML file in which the optimizer and runtime configuration are specified.
- An implementation that includes:
  - The model definition
  - Data loaders for training and evaluation
Step 1: Define Model
To define the model architecture, the run function requires a callable (either a class or a function) that takes a dictionary of params as input and returns a torch.nn.Module whose forward implementation returns a loss tensor.
As an example, let's implement FC_MNIST parametrized by the depth and the hidden size of the network. We assume that the input size is 784 and the last output dimension is 10. We use ReLU as the nonlinearity and a negative log-likelihood loss.
In model.py:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MNISTModel(nn.Module):
    def __init__(self, model_params):
        super().__init__()
        self.fc_layers = []
        input_size = 784

        # Depth is len(hidden_sizes)
        model_params["hidden_sizes"] = [
            model_params["hidden_size"]
        ] * model_params["depth"]

        for hidden_size in model_params["hidden_sizes"]:
            fc_layer = nn.Linear(input_size, hidden_size)
            self.fc_layers.append(fc_layer)
            input_size = hidden_size
        self.fc_layers = nn.ModuleList(self.fc_layers)
        self.last_layer = nn.Linear(input_size, 10)

        self.nonlin = nn.ReLU()
        self.dropout = nn.Dropout(model_params["dropout"])
        self.loss_fn = nn.NLLLoss()

    def forward(self, batch):
        # The batch holds both the inputs and the labels.
        inputs, targets = batch
        x = torch.flatten(inputs, 1)
        for fc_layer in self.fc_layers:
            x = fc_layer(x)
            x = self.nonlin(x)
            x = self.dropout(x)

        pred_logits = self.last_layer(x)
        outputs = F.log_softmax(pred_logits, dim=1)
        # The model returns the loss of this forward pass.
        loss = self.loss_fn(outputs, targets)
        return loss
```
Note
The batch passed to the forward method of the torch.nn.Module given to the run function contains both the inputs and the labels needed to compute the loss. It is up to the model to extract the inputs and labels from the batch before using them.
Note
The output of the model is expected to be the loss of that forward pass.
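Because MNISTModel expands depth and hidden_size into a list of equal hidden sizes, the layer shapes follow directly from those two params. The following plain-Python sketch (fc_mnist_layer_shapes and count_parameters are hypothetical helpers, not part of the Model Zoo) reproduces that expansion and counts weights plus biases, using the depth: 10 and hidden_size: 50 values from this example's params YAML.

```python
def fc_mnist_layer_shapes(model_params, input_size=784, num_classes=10):
    """(in_features, out_features) of each nn.Linear that MNISTModel builds."""
    hidden_sizes = [model_params["hidden_size"]] * model_params["depth"]
    shapes = []
    for hidden_size in hidden_sizes:
        shapes.append((input_size, hidden_size))
        input_size = hidden_size
    shapes.append((input_size, num_classes))  # self.last_layer
    return shapes


def count_parameters(shapes):
    """Each nn.Linear has in*out weights plus out biases."""
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in shapes)


shapes = fc_mnist_layer_shapes({"depth": 10, "hidden_size": 50})
print(len(shapes), shapes[0], shapes[-1])  # 11 (784, 50) (50, 10)
print(count_parameters(shapes))  # 62710
```

With depth: 0, the model reduces to the single 784-to-10 output layer.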
Step 2: Define dataloaders
To define the data loaders, the run function requires a callable (either a class or a function) that takes a dictionary of params as input and returns a torch.utils.data.DataLoader. When running training, the train_data_fn must be provided. When running evaluation, the eval_data_fn must be provided.
As an example, to implement FC_MNIST, we create two different functions for training and evaluation. We use torchvision.datasets functionality to download the MNIST dataset. Each of these functions returns a torch.utils.data.DataLoader.
In data.py:
```python
import torch
from torchvision import datasets, transforms


def get_train_dataloader(params):
    input_params = params["train_input"]

    batch_size = input_params.get("batch_size")
    dtype = torch.float16 if input_params["to_float16"] else torch.float32
    shuffle = input_params["shuffle"]

    train_dataset = datasets.MNIST(
        input_params["data_dir"],
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                transforms.Lambda(
                    lambda x: torch.as_tensor(x, dtype=dtype)
                ),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        drop_last=input_params["drop_last_batch"],
        shuffle=shuffle,
        num_workers=input_params.get("num_workers", 0),
    )
    return train_loader


def get_eval_dataloader(params):
    input_params = params["eval_input"]

    batch_size = input_params.get("batch_size")
    dtype = torch.float16 if input_params["to_float16"] else torch.float32

    eval_dataset = datasets.MNIST(
        input_params["data_dir"],
        train=False,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                transforms.Lambda(
                    lambda x: torch.as_tensor(x, dtype=dtype)
                ),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=batch_size,
        drop_last=input_params["drop_last_batch"],
        shuffle=False,
        num_workers=input_params.get("num_workers", 0),
    )
    return eval_loader
```
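Both loaders pass drop_last_batch through to the DataLoader's drop_last argument, which changes how many batches one pass over the data yields. A minimal arithmetic sketch (batches_per_epoch is a hypothetical helper) using the standard MNIST split of 60,000 training and 10,000 test images and a batch size of 128:

```python
import math


def batches_per_epoch(num_examples, batch_size, drop_last):
    """Batches a DataLoader yields per pass over a dataset of this size."""
    if drop_last:
        return num_examples // batch_size  # the final partial batch is discarded
    return math.ceil(num_examples / batch_size)  # the final batch may be smaller


print(batches_per_epoch(60_000, 128, drop_last=True))   # 468
print(batches_per_epoch(60_000, 128, drop_last=False))  # 469
print(batches_per_epoch(10_000, 128, drop_last=True))   # 78
```

With drop_last_batch: True, 96 training images per pass (which ones depends on shuffling) and the last 16 test images are skipped. Assuming one step per batch, max_steps: 10000 corresponds to roughly 21 passes over the training set.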
Step 3: Set up the run function
The run function must be imported from run_utils.py. Always remember to append the parent directory of the Cerebras Model Zoo repository to the Python path (sys.path) in your local setup.
All the input parameters of the run function are callables that take as input a dictionary called params. params is a dictionary containing all of the model and data parameters specified in the params YAML file of the model.
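| Parameter | Type | Notes |
|---|---|---|
| Model callable (first argument) | Callable | Required. A callable that takes in a dictionary of parameters and returns a torch.nn.Module. |
| train_data_fn | Callable | Required during a training run. |
| eval_data_fn | Callable | Required during an evaluation run. |
| default_params_fn | Callable | Optional. A callable that takes in a dictionary of parameters and sets default parameters. |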
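The requirements listed above imply a simple rule: the model callable is always needed, while the data callable that must be present depends on whether the job trains or evaluates. The sketch below encodes that rule in plain Python. select_data_fn and the "train"/"eval" mode strings are hypothetical; the real run function handles this internally and may behave differently.

```python
def select_data_fn(mode, model_fn, train_data_fn=None, eval_data_fn=None):
    """Return the data callable needed for `mode`, or raise ValueError.

    Hypothetical helper: `mode` is assumed to be "train" or "eval".
    """
    if not callable(model_fn):
        raise ValueError("A model callable is always required.")
    if mode == "train":
        data_fn = train_data_fn
    elif mode == "eval":
        data_fn = eval_data_fn
    else:
        raise ValueError(f"Unknown mode: {mode!r}")
    if not callable(data_fn):
        raise ValueError(f"A {mode}_data_fn callable is required in {mode!r} mode.")
    return data_fn
```

For example, select_data_fn("eval", MNISTModel, train_data_fn=get_train_dataloader) raises ValueError, because no eval_data_fn was supplied.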
For the FC_MNIST example, with all of the elements in place, we now import the run function from modelzoo.common.pytorch.run_utils, after appending the parent directory of the Cerebras Model Zoo to sys.path.
In run.py:
```python
import os
import sys

# Append the path to the parent directory of the Cerebras Model Zoo repository
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from modelzoo.common.pytorch.run_utils import run

from data import (
    get_train_dataloader,
    get_eval_dataloader,
)
from model import MNISTModel


def main():
    run(MNISTModel, get_train_dataloader, get_eval_dataloader)


if __name__ == '__main__':
    main()
```
Manage common params for multiple experiments
To avoid replicating params across multiple similar experiments, the run function has an optional input parameter called default_params_fn. This parameter modifies the dictionary of the params YAML file, adding default values for unspecified params.

Setting up a default_params_fn can be beneficial if you are planning multiple experiments in which only a small subset of the params YAML file changes. The default_params_fn sets up the values shared by all of the experiments, and you can create separate configuration YAML files that only contain the changes between experiments.

The default_params_fn should be a callable that takes in the params dictionary and returns a new dictionary. If default_params_fn is omitted, the params dictionary is used as is.
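As a minimal sketch, a default_params_fn can deep-copy params and fill in any keys that the YAML file leaves unset from a shared dictionary, so that each experiment's YAML only lists what differs. The SHARED_DEFAULTS values and the _fill_missing helper are illustrative assumptions, not Model Zoo code; the sketch relies only on the contract stated above (take the params dictionary, return a new one).

```python
import copy

# Hypothetical values shared by several FC_MNIST experiments.
SHARED_DEFAULTS = {
    "model": {"depth": 10, "hidden_size": 50, "dropout": 0.0},
    "optimizer": {"optimizer_type": "SGD", "momentum": 0.9},
    "runconfig": {"log_steps": 50, "seed": 1},
}


def _fill_missing(target, defaults):
    """Recursively copy keys from `defaults` that `target` does not set."""
    for key, value in defaults.items():
        if key not in target:
            target[key] = copy.deepcopy(value)
        elif isinstance(value, dict) and isinstance(target[key], dict):
            _fill_missing(target[key], value)


def default_params_fn(params):
    """Return a new params dict with shared defaults for unspecified keys."""
    new_params = copy.deepcopy(params)  # leave the caller's dict untouched
    _fill_missing(new_params, SHARED_DEFAULTS)
    return new_params
```

Such a function would be passed to run as its default_params_fn argument, alongside MNISTModel and the two data loader functions. Values set in the YAML file always win over the shared defaults.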
Step 4: Create params YAML file
At runtime, the run function requires a separate params YAML file. This file is specified during execution with the --params flag on the command line.

For example, this is the params.yaml file for the FC_MNIST implementation. We customize the fields in train_input, eval_input, and model to be used inside get_train_dataloader, get_eval_dataloader, and MNISTModel. We also specify the required optimizer and runconfig params.
```yaml
train_input:
  data_dir: "./data/mnist/train"
  batch_size: 128
  drop_last_batch: True
  shuffle: True
  to_float16: True

eval_input:
  data_dir: "./data/mnist/val"
  batch_size: 128
  drop_last_batch: True
  to_float16: True

model:
  name: "fc_mnist"
  mixed_precision: True
  depth: 10
  hidden_size: 50
  dropout: 0.0
  activation_fn: "relu"

optimizer:
  optimizer_type: "SGD"
  learning_rate: 0.001
  momentum: 0.9
  loss_scaling_factor: 1.0

runconfig:
  max_steps: 10000
  checkpoint_steps: 2000
  log_steps: 50
  seed: 1
```
The params YAML file has the following sections:
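| Section | Required | Notes |
|---|---|---|
| runconfig | Yes | Used by run to set up logging and execution. It expects fields such as those in the example above (max_steps, checkpoint_steps, log_steps, seed). |
| optimizer | Yes | Used by run to configure the optimizer. See the Optimizer subsection. |
| model | No | By convention, it is used to customize the model architecture in the model callable (MNISTModel in this example). Fields are tailored to the needs of the model callable. |
| train_input | No | By convention, it is used to customize train_data_fn. Fields are tailored to the needs of train_data_fn. |
| eval_input | No | By convention, it is used to customize eval_data_fn. Fields are tailored to the needs of eval_data_fn. |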
Optimizer
There are a number of optimizer parameters that can be used to configure the optimizer for the run. Currently, the only supported optimizers are SGD and AdamW. The optimizer type can be specified via the optimizer_type subparameter. Below are the required and optional params that can be used to configure them:
| Optimizer | Parameter | Description |
|---|---|---|
| SGD | learning_rate | See the "Learning Rate Scheduler" subsection. |
| | momentum | The momentum factor. |
| | | Optional. Weight decay (L2 penalty). (Default: 0.0) |
| AdamW | learning_rate | See the "Learning Rate Scheduler" subsection. |
| | | Optional. Adam's first beta parameter. (Default: 0.9) |
| | | Optional. Adam's second beta parameter. (Default: 0.999) |
| | | Optional. Whether or not to correct bias in Adam. (Default: False) |
| | | Parameters to exclude from weight decay. |
All of the above parameters are subparameters of the top-level optimizer parameter. Refer to the Cerebras Model Zoo git repository for examples of how to configure the optimizer.
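Misconfigured YAML files are easier to diagnose before a job is launched. The sketch below is a hypothetical pre-flight check (params_problems is not part of the Model Zoo) built only on rules stated in this section and the previous one: runconfig and optimizer are required sections, and optimizer_type must be SGD or AdamW.

```python
REQUIRED_SECTIONS = ("runconfig", "optimizer")
SUPPORTED_OPTIMIZERS = {"SGD", "AdamW"}


def params_problems(params):
    """Return a list of human-readable problems found in a params dict."""
    problems = [f"missing required section: {name}"
                for name in REQUIRED_SECTIONS if name not in params]
    if "optimizer" in params:
        optimizer_type = params["optimizer"].get("optimizer_type")
        if optimizer_type not in SUPPORTED_OPTIMIZERS:
            problems.append(f"unsupported optimizer_type: {optimizer_type!r}")
    return problems


print(params_problems({"optimizer": {"optimizer_type": "Adam"}}))
# ['missing required section: runconfig', "unsupported optimizer_type: 'Adam'"]
```

An empty list means the params dictionary passes these basic checks; it does not validate the model or data sections.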
Learning Rate Scheduler
We also support various learning rate schedulers. They are configurable using the learning_rate subparameter. Valid configurations include the following:
| Scheduler | Parameter | Description |
|---|---|---|
| Constant | learning_rate | A floating-point number specifying the learning rate to be used throughout training. |
| Piecewise constant | | The constant values to use. |
| | | The steps on which to change the learning rate values. |
| Linear | | The starting learning rate value. |
| | | The final learning rate value. |
| | | The number of steps over which to transition from the starting learning rate to the final learning rate. |
| Exponential | | The starting learning rate value. |
| | | The number of steps to decay the learning rate. |
| | | The rate at which to decay the learning rate. |
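The schedules described above (a constant rate, piecewise-constant values with step boundaries, a transition from a starting to a final value, and decay by a rate over a number of steps) can be written as functions of the global step. The sketch below is one plausible reading of those descriptions, not the Model Zoo's scheduler code. The function and argument names are hypothetical; the piecewise-constant schedule is assumed to switch to the next value once the step reaches a boundary, the linear schedule is assumed to hold the final value after the transition, and the exponential schedule uses the continuous (non-staircase) form initial_lr * decay_rate ** (step / decay_steps).

```python
def constant_lr(step, learning_rate):
    """Same learning rate at every step."""
    return learning_rate


def piecewise_constant_lr(step, values, boundaries):
    """values[i] applies until step reaches boundaries[i]."""
    if len(values) != len(boundaries) + 1:
        raise ValueError("need exactly one more value than boundaries")
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]


def linear_lr(step, initial_lr, end_lr, steps):
    """Interpolate from initial_lr to end_lr over `steps` steps, then hold."""
    fraction = min(step, steps) / steps
    return initial_lr + (end_lr - initial_lr) * fraction


def exponential_lr(step, initial_lr, decay_steps, decay_rate):
    """Shrink the rate by a factor of decay_rate every decay_steps steps."""
    return initial_lr * decay_rate ** (step / decay_steps)


for step in (0, 100, 250):
    print(step, piecewise_constant_lr(step, [0.1, 0.01, 0.001], [100, 200]))
```

A staircase variant of the exponential schedule would use integer division (step // decay_steps) in the exponent instead.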
Loss scaling
We support static and dynamic loss scaling, which are configurable through the optimizer's subparameters:
| Parameter | Description |
|---|---|
| loss_scaling_factor | A constant scalar value configures static loss scaling. Passing in the string "dynamic" configures dynamic loss scaling. |
| | The initial loss scale value when dynamic loss scaling is used. |
| | The number of steps after which to increase the loss scale. |
| | The minimum loss scale value that can be chosen by dynamic loss scaling. |
| | The maximum loss scale value that can be chosen by dynamic loss scaling. |
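Dynamic loss scaling is commonly implemented as a small state machine: shrink the scale when gradients overflow, and grow it after a run of overflow-free steps, staying between the minimum and maximum. The following is a generic plain-Python sketch of that technique under stated assumptions (a growth/shrink factor of 2 and a skipped update on overflow). The class and argument names are hypothetical, it is not the Cerebras implementation, and its constructor deliberately takes no default values for the scale settings.

```python
import math


def grads_are_finite(grads):
    """True if no gradient element is inf or NaN."""
    return all(math.isfinite(g) for g in grads)


class DynamicLossScaler:
    """Generic dynamic loss scaling (hypothetical helper, not Cerebras code)."""

    def __init__(self, initial_scale, steps_per_increase, min_scale, max_scale,
                 factor=2.0):
        self.scale = initial_scale
        self.steps_per_increase = steps_per_increase
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.factor = factor  # assumed growth/shrink factor
        self.good_steps = 0   # consecutive steps without overflow

    def scale_loss(self, loss):
        """Multiply the loss before backprop so small fp16 gradients survive."""
        return loss * self.scale

    def unscale(self, grads):
        """Divide gradients by the current scale before the optimizer step."""
        return [g / self.scale for g in grads]

    def update(self, grads):
        """Adjust the scale; return False if this step should be skipped."""
        if not grads_are_finite(grads):
            self.scale = max(self.scale / self.factor, self.min_scale)
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps >= self.steps_per_increase:
            self.scale = min(self.scale * self.factor, self.max_scale)
            self.good_steps = 0
        return True
```

With static loss scaling, only scale_loss and unscale apply, using a fixed scale such as the loss_scaling_factor: 1.0 in the example params YAML.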
Global Gradient Clipping
We support gradient clipping by global norm or by value. Both are configurable through the optimizer's subparameters:
| Parameter | Description |
|---|---|
| | Maximum norm of the gradients. |
| | Maximum value of the gradients. |
Note
The above subparameters are mutually exclusive. They cannot both be specified at the same time.
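The two clipping modes differ in what they bound: clipping by global norm rescales every gradient by the same factor when the combined L2 norm exceeds the limit, while clipping by value clamps each element independently. The sketch below works on a flat list of floats standing in for all gradient elements and mirrors the mutual-exclusion rule from the note; the function and argument names are hypothetical. For real tensors, PyTorch provides torch.nn.utils.clip_grad_norm_ and torch.nn.utils.clip_grad_value_.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by max_norm / global_norm if the norm is too large."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return list(grads)
    scale = max_norm / global_norm
    return [g * scale for g in grads]


def clip_by_value(grads, max_value):
    """Clamp every gradient element to [-max_value, max_value]."""
    return [max(-max_value, min(g, max_value)) for g in grads]


def clip_gradients(grads, max_norm=None, max_value=None):
    """Apply at most one clipping mode; specifying both is rejected."""
    if max_norm is not None and max_value is not None:
        raise ValueError("max_norm and max_value are mutually exclusive")
    if max_norm is not None:
        return clip_by_global_norm(grads, max_norm)
    if max_value is not None:
        return clip_by_value(grads, max_value)
    return list(grads)
```

With a limit of 1.0, global-norm clipping turns [3.0, 4.0] into approximately [0.6, 0.8] and keeps the gradient direction, whereas value clipping would yield [1.0, 1.0].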
Step 5: Execute script with run function
All models in the Cerebras Model Zoo use the run function inside the run.py script, for both PyTorch and TensorFlow implementations. Therefore, once you have ported your model to use the run function, you can follow the steps in the Launch your job section to launch your training or evaluation job.
Additional functionality
Logging
By default, the run function logs training information to the console and to TensorBoard, as explained in Measure throughput of your model. You can also define your own scalar and tensor summaries.
Evaluation metrics
The Cerebras Model Zoo git repository uses a base class called CBMetric to compute evaluation metrics. Metrics already defined in the Model Zoo git repository can be imported as:
```python
from modelzoo.common.pytorch.metrics import (
    AccuracyMetric,
    DiceCoefficientMetric,
    MeanIOUMetric,
    PerplexityMetric,
)
```
As an example, the GPT2 implementation in PyTorch uses some of these metrics.
How to use evaluation metrics
- Registration: All metrics must be registered with the corresponding torch.nn.Module class. This is done automatically when the CBMetric object is constructed. That is, to register a metric to a torch.nn.Module class, construct the metric object in the torch.nn.Module class's constructor.
- Update: The metrics are stateful. Every call to the metric object with the appropriate arguments automatically computes the latest metric value and saves it in the metric's internal state.
- Logging: At the very end of the run, the final metric values are computed and then logged both to the console and to the TensorBoard SummaryWriter.
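The update-then-log pattern described above can be illustrated without any Cerebras code. The plain-Python class below is a hypothetical stand-in, not CBMetric or AccuracyMetric (whose call signatures are not shown here): each call folds one batch into running totals, and the final value is computed once at the end. In the Model Zoo, the real metric object would instead be constructed inside the torch.nn.Module's constructor so that it is registered.

```python
class RunningAccuracy:
    """Hypothetical stateful metric: each call updates internal counts."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, labels, predictions):
        # Update step: fold this batch into the running state.
        self.correct += sum(int(y == p) for y, p in zip(labels, predictions))
        self.total += len(labels)

    def compute(self):
        # Logging step: derive the final value from the accumulated state.
        return self.correct / self.total if self.total else 0.0


metric = RunningAccuracy()
metric([1, 2, 3], [1, 0, 3])  # batch 1: 2 of 3 correct
metric([4], [4])              # batch 2: 1 of 1 correct
print(metric.compute())       # 0.75
```

Accumulating counts rather than per-batch accuracies keeps the result exact when batches differ in size.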
More on evaluation metrics
The implementation of the CBMetric class is in the Cerebras Model Zoo repository.
The CBMetric class is a base class for creating metrics on CS devices. Subclasses must override methods to provide the full functionality of the metric. These methods are meant to split the computation graph into two portions:
- update_on_device: Compiles and runs on the device (i.e., the CS system).
- update_on_host: Runs on the host (i.e., the CPU).
These metrics also support running on CPU and GPU.
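To see why the split into update_on_device and update_on_host is useful, consider perplexity: the per-batch work (masking token losses and summing them) is a fixed-shape reduction that suits the device, while accumulating totals and taking the exponential can happen on the host. The plain-Python sketch below only borrows the two method names from CBMetric; the class itself, the method signatures, and the compute method are hypothetical, and lists of floats stand in for tensors.

```python
import math


class PerplexitySketch:
    """Hypothetical metric split into a device portion and a host portion."""

    def __init__(self):
        self.total_loss = 0.0
        self.total_tokens = 0

    def update_on_device(self, token_losses, mask):
        # Per-batch reduction: masked loss sum and number of real tokens.
        loss_sum = sum(loss * m for loss, m in zip(token_losses, mask))
        num_tokens = sum(mask)
        return loss_sum, num_tokens

    def update_on_host(self, loss_sum, num_tokens):
        # Accumulate the small per-batch results across the whole run.
        self.total_loss += loss_sum
        self.total_tokens += num_tokens

    def compute(self):
        # Perplexity = exp(mean per-token cross-entropy loss).
        return math.exp(self.total_loss / self.total_tokens)


metric = PerplexitySketch()
# Every real token has loss ln(2); the last position is padding (mask 0).
metric.update_on_host(*metric.update_on_device([math.log(2)] * 4, [1, 1, 1, 0]))
print(round(metric.compute(), 6))  # 2.0
```

Only two numbers per batch cross from the device portion to the host portion, which keeps the host-side work small.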