Using the CerebrasEstimator

The CerebrasEstimator is a critical part of your main Python program when running on the CS system. When one of its methods, such as compile or train, is called with the IP address of the CS system provided through cs_ip, the CerebrasEstimator launches the Cerebras Graph Compiler (CGC). See also The CerebrasEstimator Interface.

In this section, an example template is used to show how the CerebrasEstimator interacts with the key code segments of your Python program.


For a detailed description of the example template, see The Template.

The following highly simplified example shows a script for neural network training:

 1 # Example script for neural network training
 2 from import CerebrasEstimator
 3 from import CSRunConfig
 4 from import CSSlurmClusterResolver
 6 def model_fn(features, labels, mode, params):
 8     ...
10     return spec
12 def input_fn(params):
14     ...
16     return dataset
18 config = CSRunConfig(
19     cs_ip=ip,
20     save_checkpoints_steps=1000,
21     log_step_count_steps=10000,
22 )
23 params = {
24     "batch_size": 32,
25     "lr": 0.1,
26     "use_cbfloat16": True
27 }
29 est = CerebrasEstimator(
30     model_fn,
31     config=config,
32     params=params,
33     model_dir='./out',
34     use_cs=True
35 )
37 est.train(input_fn, steps=100000)

Calling the CerebrasEstimator

In the est = CerebrasEstimator(...) call (line 29), the model_fn argument is a callback function. The CerebrasEstimator does not call model_fn when it receives this argument. It waits until one of its methods, such as train, is invoked.


The model_fn argument to the CerebrasEstimator interface is passed as a function object, that is, without the trailing ().
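The difference between passing a function and calling it can be seen in a minimal sketch. MiniEstimator below is a hypothetical stand-in written for this illustration, not the Cerebras API. It only shows that a stored callback does not run until a method such as train is invoked.

```python
# Hypothetical stand-in for an estimator; this is NOT the CerebrasEstimator.
class MiniEstimator:
    def __init__(self, model_fn):
        # Store the function object. It is not called here.
        self.model_fn = model_fn

    def train(self):
        # The callback runs only when a method such as train is invoked.
        return self.model_fn()


calls = []  # records every invocation of model_fn


def model_fn():
    calls.append("model_fn called")
    return "spec"


est = MiniEstimator(model_fn)           # passed without ()
calls_after_construction = list(calls)  # still empty at this point
result = est.train()                    # model_fn is invoked now
```

Writing MiniEstimator(model_fn()) instead would call model_fn immediately and hand its return value, not the function, to the constructor.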

Callback input function

  1. The est.train(input_fn, steps=100000) call (line 37) invokes the train method of the CerebrasEstimator with input_fn as a callback function. The CerebrasEstimator then calls input_fn with the params argument.


    The input_fn argument to the train method is passed without the ().

    Both the CerebrasEstimator and TensorFlow Estimator API expect the input function to:

    • Accept a standard group of input parameters through the params argument, and

    • Return a dataset that yields tensor pairs in the predefined format: a tensor with features and a tensor with labels.

  2. Any params passed to the CerebrasEstimator are passed on to the input_fn and to the model_fn when the CerebrasEstimator calls each function.

    The input_fn should return a dataset (see Dataset API for documentation).

  3. The input function builds the input pipeline and yields the batched data in the form of (features, labels) pairs, where:

    • features can be a tensor or dictionary of tensors, and

    • labels can be a tensor, a dictionary of tensors or None.


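def input_fn(params):
    batch_size = params["batch_size"]
    # ds and buffer_size are not defined in this excerpt: create ds from your
    # data source and choose buffer_size before the steps below.
    ds = ds.shuffle(buffer_size)   # randomize the order of examples
    ds = ds.repeat()               # repeat the data indefinitely
    ds = ds.batch(batch_size, drop_remainder=True)  # fixed-size batches only
    ds = ds.prefetch(buffer_size)  # prepare upcoming data ahead of time
    return ds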
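To make the effect of the shuffle, repeat and batch steps concrete, here is a pure-Python stand-in that needs no TensorFlow. It is a simplified sketch, not how tf.data is implemented: it shuffles each full pass over the data, whereas the dataset shuffle method uses a bounded buffer. The helper names example_stream and batched are hypothetical.

```python
import itertools
import random


def example_stream(examples, seed=0):
    """Yield (feature, label) pairs forever, reshuffled on every pass (shuffle + repeat)."""
    rng = random.Random(seed)
    while True:
        epoch = list(examples)
        rng.shuffle(epoch)
        yield from epoch


def batched(stream, batch_size):
    """Group consecutive pairs into (features, labels) batches of exactly batch_size."""
    while True:
        chunk = list(itertools.islice(stream, batch_size))
        if len(chunk) < batch_size:
            return  # a short final chunk is dropped, like drop_remainder=True
        features = [feature for feature, _ in chunk]
        labels = [label for _, label in chunk]
        yield features, labels


# Ten toy examples; the label is the parity of the feature.
data = [(i, i % 2) for i in range(10)]
first_four = list(itertools.islice(batched(example_stream(data), 4), 4))
```

Four batches of four need sixteen examples from a ten-example source, so the stream must have repeated. Every batch has the same size, which is the property that drop_remainder=True guarantees.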

Callback model function

The model function model_fn is used to generate the graph for your neural network model.

  1. The features and labels arguments are the handles to the batched data, produced by the input_fn, that your model will use. Once the input_fn has returned, the CerebrasEstimator calls the model_fn with features and labels, together with the following arguments:

    • The mode argument that indicates whether the caller is requesting training.

    • The params object that was passed in the est=CerebrasEstimator(...) call.


    The CerebrasEstimator calls both input_fn and model_fn itself, because both are passed to it as callback functions. Do not call either of these functions directly in your TensorFlow code.
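The call order described above can be sketched in plain Python. The run_train helper below is a hypothetical illustration of the sequence, not the CerebrasEstimator source. A real estimator builds a graph and passes symbolic tensors, while this sketch passes ordinary lists so that it runs anywhere.

```python
TRAIN = "train"  # stand-in for tf.estimator.ModeKeys.TRAIN


def run_train(model_fn, input_fn, params):
    """Hypothetical sketch of the callback order used during training."""
    dataset = input_fn(params)               # 1. input_fn receives params
    features, labels = next(iter(dataset))   # 2. the dataset yields (features, labels)
    return model_fn(features, labels, TRAIN, params)  # 3. model_fn gets all four arguments


def input_fn(params):
    batch_size = params["batch_size"]
    # One toy batch: batch_size features and batch_size labels.
    return [([0.0] * batch_size, [1] * batch_size)]


def model_fn(features, labels, mode, params):
    # Return a plain dict in place of an EstimatorSpec.
    return {"mode": mode, "batch": len(features), "lr": params["lr"]}


spec = run_train(model_fn, input_fn, {"batch_size": 32, "lr": 0.1})
```

The same params dictionary reaches both callbacks, and user code never calls input_fn or model_fn itself.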

Both the CerebrasEstimator and TensorFlow Estimator API expect the model function to accept a standard group of input parameters and return a standard group of output values.

Currently, the CerebrasEstimator supports the use of the TensorFlow Keras Layers API in the model function. However, the TensorFlow Metrics API is not supported.


def model_fn(
    features,  # This is batch_features from input_fn
    labels,    # This is batch_labels from input_fn
    mode,      # An instance of tf.estimator.ModeKeys
    params     # Additional configuration
):
    ...


An example model_fn definition is shown below.

import tensorflow as tf  # this example uses the TensorFlow 1.x-style API

def model_fn(features, labels, mode=tf.estimator.ModeKeys.TRAIN, params=None):
    """ Model definition """
    # build_model is your own function that maps features to logits
    logits = build_model(features, params)
    learning_rate = tf.constant(params["lr"])
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss_op = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
        )
        train_op = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate
        ).minimize(loss_op, global_step=tf.train.get_global_step())
        spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss_op, train_op=train_op)
        return spec
    # Other modes, such as PREDICT, are not handled in this example.
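The loss in the example above is the mean softmax cross-entropy over the batch. As a worked illustration of that arithmetic, the following pure-Python sketch computes the same quantity for one-hot labels. The function names are hypothetical, and this is not the TensorFlow implementation.

```python
import math


def softmax_cross_entropy(logits, one_hot_labels):
    """Cross-entropy between one-hot labels and softmax(logits) for one example."""
    # Subtract the maximum logit before exponentiating for numerical stability.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    # The log-softmax of each logit is z - log_sum_exp.
    return -sum(y * (z - log_sum_exp) for z, y in zip(logits, one_hot_labels))


def mean_loss(batch_logits, batch_labels):
    """Average the per-example losses, as tf.reduce_mean does in the example."""
    losses = [softmax_cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)


# Two examples with equal logits for both classes: each loss equals ln(2).
loss = mean_loss([[0.0, 0.0], [2.0, 2.0]], [[1, 0], [0, 1]])
```

With equal logits the softmax assigns probability 1/2 to each class, so every example contributes -ln(1/2) = ln 2 regardless of the label.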

Setting the runtime configuration

You can also set runtime and environment options, which typically cover information that is not captured in the model_fn and input_fn. Use the CSRunConfig object to set these Cerebras-specific options. These options are an extension of the TensorFlow RunConfig.


Make sure to add the following import statement to your Slurm-orchestrated TensorFlow code so that Slurm cluster resolving is done automatically.

from import CSSlurmClusterResolver


The Cerebras CSRunConfig class inherits from the standard TensorFlow RunConfig class. You can pass CSRunConfig the same parameters as the TensorFlow RunConfig, plus additional parameters that configure a CerebrasEstimator run, including the IP address of the CS system. These additional parameters include:

  • cs_ip: IP address of the CS system, provided by Cerebras.

  • system_name: Name of the CS system.

For the full list of TensorFlow RunConfig options, see the TensorFlow RunConfig documentation.
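The inheritance relationship can be pictured with a small sketch. BaseRunConfig and SystemRunConfig below are hypothetical stand-ins for the TensorFlow RunConfig and the CSRunConfig, written only to show how a subclass can accept the base options plus its own. The IP address is a placeholder from a range reserved for documentation.

```python
class BaseRunConfig:
    """Hypothetical stand-in for the TensorFlow RunConfig."""

    def __init__(self, save_checkpoints_steps=None, log_step_count_steps=None):
        self.save_checkpoints_steps = save_checkpoints_steps
        self.log_step_count_steps = log_step_count_steps


class SystemRunConfig(BaseRunConfig):
    """Hypothetical stand-in for CSRunConfig: base options plus system options."""

    def __init__(self, cs_ip=None, system_name=None, **base_options):
        # Forward every option the subclass does not handle to the base class.
        super().__init__(**base_options)
        self.cs_ip = cs_ip
        self.system_name = system_name


config = SystemRunConfig(
    cs_ip="192.0.2.10",           # placeholder address, not a real CS system
    save_checkpoints_steps=1000,  # handled by the base class
    log_step_count_steps=10000,   # handled by the base class
)
```

Because the subclass forwards unrecognized keyword arguments to its parent, a single constructor call can mix system-specific and base options.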


from import CSRunConfig
from import CSSlurmClusterResolver

config = CSRunConfig(