# Source code for cerebras.pytorch.sparse.static

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Provide an "optimizer" implementing static sparsity.
"""

from typing import Optional

import torch

import cerebras.pytorch as cstorch

from .base import SparsityAlgorithm


class Static(SparsityAlgorithm):
    """Sparsity algorithm that keeps the initial sparsity pattern fixed.

    This class does not regrow or re-select weights: each call to
    :meth:`update` only re-applies the existing mask (via
    ``prune_weights``) to maintain the initial sparsity pattern.
    """

    def __init__(self, sparsity: Optional[float] = None, **kwargs):
        """
        Args:
            sparsity: A float specifying the level of sparsity to apply to
                each parameter. When given, it must satisfy
                ``0.0 <= sparsity < 1.0``. ``None`` is forwarded to the
                base class unchanged.
            **kwargs: Additional keyword arguments forwarded to
                :class:`SparsityAlgorithm`.

        Raises:
            ValueError: If ``sparsity`` is not ``None`` and lies outside
                ``[0.0, 1.0)``.
        """
        if sparsity is not None and not (0.0 <= sparsity < 1.0):
            raise ValueError(
                f"Invalid sparsity level {sparsity}. Must be 0.0 <= s < 1.0"
            )

        super().__init__(sparsity, **kwargs)

    def csx_annotate_sparsity(self, param: "SparseParameter"):
        """Annotate ``param`` with its constant sparsity level for CS runs.

        Does nothing unless ``cstorch.use_cs()`` is true. Static sparsity
        does not change over training, so the ``min_sparsity``,
        ``max_sparsity`` and ``sparsity`` annotations all receive the same
        value.

        Args:
            param: The sparse parameter to annotate.
        """
        if not cstorch.use_cs():
            return

        # This simple scalar computation does not need to be traced, so it
        # is evaluated on CPU.
        with torch.device("cpu"):
            # The sparsity value is constant, so its value at step 0 is its
            # value at every step. NOTE(review): presumably
            # ``self.sparsity[param.data]`` is a per-tensor schedule callable
            # returning a scalar tensor; confirm against the base class.
            sparsity = self.sparsity[param.data](step=0).item()

        # Min, max and final sparsity are identical for a static schedule.
        for key in ("min_sparsity", "max_sparsity", "sparsity"):
            param.annotate(key, sparsity)

    @torch.no_grad()
    def update(self, optimizer):
        """Re-apply the existing mask without changing it.

        Args:
            optimizer: Not used by this algorithm (presumably accepted to
                match the base-class ``update`` signature).
        """
        # Ensure we've called apply_sparsity before update
        self._ensure_sparsity_applied()

        # Merely apply the mask to maintain initial sparsity pattern.
        self.prune_weights()