Source code for modelzoo.transformers.pytorch.t5.t5_model

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This code is adapted from
# https://github.com/huggingface/transformers/blob/master/src/transformers/models/t5/modeling_t5.py
#
# Copyright 2022 Cerebras Systems.
#
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List

import torch
import torch.nn as nn

import cerebras_pytorch as cstorch
from modelzoo.common.pytorch.layers import (
    EmbeddingLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)
from modelzoo.common.pytorch.model_utils.norms import get_norm
from modelzoo.transformers.pytorch.transformer_utils import (
    build_broadcastable_attention_mask,
    create_2D_autoregressive_mask,
    make_key_padding_mask_broadcastable,
)


class T5ForConditionalGeneration(nn.Module):
    r"""
    T5 Model with a `language modeling` head on top.

    Arguments:
        src_vocab_size (:obj:`int`, `optional`, defaults to 32128):
            Source vocabulary size of the T5 model. Defines the number of
            different tokens that can be represented by the :obj:`inputs_ids`
            passed when calling :class:`~transformers.T5Model` or
            :class:`~transformers.TFT5Model`.
        tgt_vocab_size (:obj:`int`, `optional`, defaults to 32128):
            Target vocabulary size of the T5 model. Only useful if set for a
            Transformer variant where source and target vocabularies can be
            different.
        d_model (:obj:`int`, `optional`, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (:obj:`int`, `optional`, defaults to 64):
            Size of the key, query, value projections per attention head.
            :obj:`d_kv` does *not* have to be equal to
            :obj:`d_model // num_heads`.
        d_ff (:obj:`int`, `optional`, defaults to 2048):
            Size of the intermediate feed forward layer in each
            :obj:`T5Block`.
        encoder_num_hidden_layers (:obj:`int`, `optional`, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        decoder_num_hidden_layers (:obj:`int`, `optional`):
            Number of hidden layers in the Transformer decoder. Will use the
            same value as :obj:`num_layers` if not set.
        num_heads (:obj:`int`, `optional`, defaults to 8):
            Number of attention heads for each attention layer in the
            Transformer encoder and decoder.
        relative_attention_num_buckets (:obj:`int`, `optional`, defaults to 32):
            The number of buckets to use for each attention layer.
        norm_type (:obj:`str`, `optional`, defaults to "rmsnorm"):
            Determines which type of layernorm to use. RMSNorm is the same as
            T5-style layernorm (no mean subtraction and bias correction).
        dropout_rate (:obj:`float`, `optional`, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (:obj:`float`, `optional`, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1,
            used internally for initialization testing).
        encoder_nonlinearity (:obj:`string`, `optional`, defaults to :obj:`"relu"`):
            Type of feed forward layer to be used in the encoder. Should be one
            of :obj:`"relu"`, :obj:`"geglu"`, or :obj:`"gelu"`. T5v1.1 uses the
            :obj:`"geglu"` feed forward projection. Original T5 uses
            :obj:`"relu"`.
        decoder_nonlinearity (:obj:`string`, `optional`, defaults to :obj:`"relu"`):
            Type of feed forward layer to be used in the decoder. Should be one
            of :obj:`"relu"`, :obj:`"geglu"`, or :obj:`"gelu"`. T5v1.1 uses the
            :obj:`"geglu"` feed forward projection. Original T5 uses
            :obj:`"relu"`.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the model should return the last key/values
            attentions (not used by all models).
        position_embedding_type (:obj:`string`, `optional`, defaults to :obj:`"relative"`):
            The type of position embedding to use. Should be one of
            :obj:`"fixed"`, :obj:`"learned_absolute"`, :obj:`"relative"`, or
            :obj:`None`. :obj:`"fixed"` uses a concatenation of sin curves to
            express relative position as used in the original Transformer
            paper. :obj:`"learned_absolute"` uses a learned vector for each
            position in the sequence. :obj:`"relative"` uses learned relative
            position embeddings as introduced in
            https://arxiv.org/abs/1803.02155, configured as done in the
            original T5 publication. :obj:`None` turns off position embedding
            altogether.
        src_max_position_embeddings (:obj:`int`, `optional`, defaults to :obj:`512`):
            Maximum source sequence length used to train the model.
        tgt_max_position_embeddings (:obj:`int`, `optional`, defaults to :obj:`512`):
            Maximum target sequence length used to train the model.
        use_dropout_outside_residual_path (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to set dropout calculations outside of the residual path.
            Set to `True` for T5, but `False` for Transformer.
        share_encoder_decoder_embedding (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to share the encoder/decoder embedding layer. Set to
            `True` for both the T5 and Transformer models.
        share_embedding_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to share embedding weights between the decoder and the
            language model head.
        relu_dropout_rate (:obj:`float`, `optional`, defaults to :obj:`0.1`):
            Dropout rate utilized in the FFN layer after applying the relu
            activation function. This parameter is set to `0` for the
            Transformer model, and set to `dropout_rate` for the default T5
            configuration.
            Transformer reference:
            https://github.com/tensorflow/tensor2tensor/blob/5623deb79cfcd28f8f8c5463b58b5bd76a81fd0d/tensor2tensor/models/transformer.py#L1811
            T5 reference:
            https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/t5/modeling_t5.py#L261
        use_pre_encoder_decoder_dropout (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use a dropout layer after the positional embedding layer
            and before the encoder/decoder. This is set to `False` for T5 and
            `True` for Transformer.
        use_pre_encoder_decoder_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to use layer norm before passing input tensors into the
            encoder/decoder. This is set to `True` for T5 and `False` for
            Transformer.
        use_ffn_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use bias in the hidden layer with relu activation. This
            is set to `False` for T5, and `True` for Transformer.
        lm_loss_weight (:obj:`float`, `optional`, defaults to :obj:`1.0`):
            Value that scales loss by the mean number of predictions per
            sequence in the dataset.
        use_transformer_initialization (:obj:`bool`, `optional`, defaults to :obj:`False`):
            The Transformer model tends to converge best with a scaled variant
            of Xavier uniform initialization used for linear layers. This
            contrasts with the initialization used in the original T5 paper,
            which uses He normal initialization for linear layers. Setting
            this flag to `True` switches the initialization to the Transformer
            specific scaled Xavier initialization.
    """
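    # Configuration note: the documented defaults reproduce the T5-small
    # geometry (d_model=512, d_kv=64, d_ff=2048, 6 encoder + 6 decoder layers,
    # 8 heads). The Vaswani et al. Transformer variant is obtained by flipping
    # the `use_*` switches described above (layer-norm placement, dropout
    # placement, FFN bias, initialization). A minimal forward-pass sketch is
    # given at the bottom of this file.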
    def __init__(
        self,
        src_vocab_size=32128,
        tgt_vocab_size=32128,
        d_model=512,
        d_kv=64,
        d_ff=2048,
        encoder_num_hidden_layers=6,
        decoder_num_hidden_layers=None,
        num_heads=8,
        relative_attention_num_buckets=32,
        norm_type="rmsnorm",
        dropout_rate=0.1,
        relu_dropout_rate=None,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        encoder_nonlinearity="relu",
        decoder_nonlinearity="relu",
        use_projection_bias_in_attention=False,
        attention_softmax_fp32=True,
        use_cache=False,
        decoder_start_token_id=None,
        pad_token_id=0,
        position_embedding_type="relative",
        src_max_position_embeddings=512,
        tgt_max_position_embeddings=512,
        use_dropout_outside_residual_path=True,
        share_encoder_decoder_embedding=True,
        share_embedding_weights=True,
        tie_encoder_decoder=False,
        use_pre_encoder_decoder_dropout=False,
        use_pre_encoder_decoder_layer_norm=True,
        use_ffn_bias=False,
        label_smoothing=0.0,
        use_transformer_initialization=False,
        attention_module="aiayn_attention",
        extra_attention_params={},
        **kwargs,
    ):
        super().__init__()

        # Copy only the subset of params that are referenced later
        self.d_model = d_model
        self.d_kv = d_kv
        self.attention_inner_dim = d_kv * num_heads
        self.initializer_factor = initializer_factor
        self.use_cache = use_cache
        self.share_encoder_decoder_embedding = share_encoder_decoder_embedding
        self.share_embedding_weights = share_embedding_weights
        self.tie_encoder_decoder = tie_encoder_decoder
        self.decoder_start_token_id = decoder_start_token_id
        self.pad_token_id = pad_token_id
        self.position_embedding_type = position_embedding_type
        self.label_smoothing = label_smoothing

        if decoder_num_hidden_layers is None:
            decoder_num_hidden_layers = encoder_num_hidden_layers
        if relu_dropout_rate is None:
            relu_dropout_rate = dropout_rate

        assert position_embedding_type in (
            "fixed",
            "learned_absolute",
            "relative",
            None,
        ), (
            f"Position embedding must be one of `fixed`, `learned_absolute`, "
            f"`relative`, or None. Got {position_embedding_type}."
        )
        if position_embedding_type == "learned_absolute":
            position_embedding_type = "learned"

        self.encoder_embeddings = EmbeddingLayer(
            src_vocab_size,
            d_model,
            embeddings_initializer={
                "name": "truncated_normal",
                "mean": 0.0,
                "std": 1.0,
                "a": -2.0,
                "b": 2.0,
            }
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * 1.0,
            },
            max_position_embeddings=src_max_position_embeddings,
            position_embedding_type=position_embedding_type,
            # RPE:
            num_heads=num_heads,
            bidirectional=True,
            num_relative_attention_buckets=relative_attention_num_buckets,
        )

        self.decoder_embeddings = EmbeddingLayer(
            tgt_vocab_size,
            d_model,
            embeddings_initializer={
                "name": "truncated_normal",
                "mean": 0.0,
                "std": 1.0,
                "a": -2.0,
                "b": 2.0,
            }
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * 1.0,
            },
            max_position_embeddings=tgt_max_position_embeddings,
            position_embedding_type=position_embedding_type,
            # RPE:
            num_heads=num_heads,
            bidirectional=False,
            num_relative_attention_buckets=relative_attention_num_buckets,
        )

        if self.share_encoder_decoder_embedding:
            assert (
                src_vocab_size == tgt_vocab_size
            ), "Cannot share embeddings between encoder and decoder due to different vocab sizes"
            self.decoder_embeddings.set_input_embeddings(
                self.encoder_embeddings.get_input_embeddings()
            )

        self.pre_encoder_dropout = None
        self.pre_decoder_dropout = None
        # Transformer model uses dropout right after position embeddings
        # and before the encoder call, T5 does not use it.
        if use_pre_encoder_decoder_dropout:
            self.pre_encoder_dropout = nn.Dropout(dropout_rate)
            self.pre_decoder_dropout = nn.Dropout(dropout_rate)

        assert encoder_nonlinearity in [
            "relu",
            "gelu",
            "reglu",
            "geglu",
            "swiglu",
        ], "T5/Transformer doesn't support encoder_nonlinearity {}".format(
            encoder_nonlinearity
        )
        assert decoder_nonlinearity in [
            "relu",
            "gelu",
            "reglu",
            "geglu",
            "swiglu",
        ], "T5/Transformer doesn't support decoder_nonlinearity {}".format(
            decoder_nonlinearity
        )

        if (encoder_nonlinearity == "gelu" and use_ffn_bias) or (
            decoder_nonlinearity == "gelu" and use_ffn_bias
        ):
            logging.warning(
                "Overriding use_ffn_bias to false because using gelu"
            )
            use_ffn_bias = False

        norm_class = get_norm(norm_type)

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout_rate,
            activation=encoder_nonlinearity,
            norm_layer=norm_class,
            layer_norm_eps=layer_norm_epsilon,
            norm_first=use_pre_encoder_decoder_layer_norm,
            batch_first=True,
            extra_attention_params={**extra_attention_params},
            attention_type="scaled_dot_product"
            if use_transformer_initialization
            else "dot_product",
            attention_module=attention_module,
            attention_inner_dim=self.attention_inner_dim,
            attention_softmax_fp32=attention_softmax_fp32,
            use_projection_bias_in_attention=use_projection_bias_in_attention,
            use_ffn_bias_in_attention=False,
            use_ffn_bias=use_ffn_bias,
            ffn_dropout_rate=relu_dropout_rate,
            use_ff_layer1_dropout=True,
            use_ff_layer2_dropout=True,
            attention_q_initializer={
                "name": "variance_scaling",
                "scale": 1.0 / (d_kv * 9.0),
                "mode": "fan_avg",
                "distribution": "uniform",
            }
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * ((d_model * d_kv) ** -0.5),
            },
            attention_initializer={
                "name": "variance_scaling",
                "scale": 1.0 / 9.0,
                "mode": "fan_avg",
                "distribution": "uniform",
            }
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * (d_model ** -0.5),
            },
            ffn_initializer={"name": "xavier_uniform", "gain": 1.0}
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * (d_model ** -0.5),
            },
            ffn_output_layer_initializer={"name": "xavier_uniform", "gain": 1.0}
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * (d_ff ** -0.5),
            },
        )

        encoder_final_layer_norm = norm_class(d_model, eps=layer_norm_epsilon)

        self.dropout_before_encoder = nn.Dropout(dropout_rate)

        self.encoder = TransformerEncoder(
            encoder_layer,
            num_layers=encoder_num_hidden_layers,
            norm=encoder_final_layer_norm,
        )

        self.dropout_after_encoder = None
        if use_dropout_outside_residual_path:
            self.dropout_after_encoder = nn.Dropout(dropout_rate)

        decoder_layer = TransformerDecoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout_rate,
            activation=decoder_nonlinearity,
            norm_layer=norm_class,
            layer_norm_eps=layer_norm_epsilon,
            norm_first=use_pre_encoder_decoder_layer_norm,
            batch_first=True,
            extra_attention_params={**extra_attention_params},
            attention_type="scaled_dot_product"
            if use_transformer_initialization
            else "dot_product",
            attention_module=attention_module,
            attention_inner_dim=self.attention_inner_dim,
            use_projection_bias_in_attention=use_projection_bias_in_attention,
            attention_softmax_fp32=attention_softmax_fp32,
            use_ffn_bias_in_attention=False,
            use_ffn_bias=use_ffn_bias,
            use_ff_layer1_dropout=True,
            use_ff_layer2_dropout=True,
            attention_q_initializer={
                "name": "variance_scaling",
                "scale": 1.0 / (d_kv * 9.0),
                "mode": "fan_avg",
                "distribution": "uniform",
            }
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * ((d_model * d_kv) ** -0.5),
            },
            attention_initializer={
                "name": "variance_scaling",
                "scale": 1.0 / 9.0,
                "mode": "fan_avg",
                "distribution": "uniform",
            }
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * (d_model ** -0.5),
            },
            ffn_initializer={"name": "xavier_uniform", "gain": 1.0}
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * (d_model ** -0.5),
            },
            ffn_output_layer_initializer={"name": "xavier_uniform", "gain": 1.0}
            if use_transformer_initialization
            else {
                "name": "normal",
                "mean": 0.0,
                "std": initializer_factor * (d_ff ** -0.5),
            },
        )

        decoder_final_layer_norm = norm_class(d_model, eps=layer_norm_epsilon)

        self.dropout_before_decoder = nn.Dropout(dropout_rate)

        self.decoder = TransformerDecoder(
            decoder_layer,
            num_layers=decoder_num_hidden_layers,
            norm=decoder_final_layer_norm,
        )

        self.dropout_after_decoder = None
        if use_dropout_outside_residual_path:
            self.dropout_after_decoder = nn.Dropout(dropout_rate)

        self.lm_head = nn.Linear(d_model, tgt_vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.__reset_parameters()

        assert (
            not tie_encoder_decoder
        ), "Implementation does not currently support tied Encoder/Decoder weights"
        self.tie_weights()
    def reset_parameters(self):
        self.encoder_embeddings.reset_parameters()
        self.decoder_embeddings.reset_parameters()
        if getattr(self, "relative_position_encoder", None):
            self.relative_position_encoder.reset_parameters()
        if getattr(self, "relative_position_decoder", None):
            self.relative_position_decoder.reset_parameters()
        self.encoder.reset_parameters()
        self.decoder.reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        # Initialize LM head
        if not self.share_embedding_weights:
            self.lm_head.weight.data.normal_(
                mean=0.0, std=self.initializer_factor * 1.0
            )

    # Helper for `forward` that computes everything up to (but not including)
    # the model head. This is useful for models that inherit from T5 and apply
    # a different head at the end of the model.
    def compute_hidden_states(
        self, input_ids=None, attention_mask=None, prepend_embeddings=None,
    ):
        src = self.encoder_embeddings(input_ids)
        if prepend_embeddings is not None:
            src = torch.cat([prepend_embeddings, src], dim=1)

        # Transformer uses pre-encoder dropout
        if self.pre_encoder_dropout:
            src = self.pre_encoder_dropout(src)

        # Compute relative position bias for the encoder block if applicable
        encoder_self_attn_position_bias = self.encoder_embeddings.compute_position_bias(
            src.shape[1], src.shape[1]
        )

        src = self.dropout_before_encoder(src)

        if attention_mask is not None:
            attention_mask = make_key_padding_mask_broadcastable(
                attention_mask, dtype=src.dtype
            )

        # Run the encoder over the embedded inputs
        hidden_states = self.encoder(
            src,
            mask=attention_mask,
            self_attn_position_bias=encoder_self_attn_position_bias,
        )

        if self.dropout_after_encoder:
            hidden_states = self.dropout_after_encoder(
                hidden_states
            )  # HF T5 Decoder also applies dropout at the end
        return hidden_states

    def compute_decoder_states(
        self,
        hidden_states=None,
        memory_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        labels=None,
        use_cache=None,
    ):
        assert (
            past_key_values is None
        ), "past_key_values should be None since inference is not supported yet"
        use_cache = use_cache if use_cache is not None else self.use_cache
        assert (
            not use_cache
        ), "cannot enable use_cache because inference is not supported yet"

        if labels is not None and decoder_input_ids is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        decoder_inputs_embeds = self.decoder_embeddings(decoder_input_ids)

        # Transformer uses dropout before feeding to the decoder module while
        # T5 does not use this layer
        if self.pre_decoder_dropout:
            decoder_inputs_embeds = self.pre_decoder_dropout(
                decoder_inputs_embeds
            )

        batch_size, decoder_seq_length = decoder_inputs_embeds.size()[:2]

        decoder_self_attn_position_bias = self.decoder_embeddings.compute_position_bias(
            decoder_seq_length, decoder_seq_length
        )

        if memory_mask is not None:
            memory_mask = make_key_padding_mask_broadcastable(
                memory_mask, dtype=hidden_states.dtype
            )

        if decoder_attention_mask is None:
            extended_decoder_attention_mask = (
                create_2D_autoregressive_mask(
                    decoder_seq_length,
                    decoder_seq_length,
                    device=decoder_inputs_embeds.device,
                    dtype=hidden_states.dtype,
                )
                * torch.finfo(hidden_states.dtype).min
            )
        else:
            extended_decoder_attention_mask = build_broadcastable_attention_mask(
                decoder_attention_mask,
                build_causal=True,
                device=decoder_inputs_embeds.device,
                dtype=hidden_states.dtype,
            )

        decoder_inputs_embeds = self.dropout_before_decoder(
            decoder_inputs_embeds
        )

        decoder_outputs = self.decoder(
            decoder_inputs_embeds,
            memory=hidden_states,
            tgt_mask=extended_decoder_attention_mask,
            memory_mask=memory_mask,
            past_kv=past_key_values,
            cache_present_kv=use_cache,
            self_attn_position_bias=decoder_self_attn_position_bias,
        )

        if use_cache:
            sequence_output, present_kv = decoder_outputs
        else:
            sequence_output = decoder_outputs

        if self.dropout_after_decoder:
            sequence_output = self.dropout_after_decoder(sequence_output)
        return sequence_output

    def compute_sequence_output(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        labels=None,
        use_cache=None,
        prepend_embeddings=None,
    ):
        if encoder_outputs is None:
            hidden_states = self.compute_hidden_states(
                input_ids,
                attention_mask,
                prepend_embeddings=prepend_embeddings,
            )
        else:
            hidden_states = encoder_outputs

        sequence_output = self.compute_decoder_states(
            hidden_states=hidden_states,
            memory_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            labels=labels,
            use_cache=use_cache,
        )
        return sequence_output
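    # The helpers above split the network so that subclasses can reuse the
    # encoder/decoder stack under a different head:
    #   compute_hidden_states   : encoder only,      (batch, src_len, d_model)
    #   compute_decoder_states  : decoder only,      (batch, tgt_len, d_model)
    #   compute_sequence_output : encoder + decoder, still without the LM head
    # `forward` below applies the LM head (plus the tied-embedding rescale).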
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        labels=None,
        use_cache=None,
        prepend_embeddings=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Labels for computing the language modeling loss. Indices should be
            in :obj:`[-100, 0, ..., config.vocab_size - 1]`. All labels set to
            ``-100`` are ignored (masked); the loss is only computed for labels
            in ``[0, ..., config.vocab_size]``.

        Returns:
            The language-model logits over the target vocabulary.

        Examples (using the Hugging Face reference implementation)::

            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')

            >>> # training
            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits

            >>> # inference
            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
            >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
            >>> # studies have shown that owning a dog is good for you.
        """
        sequence_output = self.compute_sequence_output(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            labels=labels,
            use_cache=use_cache,
            prepend_embeddings=prepend_embeddings,
        )

        if self.share_embedding_weights:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.d_model ** -0.5)

        if (
            cstorch.use_cs()
            and cstorch.current_executor().cs_config.precision_opt_level == 1
        ):
            return cstorch.pol(bwd_level=0)(self.lm_head)(sequence_output)
        return self.lm_head(sequence_output)
    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings
        and (if enabled) tie encoder/decoder weights.
        """
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None and self.share_embedding_weights:
            self._tie_or_clone_weights(
                output_embeddings, self.get_input_embeddings()
            )

        if self.tie_encoder_decoder:
            if hasattr(self, self.base_model_prefix):
                self = getattr(self, self.base_model_prefix)
            self._tie_encoder_decoder_weights(
                self.encoder, self.decoder, self.base_model_prefix
            )
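    # Recursively walks matching encoder/decoder submodules and points each
    # encoder parameter at the corresponding decoder parameter (mirroring the
    # Hugging Face `PreTrainedModel` helper of the same name). It is only
    # reachable when `tie_encoder_decoder=True`, which the constructor
    # currently rejects.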
    @staticmethod
    def _tie_encoder_decoder_weights(
        encoder: nn.Module, decoder: nn.Module, base_model_prefix: str
    ):
        uninitialized_encoder_weights: List[str] = []
        if decoder.__class__ != encoder.__class__:
            print(
                f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
            )

        def tie_encoder_to_decoder_recursively(
            decoder_pointer: nn.Module,
            encoder_pointer: nn.Module,
            module_name: str,
            uninitialized_encoder_weights: List[str],
            depth=0,
        ):
            assert isinstance(decoder_pointer, nn.Module) and isinstance(
                encoder_pointer, nn.Module
            ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
            if hasattr(decoder_pointer, "weight"):
                assert hasattr(encoder_pointer, "weight")
                encoder_pointer.weight = decoder_pointer.weight
                if hasattr(decoder_pointer, "bias"):
                    assert hasattr(encoder_pointer, "bias")
                    encoder_pointer.bias = decoder_pointer.bias
                return

            encoder_modules = encoder_pointer._modules
            decoder_modules = decoder_pointer._modules
            if len(decoder_modules) > 0:
                assert (
                    len(encoder_modules) > 0
                ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

                all_encoder_weights = set(
                    [
                        module_name + "/" + sub_name
                        for sub_name in encoder_modules.keys()
                    ]
                )
                encoder_layer_pos = 0
                for name, module in decoder_modules.items():
                    if name.isdigit():
                        encoder_name = str(int(name) + encoder_layer_pos)
                        decoder_name = name
                        if not isinstance(
                            decoder_modules[decoder_name],
                            type(encoder_modules[encoder_name]),
                        ) and len(encoder_modules) != len(decoder_modules):
                            # this can happen if the name corresponds to the position in a module list of layers
                            # in this case the decoder has added a cross-attention that the encoder does not have
                            # thus skip this step and subtract one layer pos from encoder
                            encoder_layer_pos -= 1
                            continue
                    elif name not in encoder_modules:
                        continue
                    elif depth > 500:
                        raise ValueError(
                            "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
                        )
                    else:
                        decoder_name = encoder_name = name
                    tie_encoder_to_decoder_recursively(
                        decoder_modules[decoder_name],
                        encoder_modules[encoder_name],
                        module_name + "/" + name,
                        uninitialized_encoder_weights,
                        depth=depth + 1,
                    )
                    all_encoder_weights.remove(module_name + "/" + encoder_name)

                uninitialized_encoder_weights += list(all_encoder_weights)

        # tie weights recursively
        tie_encoder_to_decoder_recursively(
            decoder, encoder, base_model_prefix, uninitialized_encoder_weights
        )
        if len(uninitialized_encoder_weights) > 0:
            print(
                f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
            )

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """Tie or clone module weights depending on whether we are using TorchScript or not"""
        if not isinstance(output_embeddings, list):
            output_embeddings = [output_embeddings]
            input_embeddings = [input_embeddings]

        for output_embedding, input_embedding in zip(
            output_embeddings, input_embeddings
        ):
            output_embedding.weight = input_embedding.weight
            if getattr(output_embedding, "bias", None) is not None:
                output_embedding.bias.data = nn.functional.pad(
                    output_embedding.bias.data,
                    (
                        0,
                        output_embedding.weight.shape[0]
                        - output_embedding.bias.shape[0],
                    ),
                    "constant",
                    0,
                )
            if hasattr(output_embedding, "out_features") and hasattr(
                input_embedding, "num_embeddings"
            ):
                output_embedding.out_features = input_embedding.num_embeddings

    def get_input_embeddings(self):
        # This function returns decoder token embeddings
        # in order to properly tie embeddings between the decoder
        # input and decoder output.
        return self.decoder_embeddings.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.decoder_embeddings.set_input_embeddings(new_embeddings)
        if self.share_embedding_weights:
            self.encoder_embeddings.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.decoder_start_token_id
        pad_token_id = self.pad_token_id

        assert (
            decoder_start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        assert (
            pad_token_id is not None
        ), "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(
            shifted_input_ids >= 0
        ).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
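

# -----------------------------------------------------------------------------
# Usage sketch: a minimal CPU smoke test of the forward pass. This assumes the
# ModelZoo and cerebras_pytorch dependencies imported above are installed; the
# hyperparameters and shapes below are illustrative only.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    model = T5ForConditionalGeneration(
        src_vocab_size=32128,
        tgt_vocab_size=32128,
        d_model=512,
        d_kv=64,
        d_ff=2048,
        encoder_num_hidden_layers=2,  # small stack for a quick check
        num_heads=8,
        decoder_start_token_id=0,  # T5 conventionally starts decoding with the pad id
        pad_token_id=0,
    )
    model.eval()

    batch_size, src_len, tgt_len = 2, 16, 8
    input_ids = torch.randint(0, 32128, (batch_size, src_len))
    attention_mask = torch.ones(batch_size, src_len, dtype=torch.long)
    labels = torch.randint(0, 32128, (batch_size, tgt_len))

    with torch.no_grad():
        # `labels` are shifted right internally to form the decoder inputs.
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

    # Expected shape: (batch_size, tgt_len, tgt_vocab_size)
    print(logits.shape)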