# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.nn import CrossEntropyLoss
from cerebras_pytorch.metrics import AccuracyMetric, PerplexityMetric
from modelzoo.common.pytorch.model_utils.BertPretrainModelLoss import (
BertPretrainModelLoss,
)
from modelzoo.transformers.pytorch.bert.bert_pretrain_models import (
BertPretrainModel,
)
from modelzoo.transformers.pytorch.bert.utils import check_unused_model_params


class BertForPreTrainingModel(torch.nn.Module):
"""
BERT-based models
"""

    def __init__(self, params):
super().__init__()
model_params = params["model"].copy()
self.model = self.build_model(model_params)
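        # Combined pretraining loss over the MLM head and, unless NSP is
        # disabled, the NSP classification head (see BertPretrainModelLoss).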
self.loss_fn = BertPretrainModelLoss(
disable_nsp=self.disable_nsp,
mlm_loss_weight=self.mlm_loss_weight,
label_smoothing=self.label_smoothing,
)
self.compute_eval_metrics = model_params.pop(
"compute_eval_metrics", True
)
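        # Streaming eval metrics: NSP classification accuracy (skipped when
        # NSP is disabled), masked-LM accuracy, and masked-LM perplexity.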
if self.compute_eval_metrics:
if not self.disable_nsp:
self.accuracy_metric_cls = AccuracyMetric(
name="eval/accuracy_cls"
)
self.accuracy_metric_mlm = AccuracyMetric(
name="eval/accuracy_masked_lm"
)
self.perplexity_metric = PerplexityMetric(
name="eval/mlm_perplexity"
)
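
    # Re-apply embedding weight tying once parameters have been materialized
    # on the target device (see the share_embedding_weights option).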
def _post_device_transfer(self):
self.model.tie_weights()
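
    # Hook for subclasses to substitute a different underlying model class.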
def model_class(self):
return BertPretrainModel

    def build_model(self, model_params):
cls = self.model_class()
args = self.build_model_args(model_params)
check_unused_model_params(model_params)
return cls(**args)
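
    # Translate the flat params["model"] dict into constructor kwargs for
    # BertPretrainModel, popping each entry so that any leftover keys can be
    # flagged by check_unused_model_params in build_model.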
def build_model_args(self, model_params):
self.disable_nsp = model_params.pop("disable_nsp", False)
self.mlm_loss_weight = model_params.pop("mlm_loss_weight", 1.0)
self.label_smoothing = model_params.pop("label_smoothing", 0.0)
self.vocab_size = model_params.pop("vocab_size")
position_embedding_type = model_params.pop(
"position_embedding_type", "learned"
).lower()
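        # Rotary position embeddings (RoPE) take an explicit rotary_dim; for
        # every other position embedding scheme it stays None.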
rotary_dim = None
if position_embedding_type == "rotary":
# rotary_dim defaults to 25% of head dim (hidden_size / num_heads)
# similar to other models that use RoPE like GPT-NeoX
rotary_dim = model_params.pop(
"rotary_dim",
int(
model_params["hidden_size"]
// model_params["num_heads"]
* 0.25
),
)
# https://github.com/huggingface/transformers/blob/f0577df6de36e7e7f28e90fa76da0657de038a39/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L84-L85
# https://arxiv.org/pdf/2104.09864.pdf Section 3.3
assert (
rotary_dim
<= model_params["hidden_size"] / model_params["num_heads"]
), "Rotary dimensions should be <= hidden size divided by number of attention heads."
assert (
rotary_dim % 2 == 0
), "Rotary dimension must be an even number."
return {
"disable_nsp": self.disable_nsp,
"num_classes": model_params.pop("num_classes", 2),
"vocab_size": self.vocab_size,
"max_position_embeddings": model_params.pop(
"max_position_embeddings"
),
"position_embedding_type": position_embedding_type,
"embedding_pad_token_id": model_params.pop("pad_token_id", 0),
"mask_padding_in_positional_embed": model_params.pop(
"mask_padding_in_positional_embed", False
),
"rotary_dim": rotary_dim,
"rope_theta": model_params.pop("rope_theta", 10000),
"num_relative_attention_buckets": model_params.pop(
"num_relative_attention_buckets", 32
),
"alibi_trainable_slopes": model_params.pop(
"alibi_trainable_slopes", False
),
"pos_scaling_factor": float(
model_params.pop("pos_scaling_factor", 1.0)
),
"hidden_size": model_params.pop("hidden_size"),
"share_embedding_weights": model_params.pop(
"share_embedding_weights", True
),
"num_hidden_layers": model_params.pop("num_hidden_layers"),
"layer_norm_epsilon": float(model_params.pop("layer_norm_epsilon")),
# Encoder Attn
"num_heads": model_params.pop("num_heads"),
"attention_module": model_params.pop(
"attention_module", "aiayn_attention"
),
"extra_attention_params": model_params.pop(
"extra_attention_params", {}
),
"attention_type": model_params.pop(
"attention_type", "scaled_dot_product"
),
"dropout_rate": model_params.pop("dropout_rate"),
"nonlinearity": model_params.pop("encoder_nonlinearity", "gelu"),
"mlm_nonlinearity": model_params.pop("mlm_nonlinearity", None),
"pooler_nonlinearity": model_params.pop(
"pooler_nonlinearity", None
),
"attention_dropout_rate": model_params.pop(
"attention_dropout_rate"
),
"attention_softmax_fp32": model_params.pop(
"attention_softmax_fp32", True
),
"use_projection_bias_in_attention": model_params.pop(
"use_projection_bias_in_attention", True
),
"use_ffn_bias_in_attention": model_params.pop(
"use_ffn_bias_in_attention", True
),
"filter_size": model_params.pop("filter_size"),
"use_ffn_bias": model_params.pop("use_ffn_bias", True),
"use_ffn_bias_in_mlm": model_params.pop(
"use_ffn_bias_in_mlm", True
),
"use_output_bias_in_mlm": model_params.pop(
"use_output_bias_in_mlm", True
),
"initializer_range": model_params.pop("initializer_range", 0.02),
"num_segments": model_params.pop(
"num_segments", None if self.disable_nsp else 2
),
}

    def mlm_xentropy_loss(self, labels, logits, weights=None):
"""
Calculates MLM Cross-Entropy (to be used for Perplexity calculation)
Args:
labels: Tensor of shape (batch, sequence) and type int32.
logits: Tensor of shape (batch, sequence, vocab) and type float.
weights: Optional float Tensor of shape (batch, sequence).
Returns:
The loss tensor
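
        Example (illustrative shapes only):
            labels of shape (2, 128) int32, logits of shape (2, 128, vocab)
            float, and weights of shape (2, 128) with 1.0 at masked positions
            produce a scalar equal to the weighted sum of the per-token
            cross-entropy losses.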
"""
labels = labels.detach()
logits = logits.detach()
loss_fct = CrossEntropyLoss(reduction="none")
vocab_size = logits.shape[2]
loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1).long(),)
if weights is not None:
weights = weights.detach()
loss = loss * weights.view(-1)
return loss.sum()
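
    # Runs the wrapped BertPretrainModel, computes the combined pretraining
    # loss, and updates the eval metrics when the model is in eval mode.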
def forward(self, data):
next_sentence_label = data.pop("next_sentence_label", None)
        # MLM needs a half-precision "weights" tensor; use the binary mask for now.
masked_lm_weights = data.pop("masked_lm_mask")
should_calc_loss = data.pop("should_calc_loss", True)
mlm_loss_scale = data.pop("mlm_loss_scale", None)
labels = data.pop("labels")
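        # If the labels tensor is shorter than the input sequence, the labels
        # cover only the masked positions and the model must gather the
        # corresponding MLM logits.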
_, len_labels = list(labels.size())
seq_len = data["input_ids"].shape[1]
should_gather_mlm_labels = len_labels != seq_len
data["should_gather_mlm_labels"] = should_gather_mlm_labels
mlm_logits, nsp_logits, _, _ = self.model(**data)
if mlm_loss_scale is not None:
mlm_loss_scale = mlm_loss_scale.to(mlm_logits.dtype)
masked_lm_weights = masked_lm_weights.to(mlm_logits.dtype)
total_loss = None
if should_calc_loss:
total_loss = self.loss_fn(
mlm_logits,
self.vocab_size,
labels,
nsp_logits,
next_sentence_label,
masked_lm_weights,
mlm_loss_scale,
)
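        # Update streaming eval metrics only when running evaluation.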
if not self.model.training and self.compute_eval_metrics:
if not self.disable_nsp:
nsp_label = next_sentence_label.clone()
nsp_pred = nsp_logits.argmax(-1).int()
# eval/accuracy_cls
self.accuracy_metric_cls(
labels=nsp_label,
predictions=nsp_pred,
dtype=mlm_logits.dtype,
)
mlm_preds = mlm_logits.argmax(-1).int()
mlm_labels = labels.clone()
mlm_weights = masked_lm_weights.clone()
mlm_xentr = self.mlm_xentropy_loss(
mlm_labels, mlm_logits, mlm_weights
)
# eval/accuracy_masked_lm
self.accuracy_metric_mlm(
labels=mlm_labels,
predictions=mlm_preds,
weights=mlm_weights,
dtype=mlm_logits.dtype,
)
# eval/mlm_perplexity
self.perplexity_metric(
labels=mlm_labels,
loss=mlm_xentr,
weights=mlm_weights,
dtype=mlm_logits.dtype,
)
return total_loss
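

if __name__ == "__main__":
    # Illustrative sketch only: a minimal params dict covering the keys that
    # build_model_args requires (everything else falls back to its default).
    # The values below are small placeholders, not a recommended
    # configuration, and running this requires the Cerebras ModelZoo
    # environment that provides the imports above.
    params = {
        "model": {
            "vocab_size": 30522,
            "max_position_embeddings": 512,
            "hidden_size": 128,
            "num_hidden_layers": 2,
            "num_heads": 2,
            "filter_size": 512,
            "dropout_rate": 0.1,
            "attention_dropout_rate": 0.1,
            "layer_norm_epsilon": 1.0e-5,
        }
    }
    model = BertForPreTrainingModel(params)
    print(model)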