Source code for cerebras.modelzoo.losses.BertPretrainModelLoss

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding=utf-8
#
# This code is adapted from
# https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_bert.py
#
# Copyright 2022 Cerebras Systems.
#
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.common.utils.model.transformer_utils import smooth_loss


@registry.register_loss("BertPretrainModelLoss")
class BertPretrainModelLoss(nn.Module):
    def __init__(
        self,
        disable_nsp=False,
        mlm_loss_weight=1.0,
        label_smoothing=0.0,
    ):
        super(BertPretrainModelLoss, self).__init__()
        self.disable_nsp = disable_nsp
        self.mlm_loss_weight = mlm_loss_weight
        self.label_smoothing = label_smoothing

    def forward(
        self,
        mlm_logits,
        vocab_size,
        mlm_labels,
        nsp_logits,
        nsp_labels,
        mlm_weights,
        mlm_loss_scale=None,
    ):
        # Per-token MLM cross-entropy; reduction is deferred so each
        # token can be weighted individually below.
        mlm_loss_fn = nn.CrossEntropyLoss(reduction="none")
        if mlm_loss_scale is not None:
            mlm_weights = mlm_weights * mlm_loss_scale
        mlm_loss = mlm_loss_fn(
            mlm_logits.view(-1, vocab_size),
            mlm_labels.view(-1).long(),
        )
        if self.label_smoothing > 0.0 and self.training:
            # Calculate loss correction for label smoothing
            mlm_loss = smooth_loss(
                mlm_logits, mlm_loss, self.label_smoothing, vocab_size
            )
        # Zero out non-masked positions and average over the batch.
        mlm_loss *= mlm_weights.view(-1)
        mlm_loss = torch.sum(mlm_loss) / mlm_labels.shape[0]
        if mlm_loss_scale is None:
            mlm_loss *= self.mlm_loss_weight
        total_loss = mlm_loss

        if not self.disable_nsp:
            # Next-sentence prediction is a standard binary
            # classification cross-entropy over two logits.
            nsp_loss_fn = nn.CrossEntropyLoss()
            nsp_loss = nsp_loss_fn(
                nsp_logits.view(-1, 2),
                nsp_labels.view(-1).long(),
            )
            total_loss += nsp_loss

        return total_loss
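A minimal sketch of how this loss might be invoked, assuming the import path matches this page's module name and using illustrative toy tensor shapes (in real BERT pretraining pipelines, mlm_labels and mlm_weights often have shape (batch, max_predictions) rather than (batch, seq_len); both work here since the tensors are flattened internally):

import torch

from cerebras.modelzoo.losses.BertPretrainModelLoss import (
    BertPretrainModelLoss,
)

# Hypothetical toy dimensions for illustration only.
batch_size, seq_len, vocab_size = 4, 128, 30522

loss_fn = BertPretrainModelLoss(
    disable_nsp=False,
    mlm_loss_weight=1.0,
    label_smoothing=0.0,
)

mlm_logits = torch.randn(batch_size, seq_len, vocab_size)
mlm_labels = torch.randint(0, vocab_size, (batch_size, seq_len))
# Non-zero weight only at masked positions; zeros elsewhere.
mlm_weights = (torch.rand(batch_size, seq_len) > 0.85).float()
nsp_logits = torch.randn(batch_size, 2)
nsp_labels = torch.randint(0, 2, (batch_size,))

loss = loss_fn(
    mlm_logits,
    vocab_size,
    mlm_labels,
    nsp_logits,
    nsp_labels,
    mlm_weights,
)
print(loss)  # scalar tensor: weighted MLM loss plus NSP loss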