Source code for modelzoo.transformers.pytorch.bert.bert_finetune_models

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding=utf-8
#
# This code is adapted from
# https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_bert.py
#
# Copyright 2022 Cerebras Systems.
#
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from modelzoo.transformers.pytorch.bert.bert_model import BertModel


class BertForSequenceClassificationLoss(nn.Module):
    """Fine-tuning loss for sequence classification: selects MSE,
    cross-entropy, or BCE-with-logits based on problem_type and num_labels."""

    def __init__(self, num_labels, problem_type):
        super(BertForSequenceClassificationLoss, self).__init__()
        self.num_labels = num_labels
        self.problem_type = problem_type

    def forward(self, labels, logits):
        loss = None
        if labels is not None:
            if self.problem_type is None:
                # Infer the problem type from the number of labels and the
                # label dtype.
                if self.num_labels == 1:
                    self.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.problem_type = "single_label_classification"
                else:
                    self.problem_type = "multi_label_classification"

            if self.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze().float())
                else:
                    loss = loss_fct(logits, labels.float())
            elif self.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_labels), labels.view(-1).long(),
                )
            elif self.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels.reshape(-1))
        return loss
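

# Illustrative usage sketch (not part of the original module): tensor shapes
# below are assumptions inferred from forward() above. For a single-label
# task, logits are [batch, num_labels] and labels are integer class ids of
# shape [batch].
#
#   loss_fn = BertForSequenceClassificationLoss(num_labels=3, problem_type=None)
#   logits = torch.randn(8, 3)
#   labels = torch.randint(0, 3, (8,))
#   loss = loss_fn(labels, logits)  # resolves to "single_label_classification"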


class BertForSequenceClassification(nn.Module):
    """BERT encoder with a dropout + linear head on the pooled output for
    sequence-level classification or regression."""

    def __init__(
        self,
        num_labels,
        problem_type,
        classifier_dropout,
        **model_kwargs,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.problem_type = problem_type
        self.bert = BertModel(**model_kwargs)

        if classifier_dropout is None:
            classifier_dropout = model_kwargs["dropout_rate"]
        self.dropout = nn.Dropout(classifier_dropout)

        hidden_size = model_kwargs["hidden_size"]
        self.classifier = nn.Linear(hidden_size, self.num_labels)

        self.loss_fn = BertForSequenceClassificationLoss(
            self.num_labels, self.problem_type
        )

        self.__reset_parameters()

    def reset_parameters(self):
        self.bert.reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
    ):
        _, pooled_outputs = self.bert(
            input_ids,
            segment_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        pooled_outputs = self.dropout(pooled_outputs)
        logits = self.classifier(pooled_outputs)
        return logits


class BertForQuestionAnsweringLoss(nn.Module):
    """Cross-entropy loss over start/end position logits, weighted per
    example by cls_label_weights and averaged over the batch."""

    def __init__(self):
        super(BertForQuestionAnsweringLoss, self).__init__()

    def forward(self, logits, labels, cls_label_weights):
        # [batch, max_seq_len, 2] -> [batch, 2, max_seq_len]
        logits = torch.permute(logits, [0, 2, 1])
        max_seq_len = logits.shape[-1]
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        loss = loss_fct(
            logits.reshape(-1, max_seq_len), labels.view(-1).long()
        )
        return (loss * cls_label_weights.view(-1)).sum() / labels.shape[0]
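

# Illustrative usage sketch (not part of the original module): shapes are
# assumptions inferred from forward() above. Logits come from
# BertForQuestionAnswering as [batch, max_seq_len, 2]; labels hold the start
# and end token positions per example, and cls_label_weights can mask out
# invalid examples.
#
#   loss_fn = BertForQuestionAnsweringLoss()
#   logits = torch.randn(4, 128, 2)          # [batch, max_seq_len, 2]
#   labels = torch.randint(0, 128, (4, 2))   # [batch, 2] start/end positions
#   cls_label_weights = torch.ones(4, 2)
#   loss = loss_fn(logits, labels, cls_label_weights)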


class BertForQuestionAnswering(nn.Module):
    """BERT encoder (without pooler) with a linear head that produces
    start/end span logits for extractive question answering."""

    def __init__(self, **model_kwargs):
        super().__init__()
        hidden_size = model_kwargs["hidden_size"]
        self.bert = BertModel(**model_kwargs, add_pooling_layer=False)
        self.classifier = nn.Linear(hidden_size, 2)
        self.loss_fn = BertForQuestionAnsweringLoss()
        self.__reset_parameters()

    def reset_parameters(self):
        self.bert.reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
    ):
        encoded_outputs, _ = self.bert(
            input_ids,
            segment_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        logits = self.classifier(encoded_outputs)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()
        return logits, start_logits, end_logits


class BertForTokenClassificationLoss(nn.Module):
    """Per-token cross-entropy loss, masked by the attention mask (when
    provided) and scaled by loss_weight."""

    def __init__(self, num_labels, loss_weight=1.0):
        super(BertForTokenClassificationLoss, self).__init__()
        self.num_labels = num_labels
        self.loss_weight = loss_weight

    def forward(self, logits, labels, attention_mask):
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(
                logits.view(-1, self.num_labels), labels.view(-1).long()
            )
            if attention_mask is not None:
                # Only keep active parts of the loss
                loss = loss * attention_mask.to(dtype=logits.dtype).view(-1)
            loss = (torch.sum(loss) / labels.shape[0] * self.loss_weight).to(
                logits.dtype
            )
        return loss
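

# Illustrative usage sketch (not part of the original module): shapes are
# assumptions inferred from forward() above. Logits are per-token class scores
# of shape [batch, max_seq_len, num_labels]; the attention mask zeroes out the
# loss on padding tokens.
#
#   loss_fn = BertForTokenClassificationLoss(num_labels=9)
#   logits = torch.randn(4, 128, 9)
#   labels = torch.randint(0, 9, (4, 128))
#   attention_mask = torch.ones(4, 128)
#   loss = loss_fn(logits, labels, attention_mask)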


class BertForTokenClassification(nn.Module):
    """BERT encoder (without pooler) with a dropout + linear head that
    classifies every token, e.g. for named entity recognition."""

    def __init__(
        self,
        num_labels,
        classifier_dropout=None,
        loss_weight=1.0,
        include_padding_in_loss=True,
        **model_kwargs,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.include_padding_in_loss = include_padding_in_loss
        self.bert = BertModel(**model_kwargs, add_pooling_layer=False)

        if classifier_dropout is None:
            classifier_dropout = model_kwargs["dropout_rate"]
        self.dropout = nn.Dropout(classifier_dropout)

        hidden_size = model_kwargs["hidden_size"]
        self.classifier = nn.Linear(hidden_size, num_labels)

        self.loss_fn = BertForTokenClassificationLoss(
            self.num_labels, loss_weight
        )

        self.__reset_parameters()

    def reset_parameters(self):
        self.bert.reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
    ):
        encoded_outputs, _ = self.bert(
            input_ids,
            segment_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        encoded_outputs = self.dropout(encoded_outputs)
        logits = self.classifier(encoded_outputs)
        return logits


class BertForSummarizationLoss(nn.Module):
    """Cross-entropy loss over per-[CLS]-position logits, weighted by
    label_weights and scaled by loss_weight."""

    def __init__(
        self,
        num_labels,
        loss_weight=1.0,
    ):
        super(BertForSummarizationLoss, self).__init__()
        self.loss_weight = loss_weight
        self.num_labels = num_labels

    def forward(self, logits, labels, label_weights):
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            loss = loss_fct(
                logits.view(-1, self.num_labels), labels.view(-1).long()
            )
            loss = loss * label_weights.view(-1)
            loss = (loss.sum() / labels.shape[0] * self.loss_weight).to(
                logits.dtype
            )
        return loss
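

# Illustrative usage sketch (not part of the original module): shapes are
# assumptions inferred from forward() above. Logits score each [CLS] position
# (one per sentence) with num_labels classes, and label_weights masks out
# padded positions.
#
#   loss_fn = BertForSummarizationLoss(num_labels=2)
#   logits = torch.randn(4, 32, 2)       # [batch, max_cls_tokens, num_labels]
#   labels = torch.randint(0, 2, (4, 32))
#   label_weights = torch.ones(4, 32)
#   loss = loss_fn(logits, labels, label_weights)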


class BertForSummarization(nn.Module):
    """BERT encoder (without pooler) that gathers the hidden state at each
    [CLS] token position and classifies it, e.g. for extractive
    summarization."""

    def __init__(
        self,
        num_labels=2,
        loss_weight=1.0,
        use_cls_bias=True,
        **model_kwargs,
    ):
        super().__init__()
        self.num_labels = num_labels
        hidden_size = model_kwargs["hidden_size"]
        self.bert = BertModel(**model_kwargs, add_pooling_layer=False)
        self.classifier = nn.Linear(
            hidden_size, self.num_labels, bias=use_cls_bias
        )
        self.loss_fn = BertForSummarizationLoss(self.num_labels, loss_weight)
        self.__reset_parameters()

    def reset_parameters(self):
        self.bert.reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        cls_tokens_positions=None,
    ):
        encoded_outputs, _ = self.bert(
            input_ids,
            segment_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        hidden_size = list(encoded_outputs.size())[-1]
        batch_size, max_pred = list(cls_tokens_positions.size())
        # Gather the encoder output at every [CLS] token position.
        index = torch.broadcast_to(
            cls_tokens_positions.unsqueeze(2),
            (batch_size, max_pred, hidden_size),
        ).long()
        masked_output = torch.gather(encoded_outputs, dim=1, index=index)
        encoded_outputs = masked_output
        logits = self.classifier(encoded_outputs)
        return logits
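

# Illustrative sketch of the gather step in BertForSummarization.forward()
# (not part of the original module): cls_tokens_positions indexes the encoder
# output at each [CLS] token so that only those hidden states reach the
# classifier. Dummy tensors; shapes are assumptions.
#
#   encoded_outputs = torch.randn(2, 128, 768)   # [batch, seq, hidden]
#   cls_tokens_positions = torch.tensor([[0, 40, 90], [0, 35, 0]])
#   index = torch.broadcast_to(
#       cls_tokens_positions.unsqueeze(2), (2, 3, 768)
#   ).long()
#   cls_hidden = torch.gather(encoded_outputs, dim=1, index=index)
#   # cls_hidden: [batch, max_cls_tokens, hidden] -> fed to the linear head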