# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""
Provide an optimizer implementing SET for use with the WSE.
"""
import torch
from torch.optim.optimizer import required

from .dynamic import DynamicSparsityOptimizer, InitMethodType, ScheduleType
from .utils import (
    HyperParameterType,
    make_mask_drop_minimum,
    make_mask_grow_maximum,
    set_param_group_hyperparam,
)


class SETSparsityOptimizer(DynamicSparsityOptimizer):
r"""Implements Sparse Evolutionary Training (SET)
Sparsity levels stay constant throughout training, but the lowest
magnitude weights are pruned and then regrown randomly.
See: https://arxiv.org/abs/1707.04780
    Args:
        params (iterable): iterable of parameters to sparsify or dicts
            defining parameter groups to sparsify
        init_method: Method to initialize the sparsity pattern. Can either be
            the name of a built-in method or a lambda.
        sparsity: Sparsity, either a constant or a step-aware hyperparameter.
        schedule: Sparsity update schedule. May be one of:

            * ``int``: Single regular update frequency.
            * ``list``: Irregular updates on the given steps.
            * ``dict``: Containing ``{"start": start, "freq": freq,
              "stop": stop}`` for regular updates with start & stop.
            * ``ScheduleCallable``: User function accepting a rankless
              ``torch.LongTensor`` and returning a rankless
              ``torch.BoolTensor`` (see the sketch below).
        drop_fraction: Fraction of non-pruned weights to drop each update
            step. Either a constant or a step-aware hyperparameter.
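
    A minimal ``ScheduleCallable`` sketch (illustrative only; the
    ``every_100_steps`` name is hypothetical, and the callable is assumed to
    receive the current step as a rankless ``torch.LongTensor``)::

        def every_100_steps(step: torch.Tensor) -> torch.Tensor:
            # Rankless bool tensor: True on steps 0, 100, 200, ...
            return step % 100 == 0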

    Example:
        >>> optimizer = torch.optim.SGD(
        ...     model.parameters(), lr=0.1, momentum=0.9
        ... )
        >>> sparsity_opt = SETSparsityOptimizer(
        ...     [p for n, p in model.named_parameters() if should_sparsify(n, p)],
        ...     sparsity=0.9,
        ...     schedule={"freq": 100, "stop": 1000},
        ...     drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
        ... )
        >>> sparsity_opt.hook_module(model)
        >>> sparsity_opt.initialize_sparsity()
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
        >>> sparsity_opt.step()
    """

    def __init__(
        self,
        params,
        init_method: InitMethodType = "random",
        sparsity: HyperParameterType = required,
        schedule: ScheduleType = required,
        drop_fraction: HyperParameterType = 0.3,
        **kwargs,
    ):
        # drop_fraction is a required value for SET, though it has a default
        # value. Pass it as a dynamic optimizer kwarg; it will be configured
        # on each param_group.
        kwargs["drop_fraction"] = drop_fraction
        super().__init__(
            params,
            init_method=init_method,
            sparsity=sparsity,
            schedule=schedule,
            **kwargs,
        )

    def add_param_group(self, param_group):
        # Verify all required values are specified and configure the group's
        # drop_fraction hyperparameter.
        param_group = super().add_param_group(param_group)
        set_param_group_hyperparam(param_group, "drop_fraction")
        return param_group

    @torch.no_grad()
    def update_mask(self, p, mask, sparsity, group):
        drop_fraction = group["drop_fraction"](
            self._step, group["is_update_step"]
        )

        # Keep the connections of highest magnitude weights but drop some.
        p_score = p.abs()
        mask, k = make_mask_drop_minimum(p_score, mask, drop_fraction)

        # Regrow randomly.
        regrow_score = torch.rand_like(p)
        return make_mask_grow_maximum(regrow_score, mask, sparsity, k)
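

# The helpers above encapsulate the SET update. As a rough illustration, the
# sketch below reproduces the same drop-lowest-magnitude / regrow-random cycle
# with plain torch ops on a single tensor. It is not the library
# implementation: the topk-based selection and the `demo_set_update` name are
# assumptions made for illustration only.
if __name__ == "__main__":

    def demo_set_update(p, mask, drop_fraction):
        # Count currently active weights and how many to drop this step.
        num_active = int(mask.sum())
        num_drop = int(drop_fraction * num_active)

        # Drop: keep only the highest-magnitude currently-active weights.
        score = p.abs().masked_fill(~mask, float("-inf")).flatten()
        keep_idx = score.topk(num_active - num_drop).indices
        new_mask = torch.zeros_like(mask).flatten()
        new_mask[keep_idx] = True

        # Regrow: re-enable num_drop random currently-inactive positions,
        # so the overall sparsity level stays constant.
        regrow_score = torch.rand_like(p).flatten().masked_fill(new_mask, -1.0)
        grow_idx = regrow_score.topk(num_drop).indices
        new_mask[grow_idx] = True
        return new_mask.view_as(mask)

    torch.manual_seed(0)
    p = torch.randn(8, 8)
    mask = torch.rand(8, 8) > 0.9  # roughly 90% sparse initial pattern
    mask = demo_set_update(p, mask, drop_fraction=0.3)
    # Drops and regrowths balance, so the active count is unchanged.
    print(f"active weights: {int(mask.sum())} / {mask.numel()}")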