Automatic mixed precision

The following classes and functions facilitate automatic mixed precision on the Cerebras Wafer Scale Cluster.


class cerebras.pytorch.amp.GradScaler

Facilitates mixed precision training with dynamic loss scaling (DLS), optionally combined with global gradient clipping (DLS + GCC).

For more details, see the docs for amp.initialize.

  • loss_scale – If loss_scale == "dynamic", dynamic loss scaling is configured. Otherwise, it is the loss scale value used in static loss scaling

  • init_scale – The initial loss scale value if loss_scale == "dynamic"

  • steps_per_increase – The number of steps after which to increase the loss scale

  • min_loss_scale – The minimum loss scale value that can be chosen by dynamic loss scaling

  • max_loss_scale – The maximum loss scale value that can be chosen by dynamic loss scaling

  • overflow_tolerance – The maximum allowed fraction of steps whose gradients contain infinite or undefined (NaN) values. The loss scale is reduced if this tolerance is exceeded

  • max_gradient_norm – The maximum gradient norm to use for global gradient clipping. Only applies in the DLS + GCC case. If GCC is not enabled, this parameter has no effect
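
The interaction between steps_per_increase, min_loss_scale, max_loss_scale, and overflow_tolerance can be pictured with a small pure-Python toy. This is only a sketch of one plausible policy, not the Cerebras implementation: the class name ToyDynamicLossScale, the growth/backoff factor of 2, the default values, and the choice to measure the overflow fraction against a window of steps_per_increase steps are all assumptions made for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class ToyDynamicLossScale:
    """Illustrative dynamic loss scale; NOT the Cerebras implementation."""

    scale: float = 2.0**15           # plays the role of init_scale (arbitrary default)
    steps_per_increase: int = 2000   # arbitrary default for the sketch
    min_loss_scale: float = 1.0
    max_loss_scale: float = 2.0**15
    overflow_tolerance: float = 0.0
    factor: float = 2.0              # assumed growth/backoff factor
    _steps: int = field(default=0, init=False, repr=False)      # steps in current window
    _overflows: int = field(default=0, init=False, repr=False)  # non-finite steps in window

    def update(self, found_overflow: bool) -> float:
        """Record one training step and return the (possibly changed) scale."""
        self._steps += 1
        self._overflows += int(found_overflow)
        if self._overflows / self.steps_per_increase > self.overflow_tolerance:
            # Too many non-finite steps in this window: back off, but not below the minimum.
            self.scale = max(self.scale / self.factor, self.min_loss_scale)
            self._steps = self._overflows = 0
        elif self._steps >= self.steps_per_increase:
            # A full window stayed within tolerance: grow, but not above the maximum.
            self.scale = min(self.scale * self.factor, self.max_loss_scale)
            self._steps = self._overflows = 0
        return self.scale


toy = ToyDynamicLossScale(
    scale=8.0, steps_per_increase=2, min_loss_scale=1.0, max_loss_scale=16.0
)
# Four clean steps followed by two overflowing steps.
history = [toy.update(overflow) for overflow in (False, False, False, False, True, True)]
```

In this toy run the scale grows once, is then held at max_loss_scale, and is halved on each overflowing step.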

Example usage:

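import torch
import cerebras.pytorch as cstorch

# model and optimizer are assumed to be defined elsewhere
grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

loss: torch.Tensor = ...

# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()

# Unscales the gradients of optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)

# Global gradient clipping
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    1.0,  # max gradient norm
)

# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)

# Update the grad scaler once all optimizers have been stepped
grad_scaler.update()
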
__init__(loss_scale: Optional[Union[str, float]] = None, init_scale: Optional[float] = None, steps_per_increase: Optional[int] = None, min_loss_scale: Optional[float] = None, max_loss_scale: Optional[float] = None, overflow_tolerance: float = 0.0, max_gradient_norm: Optional[float] = None)

Returns a dictionary containing the state to be saved to a checkpoint


Loads the state dictionary into the current params

scale(loss: torch.Tensor)

Scales the loss in preparation for the backward pass
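
Why scaling the loss helps can be seen without any framework: by the chain rule, multiplying the loss by a constant multiplies every gradient by the same constant, which lifts small gradients out of the float16 underflow range. The following is a minimal sketch using Python's struct module to round values to half precision. The gradient value 1e-8 and the scale 2**16 are arbitrary illustrative choices, and to_half is a helper defined here, not part of any Cerebras API.

```python
import struct


def to_half(x: float) -> float:
    """Round x to IEEE 754 half precision and back (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


tiny_grad = 1e-8      # hypothetical gradient value, below float16's smallest subnormal
loss_scale = 2.0**16  # hypothetical loss scale

# Stored directly in float16, the gradient underflows to zero.
unscaled = to_half(tiny_grad)

# Scaled before the float16 round trip and unscaled afterwards in higher
# precision, the gradient survives with a small relative error.
recovered = to_half(tiny_grad * loss_scale) / loss_scale
```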


Return the loss scale


Unscales the gradients of the optimizer’s params in place

step_if_finite(optimizer, *args, **kwargs)

Calls optimizer.step(*args, **kwargs), but only if this GradScaler detected finite grads.


The result of optimizer.step()
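
The conditional step described above amounts to a finiteness check that guards the call. A framework-free sketch, assuming gradients are plain floats and the step is any callable, follows. The name step_if_finite_sketch and the choice to return None for a skipped step are assumptions of this sketch, not documented behavior.

```python
import math
from typing import Callable, Iterable, Optional, TypeVar

T = TypeVar("T")


def step_if_finite_sketch(grads: Iterable[float], step: Callable[[], T]) -> Optional[T]:
    """Call step() only when every gradient is finite; otherwise skip it."""
    if all(math.isfinite(g) for g in grads):
        return step()
    return None  # assumption of this sketch: a skipped step yields None


calls = []  # records how many times the step callable actually ran

# All gradients finite: the step runs and its return value is passed through.
ok = step_if_finite_sketch([0.1, -2.0], lambda: calls.append("stepped") or "done")

# One NaN gradient: the step is skipped entirely.
skipped = step_if_finite_sketch(
    [0.1, float("nan")], lambda: calls.append("stepped") or "done"
)
```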


Clip the gradients of the optimizer’s params and return whether or not the norm is finite
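
Global gradient clipping rescales all gradients by a common factor so that their joint L2 norm does not exceed a maximum, and the same norm doubles as a finiteness check. The sketch below shows the idea on a flat list of floats. The function name, the return shape, and the choice to leave gradients untouched when the norm is not finite are assumptions for illustration only.

```python
import math
from typing import List, Tuple


def clip_by_global_norm(grads: List[float], max_norm: float) -> Tuple[List[float], bool]:
    """Rescale grads so their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and whether the norm was finite.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if not math.isfinite(total_norm):
        # inf or NaN somewhere in the gradients: report it and change nothing.
        return grads, False
    # Only shrink; gradients already within the limit are left as they are.
    factor = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    return [g * factor for g in grads], True


# Norm of [3, 4] is 5, so both entries are multiplied by 1/5.
clipped, finite = clip_by_global_norm([3.0, 4.0], max_norm=1.0)

# A non-finite entry makes the global norm non-finite.
_, finite_with_inf = clip_by_global_norm([float("inf"), 1.0], max_norm=1.0)
```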

step(optimizer, *args, **kwargs)

Step carries out the following two operations:

  1. Internally invokes unscale_(optimizer) (unless unscale_ was explicitly called for optimizer earlier in the iteration). As part of the unscale_, gradients are checked for infs/NaNs.

  2. Invokes optimizer.step() using the unscaled gradients. Ensures that previous optimizer state or params carry over if we encounter NaNs in the gradients.

*args and **kwargs are forwarded to optimizer.step(). Returns the return value of optimizer.step(*args, **kwargs).

  • optimizer (cerebras.pytorch.optim.Optimizer) – Optimizer that applies the gradients

  • args – Any arguments

  • kwargs – Any keyword arguments
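
The two operations above can be mimicked with toy classes. This is a sketch of the control flow only: ToyOptimizer and ToyGradScaler are hypothetical stand-ins, gradients are plain floats, and the bookkeeping of which optimizers were already unscaled is cleared inside step here for brevity, which may differ from the real class.

```python
import math
from typing import Dict, List, Optional


class ToyOptimizer:
    """Plain SGD on a list of floats; stands in for a real optimizer."""

    def __init__(self, params: List[float], lr: float = 0.1) -> None:
        self.params = params
        self.grads = [0.0] * len(params)
        self.lr = lr

    def step(self) -> str:
        self.params = [p - self.lr * g for p, g in zip(self.params, self.grads)]
        return "stepped"


class ToyGradScaler:
    """Mimics the two-phase step described above; not the Cerebras class."""

    def __init__(self, scale: float) -> None:
        self.scale = scale
        self._finite: Dict[int, bool] = {}  # optimizers already unscaled this iteration

    def unscale_(self, optimizer: ToyOptimizer) -> None:
        # Divide gradients by the scale in place and record whether all are finite.
        optimizer.grads = [g / self.scale for g in optimizer.grads]
        self._finite[id(optimizer)] = all(math.isfinite(g) for g in optimizer.grads)

    def step(self, optimizer: ToyOptimizer) -> Optional[str]:
        if id(optimizer) not in self._finite:  # 1. unscale unless already done
            self.unscale_(optimizer)
        finite = self._finite.pop(id(optimizer))
        # 2. step with unscaled gradients, or skip so params carry over unchanged
        return optimizer.step() if finite else None


scaler = ToyGradScaler(scale=4.0)
opt = ToyOptimizer([1.0])

opt.grads = [8.0]          # scaled gradient; the true gradient is 2.0
result = scaler.step(opt)  # params become [1.0 - 0.1 * 2.0]
params_after_step = list(opt.params)

opt.grads = [float("nan")]
skipped = scaler.step(opt)  # params are left unchanged
```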


Update the scales of the optimizers


Update the gradient scaler after all optimizers have been stepped


cerebras.pytorch.amp.set_half_dtype(value: Union[Literal['float16', 'bfloat16', 'cbfloat16'], torch.dtype]) -> torch.dtype

Sets the underlying 16-bit floating point dtype to use.


value – Either a 16-bit floating point torch dtype or one of the strings "float16", "bfloat16", or "cbfloat16".


The proxy torch dtype to use for the model. For dtypes that have a torch representation, this returns the value passed in. Otherwise, it returns a proxy dtype to use in the model. On CSX, these proxy dtypes are automatically and transparently converted to the real dtype during compilation.

By default, automatic mixed precision uses float16. If you want to use cbfloat16 or bfloat16 instead of float16, call this function.
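
As background for choosing between the formats, float16 and bfloat16 trade range against precision: float16 has more mantissa bits, while bfloat16 keeps the exponent range of float32. The sketch below illustrates this with the standard library only. It says nothing about cbfloat16, and to_bfloat16 approximates the conversion by truncation, whereas real hardware typically rounds to nearest. Both helper names are defined here for illustration.

```python
import struct


def to_float16(x: float) -> float:
    """Round to IEEE 754 half precision (5 exponent bits, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    """Approximate bfloat16 by keeping only the top 16 bits of a float32.

    Truncation is a simplification that keeps the sketch short.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0] & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


small, fine = 1e-8, 1.001

# Range: 1e-8 underflows to zero in float16 but stays nonzero in bfloat16.
range_demo = (to_float16(small), to_bfloat16(small))

# Precision: float16 resolves 1.001 to its nearest neighbor above 1.0,
# while bfloat16 collapses it to exactly 1.0.
precision_demo = (to_float16(fine), to_bfloat16(fine))
```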

Example usage:

cstorch.amp.set_half_dtype("bfloat16")

cerebras.pytorch.amp.optimizer_step(loss: torch.Tensor, optimizer: cerebras.pytorch.optim.optimizer.Optimizer, grad_scaler: cerebras.pytorch.amp.grad_scaler.GradScaler, max_gradient_norm: Optional[float] = None, max_gradient_value: Optional[float] = None)

Performs loss scaling, gradient scaling and optimizer step

  • loss – The loss value to scale. loss.backward should be called before this function

  • optimizer – The optimizer to step

  • grad_scaler – The gradient scaler to use to scale the parameter gradients

  • max_gradient_norm – The max gradient norm to use for gradient clipping

  • max_gradient_value – The max gradient value to use for gradient clipping
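
The two clipping parameters correspond to different techniques: clipping by norm rescales all gradients jointly, while clipping by value limits each entry independently. The source does not spell out the exact semantics of max_gradient_value, so the sketch below assumes the common elementwise clamp convention (as in torch.nn.utils.clip_grad_value_). The function name clip_by_value is hypothetical.

```python
from typing import List


def clip_by_value(grads: List[float], max_value: float) -> List[float]:
    """Clamp every gradient entry to the interval [-max_value, max_value]."""
    return [max(-max_value, min(max_value, g)) for g in grads]


# Entries inside the interval are untouched; the others are clamped to its ends.
clipped = clip_by_value([0.5, -3.0, 10.0], max_value=1.0)
```

Unlike norm-based clipping, this changes the direction of the gradient vector whenever only some entries exceed the limit.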

Example usage: