modelzoo.transformers.pytorch.bert.model.BertForPreTrainingModel
- class modelzoo.transformers.pytorch.bert.model.BertForPreTrainingModel
Bases: torch.nn.Module
BERT-based models
Methods
build_model
forward
mlm_xentropy_loss – Calculates MLM Cross-Entropy (to be used for Perplexity calculation)
- __call__(*args: Any, **kwargs: Any) → Any
Call self as a function.
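Because the class derives from torch.nn.Module, calling an instance goes through Module.__call__, which runs any registered hooks and then dispatches to forward. The sketch below illustrates this with a hypothetical TinyModel (not part of modelzoo); it only demonstrates generic torch.nn.Module behavior, not anything specific to BertForPreTrainingModel.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in module, used only to illustrate __call__."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyModel()
seen = []
# Forward hooks fire only when the module is invoked through __call__.
handle = model.register_forward_hook(lambda mod, inp, out: seen.append(out.shape))

x = torch.randn(3, 4)
y = model(x)          # __call__ -> hooks -> forward
z = model.forward(x)  # calls forward directly and skips the hooks
handle.remove()

print(y.shape, len(seen))  # torch.Size([3, 2]) 1
```

In practice, prefer model(inputs) over model.forward(inputs) so that hooks and other Module machinery are not silently bypassed.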
- static __new__(cls, *args: Any, **kwargs: Any) → Any
- mlm_xentropy_loss(labels, logits, weights=None)
Calculates MLM Cross-Entropy (to be used for Perplexity calculation)
- Parameters
labels – Tensor of shape (batch, sequence) and type int32.
logits – Tensor of shape (batch, sequence, vocab) and type float.
weights – Optional float Tensor of shape (batch, sequence).
- Returns
The loss tensor
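As a rough illustration of how such a loss relates to perplexity, the sketch below computes a per-token cross-entropy from tensors with the documented shapes and reduces it to a scalar. The function name mlm_xentropy_loss_sketch is hypothetical, and the reduction (a weighted mean over positions, with weights defaulting to all ones) is an assumption; the actual modelzoo implementation may reduce differently. Perplexity is then taken as the exponential of the mean per-token cross-entropy. Note that F.cross_entropy expects int64 class indices, so the int32 labels are cast first.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def mlm_xentropy_loss_sketch(
    labels: torch.Tensor,             # (batch, sequence), int32
    logits: torch.Tensor,             # (batch, sequence, vocab), float
    weights: Optional[torch.Tensor] = None,  # (batch, sequence), float
) -> torch.Tensor:
    """Hypothetical sketch; the modelzoo reduction may differ."""
    batch, seq, vocab = logits.shape
    # Per-token cross-entropy; cross_entropy needs int64 targets.
    per_token = F.cross_entropy(
        logits.reshape(-1, vocab).float(),
        labels.reshape(-1).long(),
        reduction="none",
    ).reshape(batch, seq)
    if weights is None:
        weights = torch.ones_like(per_token)
    weights = weights.to(per_token.dtype)
    # Assumed reduction: weighted mean over the (masked) positions.
    return (per_token * weights).sum() / weights.sum().clamp_min(1e-8)


labels = torch.tensor([[1, 2, 0]], dtype=torch.int32)
logits = torch.zeros(1, 3, 5)                # uniform over a 5-token vocab
weights = torch.tensor([[1.0, 1.0, 0.0]])    # ignore the last position
loss = mlm_xentropy_loss_sketch(labels, logits, weights)
ppl = torch.exp(loss)                         # perplexity = exp(mean CE)
print(round(loss.item(), 4), round(ppl.item(), 4))  # 1.6094 5.0
```

With uniform logits over a vocabulary of 5, every token's cross-entropy is ln 5, so the perplexity comes out as 5, matching the intuition that the model is "choosing among" 5 equally likely tokens.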