Keras Model to CerebrasEstimator
Use the KerasModelToCerebrasEstimator function to convert your Keras model to the CerebrasEstimator. This section describes the KerasModelToCerebrasEstimator function. For more on the CerebrasEstimator, see The CerebrasEstimator Interface.
Using the KerasModelToCerebrasEstimator
The KerasModelToCerebrasEstimator is a wrapper that converts the Keras model so that it can be run using the CerebrasEstimator.
Important
Make sure that you enable mixed precision by setting the mixed_float16 policy when using the KerasModelToCerebrasEstimator wrapper.
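For background (this is standard Keras behavior, not something specific to Cerebras): under the mixed_float16 policy, layers compute in float16 while their variables are kept in float32. The NumPy sketch below is only an illustration, not code the wrapper runs; it shows one reason the float32 copy matters, namely that a small weight update float16 cannot represent is rounded away.

```python
import numpy as np

# Illustration only: mixed_float16 computes in float16 but keeps variables in float32.
weight = np.float32(1.0)   # a variable stored in float32
step = np.float32(1e-4)    # a small update, e.g. learning_rate * gradient

# Applied in float16, the update is lost: just below 1.0, float16 values are
# about 0.0005 apart, so 1.0 - 0.0001 rounds back to exactly 1.0.
w16 = np.float16(weight) - np.float16(step)

# Applied to the float32 copy, the update is preserved.
w32 = weight - step

print(w16 == np.float16(1.0))  # True
print(w32 < np.float32(1.0))   # True
```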
Example
The following example code shows how you can construct the KerasModelToCerebrasEstimator:
import tensorflow as tf
from cerebras.tf.cs_model_to_estimator import KerasModelToCerebrasEstimator, KerasModelToModelFn
...
dtype = tf.keras.mixed_precision.experimental.Policy(
'mixed_float16', # Important: This is required.
loss_scale=None
)
tf.keras.mixed_precision.experimental.set_policy(dtype)
...
estimator = KerasModelToCerebrasEstimator(
model_fn=model_fn,
model_dir=None,
compile_dir=None,
    config=None,  # Optional CSRunConfig object.
params=None,
)
estimator.compile(
input_fn=input_fn,
)
where:
model_fn: Function. Required. The Keras model function.
model_dir: String. Optional. Same as the model_dir passed to tf.estimator: the location where your model and all outputs, such as checkpoints, summaries, and event files, are stored. Default value is None. See also tf.estimator.Estimator.
compile_dir: String. Optional. The directory where the compilation results are stored and from which the compilation outputs are reloaded. Default value is model_dir.
config: CSRunConfig. Optional. A CSRunConfig object specifying the environment and runtime configuration options. These options are an extension of TensorFlow RunConfig. Default value is None.
params: Dictionary. Optional. A parameters dictionary that contains additional configuration information to be passed to model_fn and input_fn. Default value is None.
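Two of the parameter behaviors described above can be sketched in plain Python: compile_dir falls back to model_dir when it is not set, and a single params dictionary reaches both model_fn and input_fn. Everything in this sketch is hypothetical: resolve_compile_dir is not part of the Cerebras API, and the input_fn and model_fn stubs and the params keys (batch_size, learning_rate) are illustrative only. The model_fn stub follows the tf.estimator-style signature (features, labels, mode, params), with mode as a plain string stand-in.

```python
from typing import Any, Dict, Optional


def resolve_compile_dir(model_dir: Optional[str], compile_dir: Optional[str]) -> Optional[str]:
    """Hypothetical helper: when compile_dir is unset, it defaults to model_dir."""
    return compile_dir if compile_dir is not None else model_dir


# Hypothetical input_fn: reads its settings from the shared params dictionary.
def input_fn(params: Dict[str, Any]) -> Dict[str, Any]:
    return {"batch_size": params["batch_size"]}


# Hypothetical model_fn with a tf.estimator-style signature; it reads the
# same params dictionary that input_fn received.
def model_fn(features: Dict[str, Any], labels: Any, mode: str,
             params: Dict[str, Any]) -> Dict[str, Any]:
    return {
        "batch_size": features["batch_size"],
        "learning_rate": params["learning_rate"],
        "mode": mode,
    }


params = {"batch_size": 256, "learning_rate": 1e-3}
features = input_fn(params)
spec = model_fn(features, None, "train", params)

print(resolve_compile_dir("model_out", None))           # model_out
print(resolve_compile_dir("model_out", "compile_out"))  # compile_out
print(spec)  # {'batch_size': 256, 'learning_rate': 0.001, 'mode': 'train'}
```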
After converting to the CerebrasEstimator, you can compile the model by calling the CerebrasEstimator method estimator.compile().