The CerebrasEstimator Interface

The CerebrasEstimator is the interface you use to train your neural network models on the CS system. It inherits from the TensorFlow Estimator (tf.estimator.Estimator).

Run the same code on a CS system, GPU, or CPU

When you use the CerebrasEstimator in your TensorFlow code, the same code runs your neural network training, without any change, on a CS system, a GPU, or a CPU.


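You construct the CerebrasEstimator and start training with the following calls and arguments:

est = CerebrasEstimator(
    model_fn,
    model_dir=None,
    compile_dir=None,
    config=None,
    params=None,
    warm_start_from=None,
)

est.train(input_fn, steps=100000)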



  • model_fn: Required. The model function that defines your neural network model.

  • model_dir: String. Optional. Same as the model_dir passed to tf.estimator.Estimator. The location where your model and all outputs, such as checkpoints and event files, are stored. Default value is None. See also tf.estimator.Estimator.

  • compile_dir: String. Optional. The location where the compilation will be executed. This is also the directory where the compilation results are stored. Default value is None.

  • config: Optional. A CSRunConfig object specifying the runtime configuration options. These options are an extension of TensorFlow RunConfig. Default value is None.

  • params: Dictionary. Optional. A params dictionary that contains additional configuration information for model_fn and input_fn. Default value is None.

  • warm_start_from: Dictionary. Optional. A dictionary specifying the initial weights that will be used. Typically set to None when running the training on the CS system. Default value is None.


The use_cs parameter of the CerebrasEstimator constructor was removed in version 0.7.0 of the CGC; passing it to the constructor now results in a compiler error. The target hardware is instead determined automatically from a combination of the runtime configuration parameter cs_ip and the use_cs setting passed to the train method.


  • train: Method. When invoked, it calls the model_fn and input_fn functions to run the neural network model training.

After constructing the CerebrasEstimator, you can run the training by calling estimator.train().


The method estimator.train() is overridden by Cerebras to handle the compilation and execution of the model on the CS system.
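The calling pattern described above can be sketched with a small stand-in class. This is a minimal illustration, not the Cerebras implementation: StandInEstimator, the toy input_fn and model_fn, and the params keys batch_size and initial_weight are all hypothetical. The sketch uses plain Python lists so that it runs without TensorFlow or a CS system. In real Estimator code, input_fn would typically return a tf.data.Dataset and model_fn would return an EstimatorSpec.

```python
# Stand-in for illustration only: this is NOT the Cerebras implementation.
# It mirrors the documented constructor arguments and the Estimator calling
# convention, so the flow of params through input_fn and model_fn is visible.


class StandInEstimator:
    """Hypothetical class taking the constructor arguments listed above."""

    def __init__(self, model_fn, model_dir=None, compile_dir=None,
                 config=None, params=None, warm_start_from=None):
        self.model_fn = model_fn
        self.model_dir = model_dir
        self.compile_dir = compile_dir
        self.config = config
        self.params = params or {}
        self.warm_start_from = warm_start_from
        self.last_result = None

    def train(self, input_fn, steps=None):
        # Estimator convention: input_fn supplies (features, labels), and
        # model_fn receives them together with the mode and params.
        # `steps` is accepted but unused; this sketch runs a single pass.
        features, labels = input_fn(self.params)
        self.last_result = self.model_fn(
            features, labels, mode="train", params=self.params)
        return self  # tf.estimator.Estimator.train also returns self


def input_fn(params):
    """Toy input function: builds one batch sized by params."""
    features = [float(i) for i in range(params["batch_size"])]
    labels = [2.0 * x for x in features]
    return features, labels


def model_fn(features, labels, mode, params):
    """Toy model function: mean squared error of a one-weight linear model."""
    w = params["initial_weight"]
    preds = [w * x for x in features]
    loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(labels)
    return {"mode": mode, "loss": loss}


params = {"batch_size": 4, "initial_weight": 1.0}
est = StandInEstimator(model_fn, model_dir="./model_dir", params=params)
est.train(input_fn, steps=1)
print(est.last_result)
```

The same params dictionary reaches both input_fn and model_fn, which is why it is a convenient place for settings such as batch size that both functions need.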