Training with number of tokens loss scaling


Language models are trained on sequences of tokens whose length can be at most the “max sequence length” (MSL). Sequences shorter than the MSL can be “padded” to fill the remaining slots. The attention mask marks which tokens in an input sequence should contribute to the loss and gradient computations. In autoregressive models like GPT, attention masks also prevent the model from looking at future tokens in an input sequence when predicting a particular output token.
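As an illustration of padding, here is a minimal sketch in plain Python. The helper name pad_to_msl and the pad token id of 0 are assumptions made for this example; they are not part of any particular data pipeline.

```python
def pad_to_msl(token_ids, max_sequence_length, pad_id=0):
    """Pad one sequence to the MSL and build its attention mask.

    pad_id=0 is a hypothetical padding token id chosen for illustration.
    """
    if len(token_ids) > max_sequence_length:
        raise ValueError("sequence is longer than the max sequence length")
    n_pad = max_sequence_length - len(token_ids)
    input_ids = list(token_ids) + [pad_id] * n_pad
    # 1 marks a real token, 0 marks a padded slot.
    attention_mask = [1] * len(token_ids) + [0] * n_pad
    return input_ids, attention_mask


input_ids, attention_mask = pad_to_msl([11, 42, 7], max_sequence_length=5)
print(input_ids)       # [11, 42, 7, 0, 0]
print(attention_mask)  # [1, 1, 1, 0, 0]
```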

The loss of a network during training is a sum over all tokens in the input sequence. To avoid weighting longer sequences more heavily than shorter ones, it is important to normalize the final loss by the number of non-padded input tokens. This normalization factor is a sum-reduction of the attention mask tensor, and we call it num_tokens. Instruction fine-tuning and downstream task evaluation are potential use cases for loss scaling by num_tokens if the sequences are “unpacked”, i.e. each sequence contains one prompt-response pair.
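The normalization described above can be sketched in plain Python. This is an illustrative sketch and not the framework's implementation: the function name num_tokens_scaled_loss and the per-token loss values are made up for the example.

```python
def num_tokens_scaled_loss(token_losses, attention_mask):
    """Sum the masked per-token losses and divide by num_tokens.

    Both arguments are nested lists of shape [batch, max_sequence_length].
    """
    masked_sum = 0.0
    num_tokens = 0
    for losses, mask in zip(token_losses, attention_mask):
        for loss, m in zip(losses, mask):
            masked_sum += loss * m   # padded slots (m == 0) contribute nothing
            num_tokens += m          # sum-reduction of the attention mask
    if num_tokens == 0:
        raise ValueError("batch contains no non-padded tokens")
    return masked_sum / num_tokens, num_tokens


# Hypothetical per-token losses for a batch of two sequences with MSL = 4.
token_losses = [[2.0, 1.0, 3.0, 0.5], [4.0, 2.0, 0.7, 0.9]]
attention_mask = [[1, 1, 1, 1], [1, 1, 0, 0]]

loss, num_tokens = num_tokens_scaled_loss(token_losses, attention_mask)
print(num_tokens)  # 6
print(loss)        # 12.5 / 6, roughly 2.083
```

The losses at the two padded positions of the second sequence (0.7 and 0.9) are excluded from both the numerator and the denominator.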

Because the attention mask has a batch dimension, in runs that use gradient accumulation or multiple CSX devices the attention mask is reduced over all devices and all gradient accumulation micro-batches. That is, the num_tokens value used in the run correctly reflects the global batch size of the model.
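To see why the reduction has to be global, consider the following sketch with two hypothetical micro-batches. It is plain Python for illustration only, and the helper masked_sum_and_count is an assumed name. Averaging the per-micro-batch normalized losses gives a different result from normalizing once by the global num_tokens.

```python
def masked_sum_and_count(token_losses, attention_mask):
    """Return (sum of masked per-token losses, number of non-padded tokens)."""
    total = 0.0
    count = 0
    for losses, mask in zip(token_losses, attention_mask):
        for loss, m in zip(losses, mask):
            total += loss * m
            count += m
    return total, count


# Two hypothetical micro-batches, each given as (token_losses, attention_mask).
micro_batches = [
    ([[2.0, 1.0, 3.0, 0.0]], [[1, 1, 1, 0]]),  # 3 non-padded tokens
    ([[4.0, 0.0, 0.0, 0.0]], [[1, 0, 0, 0]]),  # 1 non-padded token
]

partials = [masked_sum_and_count(losses, mask) for losses, mask in micro_batches]

# Global reduction: accumulate sums and token counts first, divide once.
global_sum = sum(s for s, _ in partials)
global_num_tokens = sum(c for _, c in partials)
global_loss = global_sum / global_num_tokens

# Naive alternative: normalize each micro-batch separately, then average.
naive_loss = sum(s / c for s, c in partials) / len(partials)

print(global_num_tokens)  # 4
print(global_loss)        # 2.5
print(naive_loss)         # 3.0
```

In the naive version the single token of the second micro-batch carries as much weight as the three tokens of the first; the global reduction weights all four tokens equally.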


Loss scaling by num_tokens can be enabled in the configuration YAML file, under the model section, by setting the loss_scaling and loss_weight parameters:

  loss_scaling: "num_tokens"
  loss_weight: 1.0


Scaling the loss by the number of non-padded tokens (num_tokens) matters when training language models on sequences of varying lengths. Normalizing by the count of active, non-padded tokens keeps longer sequences from being overweighted relative to shorter ones. Enabling num_tokens scaling in the model’s configuration therefore gives a more balanced training signal, which is most relevant when sequence lengths in the dataset vary widely.