Train a model with weight sparsity

Overview

In 2018, state-of-the-art neural networks such as BERT had a few hundred million parameters. Two years later, the world was introduced to GPT-3. With 175 billion parameters and a compute budget of \(3.14 \times 10^{23}\) FLOPs (floating point operations), it is estimated to have required 10,000 NVIDIA V100 GPUs for 15 days, accounting for 552 tons of CO2e emissions and 1,287 MWh of energy [Patterson et al.].

Evidently, training large models is costly. With parameter counts and datasets getting larger and larger every year, new approaches are needed to reduce the time, energy, and carbon footprint required to train. Weight sparsity, coupled with hardware that accelerates it, is a promising way to train models using significantly less compute and memory.

Figure: Unstructured weight sparsity

Weight-sparse training methods set subsets of weights to zero. Multiplications by zero can be skipped on both the forward and backward passes through the network, so the resulting sparse model requires far fewer FLOPs to train and has fewer parameters to store. Only systems that can accelerate sparsity, such as Cerebras CS-X and CS-3, can take advantage of this lower resource requirement and turn the reduction in FLOPs into significantly faster training. Finding and training sparse models that match the accuracy of their original “dense” (i.e., non-sparse) configurations is an active and open area of research!
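To make the FLOP argument concrete, here is a minimal plain-Python sketch. It counts the multiply-accumulate operations (MACs) in a matrix-vector product when zero weights are skipped. The 64x64 shape and the 75% sparsity level are arbitrary choices for illustration. The sketch says nothing about how sparsity-accelerating hardware actually schedules this work.

```python
import random

def dense_macs(weight):
    """Multiply-accumulates (MACs) needed for y = W @ x when every weight is used."""
    return sum(len(row) for row in weight)

def sparse_matvec(weight, x):
    """Compute y = W @ x, skipping zero weights; return (y, MACs actually performed)."""
    y, macs = [], 0
    for row in weight:
        acc = 0.0
        for w, xi in zip(row, x):
            if w != 0.0:  # a multiplication by zero contributes nothing, so skip it
                acc += w * xi
                macs += 1
        y.append(acc)
    return y, macs

rng = random.Random(0)
rows, cols, sparsity = 64, 64, 0.75
weight = [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

# Unstructured sparsity: zero out a random 75% of the individual weights.
cells = [(r, c) for r in range(rows) for c in range(cols)]
for r, c in rng.sample(cells, int(sparsity * rows * cols)):
    weight[r][c] = 0.0

x = [rng.gauss(0.0, 1.0) for _ in range(cols)]
y, macs = sparse_matvec(weight, x)
print(dense_macs(weight), macs)  # 4096 1024
```

With 75% of the weights set to zero, only a quarter of the multiplications remain. The same skipping applies to the matrix products in the backward pass.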

Sparsity via YAML

When using the Cerebras Model Zoo reference models, runs are parameterized via YAML configs that include model, data, and optimizer parameters (see the GPT-3 example config). To train with sparsity, include a sparsity section in your run’s YAML config file as a sibling to the model and optimizer sections. Each Parameter is sparsified independently; global sparsity is not yet supported.

For example, the following config uses the static algorithm with a sparsity level of 0.3 (30%) and init_method "random". This means that 30% of the elements in each Parameter that passes the default parameter filter are pruned once, at model initialization, and stay pruned throughout training. Non-Parameter tensors are not pruned.

sparsity:
    algorithm: "static"
    sparsity: 0.3
    init_method: "random"
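The behavior of this static config can be sketched in plain Python. A random mask prunes 30% of a toy, flattened Parameter once at initialization. The same mask is then reapplied after every optimizer update so that pruned weights stay exactly zero. Reapplying the mask after each update is an assumption of this sketch, and the library may keep pruned weights at zero differently.

```python
import random

def random_mask(numel, sparsity, rng):
    """0/1 mask that prunes int(sparsity * numel) randomly chosen elements."""
    pruned = set(rng.sample(range(numel), int(sparsity * numel)))
    return [0.0 if i in pruned else 1.0 for i in range(numel)]

rng = random.Random(0)
weights = [rng.gauss(0.0, 1.0) for _ in range(100)]  # toy, flattened Parameter

# algorithm: "static", init_method: "random" -> the mask is computed once, at init
mask = random_mask(len(weights), sparsity=0.3, rng=rng)
weights = [w * m for w, m in zip(weights, mask)]

lr = 0.01
for step in range(5):
    grads = [rng.gauss(0.0, 1.0) for _ in weights]  # stand-in for real gradients
    weights = [w - lr * g for w, g in zip(weights, grads)]  # plain SGD update
    # Assumption: reapply the fixed mask so pruned weights stay exactly zero
    weights = [w * m for w, m in zip(weights, mask)]

print(sum(w == 0.0 for w in weights) / len(weights))  # 0.3
```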

Sparsity is parameterized primarily by the following keys:

  • algorithm:

    Sparsity training algorithm to apply.

    • Static (Default): Fixed sparsity level throughout training

    • GMP: Gradual magnitude pruning

    • SET: Sparse Evolutionary Training

    • RigL: Rigging the Lottery

    You can also define a custom class that inherits from SparsityAlgorithm. As long as the class is in the global scope (e.g., by importing it in your run.py), it can be used directly in a YAML config:

    # run.py
    import cerebras.pytorch as cstorch

    class CustomSparsity(cstorch.sparse.SparsityAlgorithm):
        ...

    # YAML config
    sparsity:
        algorithm: CustomSparsity
        ...
    

    See Writing a Custom Sparsity Algorithm for more details on how to write a custom sparsity algorithm.

  • sparsity:

    The desired sparsity level between 0 and 1. 0.0 means the Parameter is kept fully dense. 1.0 means the Parameter is effectively entirely zeros. Dynamic sparsity algorithms also accept more complex configuration described below in Dynamic Hyperparameters.

    Note

    The actual sparsity level may not exactly match the target sparsity level. Only a whole number of elements can be pruned, so the achievable sparsity levels depend on the number of elements in the Parameter being sparsified.

    For example, if you sparsify a Parameter with shape (5,) targeting a sparsity level of 0.5, the actual sparsity level will only ever be 0.4 (2 of the 5 elements pruned). The smaller the Parameter, the larger this discrepancy can become. If the Parameter is a scalar tensor, the actual sparsity level will always be either 0.0 or 1.0.

  • init_method (optional):

    Method used to compute the initial sparsity pattern.

    • random (default): Sparsity is randomly distributed within each weight.

    • topk: The lowest-magnitude weights are pruned.

    • from_zeros: The sparsity pattern is determined by weight values that are already zero.

  • param_filter optional:

    Controls which Parameters are sparsified. The list of Parameter names can be found using model.named_parameters().

    When param_filter is omitted, sparsity is automatically applied to every multidimensional Parameter except those with embedding, norm, or lm_head in their name. One-dimensional Parameters such as biases are skipped (see default_sparse_param_filter).

    While this is a good default heuristic for transformer-based models [1], you can also provide a glob expression (or a list of them) so that sparsity is applied only to Parameters whose names match, e.g.

    param_filter:
    - "*dense_layer.weight"
    - "*linear_layer.weight"
    

    To match all weights, set param_filter: "*" (quoted, because a bare * is not valid YAML).

    Per-layer sparsity options can be configured by passing a list of configuration dictionaries; see Advanced param_filters below.
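The sparsity and init_method options above can be illustrated with a plain-Python sketch that builds a keep-mask for a flattened Parameter. The sketch rounds the number of pruned elements down. That rounding is an assumption chosen to match the (5,) example in the note above. init_mask, num_pruned, and actual_sparsity are hypothetical helpers, not Cerebras APIs.

```python
import math
import random

def num_pruned(numel, sparsity):
    """Whole number of elements to prune (assumption: round down)."""
    return math.floor(sparsity * numel)

def init_mask(weights, sparsity, init_method="random", seed=0):
    """Hypothetical keep-mask for a flat list of weights (1 = kept, 0 = pruned)."""
    n = len(weights)
    if init_method == "from_zeros":
        # Existing zeros define the pattern; this sketch ignores `sparsity` here.
        return [0 if w == 0.0 else 1 for w in weights]
    k = num_pruned(n, sparsity)
    if init_method == "random":
        pruned = set(random.Random(seed).sample(range(n), k))
    elif init_method == "topk":
        # Prune the k lowest-magnitude weights (keep the largest n - k).
        pruned = set(sorted(range(n), key=lambda i: abs(weights[i]))[:k])
    else:
        raise ValueError(f"unknown init_method: {init_method!r}")
    return [0 if i in pruned else 1 for i in range(n)]

def actual_sparsity(mask):
    return mask.count(0) / len(mask)

w = [0.5, -2.0, 0.1, 3.0, -0.3]                # a Parameter with shape (5,)
print(init_mask(w, 0.5, "topk"))               # [1, 1, 0, 1, 0]
print(actual_sparsity(init_mask(w, 0.5)))      # 0.4, not the 0.5 target
print(actual_sparsity(init_mask([1.2], 0.9)))  # 0.0 for a single-element tensor
```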

Dynamic Sparsity Update Schedule

Dynamic sparsity algorithms (e.g., GMP, SET, or RigL) need an additional update schedule that specifies when to update the sparsity pattern. Two basic methods are built in, with three options in total:

Regular Interval

When sparsity should be updated at a regular interval, a single frequency can be given:

sparsity:
    update:
        freq: 100

    algorithm: set
    sparsity: 0.9

Here, sparsity will be initialized at 90% and steps 0,…,99 will be performed with a fixed sparsity pattern. Every 100 steps, the sparsity pattern will be updated according to the SET algorithm.

To control the beginning and ending steps, add start and stop keys to the update dictionary. In the following example, sparsity is initialized at 0% and steps 0,…,76 are performed without sparsity. Starting at step 77 and then every 100 steps until step 377, the sparsity pattern is updated according to the SET algorithm. After step 377, the sparsity pattern continues to be applied, but it is no longer updated (stop is exclusive).

sparsity:
    update:
        start: 77
        freq: 100
        stop: 477 # An update will _not_ be performed on step 477

    algorithm: set
    sparsity: 0.9

Irregular Interval

When sparsity should be updated at arbitrary steps, specify them in a list:

sparsity:
    update:
        steps: [0, 5, 20, 50]

    algorithm: set
    sparsity: 0.9
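The three update options can be normalized into a single predicate that answers “should the sparsity pattern be updated at this step?”. The sketch below is a hypothetical plain-Python helper, not the Model Zoo's implementation. It follows the semantics described above (stop is exclusive). It also assumes that start defaults to 0 when only freq is given.

```python
def make_update_predicate(update):
    """Hypothetical helper: turn an `update` config dict into step -> bool.

    Accepts {"steps": [...]} or {"freq": N} with optional "start" and "stop".
    "stop" is exclusive; defaulting "start" to 0 is an assumption.
    """
    if "steps" in update:
        steps = frozenset(update["steps"])  # irregular interval
        return lambda step: step in steps

    freq = update["freq"]
    start = update.get("start", 0)
    stop = update.get("stop")  # None: keep updating until training ends

    def is_update_step(step):
        if step < start or (stop is not None and step >= stop):
            return False
        return (step - start) % freq == 0  # regular interval, anchored at start

    return is_update_step

windowed = make_update_predicate({"start": 77, "freq": 100, "stop": 477})
print([s for s in range(1000) if windowed(s)])  # [77, 177, 277, 377]

irregular = make_update_predicate({"steps": [0, 5, 20, 50]})
print([s for s in range(100) if irregular(s)])  # [0, 5, 20, 50]
```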

Dynamic Hyperparameters

In addition to simple constants, dynamic sparsity algorithms (e.g., GMP, SET, or RigL) can configure the sparsity field (and the drop_fraction field for SET and RigL) using a “step-aware hyperparameter”, similar to a learning rate schedule. These schedules usually require additional options, so they are specified as dictionaries.

Note

The base DynamicSparsityAlgorithm class keeps sparsity levels legal by clamping the value produced by such a dynamic hyperparameter with torch.clamp(sparsity, min=0.0, max=1.0).

Linear

\(y(step) = init + slope \cdot step\)

sparsity:
    algorithm: "gmp"
    update:
        freq: 1000
    schedule:
        type: "linear"
        init: 0.0 # value at step zero
        slope: 0.001 # increase in value each step
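As a quick check of the linear formula, the following plain-Python sketch evaluates it and clamps the result to [0, 1]. The clamp mirrors the torch.clamp call described in the note above. The function names are hypothetical. With the example values, the formula reaches 1.0 at step 1000 and stays clamped from then on.

```python
def clamp(value, lo=0.0, hi=1.0):
    """Plain-Python analogue of torch.clamp(sparsity, min=0.0, max=1.0)."""
    return max(lo, min(hi, value))

def linear_schedule(step, init, slope):
    """y(step) = init + slope * step, clamped to a legal sparsity level."""
    return clamp(init + slope * step)

# Values from the YAML example: init 0.0, slope 0.001
for step in (0, 250, 500, 1000, 2000):
    print(step, linear_schedule(step, init=0.0, slope=0.001))
```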

Exponential

\(y(step) = final + (init-final) e^{step \cdot gamma}\)

This is especially useful for GMP, where the sparsity level increases monotonically throughout training. Pruning a fixed fraction of the remaining elements in the Parameter at each update step traces this exponential curve, which asymptotically approaches an empty network.

sparsity:
    algorithm: "gmp"
    update:
        freq: 1000
    schedule:
        type: "exp"
        init: 0.0  # starting value
        final: 1.0  # asymptotic ending value

        # Prune 10% of the remaining connections every 1000 steps
        gamma: -0.00010536051  # ln(1-0.10)/1000

Cosine

\(y(step) = o + a \cdot \cos(step \cdot \pi / half\_period)\), where \(o = (init + minimum)/2\) and \(a = init - o\).

This is especially useful for RigL, which typically applies a “cosine decay” to its drop_fraction. The minimum option defaults to 0.0, and half_period sets the step at which the value reaches its minimum.

sparsity:
    algorithm: "rigl"
    update:
        freq: 1000
    sparsity: 0.9
    drop_fraction:
        type: "cosine"
        init: 0.3  # starting value
        half_period: 10000 # reaches minimum (default 0) after 10 updates
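The cosine schedule can be checked the same way. This hypothetical plain-Python helper evaluates the formula above for the drop_fraction in the RigL example at each update step. The value starts at init (0.3), passes through o = 0.15 at step 5000, and reaches minimum (0.0) at step 10000. As written, the formula is periodic, so past half_period the value rises again.

```python
import math

def cosine_schedule(step, init, half_period, minimum=0.0):
    """y(step) = o + a * cos(step * pi / half_period), o = (init + minimum) / 2, a = init - o."""
    o = (init + minimum) / 2
    a = init - o
    return o + a * math.cos(step * math.pi / half_period)

# drop_fraction at each RigL update in the example (freq: 1000, init: 0.3, half_period: 10000)
for step in range(0, 10001, 1000):
    print(step, round(cosine_schedule(step, init=0.3, half_period=10000), 4))
```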

More Config examples

The most basic configuration, applying random 30% static sparsity to all Parameters selected by the default param_filter:

sparsity:
    sparsity: 0.3

Apply uniform (static) sparsity to a selected set of weights, with a sparsity pattern guided by the weight magnitudes:

sparsity:
    sparsity: 0.9
    init_method: "topk"
    param_filter:
    - "dense_layer.weight"
    - "linear_layer.weight"

Basic dynamic sparsity using the SET algorithm. Update the sparsity pattern every 1000 iterations.

sparsity:
    algorithm: "set"
    sparsity: 0.9
    update:
        freq: 1000
    drop_fraction: 0.3

Configuring Multiple Sparsity Algorithms

Different groups of Parameters can be sparsified using different sparsity algorithms.

For example, suppose one set of weights should be statically sparsified to, say, 0.3, while another set should be dynamically sparsified with the SET algorithm. You can do this by providing a list of sparsity configurations:

sparsity:
  - param_filter: "fc1.*"
    sparsity: 0.3
  - param_filter: "fc2.*"
    algorithm: "set"
    sparsity: 0.9
    update:
        freq: 1000
    drop_fraction: 0.3
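To illustrate how glob param_filters split a model's Parameters between sparsity configurations, here is a hypothetical plain-Python sketch using fnmatch.fnmatchcase. For entries without a param_filter, it also re-creates the default heuristic described under param_filter. The following are all assumptions, not the library's behavior:

* the helper names and the parameter names and shapes;
* rejecting a Parameter that matches more than one entry.

```python
from fnmatch import fnmatchcase

def default_filter_sketch(name, shape):
    """Hypothetical re-creation of the default heuristic described under param_filter."""
    if len(shape) < 2:  # one-dimensional Parameters (e.g. biases) are skipped
        return False
    return not any(key in name for key in ("embedding", "norm", "lm_head"))

def assign_groups(param_shapes, configs):
    """Map each Parameter name to the index of the sparsity config that selects it."""
    assignment = {}
    for name, shape in param_shapes.items():
        matches = []
        for i, cfg in enumerate(configs):
            patterns = cfg.get("param_filter")
            if patterns is None:
                selected = default_filter_sketch(name, shape)
            else:
                if isinstance(patterns, str):
                    patterns = [patterns]
                selected = any(fnmatchcase(name, p) for p in patterns)
            if selected:
                matches.append(i)
        if len(matches) > 1:
            # Assumption: overlapping filters are rejected in this sketch
            raise ValueError(f"{name} matches configs {matches}")
        if matches:
            assignment[name] = matches[0]
    return assignment

configs = [
    {"param_filter": "fc1.*", "sparsity": 0.3},
    {"param_filter": "fc2.*", "algorithm": "set", "sparsity": 0.9},
]
params = {"fc1.weight": (256, 784), "fc2.weight": (10, 256)}
print(assign_groups(params, configs))  # {'fc1.weight': 0, 'fc2.weight': 1}

# No param_filter: fall back to the default heuristic
transformer = {
    "embedding.weight": (1000, 64),
    "layer0.proj_q_dense_layer.weight": (64, 64),
    "layer0.proj_q_dense_layer.bias": (64,),
    "layer0.norm1.weight": (64,),
    "lm_head.weight": (1000, 64),
}
print(assign_groups(transformer, [{"sparsity": 0.3}]))  # {'layer0.proj_q_dense_layer.weight': 0}
```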

Advanced param_filters

When each Parameter (or group of Parameters) needs different configuration, param_filters can be specified as a dictionary, mapping “patterns” to the config dictionaries to overlay on the default sparsity config options.

For example, RigL uses gradient information to guide how the sparsity pattern of a Parameter changes. On transformer networks, this can cause sparsity to be cyclically redistributed between the heads of attention projection weights when samples in a batch activate one head disproportionately to another. This ultimately decreases the effectiveness of dynamic sparsity and can even hurt model performance.

You can ensure that sparsity is distributed fairly between the heads of the multi-head attention projections:

* Specify balance_out_groups when a weight's outputs are logically N independent, stacked groups (e.g., the input projection weights that produce the multi-head attention QKV).
* Specify balance_in_groups for the reverse case (e.g., the output projection weights).

This only applies conceptually to attention projection weights, so use param_filter to apply these options only to the relevant weights. In the following example, the model has 12 attention heads.

rigl_config: &rigl-config
    algorithm: "rigl"
    sparsity: 0.9
    update:
        freq: 1000
    drop_fraction:
        type: "cosine"
        init: 0.3
        half_period: 10000

sparsity:
  - <<: *rigl-config
    param_filter: "*proj_[qkv]_dense_layer.weight"
    balance_out_groups: 12  # ensure this matches model.num_heads
  - <<: *rigl-config
    param_filter: "*linear_layer.weight"
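The idea behind balance_out_groups can be sketched as follows. The output rows of a projection weight are split into num_groups equal blocks (one per head), and each block is pruned to the same sparsity instead of ranking the whole matrix at once. This hypothetical helper uses magnitude pruning only, to show the balancing. It is not RigL's update rule or the library's implementation.

```python
import random

def balanced_out_group_mask(weight, sparsity, num_groups):
    """Hypothetical helper: prune each block of output rows to the same sparsity.

    `weight` is a list of rows (out_features x in_features) whose rows form
    `num_groups` stacked blocks, e.g. one block per attention head.
    """
    rows_per_group, rem = divmod(len(weight), num_groups)
    if rem:
        raise ValueError("out_features must be divisible by num_groups")
    mask = [[1] * len(row) for row in weight]
    for g in range(num_groups):
        cells = [(r, c)
                 for r in range(g * rows_per_group, (g + 1) * rows_per_group)
                 for c in range(len(weight[r]))]
        k = int(sparsity * len(cells))  # identical prune count in every group
        # Magnitude-based pruning within the group only (not RigL's drop/grow rule)
        for r, c in sorted(cells, key=lambda rc: abs(weight[rc[0]][rc[1]]))[:k]:
            mask[r][c] = 0
    return mask

rng = random.Random(0)
heads, head_dim, d_model = 4, 2, 8
weight = [[rng.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(heads * head_dim)]
for r in range(head_dim):  # make head 0's weights ~10x larger than the others
    weight[r] = [10 * w for w in weight[r]]

mask = balanced_out_group_mask(weight, sparsity=0.5, num_groups=heads)
pruned_per_head = [sum(row.count(0) for row in mask[h * head_dim:(h + 1) * head_dim])
                   for h in range(heads)]
print(pruned_per_head)  # [8, 8, 8, 8]
```

With a single global magnitude ranking, the larger head 0 would tend to be pruned less than the others. Balancing per group keeps every head at the same sparsity.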

Running a Sparse Model

No change is needed to the run command (see the Launch your job guide); just ensure the .yaml file has sparsity enabled. To validate your sparsity config before launching training, run with --validate_only. You can also log which weights are being sparsified by passing --logging VERBOSE to your run command.

python modelzoo/path/model/run.py CSX \
       --params params_with_sparsity.yaml \
       --num_csx=1 \
       --model_dir model_dir --mode {train,eval,eval_all,train_and_eval} \
       --mount_dirs {paths to modelzoo and to data} \
       --python_paths {paths to modelzoo and other python code if used}

When using dynamic sparsity, you can view real-time summaries by setting the config field add_summaries: True. Each group of Parameters independently summarizes its target sparsity as well as the actual computed sparsity of each tensor.

gmp_config: &gmp-config
    # Enable tensorboard summaries for target and actual sparsity
    add_summaries: True

    algorithm: "gmp"
    schedule:
        type: "exp"
        init: 0.0  # starting value
        final: 1.0  # asymptote value
        # Prune 20% of the remaining connections every 1000 steps:
        gamma: -0.0002231435513142097  # ln(1-0.20)/1000

sparsity:
  - <<: *gmp-config
    param_filter: "*fc_layers*weight"
    update:
        freq: 2000
  - <<: *gmp-config
    param_filter: "*last_layer*weight"
    update:
        freq: 1000
Figure: Sparsity summaries

Sparsity via API

Please see Sparsifying models for more details on how to configure sparsity using the Cerebras PyTorch API.