modelzoo.common.pytorch.layers.TransformerDecoder

import path: modelzoo.common.pytorch.layers.TransformerDecoder

TransformerDecoder (decoder_layer, num_layers, norm=None):

  • decoder_layer: an instance of the TransformerDecoderLayer() class (required).

  • num_layers: the number of sub-decoder-layers in the decoder (required).

  • norm: the layer normalization component (optional).
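As a rough sketch of how these constructor arguments fit together, the snippet below uses `torch.nn.TransformerDecoder` as a stand-in, on the assumption that the modelzoo class mirrors its constructor. The stand-in classes and the sizes `embed_dim = 16`, `nhead = 4`, and `num_layers = 3` are illustrative choices, not values taken from this page.

```python
import torch.nn as nn

embed_dim, nhead, num_layers = 16, 4, 3  # illustrative sizes (assumption)

# Stand-in for modelzoo's TransformerDecoderLayer (assumption: similar constructor).
decoder_layer = nn.TransformerDecoderLayer(
    d_model=embed_dim, nhead=nhead, batch_first=True
)

# Optional normalization component passed as `norm`.
final_norm = nn.LayerNorm(embed_dim)

decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers, norm=final_norm)

# The PyTorch stand-in stacks `num_layers` copies of `decoder_layer`.
print(len(decoder.layers))  # 3
```

With the modelzoo package installed, the same pattern would presumably apply after importing the classes from `modelzoo.common.pytorch.layers`; check the `TransformerDecoderLayer` reference for its actual arguments.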

forward (tgt=None, memory=None, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, rotary_position_embedding_helper=None):

  • tgt: the sequence to the decoder (required). shape [batch_size, tgt_seq_length, embed_dim].

  • memory: the sequence from the last layer of the encoder (optional). shape [batch_size, memory_length, embed_dim].

  • tgt_mask: the mask for the tgt sequence (optional). shape [tgt_seq_length, tgt_seq_length].

  • memory_mask: the mask for the memory sequence (optional). shape [memory_length, src_seq_length].

  • tgt_key_padding_mask: the mask for the tgt keys per batch (optional). shape [batch_size, tgt_seq_length].

  • memory_key_padding_mask: the mask for the memory keys per batch (optional). shape [batch_size, memory_length].

  • rotary_position_embedding_helper (RotaryPositionEmbeddingHelper): helper that creates rotary embeddings as described in the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding" (optional).
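The shapes listed above can be illustrated with a small forward pass. This is a sketch only: it again uses `torch.nn.TransformerDecoder` with `batch_first=True` as a stand-in for the modelzoo class, and it assumes the PyTorch boolean mask convention (`True` means "do not attend"), which may differ from the modelzoo implementation. The stand-in has no `rotary_position_embedding_helper` argument, so that argument is left out, and all sizes are illustrative.

```python
import torch
import torch.nn as nn

batch_size, tgt_seq_length, memory_length, embed_dim = 2, 5, 7, 16  # illustrative

# Stand-in decoder (assumption: modelzoo's class accepts the same inputs).
layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=4, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()

tgt = torch.randn(batch_size, tgt_seq_length, embed_dim)    # [batch, tgt_len, embed]
memory = torch.randn(batch_size, memory_length, embed_dim)  # [batch, mem_len, embed]

# Causal mask of shape [tgt_seq_length, tgt_seq_length]:
# True above the diagonal blocks attention to future positions.
tgt_mask = torch.triu(
    torch.ones(tgt_seq_length, tgt_seq_length, dtype=torch.bool), diagonal=1
)

# Per-batch key padding masks: True marks padded positions.
tgt_key_padding_mask = torch.zeros(batch_size, tgt_seq_length, dtype=torch.bool)
tgt_key_padding_mask[1, -2:] = True     # last two target tokens of sample 1 are padding
memory_key_padding_mask = torch.zeros(batch_size, memory_length, dtype=torch.bool)
memory_key_padding_mask[1, -3:] = True  # last three memory tokens of sample 1 are padding

with torch.no_grad():
    out = decoder(
        tgt,
        memory,
        tgt_mask=tgt_mask,
        tgt_key_padding_mask=tgt_key_padding_mask,
        memory_key_padding_mask=memory_key_padding_mask,
    )

print(out.shape)  # torch.Size([2, 5, 16])
```

The output keeps the shape of `tgt`, [batch_size, tgt_seq_length, embed_dim], which is why several decoder layers can be stacked.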