Source code for cerebras.pytorch.sparse.gmp

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Provide an optimizer implementing GMP for use with the WSE.
"""
import torch

from .dynamic import DynamicSparsityAlgorithm
from .utils import Constant, make_mask_topk_sparsity


class GMP(DynamicSparsityAlgorithm):
    r"""Implements Gradual Magnitude Pruning

    Sparsity increases monotonically based on weight magnitude.

    See: https://arxiv.org/abs/1710.01878
    """

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: All arguments are passed to the
                :py:func:`~cerebras.pytorch.sparse.DynamicSparsityAlgorithm`'s
                constructor.

        Raises:
            ValueError: If the default sparsity configured on this instance
                is a ``Constant``, since the mask would then never change.

        Example:
            sparsity_opt = cstorch.sparse.GMP(
                schedule={"type": "exp", "init": 0, "gamma": 1000*math.log(0.3)},
                update={"freq": 1000},
            )
        """
        super().__init__(**kwargs)

        # The check runs after the base constructor because it reads
        # `self.sparsity`, which is not set anywhere in this class.
        if isinstance(self.sparsity.default, Constant):
            raise ValueError(
                f"Configured with constant sparsity {self.sparsity.default.value}. "
                f"This is not valid, because the sparsity pattern would not change "
                f"during training. For a static sparsity pattern, use `algorithm=\"static\".`"
            )

    @torch.no_grad()
    def update_mask(self, p, mask, sparsity):
        """Compute a new mask for ``p`` from the magnitude of its values.

        Args:
            p: The parameter to prune. Only ``p.abs()`` is used, so it is
                presumably a ``torch.Tensor`` -- confirm against the caller.
            mask: Not read by this implementation; presumably accepted to
                match the signature the base class calls -- TODO confirm.
            sparsity: The target sparsity level, forwarded unchanged to
                ``make_mask_topk_sparsity``.

        Returns:
            Whatever ``make_mask_topk_sparsity`` returns for the magnitude
            scores and the requested sparsity.
        """
        # Magnitude pruning: the score of each weight is its absolute value.
        score = p.abs()
        return make_mask_topk_sparsity(score, sparsity)