Sample Training and Eval Scripts for Cerebras PyTorch API (Experimental)
For a full overview of the Cerebras PyTorch Experimental API and its components, please see "How to port your code using Cerebras PyTorch API (Experimental)".
Define Dataloader and Input Functions
The dataloader must be defined in a file separate from the model and main execution loop, as shown below.
dataloader.py
import os
import torch
from torchvision import datasets, transforms
def get_mnist_dataset(train=True):
    """Return the MNIST train or test split with the example's preprocessing.

    The data lives in ``<cwd>/mnist_dataset`` and is downloaded there on demand
    (``download=True``).

    Images are converted to tensors, normalized with mean 0.1307 and std
    0.3081, and cast to ``torch.float16``. Labels are cast to ``torch.int32``.

    Args:
        train: If True, return the training split; otherwise the test split.
    """
    # functools.partial objects are picklable, unlike lambdas. This lets the
    # dataset be sent to DataLoader worker processes started with the "spawn"
    # start method.
    from functools import partial

    data_dir = os.path.join(os.getcwd(), 'mnist_dataset')
    return datasets.MNIST(
        data_dir,
        train=train,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                transforms.Lambda(partial(torch.as_tensor, dtype=torch.float16)),
            ]
        ),
        target_transform=transforms.Lambda(
            partial(torch.as_tensor, dtype=torch.int32)
        ),
    )
def input_fn_train(batch_size=4, drop_last=False):
    """Build a shuffled DataLoader over the MNIST training split.

    Args:
        batch_size: Number of samples per batch.
        drop_last: If True, drop the final incomplete batch.
    """
    dataset = get_mnist_dataset(train=True)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_last,
    )
    return loader
def input_fn_eval(batch_size=4, drop_last=False):
    """Build a DataLoader over the MNIST test split, in dataset order.

    Args:
        batch_size: Number of samples per batch.
        drop_last: If True, drop the final incomplete batch.
    """
    dataset = get_mnist_dataset(train=False)
    # No shuffling, so every evaluation run sees the samples in the same order.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=drop_last,
    )
    return loader
Training Example
In the same directory as the dataloader, create a training script as follows:
training.py
""" Example of training script for FC MNIST model on CSX with Weight Streaming. """
import logging
import os
import cerebras_pytorch.experimental as cstorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from dataloader import input_fn_train, input_fn_eval
class MNISTModel(nn.Module):
    """Fully connected MNIST classifier returning log-probabilities.

    The network has ten hidden ``Linear`` layers of width 50, each followed by
    ReLU and dropout (p=0.0). A final ``Linear`` layer maps to 10 classes.
    """

    def __init__(self):
        super().__init__()
        in_features = 784  # flattened 28x28 image
        width = 50
        num_hidden = 10

        # Layers are created in the same order as before, so seeded parameter
        # initialization and state_dict keys are unchanged.
        layers = []
        for _ in range(num_hidden):
            layers.append(nn.Linear(in_features, width))
            in_features = width
        self.fc_layers = nn.ModuleList(layers)
        self.last_layer = nn.Linear(in_features, 10)
        self.nonlin = nn.ReLU()
        self.dropout = nn.Dropout(p=0.0)

    def forward(self, inputs):
        """Return log-softmax class scores of shape ``(batch, 10)``."""
        hidden = inputs.flatten(1)
        for layer in self.fc_layers:
            hidden = self.dropout(self.nonlin(layer(hidden)))
        logits = self.last_layer(hidden)
        return F.log_softmax(logits, dim=1)
