Gradient Scaling#
Overview#
Gradient scaling can improve convergence when training models with float16 gradients by minimizing gradient underflow. Please see the PyTorch docs for a more detailed explanation.
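Conceptually, the loss is multiplied by a scale factor before the backward pass so that small gradient values remain representable in float16, and the gradients are divided by the same factor before the optimizer step. A minimal manual sketch in plain PyTorch (illustrative only; model, optimizer, and loss are assumed to exist as in the examples below, and the GradScaler introduced next automates this bookkeeping, including dynamic scale adjustment):
scale = 2.0**15                      # illustrative fixed scale factor
optimizer.zero_grad()
(loss * scale).backward()            # gradients are now scaled by `scale`
for p in model.parameters():
    if p.grad is not None:
        p.grad.div_(scale)           # unscale before the optimizer step
optimizer.step()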
To facilitate gradient scaling, we introduce a Cerebras-compliant implementation of the AMP GradScaler class found in core PyTorch at cerebras_pytorch.amp.GradScaler. For example:
grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")
Its API is designed to be as similar as possible to that of the CUDA AMP GradScaler class, and its usage is identical:
loss: torch.Tensor = ...
optimizer.zero_grad()
# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()
# Unscales the gradients of the optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)
# (Optional) Global gradient clipping
torch.nn.utils.clip_grad_norm_(
model.parameters(),
1.0, # max gradient norm
)
# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)
# Update the grad scaler once all optimizers have been stepped
grad_scaler.update()
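Gradient clipping by value can be used in the same position (after unscale_ and before step) instead of clipping by norm, using the standard PyTorch utility. For example (the clip value here is purely illustrative):
# Alternative to clip_grad_norm_ above: clip each gradient element
torch.nn.utils.clip_grad_value_(
    model.parameters(),
    0.5,  # max absolute gradient value (illustrative)
)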
Using automatic mixed precision with cbfloat16 or bfloat16#
By default, automatic mixed precision uses float16. If you want to use cbfloat16 or bfloat16 instead of float16, call cerebras_pytorch.amp.set_half_dtype, e.g.:
cstorch.amp.set_half_dtype("cbfloat16")
cerebras_pytorch.amp.optimizer_step#
We introduce an optional helper function, cerebras_pytorch.amp.optimizer_step, to take care of the details of gradient scaling:
cstorch.amp.optimizer_step(
loss,
optimizer,
grad_scaler,
max_gradient_norm=..., # optionally perform gradient clipping by norm
max_gradient_val=..., # optionally perform gradient clipping by value
)
Note
This helper is useful for quickly constructing typical examples that use gradient scaling without needing to type out the details or worry about whether the grad scaler is being used correctly.
It is entirely optional and only covers the basic gradient scaling use case. For more complicated use cases, the grad scaler object must be used explicitly.
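For example, a training-step sketch using the helper (model, loss_fn, optimizer, and grad_scaler are placeholders assumed to be constructed as above; the surrounding data-loading and execution setup is omitted):
def training_step(batch):
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

    # Handles loss scaling, unscaling, optional clipping, the optimizer
    # step, and the grad scaler update in one call
    cstorch.amp.optimizer_step(
        loss,
        optimizer,
        grad_scaler,
        max_gradient_norm=1.0,  # illustrative clipping threshold
    )
    return loss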