tf.AbstractRecomputeWrapper module

class tf.AbstractRecomputeWrapper.AbstractRecomputeWrapper

Bases: abc.ABC

Utility functions for the decorator tf.custom_gradient, when used in training.

An abstract class to handle many small requirements when using the decorator tf.custom_gradient. This class is used to recompute the activations during the backward propagation part of a training step. This code acts as a backbone for recompute wrappers and reversible layers.

The following utility functions are designed to make it easy to implement the recomputation:

  • _set_recomputed_tensor and _check_get_recomputed_tensor.

    These functions attach the recomputed tensors to the corresponding forward pass tensors. They are useful for passing the recomputed tensors between, for example, reversible layers, so that no tensors need to be saved.

  • _block_recompute_and_gradients.

    This function takes a forward block of the computation, recomputes the block, and then calculates and returns the gradients associated with the block.

  • Scope handling functions

    • tf.custom_gradient.

      This decorator names the scopes of the gradients. However, the naming is based on the IdentityN ops that it attaches to the portion of the graph for which the user would like to add a custom gradient, which is not always convenient. Moreover, tf.custom_gradient does not track the appropriate control flow contexts for the variables used in that portion of the graph. The scope handling functions in this class help with both issues.

    • _get_clean_grad_scope

      This function cleans up the name scope of the gradients, which results in a cleaner graph.

    • _update_variables_for_context

      This function finds the correct variable tensors for the control flow contexts (for example, to use recomputation inside a while loop).

The basic structure for a recompute layer is as follows:

  • Define a custom gradient function using tf.custom_gradient inside the __call__ function of a recompute layer.

  • Inside the __call__ function, call the forward propagation of the layer and define the recompute+gradient function. We recommend that you use the _block_recompute_and_gradients function for this.
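
The structure above can be illustrated with a minimal sketch that uses only the public tf.custom_gradient and tf.GradientTape APIs, assuming TensorFlow 2.x with eager execution. The names forward_block and recompute_block are hypothetical. The sketch does not use AbstractRecomputeWrapper or _block_recompute_and_gradients, so it shows the general recompute pattern rather than this class's implementation.

```python
import tensorflow as tf


def forward_block(x, w):
    # Hypothetical forward block: one dense transformation with a tanh.
    return tf.tanh(tf.matmul(x, w))


@tf.custom_gradient
def recompute_block(x, w):
    # Forward pass. Only the inputs x and w are captured by the
    # gradient function below; the activations are rebuilt later.
    y = forward_block(x, w)

    def grad(dy):
        # Backward pass: recompute the block under a fresh tape,
        # then differentiate the recomputed output.
        with tf.GradientTape() as tape:
            tape.watch([x, w])
            y_recomputed = forward_block(x, w)
        dx, dw = tape.gradient(y_recomputed, [x, w], output_gradients=dy)
        return dx, dw

    return y, grad


# Usage: gradients flow through the custom gradient function.
x = tf.constant([[0.5, -1.0], [2.0, 0.25]])
w = tf.constant([[1.0, 0.5], [-0.5, 2.0]])
with tf.GradientTape() as outer_tape:
    outer_tape.watch([x, w])
    loss = tf.reduce_sum(recompute_block(x, w))
dx, dw = outer_tape.gradient(loss, [x, w])
```

The sketch passes plain tensors instead of tf.Variable objects, which sidesteps the variable and control flow handling that the scope handling functions of this class are meant to cover.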

CtrlFlowWarnedOnce = False

abstract call(*args, **kwargs)

The call function for layers that use recomputation during the backward phase.

This function is wrapped by the __call__ function of this abstract recompute wrapper, and it must be overridden by a child class to implement the forward computation of the layer.
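
The relationship between call and __call__ can be sketched in plain Python. RecomputeWrapperSketch and ScaleLayer are hypothetical stand-ins, not the real classes. The real __call__ would also set up the custom gradient, which this sketch leaves out.

```python
import abc


class RecomputeWrapperSketch(abc.ABC):
    """Hypothetical stand-in for an abstract recompute wrapper."""

    @abc.abstractmethod
    def call(self, *args, **kwargs):
        """Forward computation; child classes must override this."""

    def __call__(self, *args, **kwargs):
        # A real wrapper would define the custom gradient here.
        # This sketch only delegates to the child's forward computation.
        return self.call(*args, **kwargs)


class ScaleLayer(RecomputeWrapperSketch):
    """Hypothetical child class that implements the forward pass."""

    def __init__(self, factor):
        self.factor = factor

    def call(self, x):
        return self.factor * x


layer = ScaleLayer(3.0)
result = layer(2.0)  # __call__ dispatches to ScaleLayer.call
```

Because call is declared abstract, trying to instantiate the base class directly raises a TypeError, which enforces the requirement that a child class overrides it.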

static is_in_while_loop(graph=None)

Returns True if the specified graph (or the current graph, if none is specified) corresponds to a while loop in the forward, backward or cond graph.

Returns

True if the specified graph (or the current graph, if none is specified) corresponds to a while loop in the forward, backward or cond graph.

Return type

bool