# Copyright 2022 Cerebras Systems.
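# This code is adapted from the PyTorch tutorial "Language Modeling with
# nn.Transformer and TorchText":
# https://github.com/pytorch/tutorials/blob/master/beginner_source/transformer_tutorial.py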

"""
Language Modeling with nn.Transformer and TorchText
===================================================

This is a tutorial on training a sequence-to-sequence model that uses the
`nn.Transformer <https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`__ module.

The PyTorch 1.2 release includes a standard transformer module based on the
paper `Attention is All You Need <https://arxiv.org/abs/1706.03762>`__.
Compared to Recurrent Neural Networks (RNNs), the transformer model has proven
to be superior in quality for many sequence-to-sequence tasks while being more
parallelizable. The ``nn.Transformer`` module relies entirely on an attention
mechanism (implemented as
`nn.MultiheadAttention <https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html>`__)
to draw global dependencies between input and output. The ``nn.Transformer``
module is highly modularized such that a single component (e.g.,
`nn.TransformerEncoder <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html>`__)
can be easily adapted/composed.
"""
# Define the model
# ----------------

# In this tutorial, we train a ``nn.TransformerEncoder`` model on a
# language modeling task. The language modeling task is to assign a
# probability for the likelihood of a given word (or a sequence of words)
# to follow a sequence of words. A sequence of tokens is passed to the embedding
# layer first, followed by a positional encoding layer to account for the order
# of the words (see the ``PositionalEncoding`` module below for details). The
# ``nn.TransformerEncoder`` consists of multiple layers of
# `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__.
# Along with the input sequence, a square attention mask is required because the
# self-attention layers in ``nn.TransformerEncoder`` are only allowed to attend
# to the earlier positions in the sequence. For the language modeling task, any
# tokens in future positions should be masked. The output of the
# ``nn.TransformerEncoder`` is passed through a linear layer that produces
# unnormalized logits over the vocabulary; apply a (log-)softmax, or a loss
# such as cross-entropy, to turn them into a probability distribution.

import math

import torch
from torch import Tensor, nn
from torch.nn import Embedding, TransformerEncoder, TransformerEncoderLayer

class TransformerModel(nn.Module):
    """Transformer-encoder language model over batch-first token ids.

    Token ids are embedded, scaled by ``sqrt(d_model)``, given sinusoidal
    position encodings, passed through ``nlayers`` stacked encoder layers and
    projected to ``ntoken`` vocabulary logits by a linear layer. No softmax is
    applied: ``forward`` returns raw logits.
    """

    def __init__(
        self,
        ntoken: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        dropout: float = 0.5,
        activation: str = "gelu",
        max_len: int = 10,
    ):
        """
        Args:
            ntoken: vocabulary size; rows of the embedding table and output
                features of the final linear layer.
            d_model: embedding / hidden size.
            nhead: number of attention heads per encoder layer.
            d_hid: feed-forward width inside each encoder layer.
            nlayers: number of stacked encoder layers.
            dropout: dropout probability for the positional encoding and the
                encoder layers.
            activation: feed-forward activation forwarded to
                ``TransformerEncoderLayer``.
            max_len: forwarded to ``PositionalEncoding``; bounds the sequence
                length that can be encoded.
        """
        super().__init__()
        self.model_type = 'Transformer'
        self.encoder = Embedding(ntoken, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
        encoder_layers = TransformerEncoderLayer(
            d_model,
            nhead,
            d_hid,
            dropout,
            batch_first=True,  # all tensors are [batch_size, seq_len, ...]
            activation=activation,
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)
        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialize the embedding and output projection.

        Embedding and decoder weights are drawn uniformly from
        ``[-0.1, 0.1]``; the decoder bias is set to zero.
        """
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor of token ids, shape [batch_size, seq_len]
                (the encoder layers are built with ``batch_first=True``).
            src_mask: Tensor, shape [seq_len, seq_len]; attention mask passed
                unchanged to the ``TransformerEncoder``.

        Returns:
            output Tensor of logits, shape [batch_size, seq_len, ntoken].
        """
        # Scale embeddings by sqrt(d_model) before adding position encodings.
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output
def generate_square_subsequent_mask(
    sz: int, device=None, dtype: torch.dtype = torch.float16
) -> Tensor:
    """Generate an additive causal attention mask of shape ``(sz, sz)``.

    Entries on and below the main diagonal are ``0`` (attention allowed).
    Entries strictly above it are ``torch.finfo(dtype).min``, the most
    negative finite value of ``dtype``. A finite value is used rather than
    ``-inf`` because the mask is built by multiplying a 0/1 triangle with
    it, and ``0 * -inf`` would produce NaN.

    Args:
        sz: side length of the square mask (the sequence length).
        device: device on which to allocate the mask.
        dtype: floating-point dtype of the mask (``torch.finfo`` only accepts
            floating-point dtypes). Defaults to ``torch.float16``.

    Returns:
        Tensor of shape ``(sz, sz)`` with dtype ``dtype``.
    """
    return (
        torch.triu(torch.ones((sz, sz), device=device, dtype=dtype), diagonal=1)
        * torch.finfo(dtype).min
    )
######################################################################
# ``PositionalEncoding`` module injects some information about the
# relative or absolute position of the tokens in the sequence. The
# positional encodings have the same dimension as the embeddings so that
# the two can be summed. Here, we use ``sine`` and ``cosine`` functions of
# different frequencies.
#
class PositionalEncoding(nn.Module):
    """Add fixed sine/cosine position encodings to batch-first embeddings.

    Even feature columns hold ``sin(pos * f_k)`` and odd columns hold
    ``cos(pos * f_k)``, with frequencies ``f_k = 10000 ** (-2k / d_model)``.
    The table is precomputed for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: embedding size; may be odd.
            dropout: dropout probability applied after adding the encodings.
            max_len: number of positions precomputed; longer inputs cannot
                be encoded.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns, one fewer than div_term
        # has entries when d_model is odd, so trim the frequencies to match.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (1, max_len, d_model) so it broadcasts over the batch
        # dimension of batch-first inputs.
        self.register_buffer('pe', pe.transpose(1, 0))

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim], with
                seq_len <= max_len and embedding_dim == d_model.

        Returns:
            ``x`` plus the encodings for positions ``0 .. seq_len - 1``,
            followed by dropout; same shape as ``x``.
        """
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)