.. _pytorch-vts:

PyTorch Variable Tensor Shape
=============================

Overview
--------

Variable Tensor Shape (VTS) is a feature that allows computations on the CS system running in pipeline mode to process tensors that vary in shape from one element of a batch to the next.

In natural language processing applications, input data commonly consists of sequences of heterogeneous length. In that case, short samples are padded up to a user-defined maximum sequence length so that they can be batched together. Naive treatment of padded data can waste significant computation on padding tokens. VTS allows users to strip away this padding as samples enter the wafer and to perform the model computation on non-padded, variable-length tensors. This reduces wasted computation and shortens training time.

Models written for GPU typically include logic to ensure that padding tokens do not contribute to the final loss of the model. For such models, enabling VTS has no effect on the model's function other than increasing throughput.

Interface Details
-----------------

The VTS interface consists of two custom PyTorch ops: ``cerebras.framework.torch.nn.StripPadding`` and ``cerebras.framework.torch.nn.RestorePadding``.

The ``StripPadding`` function accepts the following arguments:

- ``input``: The input tensor to process.
- ``mask``: A mask that defines which portions of the input correspond to padding and should be stripped away. In particular, this mask defines where the end of the sequence is.
- ``axis``: The axis along which to strip away part of ``input``.

When not running on a CS system, this operation does nothing. On a CS system, it produces a version of ``input`` with the end of the tensor stripped away along axis ``axis``. The end of the tensor is defined by the first element of ``mask`` that is either 0 or ``False``.
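To make the masking rule concrete, the following sketch emulates the per-sample effect of ``StripPadding`` along ``axis=1`` in plain Python. It is not the Cerebras implementation: the helpers ``valid_length`` and ``strip_padding_reference`` are hypothetical, and nested Python lists stand in for tensors because the stripped samples no longer share a common length.

.. code-block:: python

    from typing import List, Sequence


    def valid_length(mask_row: Sequence) -> int:
        """Return the index of the first 0/False entry in a 1-D mask row.

        Hypothetical helper: if the row contains no 0/False entry, the whole
        row is treated as valid and its full length is returned.
        """
        for index, value in enumerate(mask_row):
            if not value:  # 0, 0.0 and False all mark the end of the sequence
                return index
        return len(mask_row)


    def strip_padding_reference(
        batch: Sequence[Sequence], mask: Sequence[Sequence]
    ) -> List[list]:
        """Hypothetical per-sample emulation of StripPadding along axis 1.

        ``batch`` has shape [batch_size, max_seq_len, ...] and ``mask`` has
        shape [batch_size, max_seq_len]. Each sample is truncated at the first
        0/False entry of its mask row, so the results may differ in length.
        """
        return [list(sample[: valid_length(row)]) for sample, row in zip(batch, mask)]


    # Two BERT-style samples padded to a maximum sequence length of 6.
    input_ids = [
        [101, 7592, 2088, 102, 0, 0],
        [101, 2023, 102, 0, 0, 0],
    ]
    attention_mask = [
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0, 0],
    ]

    stripped = strip_padding_reference(input_ids, attention_mask)
    print(stripped)  # [[101, 7592, 2088, 102], [101, 2023, 102]]

Because only the first 0 or ``False`` entry matters, a mask row such as ``[1, 1, 0, 1]`` truncates its sample after two elements in this sketch; any later nonzero entries are ignored.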
The ``RestorePadding`` function is used less commonly than ``StripPadding``, but it is useful when the user wants to ignore padding values in only a subset of a model. It accepts the following arguments:

- ``input``: The tensor to add padding back into.
- ``axis``: The axis along which to add padding.
- ``pad_value``: A scalar value used to pad the input tensor to its maximum shape.

As with ``StripPadding``, this operation is the identity function when not run on a CS system. On a CS system, it performs the inverse of ``StripPadding``: it takes a variable-shape tensor and pads it out to its full shape. This input tensor can be the output of a ``StripPadding`` call, or a tensor derived by passing the output of a ``StripPadding`` call through certain VTS-compatible operations. The Cerebras compiler stack infers the maximum shape of ``input`` from the shape of the input to the ``StripPadding`` operation from which ``input`` was derived. It then pads ``input`` along axis ``axis`` with the value ``pad_value`` up to that inferred shape.

Example Usage in a Model
------------------------

Using the ``StripPadding`` and ``RestorePadding`` functions, you can convert an existing model to use VTS with few changes. For example, you can enable VTS in a BERT model as follows:

.. code-block:: python

    import torch
    from cerebras.framework.torch.nn import StripPadding


    class BertModel(torch.nn.Module):
        # model initialization code

        def forward(
            self,
            input_ids,
            attention_mask,
            masked_lm_positions,
            masked_lm_weights,
            mlm_labels,
        ):
            # Strip trailing padding along the sequence axis (axis=1); each
            # mask marks where the real (non-padding) data ends.
            input_ids = StripPadding(input_ids, attention_mask, axis=1)
            masked_lm_positions = StripPadding(
                masked_lm_positions, masked_lm_weights, axis=1
            )
            labels = StripPadding(mlm_labels, masked_lm_weights, axis=1)

            # remaining model code

Limitations
-----------

Variable Tensor Shape is a feature that is still maturing and, as such, has several limitations in its current form. If a model fails to compile with the Cerebras compiler when VTS is turned on, we suggest attempting a compile without VTS.
The only supported axis for VTS is ``axis=1``, which corresponds to a variable sequence dimension for common language modeling applications.