tf.RNNLayer module

class tf.RNNLayer.RNNLayer(*args: Any, **kwargs: Any)

Bases: modelzoo.common.layers.tf.BaseLayer.BaseLayer

Generic RNN layer. Its API must match that of the Keras RNN layer so that it is compatible with the Keras Bidirectional RNN. See the TensorFlow documentation for more information on the following parameters:

Parameters
  • cell – A tf.keras instantiated cell. Currently supports LSTMCell, SimpleRNNCell and GRUCell.

  • name (str) – Name of the layer for the TensorFlow graph.

  • boundary_casting (bool) – See the documentation for BaseLayer.

  • tf_summary (bool) – See the documentation for BaseLayer.

  • **kwargs – Additional keyword arguments for BaseLayer.

call(inputs, mask=None, initial_state=None)

Pushes a sequence through the RNN stack built from the configured cell (LSTMCell, SimpleRNNCell or GRUCell).

Parameters
  • inputs (Tensor) – A sequence tensor of shape [batch_size, max_seq_len, hidden_size].

  • mask (Tensor) – (Optional) Boolean tensor of the same shape as inputs, to be applied to the inputs. Usually derived from the sequence lengths.

  • initial_state (Tensor) – (Optional) Initial state for the RNN stack, of shape [num_layers, 2, batch_size, num_units].
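The shapes of the call arguments can be illustrated with a small framework-free sketch. It uses nested Python lists rather than tensors, and the helper names (sequence_mask, zero_initial_state, shape) are hypothetical and not part of RNNLayer. It assumes the mask marks non-padded time steps as True and that a zero initial state is acceptable. What the axis of size 2 in initial_state holds is not stated here, so the sketch only reproduces the shape.

```python
def sequence_mask(seq_lens, max_seq_len, hidden_size):
    """Build a Boolean mask of shape [batch_size, max_seq_len, hidden_size].

    Assumption: True marks a real (non-padded) time step.
    """
    return [
        [[t < n] * hidden_size for t in range(max_seq_len)]
        for n in seq_lens
    ]


def zero_initial_state(num_layers, batch_size, num_units):
    """Build zeros of shape [num_layers, 2, batch_size, num_units]."""
    return [
        [[[0.0] * num_units for _ in range(batch_size)] for _ in range(2)]
        for _ in range(num_layers)
    ]


def shape(nested):
    """Return the dimensions of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


# Two sequences of lengths 3 and 1, padded to max_seq_len = 4.
mask = sequence_mask(seq_lens=[3, 1], max_seq_len=4, hidden_size=2)
state = zero_initial_state(num_layers=2, batch_size=2, num_units=5)

print(shape(mask))   # dimensions of the mask
print(shape(state))  # dimensions of the initial state
```

In a TensorFlow program, tf.sequence_mask(lengths, maxlen) produces the equivalent [batch_size, max_seq_len] Boolean mask from the sequence lengths. Whether it then needs expanding to the full input shape depends on the layer.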

get_config()

Used by the bidirectional RNN to get the parameters to pass into RNNLayer when calling it.
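The general get_config pattern can be sketched as follows. This is a toy illustration, not the RNNLayer implementation: ToyRNNLayer, make_backward_copy, and the num_units and go_backwards arguments are all hypothetical names chosen for the example. The idea is that a wrapper reads the configuration of an existing layer, adjusts it, and builds a second layer from it.

```python
class ToyRNNLayer:
    """Hypothetical stand-in for a layer; not the modelzoo RNNLayer."""

    def __init__(self, num_units, go_backwards=False, name="rnn"):
        self.num_units = num_units
        self.go_backwards = go_backwards
        self.name = name

    def get_config(self):
        # Return exactly the keyword arguments needed to rebuild this layer.
        return {
            "num_units": self.num_units,
            "go_backwards": self.go_backwards,
            "name": self.name,
        }


def make_backward_copy(layer):
    """Clone a layer from its config with the direction flipped (assumed behavior)."""
    config = layer.get_config()
    config["go_backwards"] = not config["go_backwards"]
    config["name"] = "backward_" + config["name"]
    return type(layer)(**config)


forward = ToyRNNLayer(num_units=8)
backward = make_backward_copy(forward)
print(backward.get_config())
```

Because the copy is built only from the dictionary that get_config returns, any constructor argument missing from that dictionary silently falls back to its default in the copy. This is why the configuration has to mirror the constructor signature.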