# Source code for modelzoo.transformers.pytorch.bert.fine_tuning.classifier.model

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from cerebras_pytorch.metrics import AccuracyMetric, FBetaScoreMetric
from modelzoo.transformers.pytorch.bert.bert_finetune_models import (
    BertForSequenceClassification,
    BertForSequenceClassificationLoss,
)
from modelzoo.transformers.pytorch.bert.utils import check_unused_model_params


class BertForSequenceClassificationModel(torch.nn.Module):
    """Fine-tuning wrapper pairing a BERT sequence classifier with its loss.

    The wrapper builds a ``BertForSequenceClassification`` network and a
    ``BertForSequenceClassificationLoss`` from the ``"model"`` section of the
    configuration, and optionally creates evaluation metrics that are updated
    from ``forward`` whenever the wrapped network is not in training mode.
    """

    def __init__(self, params):
        """Build the network, the loss and the optional evaluation metrics.

        Args:
            params: Configuration mapping. Only ``params["model"]`` is read;
                it is copied before any key is popped, so the caller's
                mapping is left untouched.

                Required keys: ``num_labels``, ``problem_type``,
                ``task_dropout``, ``dropout_rate``, ``vocab_size``,
                ``hidden_size``, ``num_hidden_layers``, ``num_heads``,
                ``filter_size``, ``encoder_nonlinearity``,
                ``attention_dropout_rate``, ``max_position_embeddings`` and
                ``layer_norm_epsilon``.

                Optional keys: ``embedding_dropout_rate`` (defaults to
                ``dropout_rate``), ``pooler_nonlinearity`` (defaults to
                ``None``), ``compute_eval_metrics`` (defaults to ``False``)
                and ``is_mnli_dataset`` (defaults to ``False``).

        Raises:
            KeyError: If ``"model"`` or one of the required keys is missing.
        """
        super().__init__()

        # Pop from a copy: every consumed key is removed so that whatever is
        # left over can be handed to ``check_unused_model_params`` below.
        model_params = params["model"].copy()

        self.num_labels = model_params.pop("num_labels")
        problem_type = model_params.pop("problem_type")
        classifier_dropout = model_params.pop("task_dropout")
        dropout_rate = model_params.pop("dropout_rate")
        embedding_dropout_rate = model_params.pop(
            "embedding_dropout_rate", dropout_rate
        )

        model_kwargs = {
            "vocab_size": model_params.pop("vocab_size"),
            "hidden_size": model_params.pop("hidden_size"),
            "num_hidden_layers": model_params.pop("num_hidden_layers"),
            "num_heads": model_params.pop("num_heads"),
            "filter_size": model_params.pop("filter_size"),
            "nonlinearity": model_params.pop("encoder_nonlinearity"),
            "pooler_nonlinearity": model_params.pop(
                "pooler_nonlinearity", None
            ),
            "embedding_dropout_rate": embedding_dropout_rate,
            "dropout_rate": dropout_rate,
            "attention_dropout_rate": model_params.pop(
                "attention_dropout_rate"
            ),
            "max_position_embeddings": model_params.pop(
                "max_position_embeddings"
            ),
            # Coerced with ``float`` so a value supplied as a string
            # (e.g. "1e-5") is accepted as well.
            "layer_norm_epsilon": float(
                model_params.pop("layer_norm_epsilon")
            ),
        }

        self.model = BertForSequenceClassification(
            self.num_labels, problem_type, classifier_dropout, **model_kwargs,
        )
        self.loss_fn = BertForSequenceClassificationLoss(
            self.num_labels, problem_type
        )

        self.compute_eval_metrics = model_params.pop(
            "compute_eval_metrics", False
        )
        # This flag requests two additional accuracy metrics, one for the
        # matched partition and one for the mismatched partition.
        self.is_mnli_dataset = model_params.pop("is_mnli_dataset", False)

        # NOTE(review): presumably reports keys that were never consumed
        # above -- confirm against the helper's implementation.
        check_unused_model_params(model_params)

        if self.compute_eval_metrics:
            self.accuracy_metric = AccuracyMetric(name="eval/accuracy")

            # The F1 metric is only created for two-label problems.
            if self.num_labels == 2:
                self.f1_metric = FBetaScoreMetric(
                    num_classes=self.num_labels,
                    beta=1.0,
                    average_type="micro",
                    name="eval/f1_score",
                )

            if self.is_mnli_dataset:
                self.matched_accuracy_metric = AccuracyMetric(
                    name="eval/accuracy_matched"
                )
                self.mismatched_accuracy_metric = AccuracyMetric(
                    name="eval/accuracy_mismatched"
                )

    def forward(self, data):
        """Run the classifier on a batch and return its loss.

        Args:
            data: Mapping that provides ``"input_ids"``, ``"token_type_ids"``,
                ``"attention_mask"`` and ``"labels"``. When MNLI metrics are
                being computed it must also provide ``"is_matched"`` and
                ``"is_mismatched"``, which are passed to the matched and
                mismatched accuracy metrics as ``weights``.

        Returns:
            The value produced by the loss function for ``data["labels"]``
            and the logits of the wrapped network.
        """
        logits = self.model(
            input_ids=data["input_ids"],
            token_type_ids=data["token_type_ids"],
            attention_mask=data["attention_mask"],
        )
        loss = self.loss_fn(data["labels"], logits)

        # Metrics are only updated outside of training, and only when they
        # were requested (and therefore created) in ``__init__``.
        if not self.model.training and self.compute_eval_metrics:
            # The metrics receive a copy of the labels, not the batch tensor.
            labels = data["labels"].clone()
            # Predicted class = index of the largest logit on the last axis.
            predictions = logits.argmax(-1).int()

            self.accuracy_metric(
                labels=labels, predictions=predictions, dtype=logits.dtype
            )

            if self.num_labels == 2:
                self.f1_metric(labels=labels, predictions=predictions)

            if self.is_mnli_dataset:
                self.matched_accuracy_metric(
                    labels=labels,
                    predictions=predictions,
                    weights=data["is_matched"],
                    dtype=logits.dtype,
                )
                self.mismatched_accuracy_metric(
                    labels=labels,
                    predictions=predictions,
                    weights=data["is_mismatched"],
                    dtype=logits.dtype,
                )

        return loss