# Source code for modelzoo.transformers.pytorch.transformer_utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch

import cerebras_pytorch as cstorch


def _extend_mask_to_shape_of_4(mask: torch.Tensor):
    assert len(mask.shape) in [
        2,
        3,
        4,
    ], "Masks with shape 2, 3, 4 are supported"
    if len(mask.shape) == 2:
        # [batch_size, target_seq_len]
        mask = mask[:, None, None, :]
    elif len(mask.shape) == 3:
        # [batch_size, src_seq_len, target_seq_len]
        mask = mask[:, None, :, :]
    else:
        # len(key_padding_mask.shape) == 4
        # [batch_size, num_heads, src_seq_len, target_seq_len]
        mask = mask
    return mask


def make_key_padding_mask_broadcastable(
    key_padding_mask: torch.Tensor,
    dtype=None,
    revert_mask: bool = True,
    multiply_neg_inf: bool = True,
):
    """Turn a key padding mask into a 4D mask that broadcasts over attention scores.

    Args:
        key_padding_mask (torch.Tensor): Key padding mask with 2, 3 or 4
            dimensions whose entries are either 1 or 0.
        dtype (torch.dtype): Dtype of the resulting mask. Defaults to
            ``torch.float16`` when ``None``.
        revert_mask (bool): Whether to flip the 1's and 0's of the mask
            (``1.0 - mask``). Defaults to True.
        multiply_neg_inf (bool): Whether to scale the resulting mask by the
            most negative finite value of its dtype. Defaults to True.

    Returns:
        The key padding mask of shape
        [batch_size, num_heads, src_seq_len, target_seq_len], with broadcast
        dimensions set to 1.
    """
    target_dtype = torch.float16 if dtype is None else dtype
    mask = key_padding_mask.to(dtype=target_dtype)

    # With 1 = attend / 0 = padding on input, flipping makes the padded
    # positions the ones that receive the large negative value below.
    if revert_mask:
        mask = 1.0 - mask

    mask = _extend_mask_to_shape_of_4(mask)

    if not multiply_neg_inf:
        return mask
    return mask * torch.finfo(mask.dtype).min
def build_broadcastable_attention_mask(
    attention_mask: torch.Tensor,
    attention_span: Optional[torch.Tensor] = None,
    build_causal: bool = False,
    device: Optional[torch.device] = None,
    dtype=None,
    revert_mask: bool = True,
    multiply_neg_inf: bool = True,
    num_heads: Optional[int] = None,
):
    """Build a 4D attention mask (full or causal) that broadcasts over scores.

    Note that when ``build_causal`` is set, the values of the returned mask
    come from the causal mask alone; ``attention_mask`` then only determines
    the sequence lengths and the shape the causal mask is broadcast to.

    Args:
        attention_mask (torch.Tensor): Attention mask with 2, 3 or 4
            dimensions whose entries are either 1 or 0.
        attention_span (torch.Tensor): Attention span of keys for VSL, of
            shape [batch_size, target_seq_len]. When given together with
            ``build_causal``, a VSL causal mask is built.
        build_causal (bool): If enabled, a causal mask is created according
            to the dims of ``attention_mask``.
        device (torch.device): Device on which the causal mask is created.
        dtype (torch.dtype): Dtype of the resulting mask. Defaults to
            ``torch.float16`` when ``None``.
        revert_mask (bool): Whether to flip the 1's and 0's of the attention
            mask. Defaults to True.
        multiply_neg_inf (bool): Whether to scale the resulting mask by the
            most negative finite value of its dtype. Defaults to True.
        num_heads (int): Number of heads; required for the VSL causal mask.

    Returns:
        The attention mask of shape
        [batch_size, num_heads, src_seq_len, target_seq_len], with broadcast
        dimensions set to 1.

    Raises:
        AssertionError: If ``attention_mask`` does not have 2, 3 or 4
            dimensions, or if a VSL mask is requested without ``num_heads``.
    """
    assert len(attention_mask.shape) in [2, 3, 4], (
        f"Masks with dimensions of 2, 3, 4 are supported "
        f"but found shape {attention_mask.shape}"
    )

    mask_dtype = torch.float16 if dtype is None else dtype
    mask = attention_mask.to(dtype=mask_dtype)

    # With 1 = attend / 0 = masked on input, flipping makes the masked
    # positions the ones that receive the large negative value at the end.
    if revert_mask:
        mask = 1.0 - mask

    mask = _extend_mask_to_shape_of_4(mask)

    num_keys = mask.shape[-1]
    num_queries = mask.shape[-2]
    if num_queries == 1:
        # A broadcast source axis carries no length; fall back to the
        # target length.
        num_queries = num_keys

    if build_causal:
        if attention_span is not None:
            assert (
                num_heads is not None
            ), "num_heads is needed for creating vsl mask."
            causal = create_vsl_autoregressive_mask(
                attention_span,
                num_queries,
                num_keys,
                num_heads=num_heads,
                dtype=mask_dtype,
                device=device,
            )
        else:
            causal = create_2D_autoregressive_mask(
                num_queries, num_keys, dtype=mask_dtype, device=device,
            )
        # Keep the broadcast causal mask; the second output is discarded.
        mask, _ = torch.broadcast_tensors(causal, mask)

    if multiply_neg_inf:
        mask = mask * torch.finfo(mask.dtype).min

    return mask
