cerebras_pytorch.amp#

Automatic mixed precision#

The following classes and functions are designed to facilitate automatic mixed precision on the Cerebras Wafer Scale Cluster.

GradScaler#

class cerebras_pytorch.amp.GradScaler[source]#

Facilitates mixed precision training, dynamic loss scaling (DLS), and DLS with global gradient clipping (GCC).

For more details, please see the docs for amp.initialize.

Parameters
  • loss_scale – If loss_scale == “dynamic”, then configure dynamic loss scaling. Otherwise, it is the loss scale value used in static loss scaling.

  • init_scale – The initial loss scale value if loss_scale == “dynamic”

  • steps_per_increase – The number of steps after which to increase the loss scale under dynamic loss scaling

  • min_loss_scale – The minimum loss scale value that can be chosen by dynamic loss scaling

  • max_loss_scale – The maximum loss scale value that can be chosen by dynamic loss scaling

  • overflow_tolerance – The maximum fraction of steps with infinite or NaN gradient values that is tolerated. The loss scale is reduced if this tolerance is exceeded

  • max_gradient_norm – The maximum gradient norm to use for global gradient clipping. Only applies in the DLS + GCC case; if GCC is not enabled, this parameter has no effect

Example usage:

import torch
import cerebras_pytorch as cstorch

grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

loss: torch.Tensor = ...

optimizer.zero_grad()
# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()

# Unscales the gradients of optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)

# Global gradient clipping
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    1.0,  # max gradient norm
)

# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)

# update the grad scaler once all optimizers have been stepped
grad_scaler.update()
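
The constructor below also exposes the dynamic loss scaling knobs directly. A minimal sketch, assuming dynamic loss scaling is desired; the numeric values are illustrative only, not recommended defaults:

# Illustrative tuning values for dynamic loss scaling (not recommended defaults)
grad_scaler = cstorch.amp.GradScaler(
    loss_scale="dynamic",
    init_scale=2.0**15,        # starting loss scale
    steps_per_increase=2000,   # steps after which to increase the loss scale
    min_loss_scale=1.0,        # floor for the dynamic loss scale
    max_loss_scale=2.0**20,    # ceiling for the dynamic loss scale
    overflow_tolerance=0.05,   # allowed fraction of overflowing steps
)
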
__init__(loss_scale: Optional[Union[str, float]] = None, init_scale: Optional[float] = None, steps_per_increase: Optional[int] = None, min_loss_scale: Optional[float] = None, max_loss_scale: Optional[float] = None, overflow_tolerance: float = 0.05, max_gradient_norm: Optional[float] = None)[source]#
clip_gradients_and_return_isfinite(optimizers)[source]#

Clip the gradients of the given optimizers’ parameters and return whether or not the gradient norm is finite

get_scale()[source]#

Return the loss scale

load_state_dict(state_dict)[source]#

Loads the given state dictionary into the current instance

scale(loss: torch.Tensor)[source]#

Scales the loss in preparation for the backward pass

state_dict(destination=None)[source]#

Returns a dictionary containing the state to be saved to a checkpoint
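
A minimal sketch of round-tripping the scaler state through state_dict and load_state_dict; how the state dictionary is persisted (e.g. as part of a larger checkpoint) is left to the surrounding training loop:

# Capture the scaler state for checkpointing
scaler_state = grad_scaler.state_dict()

# ... later, when resuming ...
grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")
grad_scaler.load_state_dict(scaler_state)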

step(optimizer, *args, **kwargs)[source]#

Step carries out the following two operations:

  1. Internally invokes unscale_(optimizer) (unless unscale_ was explicitly called for optimizer earlier in the iteration). As part of the unscale_, gradients are checked for infs/NaNs.

  2. Invokes optimizer.step() using the unscaled gradients, ensuring that the previous optimizer state or params carry over if NaNs are encountered in the gradients.

*args and **kwargs are forwarded to optimizer.step(). Returns the return value of optimizer.step(*args, **kwargs).

Parameters
  • optimizer (cerebras_pytorch.optim.Optimizer) – Optimizer that applies the gradients.

  • args – Any arguments.

  • kwargs – Any keyword arguments.
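
A minimal sketch of relying on step to perform the unscaling internally, assuming no gradient clipping is needed (grad_scaler, loss, and optimizer are as in the class-level example above):

grad_scaler.scale(loss).backward()

# No explicit unscale_ call: step unscales the gradients internally,
# checks them for infs/NaNs, and only then invokes optimizer.step()
grad_scaler.step(optimizer)
grad_scaler.update()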

step_if_finite(optimizer, *args, **kwargs)[source]#

Directly conditionalizes the call to optimizer.step(*args, **kwargs), executing it only if this GradScaler detected finite gradients.

Parameters
  • optimizer – Optimizer that applies the gradients.

  • args – Any arguments forwarded to optimizer.step.

  • kwargs – Any keyword arguments forwarded to optimizer.step.

Returns

The result of optimizer.step()
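
A minimal sketch, assuming the gradients were already unscaled via unscale_ so that step_if_finite only has to gate the actual optimizer.step() call:

grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)

# Step the optimizer only if the unscaled gradients are finite
grad_scaler.step_if_finite(optimizer)
grad_scaler.update()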

unscale_(optimizer)[source]#

Unscales the gradients of the optimizer’s parameters in place

update(new_scale=None)[source]#

Update the gradient scale after all optimizers have been stepped

update_scale(optimizers)[source]#

Update the scales of the optimizers

use_bfloat16#

cerebras_pytorch.amp.use_bfloat16(use_bfloat16)[source]#

By default, automatic mixed precision uses float16. If you want to use bfloat16 instead of float16, call this function.

Example usage:

cstorch.amp.use_bfloat16(True)

optimizer_step#

cerebras_pytorch.amp.optimizer_step(loss: torch.Tensor, optimizer: cerebras_pytorch.optim.optimizer.Optimizer, grad_scaler: cerebras_pytorch.amp.grad_scaler.GradScaler, max_gradient_norm: Optional[float] = None, max_gradient_value: Optional[float] = None)[source]#

Performs loss scaling, gradient scaling, and the optimizer step

Parameters
  • loss – The loss value to scale. loss.backward should be called before this function

  • optimizer – The optimizer to step

  • grad_scaler – The gradient scaler to use to scale the parameter gradients

  • max_gradient_norm – The maximum gradient norm to use for gradient clipping

  • max_gradient_value – The maximum gradient value to use for gradient clipping

Example usage:

cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=1.0,
)