Keras Model to CerebrasEstimator

Use the KerasModelToCerebrasEstimator function to convert your Keras model into a CerebrasEstimator. This section describes the KerasModelToCerebrasEstimator function. For more on the CerebrasEstimator, see The CerebrasEstimator Interface.

Using the KerasModelToCerebrasEstimator

KerasModelToCerebrasEstimator is a wrapper that converts a Keras model so that it can be run with the CerebrasEstimator.


You must use mixed precision, by specifying the mixed_float16 policy, when you use the KerasModelToCerebrasEstimator wrapper.
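To illustrate what the mixed_float16 policy means, the following NumPy sketch (not TensorFlow or Cerebras code) mimics a dense layer under that policy: the variables are stored in float32 and cast to float16 for the computation, which corresponds to the Keras policy's variable dtype (float32) and compute dtype (float16). All names in the sketch are hypothetical.

```python
import numpy as np

# Hypothetical illustration of the 'mixed_float16' idea using NumPy only:
# variables are kept in float32, computations run in float16.
VARIABLE_DTYPE = np.float32
COMPUTE_DTYPE = np.float16

rng = np.random.default_rng(0)

# "Master" weights and bias are stored in the variable dtype (float32).
kernel = rng.standard_normal((4, 3)).astype(VARIABLE_DTYPE)
bias = np.zeros(3, dtype=VARIABLE_DTYPE)


def dense_forward(x: np.ndarray) -> np.ndarray:
    """Dense layer forward pass that computes in float16."""
    # Cast the inputs and the float32 variables to the compute dtype
    # before the matrix multiply and bias add.
    x16 = x.astype(COMPUTE_DTYPE)
    return x16 @ kernel.astype(COMPUTE_DTYPE) + bias.astype(COMPUTE_DTYPE)


x = rng.standard_normal((2, 4)).astype(np.float32)
y = dense_forward(x)
print(y.dtype, kernel.dtype)  # float16 float32
```

The outputs are float16 while the stored variables remain float32, which is the split the policy is designed to give you.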


The following example code shows how you can construct the KerasModelToCerebrasEstimator:

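import tensorflow as tf

# Replace <module> with the Cerebras module that provides these classes.
from <module> import KerasModelToCerebrasEstimator, KerasModelToModelFn

dtype = tf.keras.mixed_precision.experimental.Policy(
    'mixed_float16',  # Important: This is required.
)
# Apply the policy globally (alternatively, pass dtype=dtype to your layers).
tf.keras.mixed_precision.experimental.set_policy(dtype)

estimator = KerasModelToCerebrasEstimator(
    model_fn=model_fn,  # Required. Your Keras model function.
    model_dir=None,     # Optional.
    compile_dir=None,   # Optional. Defaults to model_dir.
    config=None,        # Optional. A CSRunConfig object.
    params=None,        # Optional. Dictionary passed to model_fn and input_fn.
)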
  • model_fn: Function. Required. The Keras model function.

  • model_dir: String. Optional. Same as the model_dir passed to tf.estimator.Estimator: the location where your model and all outputs, such as checkpoints, summaries, and event files, are stored. Default value is None. See also tf.estimator.Estimator.

  • compile_dir: String. Optional. The directory where the compilation results are stored, and from which the compilation outputs are reloaded. Default value is model_dir.

  • config: CSRunConfig. Optional. A CSRunConfig object specifying the environment and runtime configuration options. These options are an extension of TensorFlow RunConfig. Default value is None.

  • params: Dictionary. Optional. A parameters dictionary that contains additional configuration information that will be passed to model_fn and input_fn. Default value is None.
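The documented defaults can be summarized with a small, hypothetical pure-Python sketch. EstimatorArgs is not part of the Cerebras API; it only mirrors the parameter list above: model_fn is required, compile_dir falls back to model_dir, and the optional params dictionary is forwarded to both input_fn and model_fn. The model_fn(features, labels, mode, params) call shape follows the tf.estimator convention and is an assumption here.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class EstimatorArgs:
    """Hypothetical container mirroring the documented arguments.

    Not the Cerebras API: it only illustrates the documented defaults.
    """

    model_fn: Callable[..., Any]           # Required.
    model_dir: Optional[str] = None        # Optional.
    compile_dir: Optional[str] = None      # Optional; defaults to model_dir.
    config: Optional[Any] = None           # Stands in for a CSRunConfig object.
    params: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        if not callable(self.model_fn):
            raise TypeError("model_fn is required and must be callable")
        # Documented default: compile_dir falls back to model_dir.
        if self.compile_dir is None:
            self.compile_dir = self.model_dir

    def call_input_fn(self, input_fn: Callable[[Dict[str, Any]], Any]) -> Any:
        # The same params dictionary is handed to input_fn ...
        return input_fn(self.params or {})

    def call_model_fn(self, features: Any, labels: Any, mode: str) -> Any:
        # ... and to model_fn (assumed tf.estimator-style signature).
        return self.model_fn(features, labels, mode, params=self.params or {})


def my_model_fn(features, labels, mode, params):
    # Hypothetical model function that just reads a parameter.
    return params["batch_size"]


args = EstimatorArgs(model_fn=my_model_fn, model_dir="model_dir",
                     params={"batch_size": 4})
print(args.compile_dir)  # model_dir
```

Because compile_dir is left unset, it resolves to model_dir, so compilation outputs are stored alongside the other model outputs unless you choose a separate directory.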

After converting your model to the CerebrasEstimator, you can compile it by calling the CerebrasEstimator method estimator.compile().