Training with number of tokens loss scaling
Language models are trained on sequences of tokens whose length can be at most the "max sequence length" (MSL). Sequences shorter than the MSL are "padded" to fill the remaining slots. Padded positions are marked by the attention mask, which indicates which tokens in an input sequence should have no effect on gradient and loss calculations. In autoregressive models like GPT, attention masks also prevent the model from looking at future tokens in an input sequence when predicting a particular output token.
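The two roles of the mask described above can be illustrated with a minimal plain-Python sketch. It assumes the common convention that 1 marks a real token and 0 marks a padded slot, and it uses a hypothetical padding token id of 0; neither convention is taken from this document, and real input pipelines may differ.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical padding token id (assumption, not from the source)


def pad_to_msl(tokens: List[int], msl: int, pad_id: int = PAD_ID) -> Tuple[List[int], List[int]]:
    """Pad a token sequence to `msl` and return (input_ids, attention_mask).

    The mask holds 1 for real tokens and 0 for padded slots, so padded
    positions can be excluded from the loss and gradient computation.
    """
    if len(tokens) > msl:
        raise ValueError(f"sequence length {len(tokens)} exceeds MSL {msl}")
    n_pad = msl - len(tokens)
    input_ids = tokens + [pad_id] * n_pad
    attention_mask = [1] * len(tokens) + [0] * n_pad
    return input_ids, attention_mask


def causal_mask(msl: int) -> List[List[int]]:
    """Lower-triangular mask: position i may attend only to positions j <= i,
    which keeps an autoregressive model from looking at future tokens."""
    return [[1 if j <= i else 0 for j in range(msl)] for i in range(msl)]
```

For example, `pad_to_msl([5, 7, 9], 5)` returns `([5, 7, 9, 0, 0], [1, 1, 1, 0, 0])`.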
During training, the loss of a network is a sum of per-token losses over the unmasked tokens in the input sequence. To avoid overweighting the loss of longer sequences relative to shorter ones, it is important to normalize the final loss by the number of non-padded input tokens. This normalization factor is a sum-reduction of the attention mask tensor, and we call it num_tokens.
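A minimal sketch of this normalization, written in plain Python over nested lists so that it runs without a framework. The function names and the `[batch][position]` layout are illustrative assumptions, not part of any Cerebras API; the attention mask again uses 1 for real tokens and 0 for padding.

```python
from typing import List


def num_tokens(attention_mask: List[List[int]]) -> int:
    """Sum-reduce the attention mask: the number of non-padded tokens in the batch."""
    return sum(sum(row) for row in attention_mask)


def normalized_loss(token_losses: List[List[float]], attention_mask: List[List[int]]) -> float:
    """Sum per-token losses over unmasked positions, then divide by num_tokens."""
    total = sum(
        loss * m  # padded positions (m == 0) contribute nothing
        for loss_row, mask_row in zip(token_losses, attention_mask)
        for loss, m in zip(loss_row, mask_row)
    )
    return total / num_tokens(attention_mask)
```

With `token_losses = [[1.0, 2.0, 3.0, 9.0], [4.0, 9.0, 9.0, 9.0]]` and `attention_mask = [[1, 1, 1, 0], [1, 0, 0, 0]]`, the masked sum is 10.0 and num_tokens is 4, so the normalized loss is 2.5; the 9.0 values in padded slots are ignored.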
The attention mask has a batch dimension. When performing gradient accumulation or multibox computation, it is therefore important to scale gradients and loss by a value of num_tokens computed over the entire batch, instead of the value local to the sub-batch on a single CSX or to a single iteration of gradient accumulation. On a multibox system, we reduce the local num_tokens values through the branch-reduce (BR) tree via a transfer to the weight host (this will likely change once we enable direct communication between activation and weight hosts). With gradient accumulation, we use unrolling in the intermediate representation (IR) to sum the local values of num_tokens, and then use the full-batch value in all gradient accumulation iterations.
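Why the full-batch value matters can be checked numerically with a small plain-Python sketch. Here a list of sub-batches stands in for either the per-CSX slices of a multibox batch or the micro-batches of gradient accumulation; the Python summation of local counts is only an illustration of the reduction, which on the system happens through the BR tree or IR unrolling as described above. All helper names are hypothetical.

```python
from typing import List, Tuple

# (token_losses, attention_mask) for one sub-batch, laid out as [batch][position]
SubBatch = Tuple[List[List[float]], List[List[int]]]


def masked_sum(token_losses: List[List[float]], attention_mask: List[List[int]]) -> float:
    """Sum of per-token losses over non-padded positions."""
    return sum(l * m for lr, mr in zip(token_losses, attention_mask) for l, m in zip(lr, mr))


def local_num_tokens(attention_mask: List[List[int]]) -> int:
    """num_tokens computed only over one sub-batch."""
    return sum(sum(row) for row in attention_mask)


def global_scaled_loss(sub_batches: List[SubBatch]) -> float:
    # Step 1: reduce the local num_tokens values into a single full-batch value.
    global_tokens = sum(local_num_tokens(mask) for _, mask in sub_batches)
    # Step 2: every sub-batch scales its summed loss by that same full-batch value.
    return sum(masked_sum(losses, mask) / global_tokens for losses, mask in sub_batches)


def naive_local_loss(sub_batches: List[SubBatch]) -> float:
    # Incorrect variant: each sub-batch divides by its own local num_tokens,
    # and the per-sub-batch results are then averaged.
    return sum(masked_sum(l, m) / local_num_tokens(m) for l, m in sub_batches) / len(sub_batches)
```

For sub-batch A with losses `[[1.0, 1.0, 1.0, 7.0]]` and mask `[[1, 1, 1, 0]]`, and sub-batch B with losses `[[5.0, 7.0, 7.0, 7.0]]` and mask `[[1, 0, 0, 0]]`, `global_scaled_loss` returns 2.0, matching the full-batch value (3.0 + 5.0) / 4. The naive variant returns 3.0 because it gives B's single token the same weight as A's three tokens. Since every sub-batch term is divided by the same constant, the accumulated gradients also equal the full-batch gradients.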
How to enable
Loss scaling by num_tokens can be enabled in the model section of the configuration YAML file by setting the loss_scaling and loss_weight parameters:
model:
  ...
  loss_scaling: "num_tokens"
  loss_weight: 1.0
Known issues and limitations
Loss scaling by num_tokens has not been thoroughly tested across all variants and may have issues when used with gradient accumulation.