Source code for cerebras.modelzoo.layers.GPTJDecoderLayer

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

from torch import Tensor

from cerebras.modelzoo.layers.RotaryPositionEmbeddingHelper import (
    RotaryPositionEmbeddingHelper,
)
from cerebras.modelzoo.layers.TransformerDecoderLayer import (
    SelfAndCrossAttnKV,
    SelfAttnKV,
    TransformerDecoderLayer,
)


class GPTJDecoderLayer(TransformerDecoderLayer):
    """Decoder layer with a parallel attention / feed-forward layout.

    Subclass of `TransformerDecoderLayer` with two differences:

    1. The self-attention branch and the feed-forward branch are both fed
       from the layer input and their outputs are added to the residual in
       a single sum (parallel decoder), instead of being chained one after
       the other.
    2. It covers both GPT-J and GPT-NeoX; the latter uses an untied layer
       norm, i.e. a separate normalization (`norm3`) in front of the
       feed-forward branch.

    Reference:
        https://www.cerebras.net/blog/how-to-harness-the-predictive-power-of-gpt-j

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multihead-attention models
            (required).
        use_untied_layer_norm (bool): whether to use untied layer_norm.
            Should be False for GPTJ and True for Neox.
        kwargs: the remaining arguments, passed unchanged to
            `TransformerDecoderLayer`.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        use_untied_layer_norm: bool = False,
        **kwargs,
    ):
        """Build the parent layer, then disable `norm3` for tied layer norm.

        Args:
            d_model: forwarded positionally to the parent constructor.
            nhead: forwarded positionally to the parent constructor.
            use_untied_layer_norm: stored on the instance; when falsy,
                `self.norm3` is overwritten with ``None``.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(d_model, nhead, **kwargs)

        self.use_untied_layer_norm = use_untied_layer_norm
        # With tied layer norm the feed-forward branch reuses the `norm1`
        # output, so `forward` must see `norm3` as None.
        # NOTE(review): `norm3` is presumably created by the parent
        # constructor -- confirm against TransformerDecoderLayer.
        if not use_untied_layer_norm:
            self.norm3 = None

    def forward(
        self,
        tgt: Tensor,
        memory: Optional[Tensor] = None,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        rotary_position_embedding_helper: Optional[
            RotaryPositionEmbeddingHelper
        ] = None,
        past_kv: Optional[Union[SelfAttnKV, SelfAndCrossAttnKV]] = None,
        cache_present_kv: bool = False,
        self_attn_position_bias: Optional[Tensor] = None,
        cross_attn_position_bias: Optional[Tensor] = None,
        layer_idx: Optional[int] = None,
    ) -> Tensor:
        """Run the parallel-decoder GPTJ layer on `tgt`.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder.
                Accepted as part of the signature but not read by this
                method.
            tgt_mask: the mask for the tgt sequence (optional). Passed to
                the self-attention block.
            memory_mask: the mask for the memory sequence (optional). Not
                read by this method.
            tgt_key_padding_mask: the mask for the tgt keys per batch
                (optional). Passed to the self-attention block.
            memory_key_padding_mask: the mask for the memory keys per batch
                (optional). Not read by this method.
            rotary_position_embedding_helper
                (Optional[RotaryPositionEmbeddingHelper]): a helper class
                to apply rotary embedding on the input tensor. Passed to
                the self-attention block.
            past_kv: past keys and values for self attention and (if
                applicable) cross attention modules. Key/value tensors have
                shape ``[batch_size, num_heads, seq_length,
                embed_dim / num_heads]``. Only ``past_kv[:2]`` is passed to
                the self-attention block (optional).
            cache_present_kv: forwarded to the self-attention block
                (optional).
                NOTE(review): this method returns only the summed output
                tensor; anything beyond the first element of the
                self-attention result is dropped here -- confirm that
                callers do not expect present keys/values back.
            self_attn_position_bias: the tensor containing position bias to
                apply in self-attention, can be obtained from relative or
                alibi position embeddings.
            cross_attn_position_bias: accepted as part of the signature but
                not read by this method.
            layer_idx: forwarded to the self-attention block (optional).

        Returns:
            ``tgt + ffn_output + attention_output``, where the attention
            output is the first element of the self-attention block result.
            Presumably the same shape as `tgt` -- TODO confirm.
        """
        normed_input = self.norm1(tgt)

        # Only the leading two entries of the cache are handed to
        # self-attention.
        self_attn_cache = None if past_kv is None else past_kv[:2]
        attn_result = self._sa_block(
            normed_input,
            tgt_mask,
            tgt_key_padding_mask,
            rotary_position_embedding_helper,
            past_kv=self_attn_cache,
            cache_present_kv=cache_present_kv,
            self_attn_position_bias=self_attn_position_bias,
            layer_idx=layer_idx,
        )

        # Untied layer norm (neox): the feed-forward branch gets its own
        # normalization of the raw input; otherwise it shares `norm1`.
        if self.norm3 is None:
            ffn_input = normed_input
        else:
            ffn_input = self.norm3(tgt)
        ffn_result = self.ffn(ffn_input)

        # Parallel residual: both branches are added to the layer input.
        return tgt + ffn_result + attn_result[0]