def create_2D_autoregressive_mask(
    src_sequence_length: int,
    target_sequence_length: int,
    dtype=None,
    device=None,
):
    """Create a reverted autoregressive (strictly upper triangular) mask.

    Entries equal to 0 mark the tokens to attend to and entries equal to 1
    mark the tokens that are skipped.

    Args:
        src_sequence_length (int): Sequence length of the source (num query
            vectors).
        target_sequence_length (int): Sequence length of the target (num key
            vectors).
        dtype (torch.dtype): Dtype of the resulting mask. Defaults to
            ``torch.float16`` when ``None``.
        device (torch.device): Device on which the mask is created.

    Returns:
        The causal mask of shape [src_seq_len, target_seq_len].
    """
    mask_dtype = torch.float16 if dtype is None else dtype
    all_ones = torch.ones(
        (src_sequence_length, target_sequence_length),
        device=device,
        dtype=mask_dtype,
    )
    # diagonal=1 zeroes the main diagonal and everything below it.
    return all_ones.triu(diagonal=1)
def create_2D_full_mask(
    src_sequence_length: int,
    target_sequence_length: int,
    dtype=None,
    device=None,
):
    """Create a full (all ones) 2D mask.

    Args:
        src_sequence_length (int): Sequence length of the source (num query
            vectors).
        target_sequence_length (int): Sequence length of the target (num key
            vectors).
        dtype (torch.dtype): Dtype of the resulting mask. Defaults to
            ``torch.float16`` when ``None``.
        device (torch.device): Device on which the mask is created.

    Returns:
        A mask of ones with shape [src_seq_len, target_seq_len].
    """
    mask_dtype = torch.float16 if dtype is None else dtype
    return torch.ones(
        src_sequence_length,
        target_sequence_length,
        dtype=mask_dtype,
        device=device,
    )
