Per-layer precision optimization level

Overview

Per-layer Precision Optimization Level (POL) gives fine-grained control over the precision level of individual operations within a model. It can improve convergence stability and help address numerical issues caused by the low-precision data types used in kernels.

Per-layer POL vs. Per-model POL

There are two types of POL: per-model POL and per-layer POL.

Per-model POL

This approach sets a single precision level for the entire model using the CsConfig.precision_opt_level parameter. This method provides a simple and efficient way to control model precision but lacks the granularity to address specific numerical challenges in individual layers.

Per-layer POL

This approach annotates individual operations (layers) with specific precision levels, for the forward pass, the backward pass, or both. The finer granularity makes it possible to target the layers that cause numerical issues without changing the precision of the rest of the model. To use it, annotate specific functions in the model code with the cstorch.pol decorator. The annotation covers whatever the decorated function contains, so it can target a single operation or a block of operations. Per-layer POL overrides the per-model POL setting, so the precision level set on a specific operation always takes precedence.

CSTorch API

To annotate a layer or set of layers, use the cstorch.pol decorator. For example:

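import torch.nn.functional as F
import cerebras.pytorch as cstorch

# Inside the model's forward method, where self.fc1 is a layer of the model.
@cstorch.pol(level=0)  # annotate every op in fc1 with POL 0
def fc1(x):
    out = self.fc1(x)
    return F.relu(out)

out = fc1(x)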

This example uses the cstorch.pol decorator to annotate all operations within the fc1 function with a precision optimization level of 0. By default, cstorch.pol applies the specified level to the forward pass operations and to the corresponding gradient computation operations in the backward pass. The decorator also accepts parameters that set the precision level for the forward and backward passes independently.

For example:

@cstorch.pol(bwd_level=0) # annotates only bwd pass ops.
def fc1(x):
    return self.fc1(x)
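
One way to reason about how the two parameters combine is the small helper below. It is a sketch under stated assumptions, not the library's logic: `resolve_levels` is a hypothetical function, and it assumes that `level` covers both passes unless `bwd_level` is given, that an unset pass falls back to the per-model level, and that both parameters may be passed together.

```python
from typing import Optional, Tuple


def resolve_levels(
    model_level: int,
    level: Optional[int] = None,
    bwd_level: Optional[int] = None,
) -> Tuple[int, int]:
    """Return the assumed (forward, backward) levels for one annotated layer.

    Hypothetical helper for illustration; not part of cstorch.
    """
    # Forward pass: per-layer `level` if given, else the per-model level.
    fwd = level if level is not None else model_level
    # Backward pass: an explicit `bwd_level` wins, else it follows the forward pass.
    bwd = bwd_level if bwd_level is not None else fwd
    return fwd, bwd


print(resolve_levels(1))                        # no annotation
print(resolve_levels(1, level=0))               # both passes annotated
print(resolve_levels(1, bwd_level=0))           # backward pass only
print(resolve_levels(2, level=1, bwd_level=0))  # different level per pass
```

Under these assumptions, `bwd_level=0` on its own leaves the forward pass at the per-model level, which matches the "only bwd pass ops" comment in the example above.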

Alternatively, pass a layer directly to the decorator instead of wrapping a function. The previous example is equivalent to:

cstorch.pol(bwd_level=0)(  # set POL parameters
    self.fc1                 # pass a layer (callable)
)(hidden_states)           # pass layer input arguments
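
The two forms are equivalent because of how Python decorators work: `@dec(args)` above a function definition is the same as calling `dec(args)(fn)`, and any callable object can take the place of `fn`. The sketch below shows this with a hypothetical stand-in decorator and a toy callable class (`pol` and `ScaleLayer` are illustrative names, not Cerebras APIs).

```python
def pol(level=None, bwd_level=None):
    """Stand-in decorator factory: it only tags the wrapper so we can inspect it."""
    def decorator(fn):
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        wrapper.pol = {"level": level, "bwd_level": bwd_level}
        return wrapper
    return decorator


class ScaleLayer:
    """Toy callable object, playing the role of a layer such as self.fc1."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return self.scale * x


fc1 = ScaleLayer(2.0)


# Form 1: wrap a function that calls the layer.
@pol(bwd_level=0)
def decorated(x):
    return fc1(x)


# Form 2: pass the layer (a callable) straight to the decorator.
direct = pol(bwd_level=0)(fc1)

print(decorated(3.0), direct(3.0))
print(decorated.pol == direct.pol)
```

Both forms produce the same result and carry the same settings; the second simply avoids defining a one-line wrapper function.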

To implement per-layer POL annotations, the cstorch.pol API communicates with the Cerebras PyTorch backend. Specifically, it tags IR nodes in the Lazy Tensor Core (LTC) intermediate representation with precision level attributes. These attributes are then propagated through the subsequent compilation stages of the Cerebras backend.

Note

The current implementation of per-layer POL supports annotating only MatMul operations. Future updates may expand support to additional operations.
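
The node tagging and the MatMul-only restriction can be pictured with a toy IR. This is a rough sketch, not the actual compiler data structures: `Node`, `tag_nodes`, `SUPPORTED_OPS`, and the `"pol"` attribute name are all hypothetical, and it assumes unsupported operations are simply left untagged (the source does not say how they are handled).

```python
from dataclasses import dataclass, field
from typing import Dict, List

# Assumption: only MatMul nodes accept a per-layer POL annotation.
SUPPORTED_OPS = {"matmul"}


@dataclass
class Node:
    """Toy IR node: an op kind plus a dictionary of attributes."""
    op: str
    attrs: Dict[str, int] = field(default_factory=dict)


def tag_nodes(nodes: List[Node], level: int) -> List[Node]:
    """Attach a precision level attribute to each supported node, in place."""
    for node in nodes:
        if node.op in SUPPORTED_OPS:
            node.attrs["pol"] = level  # hypothetical attribute name
    return nodes


# Ops traced from a decorated function such as fc1 followed by relu.
traced = [Node("matmul"), Node("add"), Node("relu")]
tag_nodes(traced, level=0)
print([(n.op, n.attrs) for n in traced])
```

In this picture, decorating a function that contains a linear layer followed by an activation would change the precision of the matrix multiplication only.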