# CONFIGURABLE VARIABLES FOR THIS SCRIPT
# Can optionally move these arguments to a params file and configure from there.
MODEL_DIR = "./"
COMPILE_ONLY = False
VALIDATE_ONLY = False
TRAINING_STEPS = 10
# Checkpoint interval passed to cstorch.configure(checkpoint_steps=...).
CKPT_STEPS = 5
# Log the training loss every LOG_STEPS steps (0 disables logging).
LOG_STEPS = 5
# Checkpoint-related configurations
# Interval at which the training loop calls save_checkpoint(). It is an alias
# of CKPT_STEPS so the two checkpoint settings cannot drift apart.
CHECKPOINT_STEPS = CKPT_STEPS
# Not referenced by main_training_loop in this script.
IS_PRETRAINED_CHECKPOINT = False
def main_training_loop():
    """Train MNISTModel on the ``WSE_WS`` backend for TRAINING_STEPS steps.

    Every LOG_STEPS steps the loss is logged. After each step, the learning
    rate and summaries are written to TensorBoard under ``MODEL_DIR/train``.
    Every CHECKPOINT_STEPS steps a checkpoint ``checkpoint_<step>.mdl`` is
    written to MODEL_DIR.

    Raises:
        ValueError: If a NaN or inf loss is observed.
    """
    torch.manual_seed(2023)

    model = MNISTModel()
    compiled_model = cstorch.compile(model, backend="WSE_WS")

    # Define loss function for FC MNIST Model
    loss_fn = torch.nn.NLLLoss()

    # Define the optimizer used for training.
    # This example will be using SGD from cerebras_pytorch.experimental.optim.Optimizer
    # For a complete list of optimizers available in the experimental API, please see
    # https://docs.cerebras.net/en/latest/wsc/port/porting-pytorch-to-cs/cstorch-api.html#initializing-the-optimizer
    optimizer = cstorch.optim.configure_optimizer(
        optimizer_type="SGD",
        params=model.parameters(),
        lr=0.01,
        momentum=0.0,
    )

    # Optionally define the learning rate scheduler
    # This example will be using LinearLR from cerebras_pytorch.experimental.optim.lr_scheduler
    # For a complete list of lr schedulers available in the experimental API, please see
    # https://docs.cerebras.net/en/latest/wsc/port/porting-pytorch-to-cs/cstorch-api.html#initializing-the-learning-rate-scheduler
    lr_params = {
        "scheduler": "Linear",
        "initial_learning_rate": 0.01,
        "end_learning_rate": 0.001,
        "total_iters": 5,
    }
    lr_scheduler = cstorch.optim.configure_lr_scheduler(optimizer, lr_params)

    # Define gradient scaling parameters.
    grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

    # Per-step losses and learning rates, collected so they can be inspected
    # or plotted after training.
    loss_values = []
    lr_values = []
    total_steps = 0

    @cstorch.step_closure
    def accumulate_loss(loss):
        # Only total_steps is rebound; loss_values is mutated in place, so it
        # needs no nonlocal declaration.
        nonlocal total_steps
        loss_values.append(loss.item())
        total_steps += 1

    @cstorch.step_closure
    def save_learning_rate():
        lr_values.append(lr_scheduler.get_last_lr())

    # DEFINE METHOD FOR SAVING CKPTS
    @cstorch.step_closure
    def save_checkpoint(step):
        logging.info(f"Saving checkpoint at step {step}")
        checkpoint_file = os.path.join(MODEL_DIR, f"checkpoint_{step}.mdl")
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            # NOTE: the "grad_scalar" key spelling is kept unchanged so the
            # checkpoint layout stays compatible with existing readers.
            "grad_scalar": grad_scaler.state_dict(),
            "global_step": step,
        }
        cstorch.save(state_dict, checkpoint_file)
        logging.info(f"Saved checkpoint {checkpoint_file}")

    global_step = 0

    # DEFINE THE ACTUAL TRAINING LOOP
    @cstorch.compile_step
    def training_step(batch):
        inputs, targets = batch
        outputs = compiled_model(inputs)
        loss = loss_fn(outputs, targets)

        cstorch.amp.optimizer_step(
            loss, optimizer, grad_scaler,
        )
        lr_scheduler.step()

        save_learning_rate()
        accumulate_loss(loss)
        return loss

    # DEFINE POST-TRAINING LOOP IF YOU ARE INTERESTED IN TRACKING SUMMARIES, ETC.
    writer = SummaryWriter(log_dir=os.path.join(MODEL_DIR, "train"))

    @cstorch.step_closure
    def post_training_step(loss):
        # global_step is read from the enclosing scope at the time the closure runs.
        if LOG_STEPS and global_step % LOG_STEPS == 0:
            # Define the logging any way desired.
            logging.info(
                f"| Train: {compiled_model.backend.name} "
                f"Step={global_step}, "
                f"Loss={loss.item():.5f}"
            )

        # Add handling for NaN values
        if torch.isnan(loss).any().item():
            raise ValueError(
                "NaN loss detected. "
                "Please try different hyperparameters "
                "such as the learning rate, batch size, etc."
            )
        if torch.isinf(loss).any().item():
            raise ValueError("inf loss detected.")

        for group, lr in enumerate(lr_scheduler.get_last_lr()):
            writer.add_scalar(f"lr.{group}", lr, global_step)

        cstorch.save_summaries(writer, global_step)

    # PERFORM TRAINING LOOPS
    batch_size = 4
    try:
        dataloader = cstorch.utils.data.DataLoader(
            input_fn_train, batch_size, num_steps=TRAINING_STEPS
        )
        for batch in dataloader:
            loss = training_step(batch)
            global_step += 1

            post_training_step(loss)

            # Save the loss value to be able to plot the loss curve
            cstorch.scalar_summary("loss", loss)

            if CHECKPOINT_STEPS and global_step % CHECKPOINT_STEPS == 0:
                save_checkpoint(global_step)
    finally:
        # Flush pending TensorBoard events and release the event file, also
        # when training aborts early (e.g. on a NaN/inf loss).
        writer.close()
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)

    # Directory that dataloader.get_mnist_dataset() downloads MNIST into.
    cwd = os.getcwd()
    os.makedirs(os.path.join(cwd, 'mnist_dataset'), exist_ok=True)

    cstorch.configure(
        model_dir=MODEL_DIR,
        compile_dir="./compile_dir",
        # Presumably exposes the current directory (which holds dataloader.py)
        # to the Cerebras workers -- confirm against the cstorch docs.
        mount_dirs=[cwd],
        python_paths=[cwd],
        compile_only=COMPILE_ONLY,
        validate_only=VALIDATE_ONLY,
        checkpoint_steps=CKPT_STEPS,
        # CSConfig params
        max_wgt_servers=1,
        num_workers_per_csx=1,
        max_act_per_csx=1,
    )

    main_training_loop()