def create_vsl_autoregressive_mask(
    attention_span: torch.Tensor,
    src_sequence_length: int,
    target_sequence_length: int,
    num_heads: int,
    dtype=None,
    device=None,
):
    """Create an autoregressive mask for variable sequence length (VSL).

    A position is masked (value 1) when the key lies after the query, or
    when the query is further past the key than that key's attention span
    allows; all other positions are 0.

    Args:
        attention_span (torch.Tensor): Attention span of the keys, of shape
            [batch_size, target_sequence_length].
        src_sequence_length (int): Sequence length of the source (num query
            vectors).
        target_sequence_length (int): Sequence length of the target (num key
            vectors).
        num_heads (int): Number of heads.
        dtype (torch.dtype): Dtype of the resulting mask. Defaults to
            ``torch.float16`` when ``None``.
        device (torch.device): Device on which the position indices are
            created.

    Returns:
        The causal mask of shape
        [batch_size, num_heads, src_seq_len, target_seq_len].
    """
    out_dtype = torch.float16 if dtype is None else dtype

    num_rows, _ = attention_span.shape
    full_shape = (
        num_rows,
        num_heads,
        src_sequence_length,
        target_sequence_length,
    )

    query_pos = torch.arange(
        src_sequence_length, device=device, dtype=torch.float32
    )
    key_pos = torch.arange(
        target_sequence_length, device=device, dtype=torch.float32
    )
    query_grid = query_pos[None, :, None].broadcast_to(full_shape)
    key_grid = key_pos[None, None, :].broadcast_to(full_shape)

    # offset[..., i, j] = i - j: negative when key j lies after query i.
    offset = query_grid - key_grid

    zero = torch.tensor(0, dtype=torch.float32)
    one = torch.tensor(1, dtype=torch.float32)

    # Target: (offset < 0) | (offset > attention_span), expressed with float
    # ops only. For integer-valued offset and attention_span this equals
    #   min(max(offset - attention_span, 0) - min(offset, 0), 1)
    per_key_span = attention_span[:, None, None, :]
    beyond_span = torch.maximum(offset - per_key_span, zero)
    ahead_of_query = torch.minimum(offset, zero)
    blocked = torch.minimum(beyond_span - ahead_of_query, one)

    return blocked.to(dtype=out_dtype)
def make_sparse_mask_broadcastable(
    sparse_mask: torch.Tensor,
    key_padding_mask: torch.Tensor,
    dtype=None,
    device=None,
    revert_mask: bool = True,
    multiply_neg_inf: bool = True,
):
    """Create a broadcastable sparse mask so that masked positions are ignored.

    The values of the returned mask come from ``sparse_mask`` alone;
    ``key_padding_mask`` only supplies the shape that the sparse mask is
    broadcast to.

    Args:
        sparse_mask (torch.Tensor): Sparse mask with shape
            [src_seq_len, target_seq_len].
        key_padding_mask (torch.Tensor): Key padding mask with 2, 3 or 4
            dimensions.
        dtype (torch.dtype): Dtype of the resulting mask. Defaults to
            ``torch.float16`` when ``None``.
        device (torch.device): The device to move the sparse mask to (used
            when not running on CS).
        revert_mask (bool): Whether to flip the 1's and 0's of the sparse
            mask. Defaults to True.
        multiply_neg_inf (bool): Whether to scale the resulting mask by the
            most negative finite value of its dtype. Defaults to True.

    Returns:
        The attention mask of shape
        [batch_size, num_heads, src_seq_len, target_seq_len], with broadcast
        dimensions set to 1.
    """
    if dtype is None:
        dtype = torch.float16

    if revert_mask:
        sparse_mask = 1.0 - sparse_mask

    if cstorch.use_cs():
        # When running on CS, move constant from CPU to device wrapped with
        # XLA literal.
        fixed_sparsity = cstorch.make_constant(sparse_mask.to(dtype=dtype))
    else:
        # When running on GPU, move constant from CPU to GPU. Cast to
        # `dtype` here as well so the result honours the requested dtype on
        # both paths and `torch.finfo` below always sees a floating dtype.
        fixed_sparsity = sparse_mask.to(device=device, dtype=dtype)

    # Only the 4D shape of the key padding mask is needed, so it is neither
    # flipped nor scaled.
    extended_key_padding_mask = make_key_padding_mask_broadcastable(
        key_padding_mask,
        dtype=dtype,
        revert_mask=False,
        multiply_neg_inf=False,
    )

    sparse_attention_mask, _ = torch.broadcast_tensors(
        fixed_sparsity, extended_key_padding_mask,
    )

    if multiply_neg_inf:
        sparse_attention_mask = (
            sparse_attention_mask
            * torch.finfo(sparse_attention_mask.dtype).min
        )

    return sparse_attention_mask
