Evaluation Metrics

We provide Cerebras-compatible metrics that can be used during evaluation to measure how well the model has been trained.

These metrics can be found in the cerebras_pytorch.metrics module (cstorch.metrics in the example below).

For example:

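import cerebras_pytorch as cstorch

# Model, backend and executor are assumed to be defined elsewhere
model = Model()
compiled_model = cstorch.compile(model, backend)

accuracy = cstorch.metrics.AccuracyMetric("accuracy")

@cstorch.trace
def eval_step(batch):
    inputs, targets = batch
    outputs = compiled_model(inputs)
    # Update the metric with this batch (assumed signature: labels, predictions)
    accuracy(labels=targets, predictions=outputs.argmax(-1))

for batch in executor:
    eval_step(batch)

# Log accumulated eval metric
print(f"Accuracy: {float(accuracy)}")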
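To make the accumulate-then-read pattern concrete without Cerebras hardware, here is a plain-Python sketch. RunningAccuracy is a hypothetical stand-in, not part of cstorch; it assumes, as the example above suggests, that calling a metric updates its state and that float() computes the accumulated value.

```python
class RunningAccuracy:
    """Hypothetical stand-in mimicking the call pattern above:
    calling the metric updates its state, float() computes the value."""

    def __init__(self, name: str):
        self.name = name
        self.reset()

    def reset(self):
        # Accumulated state: number of correct predictions and total count
        self.correct = 0
        self.total = 0

    def update(self, labels, predictions):
        # Accumulate counts arithmetically; no per-value branching
        self.correct += sum(int(y == y_hat) for y, y_hat in zip(labels, predictions))
        self.total += len(labels)

    def __call__(self, labels, predictions):
        self.update(labels, predictions)

    def compute(self):
        return self.correct / self.total if self.total else 0.0

    def __float__(self):
        return float(self.compute())


accuracy = RunningAccuracy("accuracy")
batches = [([1, 0, 1], [1, 1, 1]), ([0, 0], [0, 0])]
for labels, predictions in batches:
    accuracy(labels, predictions)

# 4 of 5 predictions are correct
print(f"Accuracy: {float(accuracy)}")  # prints "Accuracy: 0.8"
```

The metric is only read once, after all batches, which is why the real example logs float(accuracy) outside the loop.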

Writing Custom Metrics

To define a Cerebras-compliant metric, create a subclass of cerebras_pytorch.metrics.Metric.

For example:

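import cerebras_pytorch as cstorch

class CustomMetric(cstorch.metrics.Metric):

    def __init__(self, name: str):
        super().__init__(name)

    def reset(self):
        # Define (or reset) the metric's internal state here
        ...

    def update(self, *args, **kwargs):
        # Update the registered state; this must be fully traceable
        ...

    def compute(self):
        # Compute the final metric value from the accumulated state
        ...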
As the example above shows, the base Metric class expects one argument: the metric name.

In addition, there are three abstract methods that must be overridden:

  • reset

    This method resets the metric's internal state (or defines it, if this is the first time it is called).

    States can be registered via calls to register_state.

  • update

    This method is used to update the metric’s registered states.

    Note that, to remain Cerebras-compliant, update must not evaluate or inspect any tensor values; the update call is intended to be fully traced.

  • compute

    This method is used to compute the final accumulated metric value from the state that was updated in update.
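The following plain-Python sketch puts the three methods together for a mean-absolute-error metric. MetricSketch and its register_state helper are hypothetical simplifications of cstorch.metrics.Metric (the real register_state signature and state handling are not described here), so treat this as an illustration of the reset/update/compute contract rather than the library's implementation. As in the rule for update above, update only accumulates values arithmetically and never branches on them.

```python
from abc import ABC, abstractmethod


class MetricSketch(ABC):
    """Hypothetical, simplified stand-in for cstorch.metrics.Metric."""

    def __init__(self, name: str):
        self.name = name
        # The first reset() call defines the state
        self.reset()

    def register_state(self, name, value):
        # Hypothetical helper: in this sketch it only stores an attribute
        setattr(self, name, value)

    @abstractmethod
    def reset(self): ...

    @abstractmethod
    def update(self, *args, **kwargs): ...

    @abstractmethod
    def compute(self): ...

    def __call__(self, *args, **kwargs):
        self.update(*args, **kwargs)

    def __float__(self):
        return float(self.compute())


class MeanAbsoluteError(MetricSketch):
    def __init__(self, name: str):
        super().__init__(name)

    def reset(self):
        # Define (or re-initialize) the accumulated state
        self.register_state("abs_error_sum", 0.0)
        self.register_state("count", 0)

    def update(self, targets, predictions):
        # Pure accumulation: no branching on individual values
        self.abs_error_sum += sum(abs(t - p) for t, p in zip(targets, predictions))
        self.count += len(targets)

    def compute(self):
        # Final value derived from the accumulated state
        return self.abs_error_sum / max(self.count, 1)


mae = MeanAbsoluteError("mae")
mae([1.0, 2.0], [1.5, 2.0])
mae([0.0], [1.0])
print(float(mae))  # (0.5 + 0.0 + 1.0) / 3 = 0.5
```

Keeping a sum and a count rather than a running mean lets update stay a pure accumulation; the single division happens once, in compute.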