# Source code for cerebras_pytorch.sparse.set

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Provide an optimizer implementing SET for use with the WSE.
"""

import torch
from torch.optim.optimizer import required

from .dynamic import DynamicSparsityOptimizer, InitMethodType, ScheduleType
from .utils import (
    HyperParameterType,
    make_mask_drop_minimum,
    make_mask_grow_maximum,
    set_param_group_hyperparam,
)


class SETSparsityOptimizer(DynamicSparsityOptimizer):
    r"""Implements Sparse Evolutionary Training (SET).

    Sparsity levels stay constant throughout training, but the lowest
    magnitude weights are pruned and then regrown randomly.

    See: https://arxiv.org/abs/1707.04780

    Args:
        params (iterable): iterable of parameters to sparsify or dicts
            defining parameter groups to sparsify
        init_method: Method to initialize sparsity pattern. Can either be the
            name of a built-in method or a lambda.
        sparsity: Sparsity, either constant or step-aware hyperparameter
        schedule: Sparsity update schedule. May be one of:

            * ``int``: Single regular update frequency.
            * ``list``: Irregular update on the given steps.
            * ``dict``: Containing
              ``{"start": start, "freq": freq, "stop": stop}`` for regular
              updates with start & stop.
            * ``ScheduleCallable`` : User function accepting a rankless
              ``torch.LongTensor`` and returning a rankless
              ``torch.BoolTensor``

        drop_fraction: Fraction of non-pruned weights to drop each update
            step. Either a constant or a step-aware hyperparameter.
        **kwargs: Additional keyword arguments forwarded to
            ``DynamicSparsityOptimizer``.

    Example:
        >>> optimizer = torch.optim.SGD(
        ...     model.parameters(), lr=0.1, momentum=0.9
        ... )
        >>> sparsity_opt = SETSparsityOptimizer(
        ...     [p for n,p in model.named_parameters() if should_sparsify(n,p)],
        ...     sparsity=0.9,
        ...     schedule={"freq": 100, "stop": 1000},
        ...     drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
        ... )
        >>> sparsity_opt.hook_module(model)
        >>> sparsity_opt.initialize_sparsity()
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
        >>> sparsity_opt.step()
    """

    def __init__(
        self,
        params,
        init_method: InitMethodType = "random",
        sparsity: HyperParameterType = required,
        schedule: ScheduleType = required,
        drop_fraction: HyperParameterType = 0.3,
        **kwargs,
    ):
        # drop_fraction is a required value for SET though it has a default
        # value. Pass it as dynamic optimizer kwarg. It will be configured
        # on each param_group.
        kwargs["drop_fraction"] = drop_fraction

        super().__init__(
            params,
            init_method=init_method,
            sparsity=sparsity,
            schedule=schedule,
            **kwargs,
        )

    def add_param_group(self, param_group):
        """Add a param group and configure its ``drop_fraction`` hyperparameter.

        Args:
            param_group: The parameter group to register.

        Returns:
            The parameter group returned by the parent implementation, after
            ``drop_fraction`` has been set on it. The parent returns the
            group, so this override returns it too; that lets subclasses use
            the same ``param_group = super().add_param_group(...)`` pattern.
        """
        # Verify all required values are specified.
        param_group = super().add_param_group(param_group)
        set_param_group_hyperparam(param_group, "drop_fraction")
        return param_group

    @torch.no_grad()
    def update_mask(self, p, mask, sparsity, group):
        """Compute a new mask: drop low-magnitude weights, then regrow randomly.

        Args:
            p: The parameter tensor being sparsified.
            mask: The current mask for ``p``.
            sparsity: Target sparsity, passed through to
                ``make_mask_grow_maximum``.
            group: The param group dict. ``group["drop_fraction"]`` is called
                with ``(self._step, group["is_update_step"])`` to get this
                step's drop fraction.

        Returns:
            The mask produced by ``make_mask_grow_maximum``.
        """
        drop_fraction = group["drop_fraction"](
            self._step, group["is_update_step"]
        )

        # Keep the connections of highest magnitude weights but drop some.
        p_score = p.abs()
        # NOTE(review): k presumably counts the connections just dropped, so
        # the same number can be regrown below; confirm against
        # make_mask_drop_minimum.
        mask, k = make_mask_drop_minimum(p_score, mask, drop_fraction)

        # Regrow randomly: uniform random scores pick which pruned
        # connections come back.
        regrow_score = torch.rand_like(p)
        return make_mask_grow_maximum(regrow_score, mask, sparsity, k)