The CerebrasEstimator Interface
CerebrasEstimator is the interface you use to train your neural network models on the CS system. CerebrasEstimator inherits from the TensorFlow Estimator class (tf.estimator.Estimator).
Run the same code on a CS system, GPU, or CPU
When you use CerebrasEstimator in your TensorFlow code, you can run your neural network training on a CS system, a GPU, or a CPU with the same code, without any changes.
You construct the CerebrasEstimator with the following functions and arguments, and then call its train() method:

    est = CerebrasEstimator(
        model_fn,
        model_dir=None,
        compile_dir=None,
        config=None,
        params=None,
        warm_start_from=None,
    )
    est.train(input_fn, steps=100000)

model_fn: Required. The model function that defines your neural network model.
model_dir: String. Optional. Same as the model_dir passed to tf.estimator.Estimator: the location where your model and all outputs, such as checkpoints and event files, are stored. Default value is None. See also tf.estimator.Estimator.
compile_dir: String. Optional. The location where the compilation is executed. This is also the directory where the compilation results are stored. Default value is None.
config: Optional. A CSRunConfig object specifying the runtime configuration options. These options extend TensorFlow's RunConfig. Default value is None.
params: Dictionary. Optional. A dictionary that contains additional configuration information for input_fn. Default value is None.
warm_start_from: Dictionary. Optional. A dictionary specifying the initial weights to use. Typically set to None when running training on the CS system. Default value is None.
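To see how the params dictionary reaches your input function, here is a minimal pure-Python sketch, assuming the tf.estimator convention (as we understand it) that an input_fn receives params only when it declares a params argument. The dataset, the "batch_size" key, and the call_input_fn helper are hypothetical illustrations, not part of the Cerebras API.

```python
import inspect

# Hypothetical in-memory dataset, used only for illustration.
EXAMPLES = [1, 2, 3, 4, 5]


def input_fn(params):
    """Hypothetical input_fn: reads its batch size from the params dictionary."""
    batch_size = params["batch_size"]  # "batch_size" is an assumed key, not a required one
    return [EXAMPLES[i:i + batch_size] for i in range(0, len(EXAMPLES), batch_size)]


def call_input_fn(fn, params):
    """Toy dispatch: pass params only if fn declares a `params` argument."""
    if "params" in inspect.signature(fn).parameters:
        return fn(params=params)
    return fn()


batches = call_input_fn(input_fn, {"batch_size": 2})
print(batches)  # [[1, 2], [3, 4], [5]]
```

Keeping values such as the batch size in params lets you change them at construction time without editing input_fn itself.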
The use_cs parameter was removed from the CerebrasEstimator API in version 0.7.0 of the CGC; using it in this API results in a compiler error. The target hardware is now determined automatically from a combination of the cs_ip runtime configuration parameter and the use_cs parameter setting in the method definitions.
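If you build the constructor arguments programmatically, a small guard can catch the removed use_cs argument before it reaches the compiler. The sketch below is a hypothetical helper (make_estimator_kwargs is not a Cerebras function) that fills in the documented defaults of None and rejects unknown keys; it assumes the constructor accepts exactly the arguments listed above.

```python
# Optional CerebrasEstimator constructor arguments listed above; each defaults to None.
OPTIONAL_ARGS = ("model_dir", "compile_dir", "config", "params", "warm_start_from")


def make_estimator_kwargs(model_fn, **overrides):
    """Hypothetical helper: build and validate CerebrasEstimator constructor kwargs."""
    if not callable(model_fn):
        raise TypeError("model_fn is required and must be callable")
    if "use_cs" in overrides:
        # Removed from the CerebrasEstimator API in CGC 0.7.0; the target hardware
        # is now derived from cs_ip and the use_cs setting in the method definitions.
        raise TypeError("use_cs is no longer accepted by CerebrasEstimator")
    unknown = set(overrides) - set(OPTIONAL_ARGS)
    if unknown:
        raise TypeError(f"unexpected arguments: {sorted(unknown)}")
    kwargs = {name: None for name in OPTIONAL_ARGS}
    kwargs.update(overrides)
    kwargs["model_fn"] = model_fn
    return kwargs
```

You would then pass the result along, for example CerebrasEstimator(**make_estimator_kwargs(model_fn, model_dir="./model_dir")), where model_fn and the directory name are your own.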
train: Method that, when invoked, calls the input_fn function to execute the neural network model training.
After constructing the CerebrasEstimator, you can run the training by calling estimator.train(). Cerebras overrides train() to handle the compilation and execution of the model on the CS system.
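The relationship between the base Estimator and the overridden train() can be pictured with a toy model. This is a conceptual sketch only, not how CerebrasEstimator is implemented: ToyEstimator stands in for tf.estimator.Estimator, ToyCerebrasEstimator stands in for the subclass, and the "compile" step is just a log entry. All class names and the logging are hypothetical.

```python
class ToyEstimator:
    """Stand-in for tf.estimator.Estimator: train() feeds input_fn batches to model_fn."""

    def __init__(self, model_fn, params=None):
        self.model_fn = model_fn
        self.params = params or {}
        self.log = []

    def train(self, input_fn, steps):
        batches = input_fn(self.params)
        for step in range(steps):
            batch = batches[step % len(batches)]  # cycle through the batches
            self.model_fn(batch, "train", self.params)
            self.log.append(f"step {step}")
        return self


class ToyCerebrasEstimator(ToyEstimator):
    """Stand-in subclass: overrides train() to compile first, then execute."""

    def __init__(self, model_fn, params=None, compile_dir=None):
        super().__init__(model_fn, params)
        self.compile_dir = compile_dir

    def train(self, input_fn, steps):
        # Hypothetical compile step; the real estimator compiles for the CS system here.
        self.log.append(f"compile in {self.compile_dir}")
        return super().train(input_fn, steps)


est = ToyCerebrasEstimator(lambda batch, mode, params: None, compile_dir="compile_out")
est.train(lambda params: [[1], [2]], steps=3)
print(est.log)  # ['compile in compile_out', 'step 0', 'step 1', 'step 2']
```

Because the subclass keeps the same train(input_fn, steps) signature, calling code written against the base class does not need to change, which is the property the "same code" claim above relies on.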