tf.layers.AbstractRecomputeWrapper module
- class tf.layers.AbstractRecomputeWrapper.AbstractRecomputeWrapper
Bases: abc.ABC
Utility functions for the decorator tf.custom_gradient, when used in training.
An abstract class to handle many small requirements when using the decorator tf.custom_gradient. This class is used to recompute the activations during the backward propagation part of a training step. This code acts as a backbone for recompute wrappers and reversible layers.
The following utility functions are designed to make it easy to implement the recomputation:
- _set_recomputed_tensor and _check_get_recomputed_tensor: These functions attach the recomputed tensors to the corresponding forward pass tensors. They are useful for passing recomputed tensors between layers (for example, between reversible layers), so that no tensors need to be saved.
- _block_recompute_and_gradients: This function takes a forward block of the computation, recomputes the block, and then calculates and returns the gradients associated with the block.
- Scope handling functions: The tf.custom_gradient decorator names the scopes of the gradients. However, this naming is based on the IdentityN ops it attaches to the portion of the graph for which the user would like to add a custom gradient, which is not always convenient. Moreover, tf.custom_gradient does not track the appropriate control flow contexts for the variables used in that portion of the graph. The scope handling functions in this class address both issues:
  - _get_clean_grad_scope: Cleans up the gradient name scope so that the resulting graph is clean.
  - _update_variables_for_context: Finds the correct variable tensors for the current control flow context (for example, to use recomputation inside a while loop).
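The recompute-then-differentiate idea behind _block_recompute_and_gradients can be illustrated without TensorFlow. The following is a minimal pure-Python sketch under stated assumptions, not this class's implementation: the helpers block_forward and block_recompute_and_gradients are hypothetical, and each block is the scalar function y = tanh(w * x). During the forward pass only each block's input is kept; during the backward pass each block's intermediate activation is recomputed from that input before the chain rule is applied.

```python
import math


def block_forward(x, w):
    """Hypothetical block y = tanh(w * x); also returns the intermediate a = w * x."""
    a = w * x
    return math.tanh(a), a


def block_recompute_and_gradients(x, w, dy):
    """Hypothetical analogue of _block_recompute_and_gradients.

    Only the block input x survives from the forward pass, so the block's
    activations are recomputed here before applying the chain rule.
    Returns (dL/dx, dL/dw) given the upstream gradient dy = dL/dy.
    """
    y, _a = block_forward(x, w)      # recompute the forward activations
    da = dy * (1.0 - y * y)          # d tanh(a) / da = 1 - tanh(a)^2
    return da * w, da * x            # da/dx = w, da/dw = x


def two_blocks(x, w1, w2):
    """Reference forward pass through two stacked blocks."""
    h, _ = block_forward(x, w1)
    y, _ = block_forward(h, w2)
    return y


x, w1, w2 = 0.5, 0.8, -1.2

# Forward pass: keep only each block's input and discard intermediates.
h, _ = block_forward(x, w1)
y, _ = block_forward(h, w2)
saved_inputs = [x, h]

# Backward pass, last block first; dL/dy = 1 because L = y here.
dh, dw2 = block_recompute_and_gradients(saved_inputs[1], w2, 1.0)
dx, dw1 = block_recompute_and_gradients(saved_inputs[0], w1, dh)
```

Keeping only block inputs trades extra forward computation during the backward pass for lower activation memory, which is the trade-off recomputation is meant to make.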
The basic structure for a recompute layer is as follows:
- Define a custom gradient function using tf.custom_gradient inside the __call__ function of the recompute layer.
- Inside the __call__ function, call the forward propagation of the layer and define the recompute+gradient function. We recommend using the _block_recompute_and_gradients function for this.
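As an illustration of this structure, the sketch below uses plain Python instead of TensorFlow, so it is an analogue under stated assumptions rather than a real recompute layer. The class TanhRecomputeLayer and its forward method are hypothetical. Its __call__ runs the forward pass and returns the output together with a grad_fn closure, mirroring the (output, gradient function) pair that a function decorated with tf.custom_gradient returns; the closure recomputes the activation from the captured input instead of reading a stored one.

```python
import math


class TanhRecomputeLayer:
    """Hypothetical pure-Python layer y = tanh(w * x).

    __call__ returns (y, grad_fn), mirroring the pair a tf.custom_gradient
    function returns; no real TensorFlow API is used here.
    """

    def __init__(self, w):
        self.w = w

    def forward(self, x):
        return math.tanh(self.w * x)

    def __call__(self, x):
        y = self.forward(x)  # only the input x is captured for the backward pass

        def grad_fn(dy):
            # Recompute the forward activation from the captured input.
            y_re = self.forward(x)
            da = dy * (1.0 - y_re * y_re)
            return da * self.w, da * x  # (dL/dx, dL/dw)

        return y, grad_fn


layer = TanhRecomputeLayer(w=0.8)
y, grad_fn = layer(0.5)
dx, dw = grad_fn(1.0)
```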
- CtrlFlowWarnedOnce = False
- abstract call(*args, **kwargs)
The call function for layers that use recomputation during the backward phase.
This function is wrapped by the __call__ function of this abstract recompute wrapper, and it must be overridden by a child class to implement the forward computation of the layer.
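The relationship between __call__ and the abstract call method follows the standard abc pattern. The minimal sketch below is a hypothetical miniature, not this class's code: RecomputeBase and Scale are illustrative names, and __call__ simply forwards to call() where the real wrapper would also set up the custom gradient. Because call is decorated with abc.abstractmethod, the base class cannot be instantiated until a child class overrides it.

```python
import abc


class RecomputeBase(abc.ABC):
    """Hypothetical miniature: __call__ wraps the abstract call()."""

    def __call__(self, *args, **kwargs):
        # A real recompute wrapper would set up the custom gradient here;
        # this sketch only forwards to the child's call().
        return self.call(*args, **kwargs)

    @abc.abstractmethod
    def call(self, *args, **kwargs):
        """Forward computation; child classes must override this."""


class Scale(RecomputeBase):
    """Hypothetical child class implementing the forward computation."""

    def __init__(self, factor):
        self.factor = factor

    def call(self, x):
        return self.factor * x


print(Scale(3)(2))  # 6

try:
    RecomputeBase()  # abstract call() is not overridden
except TypeError as err:
    print("cannot instantiate:", err)
```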
- static is_in_while_loop(graph=None)
Returns True if the specified graph (or the current graph, if graph is unspecified) corresponds to a while loop in the forward, backward, or cond graph.
- Returns: True if the graph corresponds to a while loop; otherwise False.
- Return type: bool