modelzoo.transformers.pytorch.transformer_utils.make_sparse_mask_broadcastable
- modelzoo.transformers.pytorch.transformer_utils.make_sparse_mask_broadcastable(sparse_mask: torch.Tensor, key_padding_mask: torch.Tensor, dtype=None, device=None, revert_mask: bool = True, multiply_neg_inf: bool = True)
Combine a sparse attention mask with a key padding mask into a single broadcastable attention mask so that masked positions are ignored.
- Parameters
sparse_mask (torch.Tensor) – Sparse attention mask with shape [src_seq_len, target_seq_len].
key_padding_mask (torch.Tensor) – Key padding mask with 2, 3, or 4 dimensions.
dtype (torch.dtype) – Dtype of the resulting mask.
device (torch.device) – The device to move the sparse mask to.
revert_mask (bool) – Whether to flip the 1s and 0s of the attention mask. Defaults to True.
multiply_neg_inf (bool) – Whether to multiply the resulting mask by a negative infinity constant. Defaults to True.
- Returns
The attention mask of shape [batch_size, num_heads, src_seq_len, target_seq_len], with broadcast dimensions set to 1.
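The following is a minimal sketch of how such a mask could be assembled, not the modelzoo implementation. It assumes (this page does not confirm it) that with revert_mask=True both inputs use 1 for "attend" and 0 for "ignore", and it only handles a rank-2 key_padding_mask of shape [batch_size, target_seq_len]. The name make_sparse_mask_broadcastable_sketch is hypothetical, the two masks are combined with an elementwise maximum (logical OR on 0/1 values), and torch.finfo(dtype).min stands in for the negative infinity constant.

```python
from typing import Optional

import torch


def make_sparse_mask_broadcastable_sketch(
    sparse_mask: torch.Tensor,
    key_padding_mask: torch.Tensor,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
    revert_mask: bool = True,
    multiply_neg_inf: bool = True,
) -> torch.Tensor:
    """Hypothetical sketch, not the modelzoo implementation.

    Assumes 1 = attend and 0 = ignore in both inputs when revert_mask=True,
    and a rank-2 key_padding_mask of shape [batch_size, target_seq_len].
    """
    sparse_mask = sparse_mask.to(dtype)
    key_padding_mask = key_padding_mask.to(dtype)
    if device is not None:
        # Move both masks to the target device so they can be combined.
        sparse_mask = sparse_mask.to(device)
        key_padding_mask = key_padding_mask.to(device)

    if revert_mask:
        # Flip 1s and 0s so that 1 marks a position to ignore.
        sparse_mask = 1.0 - sparse_mask
        key_padding_mask = 1.0 - key_padding_mask

    # [src_seq_len, target_seq_len] -> [1, 1, src_seq_len, target_seq_len]
    sparse_mask = sparse_mask[None, None, :, :]
    # [batch_size, target_seq_len] -> [batch_size, 1, 1, target_seq_len]
    key_padding_mask = key_padding_mask[:, None, None, :]

    # A position is ignored if either mask ignores it; on 0/1 values an
    # elementwise max acts as logical OR. Broadcasting yields
    # [batch_size, 1, src_seq_len, target_seq_len].
    mask = torch.maximum(sparse_mask, key_padding_mask)

    if multiply_neg_inf:
        # Most negative finite value instead of -inf, so 0 * value stays 0 (not NaN).
        mask = mask * torch.finfo(dtype).min
    return mask


# Example: a 3x3 lower-triangular (causal) sparse mask and a batch of two
# sequences, the first of which pads its last key position.
sparse = torch.tril(torch.ones(3, 3))
padding = torch.tensor([[1, 1, 0], [1, 1, 1]])
mask = make_sparse_mask_broadcastable_sketch(sparse, padding)
print(mask.shape)  # torch.Size([2, 1, 3, 3])
```

Using the most negative finite value rather than float("-inf") keeps unmasked entries at zero, because 0 * -inf evaluates to NaN. Adding the returned mask to attention scores of shape [batch_size, num_heads, src_seq_len, target_seq_len] broadcasts it across heads, and a softmax over the last dimension then gives the masked keys near-zero weight.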