# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
#
# This code is adapted from
# https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_bert.py
#
# Copyright 2022 Cerebras Systems.
#
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn

from modelzoo.transformers.pytorch.bert.bert_model import BertModel


class BertForSequenceClassificationLoss(nn.Module):
    def __init__(self, num_labels, problem_type):
        super(BertForSequenceClassificationLoss, self).__init__()
        self.num_labels = num_labels
        self.problem_type = problem_type

    def forward(self, labels, logits):
        loss = None
        if labels is not None:
            if self.problem_type is None:
                # Infer the problem type from the number of labels and the
                # label dtype, as in the upstream Hugging Face implementation.
                if self.num_labels == 1:
                    self.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.problem_type = "single_label_classification"
                else:
                    self.problem_type = "multi_label_classification"
            if self.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze().float())
                else:
                    loss = loss_fct(logits, labels.float())
            elif self.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_labels), labels.view(-1).long(),
                )
            elif self.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels.reshape(-1))
        return loss
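
# Example usage (illustrative sketch, not part of the original module; the
# shapes and label values below are assumptions): single-label classification
# over a batch of 8 sequences and 3 classes. Note the (labels, logits) argument
# order of forward().
#
#     loss_fn = BertForSequenceClassificationLoss(
#         num_labels=3, problem_type="single_label_classification"
#     )
#     logits = torch.randn(8, 3)           # [batch, num_labels]
#     labels = torch.randint(0, 3, (8,))   # integer class ids
#     loss = loss_fn(labels, logits)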


class BertForSequenceClassification(nn.Module):
    def __init__(
        self, num_labels, problem_type, classifier_dropout, **model_kwargs,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.problem_type = problem_type

        self.bert = BertModel(**model_kwargs)

        if classifier_dropout is None:
            classifier_dropout = model_kwargs["dropout_rate"]
        self.dropout = nn.Dropout(classifier_dropout)

        hidden_size = model_kwargs["hidden_size"]
        self.classifier = nn.Linear(hidden_size, self.num_labels)

        self.loss_fn = BertForSequenceClassificationLoss(
            self.num_labels, self.problem_type
        )

        self.__reset_parameters()

    def reset_parameters(self):
        self.bert.reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(
        self, input_ids=None, token_type_ids=None, attention_mask=None,
    ):
        _, pooled_outputs = self.bert(
            input_ids,
            segment_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        pooled_outputs = self.dropout(pooled_outputs)
        logits = self.classifier(pooled_outputs)
        return logits
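
# Example usage (illustrative sketch, not part of the original module;
# `model_kwargs` stands for whatever BertModel configuration is used, and is
# assumed to contain at least "hidden_size" and "dropout_rate"):
#
#     model = BertForSequenceClassification(
#         num_labels=3,
#         problem_type="single_label_classification",
#         classifier_dropout=None,
#         **model_kwargs,
#     )
#     logits = model(
#         input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
#     )
#     loss = model.loss_fn(labels, logits)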


class BertForQuestionAnsweringLoss(nn.Module):
    def __init__(self):
        super(BertForQuestionAnsweringLoss, self).__init__()

    def forward(self, logits, labels, cls_label_weights):
        # [batch, max_seq_len, 2] -> [batch, 2, max_seq_len]
        logits = torch.permute(logits, [0, 2, 1])
        max_seq_len = logits.shape[-1]
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        loss = loss_fct(logits.reshape(-1, max_seq_len), labels.view(-1).long())
        return (loss * cls_label_weights.view(-1)).sum() / labels.shape[0]
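
# Example usage (illustrative sketch, not part of the original module; the
# batch size and sequence length are assumptions): `labels` holds the start and
# end token positions for each example, and `cls_label_weights` weights the two
# per-example position losses.
#
#     loss_fn = BertForQuestionAnsweringLoss()
#     logits = torch.randn(4, 128, 2)           # [batch, max_seq_len, 2]
#     labels = torch.randint(0, 128, (4, 2))    # start/end token positions
#     cls_label_weights = torch.ones(4, 2)
#     loss = loss_fn(logits, labels, cls_label_weights)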


class BertForQuestionAnswering(nn.Module):
    def __init__(self, **model_kwargs):
        super().__init__()
        hidden_size = model_kwargs["hidden_size"]

        self.bert = BertModel(**model_kwargs, add_pooling_layer=False)
        self.classifier = nn.Linear(hidden_size, 2)
        self.loss_fn = BertForQuestionAnsweringLoss()

        self.__reset_parameters()

    def reset_parameters(self):
        self.bert.reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(
        self, input_ids=None, attention_mask=None, token_type_ids=None,
    ):
        encoded_outputs, _ = self.bert(
            input_ids,
            segment_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        logits = self.classifier(encoded_outputs)
        # Split the two output channels into start- and end-position logits.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()
        return logits, start_logits, end_logits
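
# Example usage (illustrative sketch, not part of the original module;
# `model_kwargs` is assumed to contain the BertModel configuration, including
# "hidden_size"):
#
#     model = BertForQuestionAnswering(**model_kwargs)
#     logits, start_logits, end_logits = model(
#         input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
#     )
#     loss = model.loss_fn(logits, labels, cls_label_weights)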


class BertForTokenClassificationLoss(nn.Module):
    def __init__(self, num_labels, loss_weight=1.0):
        super(BertForTokenClassificationLoss, self).__init__()
        self.num_labels = num_labels
        self.loss_weight = loss_weight

    def forward(self, logits, labels, attention_mask):
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(
                logits.view(-1, self.num_labels), labels.view(-1).long()
            )
            if attention_mask is not None:
                # Only keep active parts of the loss
                loss = loss * attention_mask.to(dtype=logits.dtype).view(-1)
            loss = (torch.sum(loss) / labels.shape[0] * self.loss_weight).to(
                logits.dtype
            )
        return loss
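
# Example usage (illustrative sketch, not part of the original module; the
# shapes below are assumptions): per-token loss over a batch of 4 sequences of
# length 128 with 9 tag classes, masked by the attention mask.
#
#     loss_fn = BertForTokenClassificationLoss(num_labels=9, loss_weight=1.0)
#     logits = torch.randn(4, 128, 9)           # [batch, seq_len, num_labels]
#     labels = torch.randint(0, 9, (4, 128))    # per-token tag ids
#     attention_mask = torch.ones(4, 128)
#     loss = loss_fn(logits, labels, attention_mask)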


class BertForTokenClassification(nn.Module):
    def __init__(
        self,
        num_labels,
        classifier_dropout=None,
        loss_weight=1.0,
        include_padding_in_loss=True,
        **model_kwargs,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.include_padding_in_loss = include_padding_in_loss

        self.bert = BertModel(**model_kwargs, add_pooling_layer=False)

        if classifier_dropout is None:
            classifier_dropout = model_kwargs["dropout_rate"]
        self.dropout = nn.Dropout(classifier_dropout)

        hidden_size = model_kwargs["hidden_size"]
        self.classifier = nn.Linear(hidden_size, num_labels)

        self.loss_fn = BertForTokenClassificationLoss(
            self.num_labels, loss_weight
        )

        self.__reset_parameters()

    def reset_parameters(self):
        self.bert.reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(
        self, input_ids=None, attention_mask=None, token_type_ids=None,
    ):
        encoded_outputs, _ = self.bert(
            input_ids,
            segment_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        encoded_outputs = self.dropout(encoded_outputs)
        logits = self.classifier(encoded_outputs)
        return logits
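
# Example usage (illustrative sketch, not part of the original module;
# `model_kwargs` is assumed to contain the BertModel configuration, including
# "hidden_size" and "dropout_rate"):
#
#     model = BertForTokenClassification(num_labels=9, **model_kwargs)
#     logits = model(
#         input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
#     )
#     loss = model.loss_fn(logits, labels, attention_mask)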


class BertForSummarizationLoss(nn.Module):
    def __init__(
        self, num_labels, loss_weight=1.0,
    ):
        super(BertForSummarizationLoss, self).__init__()
        self.loss_weight = loss_weight
        self.num_labels = num_labels

    def forward(self, logits, labels, label_weights):
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            loss = loss_fct(
                logits.view(-1, self.num_labels), labels.view(-1).long()
            )
            loss = loss * label_weights.view(-1)
            loss = (loss.sum() / labels.shape[0] * self.loss_weight).to(
                logits.dtype
            )
        return loss
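
# Example usage (illustrative sketch, not part of the original module; the
# shapes below are assumptions): binary labels for up to 32 sentence [CLS]
# positions per document, with per-position weights.
#
#     loss_fn = BertForSummarizationLoss(num_labels=2, loss_weight=1.0)
#     logits = torch.randn(4, 32, 2)            # [batch, max_cls_tokens, num_labels]
#     labels = torch.randint(0, 2, (4, 32))
#     label_weights = torch.ones(4, 32)
#     loss = loss_fn(logits, labels, label_weights)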


class BertForSummarization(nn.Module):
    def __init__(
        self, num_labels=2, loss_weight=1.0, use_cls_bias=True, **model_kwargs,
    ):
        super().__init__()
        self.num_labels = num_labels
        hidden_size = model_kwargs["hidden_size"]

        self.bert = BertModel(**model_kwargs, add_pooling_layer=False)
        self.classifier = nn.Linear(
            hidden_size, self.num_labels, bias=use_cls_bias
        )
        self.loss_fn = BertForSummarizationLoss(self.num_labels, loss_weight)

        self.__reset_parameters()

    def reset_parameters(self):
        self.bert.reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        cls_tokens_positions=None,
    ):
        encoded_outputs, _ = self.bert(
            input_ids,
            segment_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        hidden_size = list(encoded_outputs.size())[-1]
        batch_size, max_pred = list(cls_tokens_positions.size())
        # Gather the encoder outputs at the [CLS] token positions given by
        # cls_tokens_positions, so that only those positions are classified.
        index = torch.broadcast_to(
            cls_tokens_positions.unsqueeze(2),
            (batch_size, max_pred, hidden_size),
        ).long()
        masked_output = torch.gather(encoded_outputs, dim=1, index=index)
        encoded_outputs = masked_output
        logits = self.classifier(encoded_outputs)
        return logits
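
# Example usage (illustrative sketch, not part of the original module;
# `model_kwargs` is assumed to contain the BertModel configuration, including
# "hidden_size"; `cls_tokens_positions` gives the index of each sentence's
# [CLS] token in the input sequence):
#
#     model = BertForSummarization(num_labels=2, **model_kwargs)
#     logits = model(
#         input_ids,
#         attention_mask=attention_mask,
#         token_type_ids=token_type_ids,
#         cls_tokens_positions=cls_tokens_positions,
#     )
#     loss = model.loss_fn(logits, labels, label_weights)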