modelzoo.vision.pytorch.dit.layers.DiTDecoder.DiTDecoder

class modelzoo.vision.pytorch.dit.layers.DiTDecoder.DiTDecoder

Bases: modelzoo.common.pytorch.layers.TransformerDecoder.TransformerDecoder

Methods

  • forward – Pass the inputs (and mask) through each decoder layer in turn.

  • reset_parameters

__call__(*args: Any, **kwargs: Any) → Any

Call self as a function.

__init__(**kwargs)

static __new__(cls, *args: Any, **kwargs: Any) → Any

forward(
    tgt: torch.Tensor,
    memory: Optional[torch.Tensor] = None,
    tgt_mask: Optional[torch.Tensor] = None,
    sparse_mask: Optional[torch.Tensor] = None,
    memory_mask: Optional[torch.Tensor] = None,
    tgt_key_padding_mask: Optional[torch.Tensor] = None,
    memory_key_padding_mask: Optional[torch.Tensor] = None,
    self_attn_position_bias: Optional[torch.Tensor] = None,
    cross_attn_position_bias: Optional[torch.Tensor] = None,
    rotary_position_embedding_helper: Optional[modelzoo.common.pytorch.model_utils.RotaryPositionEmbeddingHelper.RotaryPositionEmbeddingHelper] = None,
    past_kv: Optional[List[Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]]] = None,
    cache_present_kv: bool = False,
    **extra_args
) → Union[torch.Tensor, Tuple[torch.Tensor, List[Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]]]]

Pass the inputs (and mask) through each decoder layer in turn.

Parameters
  • tgt – the sequence to the decoder (required).

  • memory – the sequence from the last layer of the encoder (optional).

  • tgt_mask – the mask for the tgt sequence (optional).

  • memory_mask – the mask for the memory sequence (optional).

  • tgt_key_padding_mask – the mask for the tgt keys per batch (optional).

  • memory_key_padding_mask – the mask for the memory keys per batch (optional).

  • self_attn_position_bias – the tensor containing the position bias to apply in self-attention; it can be obtained from relative or ALiBi position embeddings.

  • cross_attn_position_bias – similar to self_attn_position_bias, this is the tensor containing position bias to apply in cross-attention.

  • rotary_position_embedding_helper (Optional[RotaryPositionEmbeddingHelper]) – A helper class to apply rotary embedding on the input tensor.

  • past_kv – Past keys and values for each of the decoder layers (optional).

  • cache_present_kv – specifies whether the present keys and values must be cached and returned (optional).

Shape:

See the docs in the Transformer class.
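
The interplay between past_kv, cache_present_kv, and the two possible return types of forward can be illustrated with a small pure-Python sketch. This is not the modelzoo implementation: the names Tensor, KV, toy_layer, and decode are hypothetical, tensors are modelled as plain lists of floats, and the per-layer cache handling is an assumption about how a decoder stack of this kind typically threads the cache through its layers.

```python
from typing import Callable, List, Optional, Tuple, Union

# Hypothetical stand-ins: a "tensor" is a list of floats, and a layer's
# key/value cache is a pair of such lists.
Tensor = List[float]
KV = Tuple[Tensor, Tensor]
# A toy layer maps (hidden state, optional past cache)
# to (new hidden state, present cache).
Layer = Callable[[Tensor, Optional[KV]], Tuple[Tensor, KV]]


def toy_layer(scale: float) -> Layer:
    """Build a toy layer that scales its input and extends the cache."""

    def layer(x: Tensor, past: Optional[KV]) -> Tuple[Tensor, KV]:
        keys, values = past if past is not None else ([], [])
        out = [scale * v for v in x]
        # Present cache = past entries followed by this step's entries.
        return out, (keys + x, values + out)

    return layer


def decode(
    tgt: Tensor,
    layers: List[Layer],
    past_kv: Optional[List[KV]] = None,
    cache_present_kv: bool = False,
) -> Union[Tensor, Tuple[Tensor, List[KV]]]:
    """Run tgt through each layer in order, optionally collecting caches."""
    output = tgt
    present_kv: List[KV] = []
    for i, layer in enumerate(layers):
        # Each layer receives only its own slice of the past cache.
        past = past_kv[i] if past_kv is not None else None
        output, present = layer(output, past)
        if cache_present_kv:
            present_kv.append(present)
    # The return type depends on cache_present_kv, mirroring the Union above.
    return (output, present_kv) if cache_present_kv else output


layers = [toy_layer(2.0), toy_layer(3.0)]

# Without caching, only the output is returned.
plain = decode([1.0, 2.0], layers)

# With caching, each call also returns one cache entry per layer,
# which can be fed back in as past_kv on the next step.
step1, cache = decode([1.0], layers, cache_present_kv=True)
step2, cache2 = decode([2.0], layers, past_kv=cache, cache_present_kv=True)
```

In this sketch the list passed as past_kv has one entry per layer, which matches the description of past_kv as "past keys and values for each of the decoder layers"; the contents and shapes of the real cache tensors are not modelled here.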