The CerebrasEstimator Interface
The CerebrasEstimator is the interface you use to train your neural network models on the CS system. The CerebrasEstimator inherits from the TensorFlow Estimator (tf.estimator.Estimator).
Run the same code on a CS system, GPU, or CPU
When you use the CerebrasEstimator in your TensorFlow code, you can run your neural network training on a CS system, a GPU, or a CPU with the same code, without any changes.
Syntax
You construct the CerebrasEstimator with the following functions and arguments:
est = CerebrasEstimator(
    model_fn,
    model_dir=None,
    compile_dir=None,
    config=None,
    params=None,
    warm_start_from=None,
)
est.train(input_fn, steps=100000)
where:
Arguments
model_fn: Required. The model function that defines your neural network model.
model_dir: String. Optional. Same as the model_dir argument of tf.estimator.Estimator: the location where your model and all its outputs, such as checkpoints and event files, are stored. Default value is None.
compile_dir: String. Optional. The directory in which the compilation is executed and where the compilation results are stored. Default value is None.
config: Optional. A CSRunConfig object specifying the runtime configuration options. These options are an extension of the TensorFlow RunConfig. Default value is None.
params: Dictionary. Optional. A dictionary that contains additional configuration information for model_fn and input_fn. Default value is None.
warm_start_from: Dictionary. Optional. A dictionary specifying the initial weights to use. Typically set to None when running training on the CS system. Default value is None.
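A framework-free sketch of how one params dictionary can carry hyperparameters to both model_fn and input_fn. It assumes the tf.estimator-style signatures input_fn(params) and model_fn(features, labels, mode, params); the hyperparameter names, the toy one-weight linear model, and the returned dict (standing in for a tf.estimator.EstimatorSpec) are illustrative only, and the functions are called directly here rather than through CerebrasEstimator.

```python
# Hypothetical hyperparameters you might pass as CerebrasEstimator(..., params=params).
params = {"batch_size": 4, "learning_rate": 0.1}


def input_fn(params):
    """Toy input function: builds a single batch sized by params["batch_size"]."""
    batch_size = params["batch_size"]
    features = [float(i) for i in range(batch_size)]
    labels = [2.0 * x for x in features]  # toy target relationship: y = 2x
    return features, labels


def model_fn(features, labels, mode, params):
    """Toy model function for a one-weight linear model y = w * x.

    Returns a plain dict that stands in for a tf.estimator.EstimatorSpec.
    """
    w = 0.0  # initial weight
    preds = [w * x for x in features]
    n = len(labels)
    # Mean squared error over the batch.
    loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / n
    # Derivative of the MSE with respect to w: mean of 2 * (p - y) * x.
    grad = sum(2.0 * (p - y) * x for p, y, x in zip(preds, labels, features)) / n
    # One gradient-descent step using the shared learning rate.
    new_w = w - params["learning_rate"] * grad
    return {"mode": mode, "loss": loss, "new_weight": new_w}


# Both functions receive the same params dictionary.
features, labels = input_fn(params)
spec = model_fn(features, labels, "train", params)
print(spec)
```

Keeping hyperparameters in params rather than hard-coding them in each function means input_fn and model_fn cannot drift apart, for example on the batch size.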
Important
The use_cs parameter was removed from the CerebrasEstimator constructor in version 0.7.0 of the CGC; using it in the constructor results in a compiler error. The target hardware is now determined automatically from a combination of the runtime configuration parameter cs_ip and the use_cs parameter of the train method.
Methods
train: When invoked, calls the model_fn and input_fn functions to run the neural network model training.
After constructing the CerebrasEstimator, you can run the training by calling estimator.train().
Note
The estimator.train() method is overridden by Cerebras to handle the compilation and execution of the model on the CS system.
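Because CerebrasEstimator inherits from tf.estimator.Estimator, a call such as est.train(input_fn, steps=100000) likely follows the base Estimator pattern, assuming the Cerebras override keeps the base-class call sequence: in TensorFlow graph mode, input_fn and model_fn are each called once to build the training computation, which is then executed for the requested number of steps. The framework-free sketch below mimics only that sequence; ToyEstimator, the call counters, and the callable train_op are hypothetical and do not model the compilation step that Cerebras adds.

```python
class ToyEstimator:
    """Hypothetical, framework-free stand-in for the Estimator.train call pattern."""

    def __init__(self, model_fn, params=None):
        self.model_fn = model_fn
        self.params = params if params is not None else {}
        self.global_step = 0

    def train(self, input_fn, steps):
        # Build phase: input_fn and model_fn are each called once.
        features, labels = input_fn(self.params)
        train_op = self.model_fn(features, labels, "train", self.params)
        # Execution phase: run the built training step `steps` times.
        for _ in range(steps):
            train_op()
            self.global_step += 1
        return self  # tf.estimator.Estimator.train also returns the estimator


# Counters that record how often each user function runs.
calls = {"input_fn": 0, "model_fn": 0, "train_op": 0}


def input_fn(params):
    calls["input_fn"] += 1
    return [1.0, 2.0], [2.0, 4.0]


def model_fn(features, labels, mode, params):
    calls["model_fn"] += 1

    def train_op():
        # Stands in for one execution of the built training step.
        calls["train_op"] += 1

    return train_op


est = ToyEstimator(model_fn, params={"batch_size": 2})
est.train(input_fn, steps=100)
print(calls, est.global_step)
```

This pattern is why, in graph mode, Python side effects placed directly in model_fn (a print call, for instance) run once per train() call rather than once per training step.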