Gradient Scaling#

Overview#

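Gradient scaling can improve convergence when training models with float16 gradients by reducing gradient underflow. The loss is multiplied by a scale factor before the backward pass so that small gradient values stay representable in float16, and the gradients are unscaled again before the optimizer step. Please see the PyTorch docs for a more detailed explanation.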
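The underflow problem can be illustrated with nothing but the Python standard library. The sketch below rounds a value to IEEE 754 half precision with `struct`; the gradient value and the scale factor of 2**16 are arbitrary choices made for illustration, not values taken from the Cerebras implementation.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


tiny_grad = 1e-8        # illustrative gradient, below float16's smallest subnormal (~5.96e-8)
loss_scale = 2.0 ** 16  # illustrative scale factor (an assumption, not a library default)

# Without scaling, the gradient is lost when stored in float16.
unscaled = to_float16(tiny_grad)

# With scaling, the value lands in float16's representable range ...
scaled = to_float16(tiny_grad * loss_scale)

# ... and can be unscaled afterwards in higher precision.
recovered = scaled / loss_scale

print(unscaled)   # 0.0
print(recovered)  # close to 1e-8, up to float16 rounding error
```

Because the scale factor is a power of two, scaling and unscaling introduce no additional rounding error beyond the float16 conversion itself.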

To facilitate gradient scaling, cerebras_pytorch.amp.GradScaler provides a Cerebras-compliant implementation of the AMP GradScaler class found in core PyTorch. For example:

import cerebras_pytorch as cstorch

grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

Its API is designed to match that of the CUDA AMP GradScaler class as closely as possible, so it is used in the same way:

loss: torch.Tensor = ...

optimizer.zero_grad()
# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()

# Unscale the gradients of the optimizer's assigned params in-place
# so that operations such as gradient clipping see the true gradient values
grad_scaler.unscale_(optimizer)

# (Optional) Global gradient clipping
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    1.0,  # max gradient norm
)

# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)

# Update the grad scaler once all optimizers have been stepped
grad_scaler.update()
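To give a feel for what a "dynamic" loss scale typically means, here is a toy model of the usual scheme: shrink the scale and skip the step when non-finite gradients appear, and grow the scale after a run of clean steps. The class name, its methods, and its defaults are hypothetical. The defaults mirror those of the core PyTorch GradScaler; whether the Cerebras implementation uses the same values or the same policy is an assumption.

```python
import math


class ToyDynamicLossScale:
    """Toy model of dynamic loss scaling. Hypothetical; not the Cerebras implementation."""

    def __init__(
        self,
        init_scale: float = 2.0 ** 16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps with finite gradients

    def update(self, grads) -> bool:
        """Adjust the scale; return True if the optimizer step should be applied."""
        if any(not math.isfinite(g) for g in grads):
            # Overflow: the scale was too large, so back off and skip this step.
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False

        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # A long run of clean steps: try a larger scale.
            self.scale *= self.growth_factor
            self.good_steps = 0
        return True


scaler = ToyDynamicLossScale(init_scale=8.0, growth_interval=2)
print(scaler.update([0.1, float("inf")]), scaler.scale)  # False 4.0
print(scaler.update([0.1]), scaler.scale)                # True 4.0
print(scaler.update([0.2]), scaler.scale)                # True 8.0
```

In the real workflow this bookkeeping is hidden behind grad_scaler.step(optimizer) and grad_scaler.update().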

Using automatic mixed precision with cbfloat16 or bfloat16#

By default, automatic mixed precision uses float16. To use cbfloat16 or bfloat16 instead, call cerebras_pytorch.amp.set_half_dtype, for example:

cstorch.amp.set_half_dtype("cbfloat16")

cerebras_pytorch.amp.optimizer_step#

The optional helper function cerebras_pytorch.amp.optimizer_step takes care of the details of gradient scaling:

cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=...,  # optionally perform gradient clipping by norm
    max_gradient_val=...,  # optionally perform gradient clipping by value
)

Note

This helper is useful for quickly putting together typical examples that use gradient scaling, without writing out each step or worrying about whether the grad scaler is being used correctly.

It is entirely optional and covers only the basic gradient scaling use case. For more complicated use cases, use the grad scaler object explicitly.
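For orientation, the sequence that such a helper wraps can be sketched from the manual usage shown in the overview. The function below is a hypothetical stand-in, not the cerebras_pytorch implementation: its name, its parameter names, and the choice to apply both kinds of clipping when both limits are given are assumptions. It relies only on the GradScaler methods used earlier and on torch.nn.utils.clip_grad_norm_ and torch.nn.utils.clip_grad_value_ from core PyTorch.

```python
def manual_optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=None,   # hypothetical name: clip by global norm if set
    max_gradient_value=None,  # hypothetical name: clip element-wise if set
):
    """Illustrative sequence only; not the cerebras_pytorch implementation."""
    optimizer.zero_grad()

    # Scale the loss so that small gradients survive the backward pass.
    grad_scaler.scale(loss).backward()

    # Unscale in-place so that clipping sees the true gradient magnitudes.
    grad_scaler.unscale_(optimizer)

    if max_gradient_norm is not None or max_gradient_value is not None:
        # Imported lazily so the sketch can be exercised with stand-in objects.
        import torch

        params = [p for group in optimizer.param_groups for p in group["params"]]
        if max_gradient_norm is not None:
            torch.nn.utils.clip_grad_norm_(params, max_gradient_norm)
        if max_gradient_value is not None:
            torch.nn.utils.clip_grad_value_(params, max_gradient_value)

    # Step the optimizer through the scaler, then update the scale.
    grad_scaler.step(optimizer)
    grad_scaler.update()
```

Anything that departs from this linear sequence, such as stepping several optimizers before a single update, is the kind of case where the grad scaler is better driven by hand.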