modelzoo.common.pytorch.layers.TransformerDecoder
import path: modelzoo.common.pytorch.layers.TransformerDecoder
TransformerDecoder(decoder_layer, num_layers, norm=None):
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
num_layers: the number of sub-decoder-layers in the decoder (required).
norm: the layer normalization component (optional).
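
A minimal construction sketch follows. The import of TransformerDecoderLayer from the same layers module and its constructor arguments (d_model, nhead) are assumptions modeled on torch.nn.TransformerDecoderLayer; the actual modelzoo signature may differ.

    import torch.nn as nn
    from modelzoo.common.pytorch.layers import TransformerDecoder, TransformerDecoderLayer

    embed_dim = 512
    # Assumed constructor arguments, modeled on torch.nn.TransformerDecoderLayer.
    decoder_layer = TransformerDecoderLayer(d_model=embed_dim, nhead=8)
    decoder = TransformerDecoder(
        decoder_layer,                 # the layer to replicate
        num_layers=6,                  # number of sub-decoder-layers in the stack
        norm=nn.LayerNorm(embed_dim),  # optional final layer normalization
    )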
forward(tgt=None, memory=None, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, rotary_position_embedding_helper=None):
tgt: the sequence to the decoder (required). Shape [batch_size, tgt_seq_length, embed_dim].
memory: the sequence from the last layer of the encoder (optional). Shape [batch_size, memory_length, embed_dim].
tgt_mask: the mask for the tgt sequence (optional). Shape [tgt_seq_length, tgt_seq_length].
memory_mask: the mask for the memory sequence (optional). Shape [tgt_seq_length, memory_length], since cross-attention queries come from tgt and keys from memory.
tgt_key_padding_mask: the mask for the tgt keys per batch (optional). Shape [batch_size, tgt_seq_length].
memory_key_padding_mask: the mask for the memory keys per batch (optional). Shape [batch_size, memory_length].
rotary_position_embedding_helper (RotaryPositionEmbeddingHelper): helper to create rotary embeddings according to the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding" (optional).
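
A minimal forward-pass sketch under the same assumptions as the construction sketch above. The additive -inf causal mask follows the torch.nn convention, which the modelzoo layers are assumed to share; whether boolean masks are also accepted is not confirmed by this page.

    import torch
    from modelzoo.common.pytorch.layers import TransformerDecoder, TransformerDecoderLayer

    embed_dim, batch_size, tgt_seq_length, memory_length = 512, 4, 16, 32

    decoder_layer = TransformerDecoderLayer(d_model=embed_dim, nhead=8)  # assumed signature
    decoder = TransformerDecoder(decoder_layer, num_layers=6)

    tgt = torch.rand(batch_size, tgt_seq_length, embed_dim)    # decoder input
    memory = torch.rand(batch_size, memory_length, embed_dim)  # encoder output

    # Additive causal mask: -inf strictly above the diagonal blocks attention
    # to future positions; zeros on and below the diagonal allow it.
    tgt_mask = torch.triu(
        torch.full((tgt_seq_length, tgt_seq_length), float("-inf")), diagonal=1
    )

    out = decoder(tgt=tgt, memory=memory, tgt_mask=tgt_mask)
    # out shape: [batch_size, tgt_seq_length, embed_dim]

The output keeps the same shape as tgt, so the decoder can be dropped into a model wherever a per-token representation of the target sequence is needed.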