Source code for cerebras_pytorch.sparse.gmp

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Provide an optimizer implementing GMP for use with the WSE.
"""
import torch
from torch.optim.optimizer import required

from .dynamic import DynamicSparsityOptimizer, InitMethodType, ScheduleType
from .utils import (
    HyperParameterType,
    initialize_tiebreak,
    make_mask_topk_sparsity,
)


class GMPSparsityOptimizer(DynamicSparsityOptimizer):
    r"""Implements Gradual Magnitude Pruning.

    Sparsity increases monotonically based on weight magnitude.

    See: https://arxiv.org/abs/1506.02626

    Args:
        params (iterable): iterable of parameters to sparsify or dicts
            defining parameter groups to sparsify
        init_method: Method to initialize sparsity pattern. Can either be
            the name of a built-in method or a lambda.
        sparsity: Sparsity, either constant or step-aware hyperparameter
        schedule: Sparsity update schedule. May be one of:

            * ``int``: Single regular update frequency.
            * ``list``: Irregular update on the given steps.
            * ``dict``: Containing ``{"start": start, "freq": freq,
              "stop": stop}`` for regular updates with start & stop.
            * ``ScheduleCallable``: User function accepting a rankless
              ``torch.LongTensor`` and returning a rankless
              ``torch.BoolTensor``

    Example:
        >>> optimizer = torch.optim.SGD(
        ...     model.parameters(), lr=0.1, momentum=0.9
        ... )
        >>> sparsity_opt = GMPSparsityOptimizer(
        ...     [p for n, p in model.named_parameters() if should_sparsify(n, p)],
        ...     sparsity={"type": "exp", "init": 0, "gamma": 1000 * math.log(0.3)},
        ...     schedule=1000,
        ... )
        >>> sparsity_opt.hook_module(model)
        >>> sparsity_opt.initialize_sparsity()
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
        >>> sparsity_opt.step()
    """
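    # Illustration only (not part of this module): a ``ScheduleCallable``
    # passed as ``schedule`` takes the current step as a rankless
    # ``torch.LongTensor`` and returns a rankless ``torch.BoolTensor`` that
    # is True on steps where the mask should be updated, e.g.:
    #
    #     def every_1000_steps(step: torch.LongTensor) -> torch.BoolTensor:
    #         return step % 1000 == 999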
    def __init__(
        self,
        params,
        init_method: InitMethodType = "random",
        sparsity: HyperParameterType = required,
        schedule: ScheduleType = required,
        tiebreak=None,
        **kwargs,
    ):
        if isinstance(sparsity, (int, float)):
            # Currently, constant sparsity is not supported for GMP, but a
            # cycling schedule with equal values is a workaround. Not that
            # you'd want this anyway; you might as well use static sparsity.
            sparsity = [sparsity]
        # tiebreak is optional for GMP
        kwargs["tiebreak"] = tiebreak
        super().__init__(
            params,
            init_method=init_method,
            sparsity=sparsity,
            schedule=schedule,
            **kwargs,
        )
    def add_param_group(self, param_group):
        sparsity = param_group.get("sparsity")
        if isinstance(sparsity, (int, float)):
            raise ValueError(
                f"Configured with {sparsity=}. This is not valid, because "
                f"the sparsity pattern would not change during training. "
                f"For a static sparsity pattern, use `type=\"static\"`."
            )
        # Verify all required values are specified.
        param_group = super().add_param_group(param_group)

        # If no tiebreak is specified, this will be a no-op.
        param_group["tiebreak"] = initialize_tiebreak(
            param_group.get("tiebreak")
        )

    @torch.no_grad()
    def update_mask(self, p, mask, sparsity, group):
        tiebreak = group["tiebreak"]
        score = tiebreak(p.abs())
        return make_mask_topk_sparsity(score, sparsity)
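
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the module above: ``update_mask`` delegates
# to ``make_mask_topk_sparsity``, which keeps the highest-scoring entries and
# zeroes the rest. The hypothetical stand-alone helper below shows the idea,
# assuming ``sparsity`` is a plain float rather than a step-aware tensor.

def _topk_magnitude_mask_sketch(
    weight: torch.Tensor, sparsity: float
) -> torch.Tensor:
    # Number of entries to keep after pruning to the target sparsity.
    num_keep = weight.numel() - int(sparsity * weight.numel())
    # Indices of the largest-magnitude entries, flattened.
    _, keep_idx = weight.abs().flatten().topk(num_keep)
    mask = torch.zeros(weight.numel(), dtype=torch.bool, device=weight.device)
    mask[keep_idx] = True
    return mask.view_as(weight)

# Example: prune a 4x4 weight to 75% sparsity, keeping the 4 largest |w|.
# >>> w = torch.randn(4, 4)
# >>> _topk_magnitude_mask_sketch(w, sparsity=0.75).sum()  # -> tensor(4)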