Source code for modelzoo.transformers.pytorch.t5.model

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch.nn as nn

import cerebras_pytorch as cstorch
from cerebras_pytorch.metrics import AccuracyMetric, PerplexityMetric
from modelzoo.common.pytorch.model_utils.T5ForConditionalGenerationLoss import (
    T5ForConditionalGenerationLoss,
)
from modelzoo.transformers.pytorch.t5.t5_model import T5ForConditionalGeneration


class T5ForConditionalGenerationModel(nn.Module):
    """
    T5 models
    """

    def __init__(self, params):
        super().__init__()

        model_params = params["model"].copy()
        self.model = self.build_model(model_params)
        self.loss_fn = T5ForConditionalGenerationLoss(
            params["model"].get("lm_loss_weight", 1.0),
            mlm_loss_scaling=params["model"].get(
                "mlm_loss_scaling", "batch_size"
            ),
            label_smoothing=params["model"].get("label_smoothing", 0.0),
        )

        self.compute_eval_metrics = model_params.pop(
            "compute_eval_metrics", True
        )
        if self.compute_eval_metrics:
            self.accuracy_metric = AccuracyMetric(name="eval/accuracy_lm")
            self.perplexity_metric = PerplexityMetric(name="eval/perplexity_lm")

    def _post_device_transfer(self):
        self.model.tie_weights()

    def build_model(self, model_params):
        model = None
        kwargs = {
            "src_vocab_size": model_params.pop("src_vocab_size"),
            "tgt_vocab_size": model_params.pop("tgt_vocab_size", None),
            "mlm_loss_scaling": model_params.pop(
                "mlm_loss_scaling", "batch_size"
            ),
            "label_smoothing": model_params.pop("label_smoothing", 0.0),
            "extra_ids": model_params.pop("extra_ids", 0),
            "d_model": model_params.pop("d_model"),
            "d_kv": model_params.pop("d_kv"),
            "d_ff": model_params.pop("d_ff"),
            "encoder_num_hidden_layers": model_params.pop(
                "encoder_num_hidden_layers"
            ),
            "decoder_num_hidden_layers": model_params.pop(
                "decoder_num_hidden_layers"
            ),
            "num_heads": model_params.pop("num_heads"),
            "use_projection_bias_in_attention": model_params.pop(
                "use_projection_bias_in_attention", False
            ),
            "relative_attention_num_buckets": model_params.pop(
                "relative_attention_num_buckets", 32
            ),
            # This param ties weights between lm_head and
            # decoder.embed_tokens layers.
            "tie_word_embeddings": model_params.pop(
                "share_embedding_weights", True,
            ),
            "norm_type": model_params.pop("norm_type", "rmsnorm"),
            "dropout_rate": model_params.pop("dropout_rate"),
            "layer_norm_epsilon": float(
                model_params.pop("layer_norm_epsilon", 1.0e-5),
            ),
            "encoder_nonlinearity": model_params.pop("encoder_nonlinearity"),
            "decoder_nonlinearity": model_params.pop("decoder_nonlinearity"),
            "position_embedding_type": model_params.pop(
                "position_embedding_type", "relative"
            ),
            "src_max_position_embeddings": model_params.pop(
                "src_max_position_embeddings"
            ),
            "tgt_max_position_embeddings": model_params.pop(
                "tgt_max_position_embeddings"
            ),
            "use_dropout_outside_residual_path": model_params.pop(
                "use_dropout_outside_residual_path", True
            ),
            # This param ties weights between encoder.embed_tokens and
            # decoder.embed_tokens layers.
            "share_encoder_decoder_embedding": model_params.pop(
                "share_encoder_decoder_embedding", True
            ),
            "relu_dropout_rate": model_params.pop("relu_dropout_rate", None),
            "use_pre_encoder_decoder_dropout": model_params.pop(
                "use_pre_encoder_decoder_dropout", False
            ),
            "use_pre_encoder_decoder_layer_norm": model_params.pop(
                "use_pre_encoder_decoder_layer_norm", True
            ),
            "use_ffn_bias": model_params.pop("use_ffn_bias", False),
            "lm_loss_weight": model_params.pop("lm_loss_weight", 1.0),
            "use_transformer_initialization": model_params.pop(
                "use_transformer_initialization", False
            ),
            "attention_softmax_fp32": model_params.pop(
                "attention_softmax_fp32", True
            ),
            "attention_kernel": model_params.pop("attention_kernel", None),
        }

        # Updating input and model params to account for extra ids
        # for the T5 Language Modeling task.
        extra_ids = kwargs.pop("extra_ids", 0)
        kwargs["src_vocab_size"] += extra_ids

        # T5 model has the same vocabulary size for source and target
        # sequences.
        if kwargs["tgt_vocab_size"] is None:
            kwargs["tgt_vocab_size"] = kwargs["src_vocab_size"]
        else:
            kwargs["tgt_vocab_size"] += extra_ids

        # The T5 model does not use a separate dropout rate after relu
        # computations; it applies the common dropout rate across the whole
        # model. Transformer, however, uses a `0` dropout rate there.
        if kwargs["relu_dropout_rate"] is None:
            kwargs["relu_dropout_rate"] = kwargs["dropout_rate"]

        model_params.pop("to_float16", None)
        model_params.pop("mixed_precision", None)

        model = T5ForConditionalGeneration(**kwargs)

        self.enable_vts = model_params.pop("enable_vts", False)
        if self.enable_vts:
            self.vts = cstorch.nn.StripPadding()
        else:
            self.vts = None

        # `precision_opt_level` is accessed later,
        # so we remove it from the list of unused params
        unused_params = [
            key
            for key in model_params.keys()
            if key not in ["precision_opt_level", "use_bfloat16"]
        ]
        if unused_params:
            logging.warning(
                "The following model params are unused: "
                + ", ".join(unused_params)
            )

        return model

    def _xentropy_loss(self, labels, logits, weights=None):
        """
        Calculates MLM Cross-Entropy (to be used for Perplexity calculation)

        Args:
            labels: Tensor of shape (batch, sequence) and type int32.
            logits: Tensor of shape (batch, sequence, vocab) and type float.
            weights: Optional float Tensor of shape (batch, sequence).

        Returns:
            The loss tensor
        """
        labels = labels.detach()
        logits = logits.detach()
        loss_fct = nn.CrossEntropyLoss(reduction="none")
        vocab_size = logits.shape[2]
        loss = loss_fct(
            logits.view(-1, vocab_size), labels.view(-1).long(),
        )
        if weights is not None:
            weights = weights.detach()
            loss = loss * weights.view(-1)
        return loss.sum()

    def forward(self, data):
        if self.enable_vts and not self.model.training:
            self.enable_vts = False
            logging.info(
                "VTS is only supported in train mode. Disabling for the "
                "current run."
            )
        if self.enable_vts:
            data["input_ids"] = self.vts(
                data["input_ids"], data["attention_mask"]
            )
            data["decoder_input_ids"] = self.vts(
                data["decoder_input_ids"], data["decoder_attention_mask"]
            )
            data["labels"] = self.vts(
                data["labels"], data["decoder_attention_mask"]
            )
        kwargs = {
            "input_ids": data["input_ids"],
            "attention_mask": data["attention_mask"],
            "decoder_input_ids": data["decoder_input_ids"],
            "decoder_attention_mask": data["decoder_attention_mask"],
            "labels": data["labels"],
        }
        logits = self.model(**kwargs)
        loss = None
        if data["labels"] is not None:
            loss = self.loss_fn(
                logits,
                data["labels"],
                data["decoder_attention_mask"],
                data.get("loss_weight", None),
            ).to(logits.dtype)

        # Calculate eval metrics if not training
        if not self.model.training and self.compute_eval_metrics:
            labels = data["labels"].clone()
            decoder_mask = (
                data["decoder_attention_mask"].clone().to(logits.dtype)
            )
            predictions = logits.argmax(-1).int()

            self.accuracy_metric(
                labels=labels,
                predictions=predictions,
                weights=decoder_mask,
                dtype=logits.dtype,
            )

            # eval/perplexity_lm
            cross_entropy_loss = self._xentropy_loss(
                labels, logits, decoder_mask
            )
            self.perplexity_metric(
                labels=labels,
                loss=cross_entropy_loss,
                weights=decoder_mask,
                dtype=logits.dtype,
            )

        return loss