tf.ReversibleResidualLayerWrapper module


class tf.ReversibleResidualLayerWrapper.ReversibleResidualLayerWrapper(*args: Any, **kwargs: Any)

Bases: modelzoo.common.layers.tf.AbstractRecomputeWrapper.AbstractRecomputeWrapper

A wrapper to create a reversible residual layer.

A reversible residual layer takes inputs x1 and x2 and returns outputs y1 and y2, which have the same shapes as x1 and x2, respectively. It is defined as:

y1 = x1 + f(x2, f_side_input)
y2 = x2 + g(y1, g_side_input)
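Each step adds a function of the other half, so the inputs can be recovered exactly from the outputs: first x2 = y2 - g(y1), then x1 = y1 - f(x2). Reversible residual layers rely on this invertibility to reconstruct activations during the backward pass instead of storing them. The following is a minimal sketch of the forward and inverse passes, not the wrapper's actual implementation. It assumes tensors are modelled as flat Python lists, and it forwards the same kwargs to both f and g, as the limitations below describe. The names `reversible_forward` and `reversible_inverse` are hypothetical.

```python
from typing import Callable, List, Tuple

Vector = List[float]  # stand-in for a tensor; "shape" is just the length here


def _add(a: Vector, b: Vector) -> Vector:
    # Element-wise sum; the residual addition requires matching shapes.
    if len(a) != len(b):
        raise ValueError(f"shape mismatch: {len(a)} vs {len(b)}")
    return [p + q for p, q in zip(a, b)]


def _sub(a: Vector, b: Vector) -> Vector:
    # Element-wise difference, used to undo the residual addition.
    if len(a) != len(b):
        raise ValueError(f"shape mismatch: {len(a)} vs {len(b)}")
    return [p - q for p, q in zip(a, b)]


def reversible_forward(x1: Vector, x2: Vector, f: Callable[..., Vector],
                       g: Callable[..., Vector], **side_inputs) -> Tuple[Vector, Vector]:
    # y1 = x1 + f(x2); y2 = x2 + g(y1). Side inputs go to both blocks as kwargs.
    y1 = _add(x1, f(x2, **side_inputs))
    y2 = _add(x2, g(y1, **side_inputs))
    return y1, y2


def reversible_inverse(y1: Vector, y2: Vector, f: Callable[..., Vector],
                       g: Callable[..., Vector], **side_inputs) -> Tuple[Vector, Vector]:
    # Undo the steps in reverse order: x2 needs only y1 and y2, then x1 needs x2.
    x2 = _sub(y2, g(y1, **side_inputs))
    x1 = _sub(y1, f(x2, **side_inputs))
    return x1, x2


# Hypothetical toy blocks; both accept kwargs, as the wrapper requires.
def f(v: Vector, scale: float = 1.0, **kwargs) -> Vector:
    return [scale * 2.0 * t for t in v]


def g(v: Vector, scale: float = 1.0, **kwargs) -> Vector:
    return [scale * (t + 1.0) for t in v]


y1, y2 = reversible_forward([1.0, 2.0], [3.0, 4.0], f, g, scale=0.5)
# y1 == [4.0, 6.0], y2 == [5.5, 7.5]
x1, x2 = reversible_inverse(y1, y2, f, g, scale=0.5)
# x1 == [1.0, 2.0], x2 == [3.0, 4.0]
```

Note that f and g are never inverted themselves; they only need to be re-evaluated, which is why arbitrary (non-invertible) blocks can be used.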

Limitations:

  • f and g must not close over any tensors. Pass all side inputs to f and g as kwargs to the call function; these kwargs are forwarded to the call functions of both f and g, so f and g must both accept kwargs.

  • f and g must agree on the dimensionality of their inputs and outputs so that the additions in the equations above are well defined. In particular, the input shape of f must match the output shape of g, and vice versa.
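Both limitations can be checked up front on sample inputs. The sketch below uses a hypothetical helper, `check_block_shapes`, which is not part of modelzoo, and again models tensors as flat Python lists whose length stands in for the shape. The formulation does not itself require x1 and x2 to have the same width, so the sketch uses widths 2 and 3 to make the shape constraints visible. It also illustrates the kwargs convention: the hypothetical side inputs `f_bias` and `g_scale` are passed explicitly instead of being captured from an enclosing scope, and because every kwarg reaches both blocks, each block accepts `**kwargs` and ignores the ones meant for the other.

```python
from typing import Callable, List

Vector = List[float]  # stand-in for a tensor; "shape" is just the length here


def check_block_shapes(x1: Vector, x2: Vector, f: Callable[..., Vector],
                       g: Callable[..., Vector], **side_inputs) -> None:
    """Raise ValueError if f and g cannot form a reversible residual layer (hypothetical helper)."""
    f_out = f(x2, **side_inputs)
    if len(f_out) != len(x1):
        raise ValueError(f"f maps width {len(x2)} -> {len(f_out)}; expected {len(x1)} (width of x1)")
    # g consumes y1, which always has the same shape as x1.
    g_out = g(x1, **side_inputs)
    if len(g_out) != len(x2):
        raise ValueError(f"g maps width {len(x1)} -> {len(g_out)}; expected {len(x2)} (width of x2)")


def f_block(v: Vector, f_bias: float = 0.0, **kwargs) -> Vector:
    # Maps width 3 (x2) to width 2 (x1). g's side inputs arrive in kwargs and are ignored.
    return [v[0] + v[1] + f_bias, v[2] + f_bias]


def g_block(v: Vector, g_scale: float = 1.0, **kwargs) -> Vector:
    # Maps width 2 (y1, same as x1) to width 3 (x2). f's side inputs are ignored.
    return [g_scale * v[0], g_scale * v[1], g_scale * (v[0] - v[1])]


# Compatible blocks: passes silently.
check_block_shapes([1.0, 2.0], [3.0, 4.0, 5.0], f_block, g_block, f_bias=0.5, g_scale=2.0)
```

Swapping the blocks (passing `g_block` as f) makes the first check fail, because `g_block` returns width 3 where x1 has width 2.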

Parameters
  • f_block (function) – Function (Tensor) -> (Tensor) defining the f transformation of the layer. Its input shape must match that of x2, and its output shape must match that of x1.

  • g_block (function) – Function (Tensor) -> (Tensor) defining the g transformation of the layer. Its input shape must match that of x1, and its output shape must match that of x2.

call(*args, **kwargs)