# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn

import cerebras_pytorch as cstorch
from modelzoo.common.pytorch.layers import (
EmbeddingLayer,
TransformerDecoder,
TransformerDecoderLayer,
)
from modelzoo.common.pytorch.model_utils.norms import get_norm
from modelzoo.transformers.pytorch.gpt2.sparse_mask import (
create_fixed_sparse_attention_mask,
)
from modelzoo.transformers.pytorch.transformer_utils import (
build_broadcastable_attention_mask,
make_sparse_mask_broadcastable,
)


class GPT2LMHeadModel(nn.Module):
"""
GPT-2 model with LM head
"""

    def __init__(
self,
# Embedding
vocab_size=50257,
max_position_embeddings=1024,
embd_pdrop=0.1,
position_embedding_type="learned",
position_embedding_offset=0,
hidden_size=768,
share_embedding_weights=True,
embedding_layer_norm=False,
num_relative_attention_buckets=32,
rotary_dim=None,
rope_theta=10000,
# Encoder
num_hidden_layers=12,
dropout_rate=0.1,
norm_type="layernorm",
layer_norm_epsilon=1.0e-5,
# Encoder - Attention
num_heads=12,
attention_type="scaled_dot_product",
attention_module="aiayn_attention",
extra_attention_params={},
use_projection_bias_in_attention=True,
use_ffn_bias_in_attention=True,
attention_dropout_rate=0.1,
attention_softmax_fp32=True,
fixed_sparse_attention=None,
# Encoder - ffn
filter_size=3072,
nonlinearity="gelu",
use_ffn_bias=True,
# Task-specific
use_bias_in_output=False,
initializer_range=0.02,
embedding_initializer=None,
initializer=None,
output_layer_initializer=None,
# muP (maximal update parameterization) parameters
output_logits_scale=None,
embeddings_scale=1.0,
scale_qk_dot_by_d=False,
alibi_trainable_slopes=False,
pos_scaling_factor=1.0,
scale_qk_dot_by_layer_idx=False,
):
        super().__init__()
# std deviation for weight initialization
self.initializer_range = initializer_range
self.num_hidden_layers = num_hidden_layers
self.share_embedding_weights = share_embedding_weights
self.embedding_layer_norm = embedding_layer_norm
self.max_position_embeddings = max_position_embeddings
self.position_embedding_type = position_embedding_type
self.embeddings_scale = embeddings_scale
self.num_heads = num_heads
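        # Default initializers are truncated normal with std=initializer_range;
        # the attention/FFN output projections are scaled down by
        # sqrt(2 * num_hidden_layers), following the GPT-2 residual-scaling scheme.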
if initializer is None:
attention_initializer = {
"name": "truncated_normal",
"mean": 0.0,
"std": self.initializer_range,
}
ffn_initializer = {
"name": "truncated_normal",
"mean": 0.0,
"std": self.initializer_range,
}
if output_layer_initializer is None:
output_layer_initializer = {
"name": "truncated_normal",
"mean": 0.0,
"std": self.initializer_range
/ math.sqrt(2 * self.num_hidden_layers),
}
else:
attention_initializer = initializer
ffn_initializer = initializer
if embedding_initializer is None:
embedding_initializer = {
"name": "truncated_normal",
"mean": 0.0,
"std": self.initializer_range,
}
norm_class = get_norm(norm_type)
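        # For rotary position embeddings, default the rotary dimension to the
        # full per-head dimension and validate it (see the references below).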
if position_embedding_type == "rotary":
if rotary_dim is None:
rotary_dim = hidden_size // num_heads
# https://github.com/huggingface/transformers/blob/f0577df6de36e7e7f28e90fa76da0657de038a39/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L84-L85
# https://arxiv.org/pdf/2104.09864.pdf Section 3.3
assert (
rotary_dim <= hidden_size / num_heads
), "Rotary dimensions should be <= hidden size divided by number of attention heads."
assert (
rotary_dim % 2 == 0
), "Rotary dimension must be an even number."
self.embedding_layer = EmbeddingLayer(
vocab_size=vocab_size,
embedding_size=hidden_size,
embeddings_initializer=embedding_initializer,
position_embedding_type=position_embedding_type,
position_embeddings_initializer=embedding_initializer,
max_position_embeddings=max_position_embeddings,
position_embedding_offset=position_embedding_offset,
num_heads=num_heads,
num_relative_attention_buckets=num_relative_attention_buckets,
rotary_dim=rotary_dim,
rope_theta=rope_theta,
alibi_trainable_slopes=alibi_trainable_slopes,
pos_scaling_factor=pos_scaling_factor,
)
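        # Optional norm over the embedding output, followed by embedding dropout.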
if self.embedding_layer_norm:
self.embedding_ln_f = norm_class(
hidden_size, eps=layer_norm_epsilon
)
self.drop_embd = nn.Dropout(embd_pdrop)
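        # A single pre-norm decoder layer, used as the template for the
        # num_hidden_layers-deep decoder stack constructed below.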
decoder_layer = TransformerDecoderLayer(
d_model=hidden_size,
nhead=num_heads,
dim_feedforward=filter_size,
dropout=dropout_rate,
activation=nonlinearity,
layer_norm_eps=layer_norm_epsilon,
norm_layer=norm_class,
norm_first=True,
extra_attention_params=extra_attention_params,
add_cross_attention=False,
attention_type=attention_type,
scale_qk_dot_by_d=scale_qk_dot_by_d,
scale_qk_dot_by_layer_idx=scale_qk_dot_by_layer_idx,
attention_module=attention_module,
attention_dropout_rate=attention_dropout_rate,
attention_softmax_fp32=attention_softmax_fp32,
use_projection_bias_in_attention=use_projection_bias_in_attention,
use_ffn_bias_in_attention=use_ffn_bias_in_attention,
use_ffn_bias=use_ffn_bias,
attention_initializer=attention_initializer,
attention_output_layer_initializer=output_layer_initializer,
ffn_initializer=ffn_initializer,
ffn_output_layer_initializer=output_layer_initializer,
use_ff_layer1_dropout=False,
)
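        # muP: optional multiplicative scale applied to the LM logits in forward().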
self.output_logits_scale = output_logits_scale
        # Final norm, applied to the output of the decoder stack
self.ln_f = norm_class(hidden_size, eps=layer_norm_epsilon)
self.transformer_decoder = TransformerDecoder(
decoder_layer, num_layers=num_hidden_layers, norm=self.ln_f,
)
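        # Optional GPT-3-style fixed sparse attention pattern, precomputed here
        # and made broadcastable per batch in forward().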
if fixed_sparse_attention is not None:
self.fixed_sparsity_mask = create_fixed_sparse_attention_mask(
max_sequence_length=max_position_embeddings,
n_heads=num_heads,
**fixed_sparse_attention,
)
else:
self.fixed_sparsity_mask = None
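        # LM head projecting hidden states to vocabulary logits; its weight is
        # tied to the input embedding when share_embedding_weights is True.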
self.lm_head = nn.Linear(
hidden_size, vocab_size, bias=use_bias_in_output
)
self.tie_weights()
self.__reset_parameters()

    def reset_parameters(self):
self.embedding_layer.reset_parameters()
self.transformer_decoder.reset_parameters()
self.__reset_parameters()

    def __reset_parameters(self):
# Init final norm layer
if hasattr(self.ln_f, "bias"):
self.ln_f.bias.data.zero_()
self.ln_f.weight.data.fill_(1.0)
# Initialize LM head
if not self.share_embedding_weights:
self.lm_head.weight.data.normal_(
mean=0.0, std=self.initializer_range
)
if self.lm_head.bias is not None:
self.lm_head.bias.data.zero_()

    def tie_weights(self):
if not self.share_embedding_weights:
return
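        # Point the LM head weight at the input embedding weight; if the head
        # has a bias, zero-pad it to match the (possibly resized) vocab dimension.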
output_embedding = self.get_output_embeddings()
input_embedding = self.get_input_embeddings()
output_embedding.weight = input_embedding.weight
if getattr(output_embedding, "bias", None) is not None:
output_embedding.bias.data = nn.functional.pad(
output_embedding.bias.data,
(
0,
output_embedding.weight.shape[0]
- output_embedding.bias.shape[0],
),
"constant",
0,
)
if hasattr(output_embedding, "out_features") and hasattr(
input_embedding, "num_embeddings"
):
output_embedding.out_features = input_embedding.num_embeddings

    def get_output_embeddings(self):
return self.lm_head

    def get_input_embeddings(self):
return self.embedding_layer.get_input_embeddings()

    def compute_input_embeddings(self, input_ids, position_ids=None):
hidden_states = self.embedding_layer(
input_ids, position_ids=position_ids
)
if self.embedding_layer_norm:
hidden_states = self.embedding_ln_f(hidden_states)
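        # muP: scale the embedding output (embeddings_scale defaults to 1.0).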
hidden_states = hidden_states * torch.tensor(
float(self.embeddings_scale), dtype=hidden_states.dtype
)
hidden_states = self.drop_embd(hidden_states)
return hidden_states

    def forward(
self,
input_ids=None,
attention_mask=None,
attention_span=None,
position_ids=None,
):
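        """
        Run the decoder stack over input_ids and return LM logits of shape
        (batch, sequence, vocab_size).
        """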
hidden_states = self.compute_input_embeddings(input_ids, position_ids)
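        # Build a broadcastable causal attention mask, optionally restricted by
        # the padding mask / attention span.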
causal_attention_mask = build_broadcastable_attention_mask(
attention_mask,
attention_span=attention_span,
build_causal=True,
device=input_ids.device,
dtype=hidden_states.dtype,
num_heads=self.num_heads,
)
        # Fixed sparse attention, as used in the GPT-3 model
sparse_attention_mask = None
if self.fixed_sparsity_mask is not None:
sparse_attention_mask = make_sparse_mask_broadcastable(
self.fixed_sparsity_mask,
attention_mask,
dtype=hidden_states.dtype,
device=hidden_states.device,
revert_mask=False,
)
        # Compute the ALiBi / relative position embedding bias for
        # self-attention, if the chosen position embedding type uses one
length = input_ids.shape[1]
self_attn_position_bias = self.embedding_layer.compute_position_bias(
length, length
)
hidden_states = self.transformer_decoder(
hidden_states,
tgt_mask=causal_attention_mask,
sparse_mask=sparse_attention_mask,
rotary_position_embedding_helper=self.embedding_layer.get_rope_helper(),
self_attn_position_bias=self_attn_position_bias,
)
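        # On a Cerebras system running at precision_opt_level 1, apply a
        # precision-optimization override (bwd_level=0) to the LM head.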
if (
cstorch.use_cs()
and cstorch.current_executor().cs_config.precision_opt_level == 1
):
lm_logits = cstorch.pol(bwd_level=0)(self.lm_head)(hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
# scale lm_logits for muP transfer
if self.output_logits_scale:
lm_logits = lm_logits * torch.tensor(
float(self.output_logits_scale), dtype=lm_logits.dtype,
)
return lm_logits
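

# A minimal usage sketch (hypothetical shapes; assumes the default GPT-2 small
# configuration and that this module's ModelZoo dependencies are importable):
#
#   model = GPT2LMHeadModel()
#   input_ids = torch.randint(0, 50257, (1, 128), dtype=torch.long)
#   attention_mask = torch.ones_like(input_ids)
#   logits = model(input_ids=input_ids, attention_mask=attention_mask)
#   assert logits.shape == (1, 128, 50257)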