Source code for cerebras_pytorch.sparse.rigl

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Provide an optimizer implementing RigL for use with the WSE.
"""
from functools import partial

import torch
from torch.optim.optimizer import required

from .dynamic import DynamicSparsityOptimizer, InitMethodType, ScheduleType
from .utils import (
    HyperParameterType,
    InputGroupScoreShaper,
    OutputGroupScoreShaper,
    make_mask_drop_minimum,
    make_mask_grow_maximum,
    set_param_group_hyperparam,
)


class RigLSparsityOptimizer(DynamicSparsityOptimizer):
    r"""Implements Rigging the Lottery (RigL).

    Sparsity levels stay constant throughout training, but the lowest
    magnitude weights are pruned and then regrown using a proxy measure of
    where a pruned connection would have had the most impact: the highest
    magnitude (dense) gradients of pruned weights.

    See: https://arxiv.org/abs/1911.11134

    Args:
        params (iterable): iterable of parameters to sparsify or dicts
            defining parameter groups to sparsify
        init_method: Method to initialize sparsity pattern. Can either be the
            name of a built-in method or a lambda.
        sparsity: Sparsity, either constant or step-aware hyperparameter
        schedule: Sparsity update schedule. May be one of:

            * ``int``: Single regular update frequency.
            * ``list``: Irregular update on the given steps.
            * ``dict``: Containing ``{"start": start, "freq": freq,
              "stop": stop}`` for regular updates with start & stop.
            * ``ScheduleCallable``: User function accepting a rankless
              ``torch.LongTensor`` and returning a rankless
              ``torch.BoolTensor``
        drop_fraction: Fraction of non-pruned weights to drop each update
            step. Either a constant or a step-aware hyperparameter.

    Example:
        >>> optimizer = torch.optim.SGD(
        ...     model.parameters(), lr=0.1, momentum=0.9
        ... )
        >>> sparsity_opt = RigLSparsityOptimizer(
        ...     [p for n, p in model.named_parameters() if should_sparsify(n, p)],
        ...     sparsity=0.9,
        ...     schedule={"freq": 100, "stop": 1000},
        ...     drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
        ... )
        >>> sparsity_opt.hook_module(model)
        >>> sparsity_opt.initialize_sparsity()
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
        >>> sparsity_opt.step()
    """
    def __init__(
        self,
        params,
        init_method: InitMethodType = "random",
        sparsity: HyperParameterType = required,
        schedule: ScheduleType = required,
        drop_fraction: HyperParameterType = 0.3,
        **kwargs,
    ):
        # drop_fraction is a required value for RigL though it has a default
        # value. Pass it as a dynamic optimizer kwarg. It will be configured
        # on each param_group.
        kwargs["drop_fraction"] = drop_fraction

        self._dense_grads = {}

        super().__init__(
            params,
            init_method=init_method,
            sparsity=sparsity,
            schedule=schedule,
            **kwargs,
        )
    def add_param_group(self, param_group):
        param_group = super().add_param_group(param_group)
        set_param_group_hyperparam(param_group, "drop_fraction")

        # RigL may need per-head balancing of attention projection weights
        in_groups = param_group.pop("balance_in_groups", None)
        out_groups = param_group.pop("balance_out_groups", None)

        def validate_balance(groups, err_key):
            for name, param in zip(
                param_group["param_names"], param_group["params"]
            ):
                for dim in param.shape:
                    if dim % groups == 0:
                        break
                else:
                    raise ValueError(
                        f"Sparsity group configured with `{err_key}`={groups} "
                        f"but parameter {name} does not have a dimension with "
                        f"a multiple of {groups}: {param.shape}"
                    )

        if out_groups:
            if in_groups:
                raise ValueError(
                    "Only one of `balance_in_groups` and `balance_out_groups` "
                    "can be specified at a time."
                )
            validate_balance(out_groups, "balance_out_groups")
            score_shaper = OutputGroupScoreShaper(out_groups)
        elif in_groups:
            validate_balance(in_groups, "balance_in_groups")
            score_shaper = InputGroupScoreShaper(in_groups)
        else:
            score_shaper = None
        param_group["score_shaper"] = score_shaper

        # Also add score shaping to the init_method.
        param_group["init_method"] = partial(
            param_group["init_method"], score_shaper=score_shaper
        )

    def _grad_hook(self, p, grad):
        # Save a copy of the dense gradients before masking.
        if p in self._dense_grads:
            # GPU gradient accumulation mode.
            self._dense_grads[p] += grad
        else:
            self._dense_grads[p] = grad.clone()

        return super()._grad_hook(p, grad)
    def zero_grad(self, set_to_none: bool = True):
        """
        Clears the accumulated dense gradients.
        """
        if set_to_none:
            self._dense_grads = {}
        else:
            for g in self._dense_grads.values():
                g.zero_()
    @torch.no_grad()
    def update_mask(self, p, mask, sparsity, group):
        if p not in self._dense_grads:
            raise RuntimeError(
                "RigL requires dense gradients, ensure you have called "
                "sparsity_optimizer.apply_sparsity()"
            )

        # RigL may need per-head balancing of attention projection weights
        score_shaper = group["score_shaper"]

        drop_fraction = group["drop_fraction"](
            self._step, group["is_update_step"]
        )

        # Keep the connections of highest magnitude weights but drop some.
        p_score = p.abs()
        mask, k = make_mask_drop_minimum(
            p_score, mask, drop_fraction, score_shaper=score_shaper
        )

        # Regrow where the gradient magnitude is the largest.
        regrow_score = self._dense_grads[p].abs()
        return make_mask_grow_maximum(
            regrow_score,
            mask,
            sparsity,
            k,
            score_shaper=score_shaper,
        )
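

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module above): the drop/regrow step
# performed by ``update_mask`` through ``make_mask_drop_minimum`` and
# ``make_mask_grow_maximum`` can be approximated for a single tensor with
# plain ``torch.topk``. This sketch ignores the ``score_shaper`` group
# balancing and the step-aware ``drop_fraction`` schedule, and assumes a
# boolean mask; ``rigl_update_sketch`` is a hypothetical helper written for
# explanation only, not a Cerebras API.


def rigl_update_sketch(weight, grad, mask, drop_fraction=0.3):
    """Single-tensor RigL-style mask update (conceptual sketch).

    Drops the lowest-magnitude active weights, then regrows the same number
    of connections where the dense gradient magnitude is largest, so the
    total number of active connections (and hence the sparsity) is unchanged.
    """
    flat_mask = mask.flatten()
    num_active = int(flat_mask.sum())           # connections currently kept
    num_drop = int(num_active * drop_fraction)  # connections to recycle

    # Drop: among currently active weights, keep only the highest magnitudes.
    weight_score = weight.abs().flatten().masked_fill(~flat_mask, float("-inf"))
    keep_idx = torch.topk(weight_score, num_active - num_drop).indices

    # Regrow: among currently inactive positions, pick the largest dense grads.
    grad_score = grad.abs().flatten().masked_fill(flat_mask, float("-inf"))
    grow_idx = torch.topk(grad_score, num_drop).indices

    new_mask = torch.zeros_like(flat_mask)
    new_mask[keep_idx] = True
    new_mask[grow_idx] = True
    return new_mask.view_as(mask)


# Example usage (hypothetical shapes and values):
#   mask = torch.rand(64, 64) > 0.9             # ~90% sparse initial mask
#   w = torch.randn(64, 64) * mask
#   g = torch.randn(64, 64)                     # dense gradient of w
#   mask = rigl_update_sketch(w, g, mask, drop_fraction=0.3)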