# Source code for cerebras.pytorch.sparse.static

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Provide an "optimizer" implementing static sparsity.
"""

from typing import Optional

import torch

from .base import SparseParameter, SparsityAlgorithm


class Static(SparsityAlgorithm):
    """Sparsity algorithm whose sparsity level is fixed and never updated."""

    def __init__(self, sparsity: Optional[float] = None, **kwargs):
        """Constructs a `Static` sparsity instance.

        Args:
            sparsity: A float specifying the level of sparsity to apply to
                each parameter. May be ``None``, in which case no validation
                is performed here and the value is forwarded unchanged to
                the base class.
            **kwargs: Forwarded verbatim to the base class constructor.

        Raises:
            ValueError: If ``sparsity`` is given but is not a ``float``, or
                lies outside the half-open interval ``[0.0, 1.0)``.
        """
        if sparsity is not None:
            # Only a plain float is accepted; anything else is rejected.
            if not isinstance(sparsity, float):
                raise ValueError(
                    "Static sparsity algorithm only supports constant sparsity"
                )

            # Chained comparison kept as-is so that NaN is rejected as well.
            in_range = 0.0 <= sparsity < 1.0
            if not in_range:
                raise ValueError(
                    f"Invalid sparsity level {sparsity}. Must be 0.0 <= s < 1.0"
                )

            # NOTE: `sparsity_level` is only assigned when a value was given;
            # otherwise it is first set by `csx_annotate_sparsity`.
            self.sparsity_level = sparsity

        super().__init__(sparsity, **kwargs)

    def csx_annotate_sparsity(self, param: SparseParameter):
        """Record the constant sparsity level on ``param`` as annotations.

        The level is looked up via ``self.sparsity[param.data]`` evaluated at
        ``step=0``, cached on ``self.sparsity_level``, and then written to
        the ``"min_sparsity"``, ``"max_sparsity"`` and ``"sparsity"``
        annotations of ``param``.

        Args:
            param: The sparse parameter to annotate.
        """
        # The scalar lookup is done under a CPU device context because it
        # does not need to be traced.
        with torch.device("cpu"):
            # The sparsity is constant, so the value at step 0 is
            # representative of every step.
            schedule = self.sparsity[param.data]
            level = schedule(step=0).item()
            self.sparsity_level = level

        # Min, max and current sparsity all coincide for static sparsity.
        for annotation in ("min_sparsity", "max_sparsity", "sparsity"):
            param.annotate(annotation, level)

    def update(self, optimizer):
        """Notify the registered sparsity hooks; no mask update is performed.

        Every hook in ``self._target_sparsity_hooks`` and then every hook in
        ``self._computed_sparsity_hooks`` is called with
        ``(self, "*", self.sparsity_level)``.

        Args:
            optimizer: Unused by this method.
        """
        # Static sparsity has nothing to recompute; only the hooks fire.
        for target_hook in self._target_sparsity_hooks.values():
            target_hook(self, "*", self.sparsity_level)

        for computed_hook in self._computed_sparsity_hooks.values():
            computed_hook(self, "*", self.sparsity_level)