def get_extended_attention_mask(
    attention_mask: torch.Tensor,
    input_shape: Optional[Tuple[int]] = None,
    causal: bool = False,
    device: Optional[torch.device] = None,
    dtype=None,
) -> torch.Tensor:
    """
    Makes broadcastable attention and causal masks so that future and masked
    tokens are ignored.

    Arguments:
        attention_mask (:obj:`torch.Tensor`):
            Mask with ones indicating tokens to attend to, zeros for tokens
            to ignore. Must be 2D or 3D.
        input_shape (:obj:`Tuple[int]`):
            The shape of the input to the model, unpacked as
            ``(batch_size, seq_length)`` (required for causal masks).
        causal: (`bool`):
            If enabled and ``attention_mask`` is 2D, the returned mask is
            causal. The causal mask replaces the padding mask values.
        device: (:obj:`torch.device`):
            The device of the input to the model.
        dtype (:obj:`torch.dtype`):
            Dtype of the resulting mask; ``torch.float16`` when ``None``.

    Returns:
        :obj:`torch.Tensor` The extended 4D attention mask in ``dtype``, with
        masked positions scaled to the most negative finite value.

    Raises:
        ValueError: If ``attention_mask`` is neither 2D nor 3D.
    """
    mask_dtype = torch.float16 if dtype is None else dtype

    # Ones mean "attend" on input; after the flip, ones mark the positions
    # that get the large negative value at the end.
    attention_mask = 1.0 - attention_mask.to(dtype=mask_dtype)

    rank = attention_mask.dim()
    if rank not in (2, 3):
        raise ValueError(
            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
        )

    if rank == 3:
        # A [batch_size, from_seq_length, to_seq_length] mask was supplied;
        # it only needs a heads axis.
        extended = attention_mask[:, None, :, :]
    else:
        # [batch_size, seq_length] padding mask -> [batch, 1, 1, seq_length].
        extended = attention_mask[:, None, None, :]
        if causal:
            # Decoder case: a causal mask is used instead of the padding mask.
            _, seq_length = input_shape
            positions = torch.arange(seq_length, device=device)
            # future[i, j] is 1 where key j comes after query i.
            future = (positions[None, :] > positions[:, None]).to(
                attention_mask.dtype
            )

            total_length = attention_mask.shape[1]
            if total_length > seq_length:
                # With past_key_values the mask is longer than the input;
                # the leading (past) columns are left unmasked.
                past = torch.zeros(
                    (seq_length, total_length - future.shape[1]),
                    device=device,
                    dtype=future.dtype,
                )
                future = torch.cat([past, future], dim=-1)

            extended, _ = torch.broadcast_tensors(future, extended)

    # Scaling the 1.0 entries to the minimum float makes them vanish after
    # the softmax once added to the raw scores.
    return extended * torch.finfo(extended.dtype).min
def smooth_loss(prediction_scores, loss, label_smoothing, classes):
    """
    Add label smoothing to a loss; this is a workaround method of label
    smoothing in our system.

    The smoothing term is the negative mean log-probability over the
    ``classes`` axis of ``prediction_scores`` (viewed as ``[-1, classes]``),
    blended with ``loss`` as
    ``(1 - label_smoothing) * loss + label_smoothing * smoothing_term``.
    """
    flat_logits = prediction_scores.view(-1, classes)
    log_probs = torch.log_softmax(flat_logits, dim=-1)
    uniform_term = log_probs.mean(dim=-1).neg()
    keep_weight = 1.0 - label_smoothing
    return keep_weight * loss + label_smoothing * uniform_term