Source code for cerebras.pytorch.sparse.set

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Provide an optimizer implementing SET for use with the WSE.
"""

import torch

from .dynamic import DynamicSparsityAlgorithm
from .utils import (
    HyperParameterScheduleType,
    make_hyperparam_schedule,
    make_mask_drop_minimum,
    make_mask_grow_maximum,
)


class SET(DynamicSparsityAlgorithm):
    r"""
    Sparse Evolutionary Training (SET) dynamic sparsity algorithm.

    Sparsity levels stay constant throughout training, but the lowest
    magnitude weights are pruned and then regrown randomly.

    See: https://arxiv.org/abs/1707.04780
    """

    def __init__(
        self, drop_fraction: HyperParameterScheduleType = 0.3, **kwargs
    ):
        """
        Args:
            drop_fraction: Fraction of non-pruned weights to drop each
                update step. Either a constant or a step-aware
                hyperparameter.
            **kwargs: Any additional arguments are passed to the
                :py:func:`cstorch.sparse.DynamicSparsityAlgorithm`'s
                constructor.

        Example:

        .. code-block:: python

            sparsity_opt = cstorch.sparse.SET(
                sparsity=0.9,
                update={"freq": 100, "stop": 1000},
                drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
            )
        """
        super().__init__(**kwargs)

        # Although drop_fraction has a default, SET always needs one. Wrap
        # the raw value (constant or schedule spec) in a schedule object so
        # that update_mask can evaluate it for the current step.
        self.drop_fraction = make_hyperparam_schedule(drop_fraction)

    @torch.no_grad()
    def update_mask(self, p, mask, sparsity):
        """
        Produce a new mask for ``p`` by dropping low-magnitude connections
        and regrowing connections at random.

        Args:
            p: Parameter tensor; its absolute values are the drop scores.
            mask: Current mask for ``p``, handed to the drop helper.
            sparsity: Sparsity value forwarded to the grow helper.

        Returns:
            Whatever ``make_mask_grow_maximum`` returns for the regrown
            mask (presumably the updated mask tensor -- confirm against
            ``.utils``).
        """
        # Evaluate the schedule for the current step *before* notifying it,
        # so the value used below is the one in effect for this step.
        current_drop_fraction = self.drop_fraction(self.step)

        # Tell the schedule whether this step is an update step so it can
        # advance its own state.
        self.drop_fraction.update(self.is_update_step)

        # Drop phase: rank connections by weight magnitude and remove a
        # fraction of the weakest ones. The second return value is passed
        # on to the grow helper (presumably the number dropped -- confirm).
        magnitude = torch.abs(p)
        pruned_mask, num_dropped = make_mask_drop_minimum(
            magnitude, mask, current_drop_fraction
        )

        # Grow phase: uniform random scores make the regrowth random.
        random_score = torch.rand_like(p)
        return make_mask_grow_maximum(
            random_score, pruned_mask, sparsity, num_dropped
        )