Early Stopping

Using a custom hook called CerebrasEarlyStoppingHook, you can stop neural network training early based on custom logic. This hook is similar to the Keras EarlyStopping class. CerebrasEarlyStoppingHook can be used with TensorFlow on either the CS system or a CPU.

Important

Early stopping with CerebrasEarlyStoppingHook is currently supported only on data accessible to the model_fn. During training, CerebrasEarlyStoppingHook computes the stopping condition only from the data provided for the training run. During evaluation, it computes the stopping condition only from the validation data.

Example

See the following TensorFlow example.

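import tensorflow as tf
# CerebrasEarlyStoppingHook and CSEstimatorSpec are provided by the Cerebras
# software stack; import them from the module your installation provides.

def acc_early_stop(logits, labels):
    # tf.compat.v1.metrics.accuracy(labels, predictions) returns (accuracy, update_op).
    train_acc = tf.compat.v1.metrics.accuracy(
        tf.argmax(labels, 1), tf.argmax(logits, 1)
    )
    # Return True if training accuracy is greater than 90%.
    return tf.math.greater(train_acc[0], tf.constant(0.9))

def loss_early_stop(loss, threshold):
    # Return True if training loss is lower than threshold.
    return tf.math.less(loss, tf.constant(threshold))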

def model_fn(features, labels, mode, params):
    ...
    # Specify the model.
    ...
    training_hooks = [
        # Check acc_early_stop every 1000th iteration and stop training if True.
        CerebrasEarlyStoppingHook(acc_early_stop, [logits, labels], every_n_iter=1000),
        # Check loss_early_stop every 500th iteration and stop training if True.
        CerebrasEarlyStoppingHook(loss_early_stop, [loss, 0.01], every_n_iter=500)
    ]
    ...
    spec = CSEstimatorSpec(
        ...
        training_hooks=training_hooks
        ...
    )
    return spec

In the above example, the function acc_early_stop returns True if the training accuracy is greater than 90%, and the function loss_early_stop returns True if the training loss is lower than the threshold argument.
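To make the two conditions concrete, the following plain-Python sketch computes the same quantities on small in-memory lists instead of tensors. The helper names (argmax, accuracy_condition, loss_condition) are hypothetical and only mirror the logic of acc_early_stop and loss_early_stop. One difference to keep in mind: tf.compat.v1.metrics.accuracy is a streaming metric that accumulates across batches, whereas this sketch scores a single batch.

```python
def argmax(row):
    """Return the index of the largest value in a list (first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy_condition(logits, labels, target=0.9):
    """Hypothetical mirror of acc_early_stop: True if batch accuracy > target.

    labels are one-hot rows and logits are per-class scores, so the true and
    predicted classes are the argmax of each row, as with tf.argmax(..., 1).
    """
    correct = sum(argmax(p) == argmax(t) for p, t in zip(logits, labels))
    return correct / len(labels) > target


def loss_condition(loss, threshold):
    """Hypothetical mirror of loss_early_stop: True if loss < threshold."""
    return loss < threshold


labels = [[1, 0], [0, 1], [0, 1], [1, 0]]
logits = [[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [1.2, 0.4]]
print(accuracy_condition(logits, labels))  # 3 of 4 correct -> 0.75 -> False
print(loss_condition(0.005, 0.01))         # True
```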

The first CerebrasEarlyStoppingHook in the training_hooks list evaluates the acc_early_stop function once every 1000 iterations. If acc_early_stop returns True, training stops. If the training accuracy is not greater than 90%, acc_early_stop is evaluated again 1000 iterations later.

Similarly, the second CerebrasEarlyStoppingHook evaluates the loss_early_stop function every 500 iterations and stops training if it returns True.
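The scheduling described above can be sketched as a small plain-Python simulation. This is a hypothetical model of the documented behavior, not the CerebrasEarlyStoppingHook implementation: the class EarlyStopCheck, the train loop, and the fake_loss/fake_accuracy curves are invented for illustration. Because real hook arguments are tensors whose values change every step, the sketch passes a callable that produces the arguments for a given step. Counting iterations from 1 is also an assumption.

```python
class EarlyStopCheck:
    """Hypothetical stand-in for one CerebrasEarlyStoppingHook."""

    def __init__(self, condition, args_at_step, every_n_iter):
        self.condition = condition          # returns True to request a stop
        self.args_at_step = args_at_step    # step -> list of condition arguments
        self.every_n_iter = every_n_iter

    def should_stop(self, step):
        # Evaluate the condition only on every every_n_iter-th step.
        if step % self.every_n_iter != 0:
            return False
        return bool(self.condition(*self.args_at_step(step)))


def train(num_steps, hooks):
    """Run up to num_steps steps; return the step at which training ended."""
    for step in range(1, num_steps + 1):
        # ... one training step would run here ...
        if any(h.should_stop(step) for h in hooks):
            return step
    return num_steps


# Fake metric curves (hypothetical): loss decays, accuracy rises.
def fake_loss(step):
    return 1.0 / step


def fake_accuracy(step):
    return min(1.0, step / 2000)


hooks = [
    EarlyStopCheck(lambda acc: acc > 0.9,
                   lambda s: [fake_accuracy(s)], every_n_iter=1000),
    EarlyStopCheck(lambda loss, t: loss < t,
                   lambda s: [fake_loss(s), 0.01], every_n_iter=500),
]
print(train(10_000, hooks))  # 500: the loss check fires first
```

With both hooks active, the loss condition is first checked at step 500 (loss 0.002 < 0.01), so training stops there even though the accuracy condition would not be satisfied until its check at step 2000.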

Note

A function such as acc_early_stop or loss_early_stop must return a rank-0 (scalar) Boolean tensor. There are no other restrictions on the computation inside such a function. The function runs on the host.