tf.RNNLayer module
- class tf.RNNLayer.RNNLayer(*args: Any, **kwargs: Any)
Bases: modelzoo.common.layers.tf.BaseLayer.BaseLayer
Generic RNN layer. The API must match that of the Keras RNN layer so that this layer is compatible with the Keras Bidirectional RNN. See the TensorFlow documentation for more information on the following parameters:
- Parameters
cell – An instantiated tf.keras cell. Currently supports LSTMCell, SimpleRNNCell, and GRUCell.
name (str) – Name of the layer for the TensorFlow graph.
boundary_casting (bool) – See the documentation for BaseLayer.
tf_summary (bool) – See the documentation for BaseLayer.
**kwargs – Additional keyword arguments for BaseLayer.
- call(inputs, mask=None, initial_state=None)
Passes a sequence through the RNN stack.
- Parameters
inputs (Tensor) – A sequence of size [batch_size, max_seq_len, hidden_size].
mask (Tensor) – (Optional) Boolean tensor of the same size as inputs, to be applied to the inputs. Usually derived from the sequence lengths.
initial_state (Tensor) – (Optional) Initial state for the RNN stack, of size [num_layers, 2, batch_size, num_units].
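The mask is described above as usually derived from the sequence lengths. A minimal plain-Python sketch of that derivation follows. The helper name sequence_mask and the example lengths are hypothetical, and the result has shape [batch_size, max_seq_len], so it may still need to be expanded to match the size of inputs as the parameter description requires. In TensorFlow, tf.sequence_mask builds the same kind of boolean mask as a tensor.

```python
from typing import List


def sequence_mask(seq_lens: List[int], max_seq_len: int) -> List[List[bool]]:
    """Return a [batch_size, max_seq_len] mask: True for real steps, False for padding."""
    return [[step < length for step in range(max_seq_len)] for length in seq_lens]


# Hypothetical batch of three sequences, each padded to length 4.
mask = sequence_mask([4, 2, 0], max_seq_len=4)
for row in mask:
    print(row)
```

Each row marks which time steps of one sequence hold real data. Padded positions are False.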
- get_config()
Used by the bidirectional RNN to get the parameters to pass into RNNLayer when calling it.
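A wrapper such as the Keras Bidirectional layer typically reads a layer's configuration through get_config() and builds a second, reversed copy from it. The plain-Python sketch below illustrates that round trip with a hypothetical ToyRNNLayer class. The attribute names (units, go_backwards) and the make_backward_copy helper are assumptions for illustration only, not the actual RNNLayer implementation.

```python
from typing import Any, Dict


class ToyRNNLayer:
    """Hypothetical stand-in for a layer that exposes get_config()."""

    def __init__(self, units: int, go_backwards: bool = False, name: str = "rnn"):
        self.units = units
        self.go_backwards = go_backwards
        self.name = name

    def get_config(self) -> Dict[str, Any]:
        # Return exactly the keyword arguments needed to rebuild this layer.
        return {"units": self.units, "go_backwards": self.go_backwards, "name": self.name}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyRNNLayer":
        return cls(**config)


def make_backward_copy(layer: ToyRNNLayer) -> ToyRNNLayer:
    """Clone a layer from its config with the direction flipped (wrapper-style)."""
    config = layer.get_config()
    config["go_backwards"] = not config["go_backwards"]
    config["name"] = "backward_" + config["name"]
    return layer.__class__.from_config(config)


forward = ToyRNNLayer(units=8)
backward = make_backward_copy(forward)
print(backward.get_config())
```

The point of the pattern is that get_config() must return every argument the constructor needs. Otherwise the copy built by the wrapper would differ from the original layer.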