Weight Streaming Execution

Overview

The execution mode determines how the Cerebras runtime deploys your neural network model onto the Cerebras Wafer-Scale Engine (WSE). The Cerebras Wafer-Scale Cluster supports Weight Streaming Mode, in which the neural network model is loaded one layer at a time. This layer-by-layer approach is used for very large models: the weights of any single layer fit within memory, but the model as a whole does not.

../../_images/cs-exec-mode-ws-computation.png

Fig. 2 Weight Streaming Mode
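As a rough illustration of when this mode applies, the following Python sketch checks whether every layer fits in memory on its own while the whole model does not. The layer shapes, the 16-bit weight size, and the capacity values are hypothetical numbers chosen for the example, not Cerebras specifications.

```python
# Hypothetical sizing check: all shapes and capacities below are illustrative only.
BYTES_PER_PARAM = 2  # assumption: 16-bit weights


def layer_bytes(n_in, n_out, bytes_per_param=BYTES_PER_PARAM):
    """Bytes needed for one fully connected layer (weight matrix plus bias)."""
    return (n_in * n_out + n_out) * bytes_per_param


def needs_weight_streaming(layer_shapes, capacity_bytes):
    """True when each layer fits on its own but the whole model does not."""
    sizes = [layer_bytes(n_in, n_out) for n_in, n_out in layer_shapes]
    every_layer_fits = max(sizes) <= capacity_bytes
    whole_model_fits = sum(sizes) <= capacity_bytes
    return every_layer_fits and not whole_model_fits


# Three equally sized layers of roughly 33.6 MB each (about 100.7 MB in total).
shapes = [(4096, 4096)] * 3
print(needs_weight_streaming(shapes, capacity_bytes=40_000_000))   # one layer fits, model does not
print(needs_weight_streaming(shapes, capacity_bytes=200_000_000))  # whole model fits
```

With the 40 MB capacity the check is true, which is the situation the layer-by-layer approach targets. With 200 MB the whole model fits, so the condition described above no longer holds.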

To show how Weight Streaming works, consider the following example of a 3-layer FC-MNIST network:

../../_images/cs-exec-mode-3l-fcmnist.png

Weight Streaming Execution during runtime

During runtime, Weight Streaming proceeds as shown in Fig. 3. At each step, one layer is loaded onto the Cerebras Wafer-Scale Engine (WSE). During forward propagation, the weights flow from the MemoryX server to the WSE. During backpropagation, the weights are streamed from MemoryX to the WSE again. The WSE computes the weight gradients and streams them back to MemoryX, where the learning algorithm uses them to update the stored weights.

../../_images/cs-exec-mode-ws-execution.png

Fig. 3 Weight streaming execution mode.
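The order of transfers described above can be written out as a short trace. This is only a sketch of the sequence of events for one training step; the function name and the tuple layout are made up for the illustration and are not part of any Cerebras API.

```python
def weight_streaming_trace(num_layers):
    """Return the ordered list of transfers for one training step."""
    trace = []
    # Forward pass: layers are loaded in order, weights flow MemoryX -> WSE.
    for layer in range(1, num_layers + 1):
        trace.append(("forward", layer, "weights: MemoryX -> WSE"))
    # Backward pass: layers in reverse order. Weights are streamed again,
    # then the computed weight gradients are streamed back for the update.
    for layer in range(num_layers, 0, -1):
        trace.append(("backward", layer, "weights: MemoryX -> WSE"))
        trace.append(("backward", layer, "gradients: WSE -> MemoryX"))
    return trace


for phase, layer, transfer in weight_streaming_trace(3):
    print(f"{phase:8s} layer {layer}: {transfer}")
```

For the 3-layer example this yields nine transfers: three weight loads in the forward pass, then a weight load and a gradient transfer for each layer in the backward pass.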

Compilation

The Cerebras compiler extracts the operation graph from your code and matches its operations to compatible kernels in the Cerebras Software Platform. Each matched kernel corresponds to a layer in the network’s dataflow graph. Fig. 4 shows a mathematical representation of these layers.

To inspect the outcome of this preliminary compilation phase, use the --validate_only flag. The Cerebras compiler plans the mapping of one kernel (layer) at a time onto the entire Cerebras Wafer-Scale Engine (WSE), first for the forward propagation pass and then for the backward propagation pass.

When multiple CS-X systems are requested within the Cerebras Wafer-Scale Cluster, the same mapping is used for every layer across all CS-X systems. To precompile the model, use the --compile_only flag.

../../_images/cs-exec-mode-fcmnist-math.png

Fig. 4 Mathematical representation of example network showing color-coded layers and the activations flowing from one layer to the next layer.
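For readers without access to the figure, a 3-layer fully connected network of this kind is commonly written as follows. This is a generic formulation, assuming an elementwise activation function f and per-layer biases b_i; the exact notation used in Fig. 4 may differ.

```latex
\begin{aligned}
h_1 &= f(W_1 x + b_1) \\
h_2 &= f(W_2 h_1 + b_2) \\
\hat{y} &= W_3 h_2 + b_3
\end{aligned}
```

Here x is the input, W_i and b_i are the parameters of layer i, and h_1 and h_2 are the activations that flow from one layer to the next.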

Training

During the forward propagation phase, as illustrated in Fig. 5:

  • Layer 1 is loaded onto the WSE first. Input pre-processing servers process the training data and stream it to the WSE, while MemoryX streams the weights of layer 1 into the WSE. If the training job requests multiple CS-X systems, SwarmX broadcasts the weights from MemoryX to all the WSEs, and the batch of training samples is divided into equally sized subsets, or shards, with one shard sent to each CS-X. This technique is known as data parallelism.

  • All WSEs perform the forward computation for layer 1 in parallel.

  • The calculated activations for layer 1 are retained in the WSE’s memory.

  • Next, the weights of layer 2 are broadcast from MemoryX to the WSEs.

  • Each WSE performs the forward computation for layer 2, using its stored activations from layer 1.

  • This same process is repeated for layer 3.

  • In this way, the forward computation for each layer uses the activations computed by the preceding layer, and the activations of the current layer are stored in the WSE’s memory for use by the next loaded layer.

  • At the loss layer, the labels from the training data are used to compute the network loss delta, which is the gradient of the scalar loss with respect to the activation of the output layer (layer 3). This loss delta is the starting point for calculating the layer-by-layer deltas and weight gradients during the subsequent backward pass.

    ../../_images/cs-exec-mode-ws-fwd.png

    Fig. 5 Schematic representation of the forward propagation in Weight Streaming Mode.
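The forward steps above can be sketched in plain Python. This is a toy model under stated assumptions, not the Cerebras runtime: each "WSE" is a list of per-sample activations, layers have no bias, ReLU is applied after every layer, and all function names are hypothetical.

```python
def matvec(weights, vector):
    """Multiply a weight matrix (list of rows) by an activation vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in weights]


def relu(vector):
    return [max(0.0, v) for v in vector]


def shard_batch(batch, num_systems):
    """Split the batch into equally sized shards, one per CS-X."""
    if len(batch) % num_systems:
        raise ValueError("batch size must be divisible by the number of systems")
    size = len(batch) // num_systems
    return [batch[i * size:(i + 1) * size] for i in range(num_systems)]


def streamed_forward(memoryx_weights, batch, num_systems):
    """Run the forward pass one layer at a time on every shard."""
    # Each entry holds the activations kept in one WSE's memory.
    activations = shard_batch(batch, num_systems)
    for layer_weights in memoryx_weights:  # one layer streamed per step
        # Broadcast stand-in: every WSE receives the same layer weights and
        # applies them to the activations it stored for the previous layer.
        activations = [
            [relu(matvec(layer_weights, sample)) for sample in shard]
            for shard in activations
        ]
    return activations


weights = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0]]]  # two tiny layers
batch = [[1.0, 2.0], [3.0, -4.0], [0.0, 1.0], [2.0, 2.0]]
outputs = streamed_forward(weights, batch, num_systems=2)
print(outputs)
```

Because every WSE applies identical weights to its own shard, the sharded result matches what a single system would compute on the full batch, only split across systems.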

In the backward propagation phase, as illustrated in Fig. 6:

  • The weights of layer 3 are broadcast from MemoryX to the WSEs, where the gradient and delta computations for layer 3 are performed. In practice, the weights of this output layer can be retained after the forward pass instead of being streamed again, which saves time.

  • The gradients for layer 3 are then streamed out of the WSEs to SwarmX. SwarmX aggregates (sums together) the gradients from multiple WSEs and presents this sum to MemoryX. MemoryX uses this aggregate gradient in the learning algorithm to update its stored copy of the layer weights.

  • Next, the weights of layer 2 are streamed from MemoryX to the WSE, and the WSE performs the same gradient and delta computations for layer 2. The layer 2 gradients are then streamed out of the WSE to MemoryX, where the weight update takes place. If the training job requests multiple CS-X systems, SwarmX broadcasts the weights from MemoryX to all the WSEs and aggregates the gradients from the WSEs on their way to MemoryX.

  • This backward pass sequence continues until the gradients for layer 1 are streamed out to MemoryX, where the weights are updated. The forward pass for the next training batch then begins with the updated weights.

    ../../_images/cs-exec-mode-ws-bwd.png

    Fig. 6 Schematic representation of the backward propagation in Weight Streaming Mode.
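The gradient path in the backward pass can be sketched as follows. The per-WSE gradients are given as inputs rather than computed, weights are flattened into plain lists, and a plain SGD step stands in for the learning algorithm, which the text does not specify. All function names are hypothetical.

```python
def swarmx_reduce(per_wse_gradients):
    """Sum gradients element-wise across WSEs (stand-in for SwarmX)."""
    return [sum(values) for values in zip(*per_wse_gradients)]


def memoryx_update(weights, summed_gradient, learning_rate):
    """Apply an SGD step to the stored copy of the weights (assumed optimizer)."""
    return [w - learning_rate * g for w, g in zip(weights, summed_gradient)]


def backward_step(memoryx_weights, gradients_by_layer, learning_rate):
    """Walk the layers from last to first, reducing and updating one at a time."""
    updated = dict(memoryx_weights)
    for layer in sorted(memoryx_weights, reverse=True):
        total = swarmx_reduce(gradients_by_layer[layer])
        updated[layer] = memoryx_update(memoryx_weights[layer], total, learning_rate)
    return updated


# Two layers, two WSEs. Each WSE contributes one gradient list per layer.
stored = {1: [1.0, 2.0], 2: [0.5]}
gradients = {
    1: [[0.25, 0.5], [0.75, 0.5]],
    2: [[1.0], [1.0]],
}
new_weights = backward_step(stored, gradients, learning_rate=0.5)
print(new_weights)
```

The reduction is a sum rather than an average, matching the description of SwarmX above. Any scaling by the number of systems would be up to the learning algorithm.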

Meanwhile, within the user node:

  • As the loss values are calculated on the WSEs, SwarmX consolidates them and then transmits the results to the user node.

  • Additionally, at specified intervals, all the weights can be retrieved from MemoryX and transferred to the user node for checkpointing.
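A minimal sketch of the user-node view is shown below. It assumes that consolidating the losses means taking their mean and that checkpoints are taken every fixed number of steps. Both are assumptions made for the illustration, and the function names are hypothetical.

```python
def consolidate_losses(per_wse_losses):
    """Combine per-WSE loss values into one number (a mean is assumed here)."""
    return sum(per_wse_losses) / len(per_wse_losses)


def is_checkpoint_step(step, checkpoint_interval):
    """True on every checkpoint_interval-th step, counting from 1."""
    return checkpoint_interval > 0 and step % checkpoint_interval == 0


def run_user_node(losses_per_step, checkpoint_interval):
    """Return the reported losses and the steps at which weights are fetched."""
    reported, checkpoint_steps = [], []
    for step, per_wse_losses in enumerate(losses_per_step, start=1):
        reported.append(consolidate_losses(per_wse_losses))
        if is_checkpoint_step(step, checkpoint_interval):
            # At this point the weights would be retrieved from MemoryX.
            checkpoint_steps.append(step)
    return reported, checkpoint_steps


losses = [[2.0, 4.0], [1.5, 2.5], [1.0, 2.0], [0.5, 1.5]]
reported, checkpoint_steps = run_user_node(losses, checkpoint_interval=2)
print(reported)
print(checkpoint_steps)
```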