Eval Example
In the same directory as the dataloader, create an evaluation script as follows:
eval.py
""" Example of training script for FC MNIST model on CSX with Weight Streaming. """
import logging
import os

import cerebras_pytorch.experimental as cstorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

from dataloader import input_fn_eval
class MNISTModel(nn.Module):
    """Fully connected MNIST classifier returning log-probabilities.

    The network has ten hidden ``Linear`` layers of width 50, each followed by
    ReLU and dropout (p=0.0). A final ``Linear`` layer maps to 10 classes.
    """

    def __init__(self):
        super().__init__()
        in_features = 784  # flattened 28x28 image
        width = 50
        num_hidden = 10

        # Layers are created in the same order as before, so seeded parameter
        # initialization and state_dict keys are unchanged.
        layers = []
        for _ in range(num_hidden):
            layers.append(nn.Linear(in_features, width))
            in_features = width
        self.fc_layers = nn.ModuleList(layers)
        self.last_layer = nn.Linear(in_features, 10)
        self.nonlin = nn.ReLU()
        self.dropout = nn.Dropout(p=0.0)

    def forward(self, inputs):
        """Return log-softmax class scores of shape ``(batch, 10)``."""
        hidden = inputs.flatten(1)
        for layer in self.fc_layers:
            hidden = self.dropout(self.nonlin(layer(hidden)))
        logits = self.last_layer(hidden)
        return F.log_softmax(logits, dim=1)
