# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
from cerebras_pytorch.metrics import AccuracyMetric, PerplexityMetric
from modelzoo.common.pytorch.model_utils.GPTLMHeadModelLoss import (
GPTLMHeadModelLoss,
)
from modelzoo.transformers.pytorch.gptj.gptj_model import GPTJModel
class GptjModel(torch.nn.Module):
"""
GPT-2 models
"""
    def __init__(self, params):
super().__init__()
model_params = params["model"].copy()
self.compute_eval_metrics = model_params.pop(
"compute_eval_metrics", True
)
if self.compute_eval_metrics:
self.perplexity_metric = PerplexityMetric(name="eval/lm_perplexity")
self.accuracy_metric = AccuracyMetric(name="eval/accuracy")
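        # build_model pops the remaining model params and also records
        # self.loss_weight and self.loss_scaling, which the loss function
        # constructed below depends on.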
self.model = self.build_model(model_params)
self.loss_fn = GPTLMHeadModelLoss(
params["model"]["vocab_size"], self.loss_scaling, self.loss_weight,
)
def _post_device_transfer(self):
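        # Re-tie the shared input-embedding / LM-head weights after the
        # parameters have been transferred to the device (see the
        # share_embedding_weights option in build_model).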
self.model.tie_weights()
def build_model(self, model_params):
attention_type = model_params.pop("attention_type")
if attention_type not in ["scaled_dot_product", "dot_product"]:
raise ValueError(
"attention_type should be 'scaled_dot_product' or 'dot_product'."
)
position_embedding_type = model_params.pop(
"position_embedding_type", "rotary"
).lower()
assert (
position_embedding_type != "alibi"
), "alibi position embedding is not yet supported by gptj"
rope_theta = model_params.pop("rope_theta", 10000)
rotary_dim = None
num_relative_attention_buckets = None
if position_embedding_type == "rotary":
rotary_dim = model_params.pop(
"rotary_dim",
int(
model_params["hidden_size"]
// model_params["num_heads"]
* 0.25
),
)
# https://github.com/huggingface/transformers/blob/f0577df6de36e7e7f28e90fa76da0657de038a39/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L84-L85
# https://arxiv.org/pdf/2104.09864.pdf Section 3.3
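            # Example: hidden_size=4096 and num_heads=16 give a head
            # dimension of 256, so rotary_dim defaults to 64, which passes
            # both checks below (even and no larger than the head dimension).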
assert (
rotary_dim
<= model_params["hidden_size"] / model_params["num_heads"]
), "Rotary dimensions should be <= hidden size divided by number of attention heads."
assert (
rotary_dim % 2 == 0
), "Rotary dimension must be an even number."
else:
# relative PE
num_relative_attention_buckets = model_params.pop(
"num_relative_attention_buckets", 32
)
self.loss_weight = model_params.pop("loss_weight", 1.0)
self.loss_scaling = model_params.pop(
"loss_scaling", "num_tokens"
).lower()
if self.loss_weight != 1.0 and self.loss_scaling == "num_tokens":
logging.warning(
f"loss_weight cannot be {self.loss_weight} for num_tokens "
f"loss_scaling. Setting loss_weight to 1.0."
)
self.loss_weight = 1.0
model = GPTJModel(
hidden_size=model_params.pop("hidden_size"),
# Embedding params
vocab_size=model_params.pop("vocab_size"),
max_position_embeddings=model_params.pop(
"max_position_embeddings", 1024
),
embd_pdrop=model_params.pop("embedding_dropout_rate", 0.1),
share_embedding_weights=model_params.pop(
"share_embedding_weights", True
),
position_embedding_type=position_embedding_type,
rotary_dim=rotary_dim,
rope_theta=rope_theta,
num_relative_attention_buckets=num_relative_attention_buckets,
# Decoder params
num_hidden_layers=model_params.pop("num_hidden_layers"),
filter_size=model_params.pop("filter_size"),
dropout_rate=model_params.pop("residual_dropout_rate", 0.1),
nonlinearity=model_params.pop("nonlinearity", "gelu"),
norm_type=model_params.pop("norm_type", "layernorm"),
layer_norm_epsilon=float(
model_params.pop("layer_norm_epsilon", 1.0e-5)
),
use_ffn_bias=model_params.pop("use_ffn_bias", False),
use_untied_layer_norm=model_params.pop(
"use_untied_layer_norm", False
),
# Attention params
num_heads=model_params.pop("num_heads"),
attention_module=model_params.pop(
"attention_module", "aiayn_attention"
),
extra_attention_params=model_params.pop(
"extra_attention_params", {}
),
attention_type=attention_type,
attention_dropout_rate=model_params.pop(
"attention_dropout_rate", 0.1
),
attention_softmax_fp32=model_params.pop(
"attention_softmax_fp32", True
),
use_projection_bias_in_attention=model_params.pop(
"use_projection_bias_in_attention", True
),
use_ffn_bias_in_attention=model_params.pop(
"use_ffn_bias_in_attention", True
),
# Task-specific
initializer_range=model_params.pop("initializer_range", 0.02),
use_bias_in_output=model_params.pop("use_bias_in_output", False),
norm_first=model_params.pop("norm_first", True),
embedding_initializer=model_params.pop(
"embedding_initializer", None
),
attention_initializer=model_params.pop("initializer", None),
output_layer_initializer=model_params.pop(
"output_layer_initializer", None
),
alibi_trainable_slopes=model_params.pop(
"alibi_trainable_slopes", False
),
pos_scaling_factor=float(
model_params.pop("pos_scaling_factor", 1.0)
),
)
model_params.pop("mixed_precision", None)
        # `mixed_precision` is popped above and `fp16_type` is accessed later
        # (outside this method), so neither is reported as an unused param.
unused_params = [
key for key in model_params.keys() if key != "fp16_type"
]
if unused_params:
logging.warning(
"The following model params are unused: "
+ ", ".join(unused_params)
)
return model
def forward(self, data, reduce_batch=True):
assert (
"input_ids" in data
and "attention_mask" in data
and "labels" in data
), "GPT-J model expects these data fields: input_ids, attention_mask, labels"
assert (
data["input_ids"].dtype == torch.int32
and data["attention_mask"].dtype == torch.int32
and data["labels"].dtype == torch.int32
), "The dtype for all inputs should be torch.int32"
lm_logits = self.model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
attention_span=data.get("attention_span"), # VSL-only input
position_ids=data.get("position_ids"), # VSL-only input
)
loss = self.loss_fn(
lm_logits,
data["labels"],
data["attention_mask"],
reduce_batch=reduce_batch,
)
# Calculate eval metrics if not training
if not self.model.training and self.compute_eval_metrics:
lm_labels = data["labels"].clone()
lm_weights = data["attention_mask"].clone()
lm_preds = lm_logits.argmax(-1).int()
self.accuracy_metric(
labels=lm_labels, predictions=lm_preds, weights=lm_weights,
)
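            # Recover the total (un-normalized) loss from the scaled value so
            # the perplexity metric can normalize by the token weights itself.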
if self.loss_scaling == "num_tokens":
unscaled_loss = loss * torch.sum(
data["attention_mask"].clone(), dtype=torch.float32
)
elif self.loss_scaling == "batch_size":
unscaled_loss = loss * torch.tensor(
lm_labels.shape[0] / self.loss_weight, dtype=torch.float32
)
else:
raise ValueError(
f"Loss scaling can't be set to {self.loss_scaling}. \
Should be either 'num_tokens' or 'batch_size'"
)
self.perplexity_metric(
labels=lm_labels, loss=unscaled_loss, weights=lm_weights,
)
return loss
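
# Usage sketch (illustrative only): the config keys below mirror the pops in
# build_model, and the values roughly follow the public GPT-J 6B configuration;
# treat both as assumptions rather than a supported recipe.
#
#   params = {
#       "model": {
#           "vocab_size": 50400,
#           "hidden_size": 4096,
#           "num_hidden_layers": 28,
#           "num_heads": 16,
#           "filter_size": 16384,
#           "attention_type": "scaled_dot_product",
#           "position_embedding_type": "rotary",
#           "max_position_embeddings": 2048,
#       }
#   }
#   model = GptjModel(params)
#   batch = {
#       "input_ids": torch.zeros(1, 2048, dtype=torch.int32),
#       "attention_mask": torch.ones(1, 2048, dtype=torch.int32),
#       "labels": torch.zeros(1, 2048, dtype=torch.int32),
#   }
#   loss = model(batch)  # scalar training loss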