modelzoo.vision.pytorch.dit.layers.DiTDecoder.DiTDecoder#
- class modelzoo.vision.pytorch.dit.layers.DiTDecoder.DiTDecoder[source]#
Bases: modelzoo.common.pytorch.layers.TransformerDecoder.TransformerDecoder
Methods
- forward – Pass the inputs (and mask) through each decoder layer in turn.
- reset_parameters
- __call__(*args: Any, **kwargs: Any) Any #
Call self as a function.
- static __new__(cls, *args: Any, **kwargs: Any) Any #
- forward(tgt: torch.Tensor, memory: Optional[torch.Tensor] = None, tgt_mask: Optional[torch.Tensor] = None, sparse_mask: Optional[torch.Tensor] = None, memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, memory_key_padding_mask: Optional[torch.Tensor] = None, self_attn_position_bias: Optional[torch.Tensor] = None, cross_attn_position_bias: Optional[torch.Tensor] = None, rotary_position_embedding_helper: Optional[modelzoo.common.pytorch.model_utils.RotaryPositionEmbeddingHelper.RotaryPositionEmbeddingHelper] = None, past_kv: Optional[List[Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]]] = None, cache_present_kv: bool = False, **extra_args) Union[torch.Tensor, Tuple[torch.Tensor, List[Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]]]] [source]#
Pass the inputs (and mask) through each decoder layer in turn.
- Parameters
tgt – the sequence to the decoder (required).
memory – the sequence from the last layer of the encoder (optional).
tgt_mask – the mask for the tgt sequence (optional).
sparse_mask – the mask used by sparse-attention layers (optional).
memory_mask – the mask for the memory sequence (optional).
tgt_key_padding_mask – the mask for the tgt keys per batch (optional).
memory_key_padding_mask – the mask for the memory keys per batch (optional).
self_attn_position_bias – the tensor containing the position bias to apply in self-attention; it can be obtained from relative or ALiBi position embeddings.
cross_attn_position_bias – the tensor containing the position bias to apply in cross-attention, analogous to self_attn_position_bias.
rotary_position_embedding_helper (Optional[RotaryPositionEmbeddingHelper]) – a helper class to apply rotary position embeddings to the input tensor.
past_kv – past keys and values for each of the decoder layers (optional).
cache_present_kv – specifies whether the present keys and values must be cached and returned (optional).
- Shape:
see the docs in the Transformer class.
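A minimal usage sketch follows. It assumes the constructor mirrors torch.nn.TransformerDecoder (a decoder layer plus a layer count) and that TransformerDecoderLayer is importable from modelzoo.common.pytorch.layers; the dimensions and layer configuration are illustrative only, not a real DiT setup.

```python
import torch
from modelzoo.common.pytorch.layers import TransformerDecoderLayer
from modelzoo.vision.pytorch.dit.layers.DiTDecoder import DiTDecoder

# Illustrative hyperparameters; real DiT configs come from the model config.
decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
decoder = DiTDecoder(decoder_layer, num_layers=6)

tgt = torch.rand(2, 16, 512)  # (batch, sequence, d_model), assumed batch-first

# Self-attention-only pass: memory is optional in this decoder's forward.
out = decoder(tgt)

# Per the return annotation, enabling KV caching makes forward return a
# (output, present_kv) tuple instead of a single tensor.
out, present_kv = decoder(tgt, cache_present_kv=True)
```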