# CONFIGURABLE VARIABLES FOR THIS SCRIPT
# Can optionally move these arguments to a params file and configure from there.
MODEL_DIR = "./"
COMPILE_ONLY = False
VALIDATE_ONLY = False
# Checkpoint interval passed to cstorch.configure(checkpoint_steps=...).
CKPT_STEPS = 5
# Checkpoint-related configurations
# Alias of CKPT_STEPS so the two checkpoint settings cannot drift apart.
CHECKPOINT_STEPS = CKPT_STEPS
# Path of the checkpoint to evaluate (e.g. one written by training.py). With
# None, the model is evaluated with randomly initialized weights.
CHECKPOINT_PATH_EVAL = None
def main_eval_loop(batch_size=4, num_steps=10):
    """Evaluate MNISTModel on the MNIST test split using the ``WSE_WS`` backend.

    If CHECKPOINT_PATH_EVAL is set, the model weights and the starting global
    step are restored from that checkpoint. Otherwise the model keeps its
    random initialization. The loss is logged for every step. Accuracy and
    perplexity are written to TensorBoard under ``MODEL_DIR/eval`` after the
    evaluation loop.

    Args:
        batch_size: Number of samples per evaluation batch.
        num_steps: Number of evaluation steps to run.

    Raises:
        ValueError: If a NaN or inf loss is observed.
    """
    model = MNISTModel()
    compiled_model = cstorch.compile(model, backend="WSE_WS")

    def load_checkpoint(checkpoint_path):
        """Load model weights from ``checkpoint_path``.

        Returns the stored ``global_step``, or 0 if the checkpoint has none.
        """
        state_dict = cstorch.load(checkpoint_path)
        model.load_state_dict(state_dict["model"])
        return state_dict.get("global_step", 0)

    global_step = 0
    if CHECKPOINT_PATH_EVAL is not None:
        global_step = load_checkpoint(CHECKPOINT_PATH_EVAL)
    else:
        logging.info(
            "No checkpoint was provided, model parameters will be "
            "initialized randomly"
        )

    writer = SummaryWriter(log_dir=os.path.join(MODEL_DIR, "eval"))

    # Define the metrics used by the model for evaluation.
    # This example shows two different eval metrics being used,
    # accuracy and perplexity. NOTE: For a complete list of eval metrics
    # available in the experimental API, please see
    # https://docs.cerebras.net/en/latest/wsc/port/porting-pytorch-to-cs/cstorch-api.html#evaluation-metrics
    accuracy = cstorch.metrics.AccuracyMetric(
        "accuracy", compute_on_system=True
    )
    perplexity = cstorch.metrics.PerplexityMetric(
        "perplexity", compute_on_system=True
    )

    # Define loss function for FC MNIST Model
    loss_fn = torch.nn.NLLLoss()

    @cstorch.compile_step
    def eval_step(batch):
        inputs, targets = batch
        outputs = compiled_model(inputs).to(torch.float16)
        loss = loss_fn(outputs, targets)

        accuracy(
            labels=targets.clone(), predictions=outputs.argmax(-1).int(),
        )
        perplexity(labels=targets.clone(), loss=loss)
        return loss

    total_loss = 0
    total_steps = 0

    @cstorch.step_closure
    def post_eval_step(loss: torch.Tensor):
        nonlocal total_loss
        nonlocal total_steps

        # global_step is read from the enclosing scope at the time the closure runs.
        logging.info(
            f"| Eval: {compiled_model.backend.name} "
            f"Step={global_step}, "
            f"Loss={loss.item():.5f}"
        )

        if torch.isnan(loss).any().item():
            raise ValueError("NaN loss detected.")
        if torch.isinf(loss).any().item():
            raise ValueError("inf loss detected.")

        total_loss += loss.item()
        total_steps += 1
        cstorch.scalar_summary("loss", loss)

    # PERFORM EVAL LOOPS
    try:
        dataloader = cstorch.utils.data.DataLoader(
            input_fn_eval, batch_size, num_steps=num_steps
        )
        for batch in dataloader:
            loss = eval_step(batch)
            global_step += 1
            post_eval_step(loss)

        writer.add_scalar("Eval Accuracy", float(accuracy), global_step)
        writer.add_scalar("Eval Perplexity", float(perplexity), global_step)
        cstorch.save_summaries(writer, global_step)
    finally:
        # Flush pending TensorBoard events and release the event file, also
        # when evaluation aborts early (e.g. on a NaN/inf loss).
        writer.close()
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)

    # Directory that dataloader.get_mnist_dataset() downloads MNIST into.
    cwd = os.getcwd()
    os.makedirs(os.path.join(cwd, 'mnist_dataset'), exist_ok=True)

    cstorch.configure(
        model_dir=MODEL_DIR,
        compile_dir="./compile_dir",
        # Presumably exposes the current directory (which holds dataloader.py)
        # to the Cerebras workers -- confirm against the cstorch docs.
        mount_dirs=[cwd],
        python_paths=[cwd],
        compile_only=COMPILE_ONLY,
        validate_only=VALIDATE_ONLY,
        checkpoint_steps=CKPT_STEPS,
        # CSConfig params
        max_wgt_servers=1,
        num_workers_per_csx=1,
        max_act_per_csx=1,
    )

    main_eval_loop()