modelzoo.transformers.pytorch.transformer_utils.build_broadcastable_attention_mask#

modelzoo.transformers.pytorch.transformer_utils.build_broadcastable_attention_mask(attention_mask: torch.Tensor, build_causal: bool = False, device: Optional[torch.device] = None, dtype=None, revert_mask: bool = True, multiply_neg_inf: bool = True)[source]#

Create a broadcastable attention mask (full or causal) so that masked positions are ignored.

Parameters
  • attention_mask (torch.Tensor) – attention mask with 2, 3, or 4 dimensions, whose entries are either 1 or 0.

  • build_causal (bool) – If enabled, a causal mask is created according to the dimensions of attention_mask.

  • device (torch.device) – The device of the input to the model, used for causal mask creation.

  • dtype (torch.dtype) – Dtype of the resulting mask.

  • revert_mask (bool) – whether to flip the 1s and 0s of the attention mask; defaults to True.

  • multiply_neg_inf (bool) – whether to multiply the resulting mask by a negative infinity constant; defaults to True.

Returns

The attention mask of shape [batch_size, num_heads, src_seq_len, target_seq_len], with broadcast dimensions set to 1.
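A minimal usage sketch based on the signature above; the tensor shapes and sample values are illustrative assumptions, and the exact contents of the returned mask depend on the modelzoo implementation.

```python
import torch

from modelzoo.transformers.pytorch.transformer_utils import (
    build_broadcastable_attention_mask,
)

batch_size, seq_len = 2, 5
# 1 = attend, 0 = masked (e.g. padding at the end of each sequence).
attention_mask = torch.tensor(
    [[1, 1, 1, 1, 0],
     [1, 1, 1, 0, 0]]
)

# Build a causal mask broadcastable to
# [batch_size, num_heads, src_seq_len, target_seq_len]. With the defaults
# (revert_mask=True, multiply_neg_inf=True), masked positions hold a large
# negative value so they vanish after softmax.
extended_mask = build_broadcastable_attention_mask(
    attention_mask,
    build_causal=True,
    device=attention_mask.device,
    dtype=torch.float32,
)

# Because the broadcast dimensions are size 1, the mask can be added directly
# to attention scores of shape [batch_size, num_heads, seq_len, seq_len].
num_heads = 4  # illustrative value
scores = torch.randn(batch_size, num_heads, seq_len, seq_len)
probs = torch.softmax(scores + extended_mask, dim=-1)
```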