Port your code using Cerebras PyTorch API (Experimental)#
Note
This API is marked experimental. While the design has, for the most part, been finalized, backwards-compatibility-breaking changes may be introduced in a future release before it is marked as stable.
High Level Overview#
To leverage the Cerebras PyTorch API for porting and running your model on a Cerebras Wafer-Scale cluster, there are a few high-level steps that must be taken:
- Configure the run. This configuration involves setting up the cluster for the run. The configure function must be called before the model is instantiated.
- Compile the model. In order to train or evaluate a model on a Cerebras Wafer-Scale cluster, it must be compiled.
- Instantiate a Cerebras compliant optimizer (training only). Due to the nature of lazy tensor tracing and execution, the core PyTorch optimizer implementations (having been designed for eager execution) are not compatible with running on Cerebras hardware. We provide a number of drop-in Cerebras compliant replacements for commonly used optimizers.
- Construct a training step function. Also due to the nature of lazy tensor tracing and execution, we need to be able to capture the entire computation graph. We capture this inside a training/evaluation step.
- Define the training loop. The training loop is now fully exposed and customizable.
- Initialize the DataLoader. The Cerebras dataloader must be initialized to distribute data loading across the worker nodes.
There are further steps you can perform to supplement and improve the performance and numerics of your run, such as gradient scaling and learning rate scheduling.
There are also some caveats that one should be aware of when running on the Cerebras Wafer-Scale Cluster; these are called out in the relevant sections below.
Importing cstorch#
Currently the experimental API can be accessed by importing:
import cerebras_pytorch.experimental as cstorch
Once stable, the experimental modules will be merged into the base of cerebras_pytorch.
Configuring a run#
Configuring a run on the Cerebras Wafer-Scale Cluster using the new Cerebras PyTorch API is done via a call to cstorch.configure:
cstorch.configure(
model_dir=...,
compile_dir=...,
compile_only=...,
validate_only=...,
checkpoint_steps=...,
# CSConfig params
num_csx=...,
max_wgt_servers=...,
mount_dirs=...,
python_paths=...,
transfer_processes=...,
num_workers_per_csx=...,
job_labels=...,
max_act_per_csx=...,
)
These are all of the parameters you would need to configure the cluster for the run. Note, most of these parameters are optional and have reasonable defaults. See below for a detailed description of each supported parameter and its default.
- model_dir (str): The path to the directory where model related artifacts get saved on the appliance, not on the user node. (Default: "./")
- compile_dir (str): The path to the directory where compile related artifacts get saved on the appliance, not on the user node. (Default: "/opt/cerebras/cached_compile")
- compile_only (bool): If True, configure the run to only compile the model, not execute. (Default: False)
- validate_only (bool): If True, configure the run to only trace the model, not compile or execute. (Default: False)
- checkpoint_steps (int): The interval at which checkpoints will be saved. Note, checkpoints will only be available for saving on these steps as well as on the last step (if the value is non-zero). (Default: 0)
- num_csx (int): The number of Cerebras systems in the cluster to use in this run. (Default: 1)
- max_wgt_servers (int): The maximum number of weight servers to use in this run. (Only applicable for weight streaming.) (Default: 24)
- mount_dirs (list[str]): A list of paths to local directories that should be mounted to the cluster. (Default: None)
- python_paths (list[str]): A list of paths that worker pods respect as PYTHONPATH, in addition to the PYTHONPATH set in the container image. (Default: None)
- transfer_processes (int): The number of processes used to transfer data to/from the appliance. (Default: 5)
- num_workers_per_csx (int): The number of streaming workers per system. (Default: 1)
- job_labels (list[str]): A list of equal-sign-separated key-value pairs that get applied as part of the job metadata. (Default: None)
- max_act_per_csx (int): The maximum number of activation servers per system. (Default: 1)
Note
We don’t currently support multiple CS runs in a single process. This means that the above configuration function can only be called once and applies globally. Any runs with different configurations must be run in separate processes.
Compiling a torch.nn.Module#
In order to prepare a model for compilation, we introduce the cstorch.compile function:
model: torch.nn.Module = ... # Any PyTorch module
compiled_model = cstorch.compile(model, backend="wse_ws")
We designed this function after the torch.compile function that was introduced in PyTorch 2.0.
Note
The call to cstorch.compile, much like torch.compile, does not actually compile the model. It only prepares the model for compilation, meaning that it moves the model parameters to the appropriate torch.device and prepares the internals of cstorch to trace it.
The actual compilation does not happen until the first iteration is complete and the batch size is known.
Initializing the Optimizer#
Much like in our previous releases, we cannot support the vanilla PyTorch optimizers. This is because the implementations available in core PyTorch were designed with eager execution in mind and are fundamentally incompatible with traced lazy execution.
As such, we provide our own drop-in replacements for these optimizers inside cstorch.optim. See the table below for a list of all of the optimizers that are currently available.
For convenience, we also include a configuration helper function:
cstorch.optim.configure_optimizer(
optimizer_type="...", # name of the optimizer
params=..., # The model parameters
..., # kwargs to be passed into the optimizer class's init
)
This function is useful when you want to initialize an optimizer from some configuration dictionary. An example of its usage:
optimizer_params = {
"optimizer_type": "SGD",
"lr": 0.001,
"momentum": 0.5,
}
optimizer = cstorch.optim.configure_optimizer(
optimizer_type=optimizer_params.pop("optimizer_type"),
params=model.parameters(),
**optimizer_params
)
Defining a custom Cerebras optimizer#
In order to define a Cerebras compliant optimizer, one must create a subclass of
cstorch.optim.Optimizer
, e.g.
class CustomOptimizer(cstorch.optim.Optimizer):
def __init__(self, params, ...):
...
defaults = ...
super().__init__(params, defaults, enable_global_step=...)
...
def preinitialize(self):
...
def step(self, closure=None):
...
def state_names_to_sparsify(self):
...
As can be seen in the above example, similar to torch.optim.Optimizer, the base cstorch.optim.Optimizer class expects 3 arguments: the model parameters, the param group defaults, and the optional enable_global_step, which defines a global step state variable for each parameter.
In addition, there are 3 abstract methods that must be overridden:
- preinitialize: This method is used to initialize any state variables that will be used by the optimizer. For example, cstorch.optim.SGD defines its momentum buffers in its preinitialize method. Note, in order to remain Cerebras compliant, no state variables may be initialized outside of the preinitialize method.
- step: This method is where the optimizer step is implemented. Note, due to the nature of lazy tensor tracing and execution, no python-level conditions or loops may be used to dynamically define the control flow; only torch ops (such as torch.where) may be used. Having said this, static structures are allowed, for example a loop with a fixed number of iterations, or a python conditional that doesn't involve any torch tensors (i.e. one whose condition involves only constant variables).
- state_names_to_sparsify: This method should return the names of the state variables that should be sparsified. Please see the existing optimizer implementations for examples.
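To make these methods concrete, below is a minimal sketch of a custom SGD-with-momentum optimizer. It assumes that the cstorch.optim.Optimizer base class exposes param_groups and a per-parameter state dictionary in the same way that torch.optim.Optimizer does; the class name and hyperparameters are illustrative and not part of the API.
import torch
import cerebras_pytorch.experimental as cstorch

class SimpleMomentumSGD(cstorch.optim.Optimizer):
    def __init__(self, params, lr=0.01, momentum=0.9):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults, enable_global_step=False)

    def preinitialize(self):
        # All state variables must be created here, never lazily inside step()
        for group in self.param_groups:
            for p in group["params"]:
                self.state[p]["momentum_buffer"] = torch.zeros_like(p)

    def step(self, closure=None):
        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                buf = self.state[p]["momentum_buffer"]
                # Only torch ops; no control flow that depends on tensor values
                buf.mul_(momentum).add_(p.grad)
                p.add_(buf, alpha=-lr)

    def state_names_to_sparsify(self):
        # The momentum buffer should follow the parameter's sparsity pattern
        return ["momentum_buffer"]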
Initializing the Learning Rate Scheduler#
Similar to the optimizers, the vanilla PyTorch learning rate schedulers are not compatible with traced lazy execution.
As such, we provide our own drop-in replacements for some common schedulers inside cstorch.optim.lr_scheduler. See the table below for a list of all of the learning rate schedulers that are currently available.
Similar to cstorch.optim.configure_optimizer, for convenience, we also include the following configuration helper function for learning rate schedulers.
cstorch.optim.configure_lr_scheduler(
optimizer=..., # the optimizer object
learning_rate=..., # the learning rate configuration
)
The expected format for the learning_rate parameter is one of the following:
1. learning_rate is a python scalar (int or float). In this case, configure_lr_scheduler returns an instance of ConstantLR with the provided value as the constant learning rate.
2. learning_rate is a dictionary. In this case, the dictionary is expected to contain the key scheduler, whose value is the name of the scheduler you want to configure. The rest of the parameters in the dictionary are passed as keyword arguments to the specified scheduler's init method.
3. learning_rate is a list of dictionaries. In this case, we assume what is being configured is a SequentialLR, unless any one of the dictionaries contains the key main_scheduler with the corresponding value ChainedLR. In either case, each element of the list is expected to be a dictionary that follows the format outlined in case 2. If what is being configured is indeed a SequentialLR, each dictionary entry is also expected to contain the key total_iters, specifying the total number of iterations each scheduler should be applied for.
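For illustration, the three formats might be used as in the sketch below. The scheduler names are placeholders; use the names and init parameters of the schedulers listed in the table above.
# 1. Python scalar: a constant learning rate (ConstantLR)
lr_scheduler = cstorch.optim.configure_lr_scheduler(optimizer, learning_rate=0.001)

# 2. Dictionary: a single scheduler selected by name
lr_scheduler = cstorch.optim.configure_lr_scheduler(
    optimizer,
    learning_rate={
        "scheduler": "<scheduler name>",
        # plus the keyword arguments for that scheduler's init method
    },
)

# 3. List of dictionaries: a SequentialLR; each entry also needs total_iters
lr_scheduler = cstorch.optim.configure_lr_scheduler(
    optimizer,
    learning_rate=[
        {"scheduler": "<scheduler name>", "total_iters": 500},
        {"scheduler": "<scheduler name>", "total_iters": 500},
    ],
)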
Defining a custom Cerebras learning rate scheduler#
In order to define a Cerebras compliant learning rate scheduler, one must create
a subclass of cstorch.optim.lr_scheduler.LRScheduler
, e.g.
class CustomScheduler(cstorch.optim.lr_scheduler.LRScheduler):
def __init__(self, optimizer, ...):
...
super().__init__(optimizer, total_iters=..., last_epoch=...)
...
def _get_closed_form_lr(self) -> torch.Tensor:
...
As can be seen in the above example, the base cstorch.optim.lr_scheduler.LRScheduler class expects 3 arguments: the optimizer whose learning rate is being scheduled, and, optionally, the total number of iterations the scheduler runs for and the last epoch to start on.
In addition, there is one abstract method that must be overridden:
- _get_closed_form_lr: This method is where the full learning rate schedule is defined in closed form. Note, due to the nature of lazy tensor tracing and execution, no python-level conditions or loops may be used to dynamically define the control flow; only torch ops (such as torch.where) may be used. Having said this, static structures are allowed, for example a loop with a fixed number of iterations, or a python conditional that doesn't involve any torch tensors (i.e. one whose condition involves only constant variables). This method is expected to return a torch.Tensor that represents the full learning rate schedule as a computed tensor. Please see the existing LR scheduler implementations for examples of how to properly define the schedule.
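As an illustration, below is a minimal sketch of a linear-warmup-then-constant scheduler written in closed form. It assumes that the base class exposes the current step as a torch tensor (shown here as self.last_epoch, mirroring the constructor argument); consult the existing cstorch schedulers for the exact attribute and conventions to use.
import torch
import cerebras_pytorch.experimental as cstorch

class WarmupConstantLR(cstorch.optim.lr_scheduler.LRScheduler):
    def __init__(self, optimizer, base_lr, warmup_steps, total_iters=None, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, total_iters=total_iters, last_epoch=last_epoch)

    def _get_closed_form_lr(self) -> torch.Tensor:
        # Assumed: self.last_epoch holds the current step as a tensor during tracing
        step = torch.as_tensor(self.last_epoch, dtype=torch.float32)
        warmup_lr = self.base_lr * step / float(self.warmup_steps)
        constant_lr = torch.full_like(step, self.base_lr)
        # Control flow must be expressed with torch ops such as torch.where
        return torch.where(step < self.warmup_steps, warmup_lr, constant_lr)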
Initializing the DataLoader#
The Cerebras Wafer-Scale cluster makes use of worker nodes to stream data to the system in order to maximize utilization by keeping the input buffers saturated. Since the workers run on their own nodes, they cannot share a single PyTorch dataloader. Hence, we require a mechanism for each worker to be able to initialize its own dataloader.
To facilitate this, we introduce a custom dataloader class:
dataloader = cstorch.utils.data.DataLoader(
input_fn,
num_steps=..., # specifies the number of steps to generate
max_steps=..., # specifies the maximum number of steps to generate
num_epochs=..., # specifies the number of epochs to generate
steps_per_epoch=..., # specifies the number of steps per epoch to generate
..., # kwargs to be passed into the input_fn
)
It takes in an input_fn parameter, which is a callable that takes in some parameters and returns a torch.utils.data.DataLoader:
def input_fn(...) -> torch.utils.data.DataLoader:
...
Each worker will call this input function to construct its own dataloader object. This means that some data sharding scheme is required if the intent is for each worker to stream a unique set of data (see the sketch at the end of this section).
The dataloader class also takes in two mutually exclusive groups of parameters:
- num_steps: Specifies the total number of steps to run. If max_steps is also provided, the run executes at most min(num_steps, max_steps) steps.
- num_epochs and steps_per_epoch: If num_epochs is specified, then we use steps_per_epoch (or the length of the dataloader if unspecified) to calculate the total number of steps to run. If max_steps is also provided, the run executes at most min(num_epochs * steps_per_epoch, max_steps) steps.
Any other parameters that are passed into the DataLoader init are forwarded and passed into the input_fn.
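As a sketch of how the pieces fit together, the input_fn below builds an ordinary PyTorch dataloader. MyDataset, batch_size, and shuffle are hypothetical; the extra keyword arguments passed to cstorch.utils.data.DataLoader are simply forwarded to the input_fn.
import torch
import cerebras_pytorch.experimental as cstorch

def train_input_fn(batch_size, shuffle=True) -> torch.utils.data.DataLoader:
    # MyDataset is a hypothetical map-style dataset. Each worker constructs its
    # own dataloader from this function, so any sharding between workers would
    # also need to happen here.
    dataset = MyDataset(...)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True
    )

dataloader = cstorch.utils.data.DataLoader(
    train_input_fn,
    num_steps=1000,   # run for 1000 steps in total
    batch_size=32,    # forwarded to train_input_fn
    shuffle=True,     # forwarded to train_input_fn
)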
Using Gradient Scaling#
Gradient scaling can improve convergence when training models with float16 gradients by minimizing gradient underflow. Please see the PyTorch docs for a more detailed explanation.
To facilitate gradient scaling, we introduce a Cerebras implementation of the AMP GradScaler class found in core PyTorch.
grad_scaler = cstorch.amp.GradScaler(
loss_scale=...,
init_scale=...,
steps_per_increase=...,
min_loss_scale=...,
max_loss_scale=...,
overflow_tolerance=...,
max_gradient_norm=...,
)
It is designed to be as similar as possible to the API of the CUDA AMP GradScaler class.
See below for a description of each parameter:
- loss_scale (str or float): If loss_scale == "dynamic", then dynamic loss scaling is configured. Otherwise, it is the loss scale value used in static loss scaling. (Default: 0.0)
- init_scale (float): The initial loss scale value if loss_scale == "dynamic". (Default: None)
- steps_per_increase (int): The number of steps after which to increase the loss scale under dynamic loss scaling. (Default: None)
- min_loss_scale (float): The minimum loss scale value that can be chosen by dynamic loss scaling. (Default: None)
- max_loss_scale (float): The maximum loss scale value that can be chosen by dynamic loss scaling. (Default: None)
- overflow_tolerance (float): The maximum fraction of steps involving infinite or undefined values in the gradient that is allowed. The loss scale is reduced if this tolerance is exceeded. (Default: 0.05)
- max_gradient_norm (float): The maximum gradient norm to use for global gradient clipping. Only applies when dynamic loss scaling and global gradient clipping are both enabled; if global gradient clipping is not enabled, this parameter has no effect. (Default: 0.05) (Note: only used in pipeline mode)
Its usage is practically identical to the usage of the CUDA AMP GradScaler:
loss: torch.Tensor = ...
optimizer.zero_grad()
# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()
# Unscales the gradients of optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)
# Global gradient clipping
torch.nn.utils.clip_grad_norm_(
model.parameters(),
1.0, # max gradient norm
)
# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)
# update the grad scaler once all optimizers have been stepped
grad_scaler.update()
cstorch.amp.optimizer_step#
We introduce an optional helper function to take care of the details of gradient scaling:
cstorch.amp.optimizer_step(
loss,
optimizer,
grad_scaler,
max_gradient_norm=..., # optionally perform gradient clipping by norm
    max_gradient_val=..., # optionally perform gradient clipping by value
)
It is useful for quickly constructing typical examples that use gradient scaling without needing to type up the details or worry about whether the grad scaler is being used correctly.
This is completely optional and only covers the basic gradient scaler use case. For more complicated use cases, the grad scaler object must be used explicitly.
Constructing the training step#
In order to compile the full training graph, the training step must be captured in its entirety. To handle this, we introduce the cstorch.compile_step decorator:
@cstorch.compile_step
def training_step(inputs, targets):
outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)
cstorch.amp.optimizer_step(
loss, optimizer, grad_scaler, max_gradient_norm=1.0
)
if lr_scheduler:
lr_scheduler.step()
return loss
This decorator should wrap some function that encapsulates the entirety of a single training iteration. That is to say, everything that is intended to run on a Cerebras system should be inside this wrapped function.
In addition, no tensor value may be eagerly evaluated at any point inside this training step. This means no tensor is allowed to be printed, fetched via a debugger, or used as part of a python conditional. Any operation that requires knowing the value of a tensor inside the training step will result in an error stating that it is not allowed to read a tensor's contents outside of a step closure.
Another caveat with compile_step is that any variables that are not torch tensors will only ever see their first value. For example, if you have an int counter that you increment inside a cstorch.compile_step wrapped function, it will hold its first value on all iterations. This is because the training step graph is only captured once; hence, any pure python ops only run once.
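A small illustration of this caveat, assuming a compiled model, loss function, optimizer, and grad scaler as in the previous sections:
counter = 0  # plain python int, not a torch tensor

@cstorch.compile_step
def training_step(batch):
    global counter
    counter += 1  # only runs while the graph is traced, not on every step
    inputs, targets = batch
    loss = loss_fn(compiled_model(inputs), targets)
    cstorch.amp.optimizer_step(loss, optimizer, grad_scaler)
    # Any logic based on `counter` here would always see its traced value
    return loss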
Step Closures#
By design, in the execution scheme used by the Cerebras Wafer-Scale cluster, the client and the server run asynchronously with respect to each other. This prevents the server from being bottlenecked by client-side processes such as disk IO or networking.
However, this means that a computed tensor may not be available to fetch from the server when the client requests it. For example, the call to compile happens in the first iteration of the training loop; until compile is complete and execution starts on the cluster, no tensor is available to fetch.
To handle this, we introduce the concept of a step closure via the step_closure decorator:
@cstorch.step_closure
def closure(...):
...
Any tensors that are passed into a “step closure” are fetched from the server and their value is materialized before the closure is actually called. If the tensor is not yet available, it waits until the server “catches up” to the current step and the tensor value is available to be fetched before actually calling the closure.
One caveat regarding values passed into step closures is that the value seen by the closure is the last value set to that tensor, not the value at the point it was passed in. This means that if the tensor is updated in place after being passed into the step closure, the in-place modified tensor is what gets materialized before being passed into the closure.
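For example, a simple step closure that logs the loss each step might look like the following sketch; the loss tensor is materialized before the body of the closure runs.
@cstorch.step_closure
def log_loss(loss: torch.Tensor):
    # `loss` has already been fetched from the server at this point
    print(f"loss = {loss.item()}")

# inside the training loop
loss = training_step(batch)
log_loss(loss)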
Saving/Loading Checkpoints#
To save and load weights in a Cerebras run, we provide a custom Cerebras H5-based checkpoint format that is far more performant and efficient than the core PyTorch pickle-based checkpoint format, especially for models with extremely large weights, such as large language models.
To save a checkpoint, we provide a very familiar cstorch.save function that you can use in exactly the same way as torch.save:
state_dict = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
...
}
cstorch.save(state_dict, "<path to save checkpoint to>")
Similarly, we provide a very familiar cstorch.load function that can be used in exactly the same way as torch.load:
state_dict = cstorch.load("<path to load checkpoint from>")
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
...
Note
This is a perfect example of a case where one should use a step closure to make sure that the weights are available to fetch from the server before they can be saved to the checkpoint file.
@cstorch.step_closure
def save_checkpoint():
state_dict = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
...
}
cstorch.save(state_dict, "<path to save checkpoint to>")
Note
It is only possible to fetch weights on the predetermined checkpoint steps configured in cstorch.configure. This restriction exists to make training more performant.
For example, if the configuration was checkpoint_steps=100, you are only allowed to fetch the weights to take a checkpoint on every 100th step and at the very end, on the last step.
Converting checkpoints to Pickle-based format#
If you have a checkpoint in the Cerebras H5-based format and wish to use it in a CPU/GPU workflow, it can easily be converted to a PyTorch pickle-based format:
state_dict = cstorch.load("<path to H5 checkpoint>", map_location="cpu")
torch.save(state_dict, "<path to torch checkpoint>")
Note, this will eagerly load the entirety of the checkpoint into memory. Thus, it may cause memory issues when loading checkpoints for very large models.
Training Loop#
The training loop is now fully exposed and customizable. A very basic example of a training loop could be:
@cstorch.step_closure
def post_training_step(loss: torch.Tensor):
print("Loss: ", loss.item())
for i, batch in enumerate(dataloader):
loss = training_step(batch)
post_training_step(loss)
if i % checkpoint_steps == 0:
save_checkpoint()
Please see the Full Example (Training) for a comprehensive example of the full training API.
Evaluation Metrics#
We provide Cerebras compatible metrics that can be used during evaluation to measure how well the model has trained.
They are found in the cstorch.metrics module. The metrics that are currently available are:
- AccuracyMetric(name, compute_on_system)
- PerplexityMetric(name, compute_on_system)
All of the metric classes take in a name parameter, as well as whether or not to compute them on the system. Note, for weight streaming mode, compute_on_system must be set to True.
These metric classes keep an internal state and will return the final computed value. Please see the Full Example (Evaluation) to see how these metrics may be used.
Full Example (Training)#
Shown below is a simple skeleton of a full training script. For a complete, executable example please see our sample training script.
import torch
import cerebras_pytorch.experimental as cstorch
# Needs to be run at the beginning before model initialization
checkpoint_steps = 100
cstorch.configure(
checkpoint_steps=checkpoint_steps,
mgmt_address=...,
...
)
# user defined model
model: torch.nn.Module = ...
compiled_model = cstorch.compile(model, backend="wse_ws")
loss_fn: torch.nn.Module = ...
optimizer: cstorch.optim.Optimizer = cstorch.optim.configure_optimizer(
optimizer_type="...",
params=model.parameters(),
...
)
lr_scheduler: cstorch.optim.lr_scheduler.LRScheduler = cstorch.optim.configure_lr_scheduler(
optimizer, learning_rate=...,
)
grad_scaler = None
if loss_scale != 0.0:
grad_scaler = cstorch.amp.GradScaler(...)
@cstorch.step_closure
def save_checkpoint(step):
checkpoint_file = f"checkpoint_{step}.mdl"
state_dict = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
if lr_scheduler:
state_dict["lr_scheduler"] = lr_scheduler.state_dict()
if grad_scaler:
state_dict["grad_scaler"] = grad_scaler.state_dict()
state_dict["global_step"] = step
cstorch.save(state_dict, checkpoint_file)
global_step = 0
# Load checkpoint if provided
if checkpoint_path is not None:
state_dict = cstorch.load(checkpoint_path)
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
if lr_scheduler:
lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
if grad_scaler:
grad_scaler.load_state_dict(state_dict["grad_scaler"])
global_step = state_dict.get("global_step", 0)
dataloader = cstorch.utils.data.DataLoader(
train_dataloader_fn,
num_steps=1000,
...
)
@cstorch.compile_step
def training_step(batch):
inputs, targets = batch
outputs = compiled_model(inputs)
loss = loss_fn(outputs, targets)
cstorch.amp.optimizer_step(
loss, optimizer, grad_scaler, max_gradient_norm=1.0
)
return loss
@cstorch.step_closure
def post_training_step(loss: torch.Tensor):
print("Loss: ", loss.item())
for i, batch in enumerate(dataloader):
    loss = training_step(batch)
post_training_step(loss)
# Can only save checkpoint on predetermined steps
if i % checkpoint_steps == 0:
save_checkpoint(i)
Full Example (Evaluation)#
Evaluation is less complex compared to training. There is no optimizer or gradient scaler that needs to be initialized.
Shown below is a simple skeleton of a full evaluation script. For a complete, executable example please see our sample eval script.
import torch
import cerebras_pytorch.experimental as cstorch
import cerebras_pytorch.experimental.metrics as metrics
cstorch.configure(mgmt_address=...)
model: torch.nn.Module = ...
compiled_model = cstorch.compile(model, backend="wse_ws")
compiled_model.eval()
loss_fn: torch.nn.Module = ...
accuracy = metrics.AccuracyMetric("accuracy", compute_on_system=True)
state_dict = cstorch.load("<path to checkpoint file>")
model.load_state_dict(state_dict["model"])
dataloader = cstorch.utils.data.DataLoader(
eval_dataloader_fn,
num_steps=100,
...
)
@cstorch.compile_step
def evaluation_step(batch):
inputs, targets = batch
outputs = compiled_model(inputs)
loss = loss_fn(outputs, targets)
accuracy(
labels=targets.clone(),
predictions=outputs.argmax(-1).int(),
)
return loss
total_loss = 0
total_steps = 0
@cstorch.step_closure
def post_eval_step(loss: torch.Tensor):
    global total_loss, total_steps
total_loss += loss.item()
total_steps += 1
with torch.no_grad():
for batch in dataloader:
loss = evaluation_step(batch)
post_eval_step(loss)
print(f"Eval Accuracy: {float(accuracy)}"))