Port your code using Cerebras PyTorch API (Experimental)#

Note

This API is marked experimental. While the design has largely been finalized, backwards-incompatible changes may still be introduced in a future release before it is marked as stable.

High Level Overview#

To leverage the Cerebras PyTorch API for porting and running your model on a Cerebras Wafer-Scale cluster, there are a few high-level steps that must be completed:

  1. Configure the run

    This configuration involves setting up the cluster for the run. The configure function must be called before the model is instantiated.

  2. Compile the model

    In order to train or evaluate a model on a Cerebras Wafer-Scale cluster it must be compiled.

  3. Instantiate a Cerebras compliant optimizer (training-only)

    Due to the nature of lazy tensor tracing and execution, the core PyTorch optimizer implementations (having been designed for eager execution) are not compatible with running on Cerebras hardware. We provide a number of drop-in Cerebras compliant replacements for commonly used optimizers.

  4. Construct a training step function

    Also due to the nature of lazy tensor tracing and execution, we need to capture the entirety of the computation graph. We capture this inside a training/evaluation step.

  5. Define the training loop

    The training loop is now fully exposed and customizable.

  6. Initialize the DataLoader

    The Cerebras dataloader must be initialized so that data loading can be distributed across the worker nodes.

There are further steps you can take to supplement and improve the performance and numerics of your run, such as:

  1. Gradient scaling

  2. Learning rate scheduling

There are also some caveats that one should be aware of when running on the Cerebras Wafer-Scale Cluster, such as:

  1. Step Closures

  2. Saving/Loading Checkpoints

Importing cstorch#

Currently the experimental API can be accessed by importing:

import cerebras_pytorch.experimental as cstorch

Once stable, the experimental modules will be merged into the base of cerebras_pytorch.

Configuring a run#

Configuring a run on the Cerebras Wafer-Scale Cluster using the new Cerebras PyTorch API is done via a call to cstorch.configure:

cstorch.configure(
    model_dir=...,
    compile_dir=...,
    compile_only=...,
    validate_only=...,
    checkpoint_steps=...,
    # CSConfig params
    num_csx=...,
    max_wgt_servers=...,
    mount_dirs=...,
    python_paths=...,
    transfer_processes=...,
    num_workers_per_csx=...,
    job_labels=...,
    max_act_per_csx=...,
)

These are all of the parameters you would need to configure the cluster for the run. Note, most of these parameters are optional and have reasonable defaults. See the table below for a detailed description of each supported parameter and its default.

| Parameter | Type | Description |
|---|---|---|
| model_dir | str | The path to the directory where model-related artifacts get saved on the appliance, not on the user node. (Default: "./") |
| compile_dir | str | The path to the directory where compile-related artifacts get saved on the appliance, not on the user node. (Default: "/opt/cerebras/cached_compile") |
| compile_only | bool | If True, configure the run to only compile the model, not execute. (Default: False) |
| validate_only | bool | If True, configure the run to only trace the model, not compile or execute. (Default: False) |
| checkpoint_steps | int | The interval at which checkpoints will be saved. Note, checkpoints are only available for saving on these steps as well as on the last step (if the value is non-zero). (Default: 0) |
| num_csx | int | The number of Cerebras systems in the cluster to use in this run. (Default: 1) |
| max_wgt_servers | int | The maximum number of weight servers to use in this run. (Only applicable for weight streaming.) (Default: 24) |
| mount_dirs | list[str] | A list of paths to local directories that should be mounted to the cluster. (Default: None) |
| python_paths | list[str] | A list of paths that worker pods respect as PYTHONPATH, in addition to the PYTHONPATH set in the container image. (Default: None) |
| transfer_processes | int | Number of processes used to transfer data to/from the appliance. (Default: 5) |
| num_workers_per_csx | int | Number of streaming workers per system. (Default: 1) |
| job_labels | list[str] | A list of equal-sign-separated key-value pairs that get applied as part of the job metadata. (Default: None) |
| max_act_per_csx | int | Maximum number of activation servers per system. (Default: 1) |
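
For illustration, a minimal configuration might look like the following sketch; all values shown here are placeholders for your own run.

import cerebras_pytorch.experimental as cstorch

cstorch.configure(
    model_dir="./model_dir",            # artifacts directory on the appliance
    checkpoint_steps=100,               # allow checkpoints every 100 steps and on the last step
    num_csx=1,                          # number of CS-X systems to use
    mount_dirs=["/path/to/data"],       # hypothetical local directory to mount
    python_paths=["/path/to/modules"],  # hypothetical extra PYTHONPATH entry for the workers
)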

Note

We don’t currently support multiple CS runs in a single process. This means that the above configuration function can only be called once and applies globally. Any runs with different configurations must be run in separate processes.

Compiling a torch.nn.Module#

In order to prepare a model for compilation, we introduce the cstorch.compile function:

model: torch.nn.Module = ...  # Any PyTorch module
compiled_model = cstorch.compile(model, backend="wse_ws")

We designed this function after the torch.compile function that was introduced in PyTorch 2.0.

Note

The call to cstorch.compile, much like torch.compile, does not actually compile the model. It only prepares the model for compilation, meaning it moves the model parameters to the appropriate torch.device and prepares the internals of cstorch to trace it.

The actual compilation does not happen until the first iteration is complete and the batch size is known.

Initializing the Optimizer#

Much like in our previous releases, we cannot support the vanilla PyTorch optimizers. This is because the implementations available in core PyTorch were designed with eager execution in mind and are fundamentally incompatible with traced lazy execution.

As such, we provide our own drop-in replacements for these optimizers inside cstorch.optim. See the table below for a list of all of the optimizers that are currently available:

For convenience, we also include a configuration helper function:

cstorch.optim.configure_optimizer(
    optimizer_type="...",  # name of the optimizer
    params=...,  # The model parameters
    ...,  # kwargs to be passed into the optimizer class's init
)

This function is useful when you want to initialize an optimizer from some configuration dictionary. An example of its usage:

optimizer_params = {
    "optimizer_type": "SGD",
    "lr": 0.001,
    "momentum": 0.5,
}
optimizer = cstorch.optim.configure_optimizer(
    optimizer_type=optimizer_params.pop("optimizer_type"),
    params=model.parameters(),
    **optimizer_params
)

Defining a custom Cerebras optimizer#

In order to define a Cerebras compliant optimizer, one must create a subclass of cstorch.optim.Optimizer, e.g.

class CustomOptimizer(cstorch.optim.Optimizer):

    def __init__(self, params, ...):
        ...
        defaults = ...
        super().__init__(params, defaults, enable_global_step=...)

    ...

    def preinitialize(self):
        ...

    def step(self, closure=None):
        ...

    def state_names_to_sparsify(self):
        ...

As can be seen in the above example, similar to torch.optim.Optimizer, the base cstorch.optim.Optimizer class expects the model parameters and the param group defaults, as well as an optional enable_global_step argument, which defines a global step state variable for each parameter.

In addition, there are three abstract methods that must be overridden:

  1. preinitialize

    This method is used to initialize any state variables that will be used by the optimizer. For example, cstorch.optim.SGD defines its momentum buffers in its preinitialize method.

    Note, in order to remain Cerebras compliant, no state variables may be initialized outside of the preinitialize method.

  2. step

    This method is where the optimizer step is implemented. Note, due to the nature of lazy tensor tracing and execution, no Python-level conditions or loops may be used to dynamically define the control flow; only torch ops (such as torch.where) may be used.

    Having said this, static structures are allowed: for example, a loop with a fixed number of iterations, or a Python conditional that involves only constant variables rather than torch tensors. See the sketch after this list.

  3. state_names_to_sparsify

    This method should return the names of the state variables that should be sparsified. Please see the existing optimizer implementations for examples.
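
For illustration, here is a minimal sketch of a custom optimizer that follows this structure. The SimpleMomentumSGD name and its hyperparameters are hypothetical, and the sketch assumes the base class exposes param_groups and state in the same way torch.optim.Optimizer does; refer to the built-in cstorch.optim implementations for authoritative examples.

import torch
import cerebras_pytorch.experimental as cstorch

class SimpleMomentumSGD(cstorch.optim.Optimizer):
    """A hypothetical momentum SGD, shown only to illustrate the required structure."""

    def __init__(self, params, lr=0.01, momentum=0.9):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)

    def preinitialize(self):
        # All optimizer state must be created here, before tracing begins
        for group in self.param_groups:
            for p in group["params"]:
                self.state[p]["momentum_buffer"] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                buf = self.state[p]["momentum_buffer"]
                # Only torch ops here; no control flow that depends on tensor values
                buf.mul_(momentum).add_(p.grad)
                p.sub_(lr * buf)
        return loss

    def state_names_to_sparsify(self):
        # State that mirrors the parameter shape (the momentum buffer) is sparsified
        return ["momentum_buffer"]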

Initializing the Learning Rate Scheduler#

Similar to the optimizers, the vanilla PyTorch learning rate schedulers are not compatible with traced lazy execution.

As such, we provide our own drop-in replacements for some common schedulers inside cstorch.optim.lr_scheduler. See the table below for a list of all of the learning rate schedulers that are currently available.

Similar to cstorch.optim.configure_optimizer, for convenience, we also include the following configuration helper function for learning rate schedulers.

cstorch.optim.configure_lr_scheduler(
    optimizer=...,  # the optimizer object
    learning_rate=..., # the learning rate configuration
)

The expected format for the learning_rate parameter is one of the following (examples follow the list):

  1. learning_rate is a python scalar (int or float)

    In this case, configure_lr_scheduler returns an instance of ConstantLR with the provided value as the constant learning rate.

  2. learning_rate is a dictionary

    In this case, the dictionary is expected to contain the key scheduler which contains the name of the scheduler you want to configure.

    The rest of the parameters in the dictionary are passed in as keyword arguments to the specified scheduler's init method.

  3. learning_rate is a list of dictionaries

    In this case, we assume what is being configured is a SequentialLR unless any one of the dictionaries contains the key main_scheduler with the corresponding value ChainedLR.

    In either case, each element of the list is expected to be a dictionary that follows the format outlined in case 2.

    If what is being configured is indeed a SequentialLR, each dictionary entry is also expected to contain the key total_iters specifying the total number of iterations each scheduler should be applied for.
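
For illustration, here is a rough sketch of the three forms. The scheduler and total_iters keys come from the description above; any other keyword arguments are scheduler-specific, and the val key shown here is a hypothetical placeholder for your chosen scheduler's arguments.

# 1. Scalar: configures a ConstantLR with the given value
lr_scheduler = cstorch.optim.configure_lr_scheduler(optimizer, learning_rate=0.001)

# 2. Dictionary: "scheduler" names the scheduler; remaining keys are passed as
#    keyword arguments to that scheduler's init method
lr_scheduler = cstorch.optim.configure_lr_scheduler(
    optimizer,
    learning_rate={"scheduler": "ConstantLR", "val": 0.001},  # "val" is a hypothetical kwarg
)

# 3. List of dictionaries: configures a SequentialLR; each entry also needs "total_iters"
lr_scheduler = cstorch.optim.configure_lr_scheduler(
    optimizer,
    learning_rate=[
        {"scheduler": "ConstantLR", "val": 0.01, "total_iters": 500},
        {"scheduler": "ConstantLR", "val": 0.001, "total_iters": 500},
    ],
)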

Defining a custom Cerebras learning rate scheduler#

In order to define a Cerebras compliant learning rate scheduler, one must create a subclass of cstorch.optim.lr_scheduler.LRScheduler, e.g.

class CustomScheduler(cstorch.optim.lr_scheduler.LRScheduler):

    def __init__(self, optimizer, ...):
        ...
        super().__init__(optimizer, total_iters=..., last_epoch=...)

    ...

    def _get_closed_form_lr(self) -> torch.Tensor:
        ...

As can be seen in the above example, the base cstorch.optim.lr_scheduler.LRScheduler class expects the optimizer whose learning rate is being scheduled, optionally the total number of iterations the scheduler runs for, and optionally the last epoch to start on.

In addition, there is one abstract method that must be overridden:

  1. _get_closed_form_lr

    This method is where the full schedule is defined in closed form. Note, due to the nature of lazy tensor tracing and execution, no Python-level conditions or loops may be used to dynamically define the control flow; only torch ops (such as torch.where) may be used.

    Having said this, static structures are allowed: for example, a loop with a fixed number of iterations, or a Python conditional that involves only constant variables rather than torch tensors.

    This method is expected to return a torch.Tensor that represents the full learning rate schedule as a computed tensor.

    Please see the existing LR scheduler implementations for examples of how to properly define the schedule.
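
For illustration, here is a minimal sketch of a custom scheduler. The HalvingLR name and its arguments are hypothetical, and the sketch assumes last_epoch tracks the current step as a tensor during tracing, as suggested by the base class signature above.

import torch
import cerebras_pytorch.experimental as cstorch

class HalvingLR(cstorch.optim.lr_scheduler.LRScheduler):
    """A hypothetical scheduler: halve the initial LR after decay_step iterations."""

    def __init__(self, optimizer, initial_lr, decay_step, total_iters=None, last_epoch=-1):
        self.initial_lr = initial_lr
        self.decay_step = decay_step
        super().__init__(optimizer, total_iters=total_iters, last_epoch=last_epoch)

    def _get_closed_form_lr(self) -> torch.Tensor:
        # Express the whole schedule with torch ops so it remains traceable;
        # `self.last_epoch` is assumed to be the current step during tracing
        step = torch.as_tensor(self.last_epoch, dtype=torch.float32)
        return torch.where(
            step < self.decay_step,
            torch.tensor(self.initial_lr, dtype=torch.float32),
            torch.tensor(self.initial_lr / 2, dtype=torch.float32),
        )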

Initializing the DataLoader#

The Cerebras Wafer-Scale cluster uses worker nodes to stream data to the system, maximizing utilization by keeping the input buffers saturated. Because the workers run on their own nodes, they cannot share a single PyTorch dataloader. Hence, we need a mechanism for each worker to initialize its own dataloader.

To facilitate this, we introduce a custom dataloader class:

dataloader = cstorch.utils.data.DataLoader(
    input_fn,
    num_steps=...,  # specifies the number of steps to generate
    max_steps=...,  # specifies the maximum number of steps to generate
    num_epochs=...,  # specifies the number of epochs to generate
    steps_per_epoch=...,  # specifies the number of steps per epoch to generate
    ...,  # kwargs to be passed into the input_fn
)

It takes an input_fn parameter, which is a callable that takes in some parameters and returns a torch.utils.data.DataLoader:

def input_fn(...) -> torch.utils.data.DataLoader:
    ...

Each worker will call this input function to construct its own dataloader object. This means that some data sharding scheme is required if the intent is for each worker to stream a unique set of data.

The dataloader class also takes in two mutually exclusive groups of parameters:

  1. num_steps

    This specifies the total number of steps to run. If max_steps is provided, it will run a max of min(num_steps, max_steps) steps.

  2. num_epochs and steps_per_epoch

    If num_epochs is specified, then we use steps_per_epoch (or the length of the dataloader if unspecified) to calculate the total number of steps to run.

    If max_steps is provided, it will run a max of min(num_epochs * steps_per_epoch, max_steps) steps.

Any other parameters that are passed into the DataLoader init are forwarded and passed into the input_fn.
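
For illustration, here is a sketch of an input function and its use. The batch_size and data_dir parameters and the dataset construction are hypothetical placeholders.

import torch
import cerebras_pytorch.experimental as cstorch

def train_input_fn(batch_size: int, data_dir: str) -> torch.utils.data.DataLoader:
    # Each worker node calls this function to build its own dataloader, so the
    # dataset should shard its data if each worker must stream unique samples
    dataset = ...  # hypothetical torch.utils.data.Dataset built from data_dir
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

dataloader = cstorch.utils.data.DataLoader(
    train_input_fn,
    num_steps=1000,
    # any additional kwargs are forwarded to train_input_fn
    batch_size=32,
    data_dir="/path/to/data",
)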

Using Gradient Scaling#

Gradient scaling can improve convergence when training models with float16 gradients by minimizing gradient underflow. Please see the PyTorch docs for a more detailed explanation.

To facilitate gradient scaling, we introduce a Cerebras implementation of the AMP GradScaler class found in core PyTorch.

grad_scaler = cstorch.amp.GradScaler(
    loss_scale=...,
    init_scale=...,
    steps_per_increase=...,
    min_loss_scale=...,
    max_loss_scale=...,
    overflow_tolerance=...,
    max_gradient_norm=...,
)

It is designed to be as similar as possible to the API of the CUDA AMP GradScaler class.

See the table below for a description of each parameter:

| Parameter | Type | Description |
|---|---|---|
| loss_scale | str or float | If loss_scale == "dynamic", configure dynamic loss scaling. Otherwise, it is the loss scale value used in static loss scaling. (Default: 0.0) |
| init_scale | float | The initial loss scale value if loss_scale == "dynamic". (Default: None) |
| steps_per_increase | int | The number of steps after which to increase the loss scale when using dynamic loss scaling. (Default: None) |
| min_loss_scale | float | The minimum loss scale value that can be chosen by dynamic loss scaling. (Default: None) |
| max_loss_scale | float | The maximum loss scale value that can be chosen by dynamic loss scaling. (Default: None) |
| overflow_tolerance | float | The maximum fraction of steps involving infinite or undefined values in the gradient that we allow; the loss scale is reduced if the tolerance is exceeded. (Default: 0.05) |
| max_gradient_norm | float | The maximum gradient norm to use for global gradient clipping. Only applies in the DLS + GCC case; if GCC is not enabled, this parameter has no effect. (Default: 0.05) (Note: only used in pipeline mode) |

Its usage is practically identical to the usage of the CUDA AMP GradScaler:

loss: torch.Tensor = ...

optimizer.zero_grad()
# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()

# Unscales the gradients of optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)

# Global gradient clipping
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    1.0,  # max gradient norm
)

# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)

# update the grad scaler once all optimizers have been stepped
grad_scaler.update()

cstorch.amp.optimizer_step#

We introduce an optional helper function to take care of the details of gradient scaling:

cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=...,  # optionally perform gradient clipping by norm
    max_gradient_val=...,  # optionally perform gradient clipping by value
)

It is useful for quickly constructing typical examples that use gradient scaling without needing to type up the details or worry about whether the grad scaler is being used correctly.

This is completely optional and only covers the basic gradient scaler use case. For more complicated use cases, the grad scaler object must be used explicitly.

Constructing the training step#

In order to compile the full training graph, the training step must be captured in its entirety. To handle this we introduce the cstorch.compile_step decorator:

@cstorch.compile_step
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    cstorch.amp.optimizer_step(
        loss, optimizer, grad_scaler, max_gradient_norm=1.0
    )

    if lr_scheduler:
        lr_scheduler.step()

    return loss

This decorator should wrap some function that encapsulates the entirety of a single training iteration. That is to say, everything that is intended to run on a Cerebras system should be inside this wrapped function.

In addition, no tensor value may be eagerly evaluated at any point inside this training step. This means no tensor is allowed to be printed, fetched via a debugger, or used as part of a Python conditional. Any operation that requires knowing the value of a tensor inside the training step will result in an error stating that it is not allowed to read a tensor’s contents outside of a step closure.

Another caveat with compile_step is that any variables that are not torch tensors will only ever see their first value. For example, if you increment an int counter inside a cstorch.compile_step wrapped function, it will hold its first value on all iterations. This is because the training step graph is only captured once; hence, any pure Python ops only run once, as illustrated below.
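
As a hypothetical illustration of this caveat, the counter below is incremented only while the step is being traced, not on every iteration of the training loop (compiled_model is assumed from the earlier examples):

counter = 0

@cstorch.compile_step
def traced_step(batch):
    global counter
    counter += 1  # pure Python op: executes once during tracing, not every step
    return compiled_model(batch)

After the training loop has run many iterations, counter is still 1, because the Python body only executed during the single trace.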

Step Closures#

By design, in the execution schema used by the Cerebras Wafer-Scale cluster, the client and the server run asynchronously to each other. This prevents the server from becoming bottlenecked by client processes such as disk IO or networking.

However, this means that a computed tensor may not be available to fetch from the server when the client requests it. For example, the call to compile happens in the first iteration of the training loop. Until compile is complete and execution starts on the cluster, no tensor is available to fetch.

To handle this, we introduce the concept of a step closure via the step_closure decorator:

@cstorch.step_closure
def closure(...):
    ...

Any tensors that are passed into a “step closure” are fetched from the server and their value is materialized before the closure is actually called. If the tensor is not yet available, it waits until the server “catches up” to the current step and the tensor value is available to be fetched before actually calling the closure.

One caveat regarding values passed into step closures is that the value seen by the step closure is the last value set to that tensor, not the value at the point of definition. This means that if the tensor is updated in place after being passed into the step closure, the modified tensor is what gets materialized and passed into the closure.
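
For illustration, here is a minimal sketch of logging a loss through a step closure; training_step and batch are assumed from the surrounding examples:

@cstorch.step_closure
def log_loss(loss: torch.Tensor):
    # `loss` has been fetched from the server and materialized by this point
    print("Loss:", loss.item())

loss = training_step(batch)  # returns a traced (lazy) tensor
log_loss(loss)               # runs once the value for this step is available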

Saving/Loading Checkpoints#

To save and load weights in a Cerebras run, we provide a custom Cerebras H5-based checkpoint format that is far more performant and efficient than the core PyTorch pickle-based checkpoint format, especially for models with extremely large weights, such as large language models.

To save a checkpoint, we provide a very familiar cstorch.save function that you can use in exactly the same way as torch.save:

state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    ...
}
cstorch.save(state_dict, "<path to save checkpoint to>")

Similarly, we provide a very familiar cstorch.load function that can also be used in exactly the same way as torch.load:

state_dict = cstorch.load("<path to save checkpoint to>")

model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
...

Note

This is a perfect example of a case where one should use a step closure to make sure that the weights are available to fetch from the server before they can be saved to the checkpoint file.

@cstorch.step_closure
def save_checkpoint():
    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        ...
    }
    cstorch.save(state_dict, "<path to save checkpoint to>")

Note

It is only possible to fetch weights on the predetermined checkpoint steps configured in cstorch.configure. This restriction is in place to keep training performant.

For example, if the configuration was checkpoint_steps=100, you are only allowed to fetch the weights to take a checkpoint on every 100th step and at the very end, on the last step.

Converting checkpoints to Pickle-based format#

If you have a checkpoint in the Cerebras H5-based format and wish to use it in a CPU/GPU workflow, it can easily be converted to a PyTorch pickle-based format:

state_dict = cstorch.load("<path to H5 checkpoint>", map_location="cpu")
torch.save(state_dict, "<path to torch checkpoint>")

Note, this will eagerly load the entirety of the checkpoint into memory. Thus, it may cause memory issues when loading checkpoints for very large models.

Training Loop#

The training loop is now fully exposed and customizable. A very basic example of a training loop could be:

@cstorch.step_closure
def post_training_step(loss: torch.Tensor):
    print("Loss: ", loss.item())

for i, batch in enumerate(dataloader):
    loss = training_step(batch)

    post_training_step(loss)

    if i % checkpoint_steps == 0:
        save_checkpoint()

Please see the Full Example (Training) for a comprehensive example of the full training API.

Evaluation Metrics#

We provide Cerebras-compatible metrics that can be used during evaluation to measure how well the model has trained.

They are found in the cstorch.metrics module. See the table below for a list of all of the metrics that are currently available:

AccuracyMetric(name, compute_on_system)

PerplexityMetric(name, compute_on_system)

All of the metric classes take in a name parameter as well as a compute_on_system flag that controls whether the metric is computed on the system. Note, for weight streaming mode, compute_on_system must be set to True.

These metric classes keep an internal state and will return the final computed value. Please see the Full Example (Evaluation) to see how these metrics may be used.

Full Example (Training)#

Shown below is a simple skeleton of a full training script. For a complete, executable example please see our sample training script.

import torch
import cerebras_pytorch.experimental as cstorch

# Needs to be run at the beginning before model initialization
checkpoint_steps = 100
cstorch.configure(
    checkpoint_steps=checkpoint_steps,
    mgmt_address=...,
    ...
)

# user defined model
model: torch.nn.Module = ...
compiled_model = cstorch.compile(model, backend="wse_ws")

loss_fn: torch.nn.Module = ...

optimizer: cstorch.optim.Optimizer = cstorch.optim.configure_optimizer(
    optimizer_type="...",
    params=model.parameters(),
    ...
)
lr_scheduler: cstorch.optim.lr_scheduler.LRScheduler = cstorch.optim.configure_lr_scheduler(
    optimizer, learning_rate=...,
)

grad_scaler = None
if loss_scale != 0.0:
    grad_scaler = cstorch.amp.GradScaler(...)

@cstorch.step_closure
def save_checkpoint(step):
    checkpoint_file = f"checkpoint_{step}.mdl"

    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if lr_scheduler:
        state_dict["lr_scheduler"] = lr_scheduler.state_dict()
    if grad_scaler:
        state_dict["grad_scaler"] = grad_scaler.state_dict()

    state_dict["global_step"] = step

    cstorch.save(state_dict, checkpoint_file)

global_step = 0

# Load checkpoint if provided
if checkpoint_path is not None:
    state_dict = cstorch.load(checkpoint_path)

    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optimizer"])
    if lr_scheduler:
        lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
    if grad_scaler:
        grad_scaler.load_state_dict(state_dict["grad_scaler"])

    global_step = state_dict.get("global_step", 0)

dataloader = cstorch.utils.data.DataLoader(
    train_dataloader_fn,
    num_steps=1000,
    ...
)

@cstorch.compile_step
def training_step(batch):
    inputs, targets = batch
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    cstorch.amp.optimizer_step(
        loss, optimizer, grad_scaler, max_gradient_norm=1.0
    )

    return loss

@cstorch.step_closure
def post_training_step(loss: torch.Tensor):
    print("Loss: ", loss.item())

for i, batch in enumerate(dataloader):
    loss = training_step(batch)

    post_training_step(loss)

    # Can only save checkpoint on predetermined steps
    if i % checkpoint_steps == 0:
        save_checkpoint(i)

Full Example (Evaluation)#

Evaluation is less complex compared to training. There is no optimizer or gradient scaler that needs to be initialized.

Shown below is a simple skeleton of a full evaluation script. For a complete, executable example please see our sample eval script.

import torch
import cerebras_pytorch.experimental as cstorch
import cerebras_pytorch.experimental.metrics as metrics


cstorch.configure(mgmt_address=...)

model: torch.nn.Module = ...
compiled_model = cstorch.compile(model, backend="wse_ws")
compiled_model.eval()

loss_fn: torch.nn.Module = ...

accuracy = metrics.AccuracyMetric("accuracy", compute_on_system=True)

state_dict = cstorch.load("<path to checkpoint file>")
model.load_state_dict(state_dict["model"])

dataloader = cstorch.utils.data.DataLoader(
    eval_dataloader_fn,
    num_steps=100,
    ...
)

@cstorch.compile_step
def evaluation_step(batch):
    inputs, targets = batch
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    accuracy(
        labels=targets.clone(),
        predictions=outputs.argmax(-1).int(),
    )

    return loss

total_loss = 0
total_steps = 0

@cstorch.step_closure
def post_eval_step(loss: torch.Tensor):
    global total_loss, total_steps

    total_loss += loss.item()
    total_steps += 1


with torch.no_grad():
    for batch in dataloader:
        loss = evaluation_step(batch)

        post_eval_step(loss)

print(f"Eval Accuracy: {float(accuracy)}"))