modelzoo.transformers.pytorch.bert.model.BertForPreTrainingModel

class modelzoo.transformers.pytorch.bert.model.BertForPreTrainingModel

Bases: torch.nn.Module

BERT-based models

Methods

build_model

forward

mlm_xentropy_loss

Calculates MLM Cross-Entropy (to be used for Perplexity calculation)

__call__(*args: Any, **kwargs: Any) → Any

Call self as a function.

__init__(params)

static __new__(cls, *args: Any, **kwargs: Any) → Any

mlm_xentropy_loss(labels, logits, weights=None)

Calculates the MLM cross-entropy loss (to be used for perplexity calculation).

Parameters
  • labels – Tensor of shape (batch, sequence) and type int32.

  • logits – Tensor of shape (batch, sequence, vocab) and type float.

  • weights – Optional float Tensor of shape (batch, sequence).

Returns

The computed loss tensor.
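The documented shapes and dtypes admit a straightforward computation. The sketch below is a minimal stand-alone version of such a masked-language-model cross-entropy, not the model zoo's exact implementation: the reduction (a plain sum over weighted per-token losses) and the dtype casts are assumptions, and the function name is reused here only for illustration.

```python
import torch
import torch.nn.functional as F

def mlm_xentropy_loss(labels, logits, weights=None):
    """Per-token MLM cross-entropy, summed over the batch (reduction is an assumption).

    labels:  (batch, sequence) int tensor of target token ids.
    logits:  (batch, sequence, vocab) float tensor of unnormalized scores.
    weights: optional (batch, sequence) float tensor, e.g. 1.0 at masked
             positions and 0.0 elsewhere, so only masked tokens contribute.
    """
    vocab_size = logits.shape[-1]
    # Flatten to (batch * sequence, vocab) / (batch * sequence,) for cross_entropy.
    per_token = F.cross_entropy(
        logits.reshape(-1, vocab_size).float(),
        labels.reshape(-1).long(),
        reduction="none",
    )
    if weights is not None:
        per_token = per_token * weights.reshape(-1).float()
    return per_token.sum()
```

Dividing the returned sum by `weights.sum()` would give the mean loss per masked token, whose exponential is the usual perplexity estimate.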