Static graphs

As of the 2.1.0 software release, we do not officially support reprogramming the Cerebras Wafer-Scale cluster after its initial programming. As a result, multiple compiles are not supported, and the PyTorch compute graph must not change between iterations.

To define a training or evaluation step, decorate a function with cerebras.pytorch.trace.

For example:

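import torch
import cerebras.pytorch as cstorch

loss_fn = torch.nn.CrossEntropyLoss()

# compiled_model and optimizer are assumed to be defined earlier
@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss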

By default, the training_step function is traced only once, so the body of the traced function must describe a static computation graph. For any Python conditional, only the branch taken during the first iteration is encoded into the graph. Any Python loop is unrolled as many times as it ran during the first iteration.

Likewise, any other side effects, such as print statements and updates to Python scalars, happen only once, while the function is being traced.
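To make these rules concrete, here is a toy trace-once decorator in plain Python. It is a hypothetical illustration of the semantics described above, not how cerebras.pytorch.trace is implemented: `Tensor`, `trace_once`, `step`, and `trace_count` are invented names, and the "graph" is simply a list of recorded operations that is replayed on every call.

```python
class Tensor:
    """Toy symbolic tensor: arithmetic appends an op to a graph instead of computing."""

    def __init__(self, ops):
        self.ops = ops

    def __add__(self, c):
        self.ops.append(lambda v, c=c: v + c)
        return self

    def __mul__(self, c):
        self.ops.append(lambda v, c=c: v * c)
        return self


def trace_once(fn):
    """Hypothetical stand-in for a trace-once decorator (not the cstorch API)."""
    ops, traced = [], False

    def wrapper(x, *py_args):
        nonlocal traced
        if not traced:
            # The Python body of fn (control flow and side effects) runs only here.
            fn(Tensor(ops), *py_args)
            traced = True
        for op in ops:  # every call, including the first, replays the frozen graph
            x = op(x)
        return x

    return wrapper


trace_count = 0  # a plain Python scalar mutated inside the traced function


@trace_once
def step(x, n_steps, use_scale):
    global trace_count
    trace_count += 1        # side effect: happens only while tracing
    print("tracing step")   # printed once, not on every iteration
    if use_scale:           # resolved once: only the branch taken at trace time is kept
        x = x * 2
    for _ in range(n_steps):  # unrolled n_steps times at trace time
        x = x + 1
    return x


print(step(0, 3, True))    # prints "tracing step", then 3
print(step(10, 1, False))  # 23, not the eager result 11: the first trace is reused
print(trace_count)         # 1
```

The second call ignores its new `n_steps` and `use_scale` arguments entirely, because those values only influenced Python control flow during the first trace.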

Retracing every iteration

To retrace the function on every iteration, set the retrace_every_iteration flag when constructing the backend.

For example:

backend = cstorch.backend("CSX", ..., retrace_every_iteration=True)

With this flag set to True, the function decorated with cerebras.pytorch.trace is traced on every iteration. The benefit is that side effects such as print statements and updates to Python scalars now happen on every iteration.

Note, however, that dynamic graph logic is still not captured: Python conditionals are resolved at trace time, and Python loops are unrolled. If the computation graph changes in any way between iterations, a compile error is raised.
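A toy sketch in plain Python can show what retracing every iteration changes and what it does not. It assumes only the behavior described above: the Python body, and therefore its side effects, runs on every call; loops are still unrolled; and a graph that differs from the first trace is rejected. `Recorder`, `trace_every_iteration`, `step`, and the `RuntimeError` are hypothetical stand-ins; on a real system the failure surfaces as a compile error.

```python
class Recorder:
    """Toy symbolic tensor that records (op_name, constant) pairs."""

    def __init__(self, ops):
        self.ops = ops

    def __add__(self, c):
        self.ops.append(("add", c))
        return self


def trace_every_iteration(fn):
    """Hypothetical retrace-every-call decorator (not the cstorch API)."""
    first_graph = None

    def wrapper(x, *py_args):
        nonlocal first_graph
        ops = []
        fn(Recorder(ops), *py_args)  # Python body and side effects run on every call
        if first_graph is None:
            first_graph = ops
        elif ops != first_graph:
            # Stands in for the compile error raised when the graph changes.
            raise RuntimeError("computation graph changed between iterations")
        for _, c in ops:  # execute the (unchanged) graph
            x = x + c
        return x

    return wrapper


log = []  # Python-side state updated inside the traced function


@trace_every_iteration
def step(x, n_steps):
    log.append(n_steps)       # side effect now happens every iteration
    for _ in range(n_steps):  # still unrolled at trace time
        x = x + 1
    return x


print(step(0, 2))  # 2
print(step(5, 2))  # 7
try:
    step(0, 3)  # the loop now unrolls three times, so the graph differs
except RuntimeError as err:
    print("error:", err)
```

Retracing makes the side effect visible on every call (`log` grows each time), but it does not make the loop dynamic: changing the trip count changes the graph, which is rejected.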


Retracing every iteration can hurt performance for smaller models, where tracing time can outweigh the time it takes to execute the model.

For larger models, where execution time dominates, the retracing overhead should be negligible.
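As a rough back-of-the-envelope model, assuming tracing and execution run back to back rather than overlapping, the share of each iteration spent retracing is t_trace / (t_trace + t_exec). The function name and the timings below are made-up illustrative values, not Cerebras measurements.

```python
def retrace_overhead(trace_s: float, exec_s: float) -> float:
    """Fraction of an iteration spent retracing, assuming no overlap with execution."""
    return trace_s / (trace_s + exec_s)


# Hypothetical timings: the same 50 ms trace against a short and a long execution.
small_model = retrace_overhead(trace_s=0.05, exec_s=0.02)
large_model = retrace_overhead(trace_s=0.05, exec_s=5.0)

print(f"{small_model:.1%}")  # 71.4%
print(f"{large_model:.1%}")  # 1.0%
```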