# Source code for cerebras.pytorch.sparse.base

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Callable, Dict, Optional, Union, final
from warnings import warn
from weakref import WeakValueDictionary, ref

import torch

import cerebras.pytorch as cstorch
from cerebras.pytorch.backend import current_backend_impl
from cerebras.pytorch.utils.weak import DefaultWeakIdKeyDictionary

from .init import InitMethodType, make_init_method
from .utils import HyperParameterSchedule, make_hyperparam_schedule


class SparsityAlgorithm(ABC):
    """Base class for all sparsity algorithms

    This class is responsible for sparsifying parameters and registering
    hooks to apply the sparsity pattern to the parameters before forward
    and to the gradients after backward. It also registers hooks to update
    the sparsity pattern after each optimizer step.

    .. warning::

        The way that sparse parameters are represented in the
        cerebras.pytorch API is via a mask tensor. This mask tensor is
        multiplied inplace to the original dense parameter before forward
        and to the gradients after backward. However, this is not the way
        that sparse parameters are represented on a Cerebras system. There,
        sparse parameters are handled natively in CSR format. As such, there
        is no mask tensor that can be referenced on the system side. What
        this means is that using the mask tensor haphazardly can lead to
        compile failures. Even if compile succeeds, any operations performed
        on the mask can be very computationally expensive. Having said that,
        there are several operations on masks that are supported on the
        Cerebras system. Please see the usage in the prepackaged algorithms
        as a guide for when and how it is acceptable to use the mask.
    """

    # Per-subclass instance counter used to give every algorithm a unique name
    _sparsity_algorithm_count = defaultdict(int)

    def __init__(self, sparsity, init_method: InitMethodType = "random"):
        """
        Args:
            sparsity: The sparsity level to use for the algorithm. This can
                be a float or a
                :py:class:`~cerebras.pytorch.sparse.utils.HyperParameterSchedule`.
                If a dictionary is passed in, then it is automatically
                converted to a
                :py:class:`~cerebras.pytorch.sparse.utils.HyperParameterSchedule`
            init_method: The method to use to initialize the sparsity mask.
                See :py:func:`~cerebras.pytorch.sparse.init.make_init_method`
                for more details.
        """
        count = SparsityAlgorithm._sparsity_algorithm_count[self.__class__]
        self.name = f"sparsity_{self.__class__.__name__.lower()}_{count}"
        SparsityAlgorithm._sparsity_algorithm_count[self.__class__] += 1

        if sparsity is not None:
            self.sparsity = sparsity

        self.init_method = make_init_method(init_method)

        self.sparse_modules = torch.utils.weak.WeakIdKeyDictionary()
        self.sparse_optimizers = torch.utils.weak.WeakIdKeyDictionary()
        # Maps parameter name -> SparseParameter. Values are held weakly; the
        # strong reference is stored on the parameter as ``_sparse_param``.
        self.sparse_params = WeakValueDictionary()

        self._backend = current_backend_impl()
        self._backend.setup_sparsity(self)

        # When True, update() is called automatically after every optimizer
        # step (see the step post-hook installed in sparsify_optimizer).
        self.autoupdate = True

    @property
    def num_sparse_params(self):
        """
        Returns the number of parameters that have been sparsified by this
        algorithm
        """
        return len(self.sparse_params)

    def initialize(self):
        """
        Initialize the sparsity pattern for all parameters sparsified by this
        algorithm
        """
        for sparse_param in self.sparse_params.values():
            sparse_param.initialize()

    def csx_annotate_sparsity(self, param: "SparseParameter"):
        """
        Annotate the parameter with hints about the sparsity pattern as
        performance hints for the Cerebras compiler

        The base implementation does nothing; subclasses may override it.

        Args:
            param: The sparse parameter to annotate with hints
        """

    @property
    def sparsity(self) -> Dict[torch.Tensor, HyperParameterSchedule]:
        """
        Returns a mapping between a parameter and its sparsity schedule

        If no sparsity was ever specified, an empty mapping is created whose
        default factory raises a ``ValueError``.
        """
        if not hasattr(self, "_sparsity"):

            def default_error():
                raise ValueError(
                    f"{self.__class__.__name__} sparsity algorithm expected "
                    f"`sparsity` to be specified, but got none."
                )

            self._sparsity = DefaultWeakIdKeyDictionary(default_error)

        return self._sparsity

    @sparsity.setter
    def sparsity(self, sparsity):
        """
        Creates a mapping between a parameter and its sparsity schedule

        * If ``sparsity`` is a dict keyed by tensors, it replaces any existing
          mapping, and the default factory for parameters missing from it
          raises a ``KeyError``.
        * Otherwise ``sparsity`` is converted into the default schedule and
          any previously existing per-parameter schedules are kept.
        """
        if isinstance(sparsity, dict) and any(
            isinstance(k, torch.Tensor) for k in sparsity
        ):

            def default_error():
                raise KeyError("No sparsity schedule found for parameter")

            self._sparsity = DefaultWeakIdKeyDictionary(
                default_error,
                {p: make_hyperparam_schedule(s) for p, s in sparsity.items()},
            )
        else:
            # If a mapping exists, this will effectively just set the default
            # schedule and keep the previously existing schedules
            prev = getattr(self, "_sparsity", {})
            default_schedule = make_hyperparam_schedule(sparsity)
            self._sparsity = DefaultWeakIdKeyDictionary(
                lambda: default_schedule, prev
            )

    def sparsify_parameter(
        self, module: torch.nn.Module, name: str, param: torch.Tensor
    ) -> None:
        """
        Initialize the mask for a parameter in the given module.

        Parameters that are ``None``, already sparsified, or marked with
        ``requires_dense=True`` are skipped.

        Args:
            module: The module that owns the parameter
            name: The full name of the parameter
            param: The parameter tensor to sparsify
        """
        if param is None:
            # Unset placeholder parameter; nothing to sparsify
            return

        if hasattr(param, "_sparse_param"):
            # Parameter already sparsified
            return

        if getattr(param, "requires_dense", False):
            # Parameter has been marked as not sparsifiable
            return

        # This simple scalar computation does not need to be traced
        with torch.device("cpu"):
            # Get the sparsity schedule for the given parameter and evaluate
            # it at the current step (0 if no step has been set) to get the
            # initial sparsity value
            sparsity = self.sparsity[param](getattr(self, "step", 0))
            # Ensure that the sparsity level is valid
            sparsity = torch.clamp(sparsity, min=0.0, max=1.0)

        init_method = partial(self.init_method, sparsity=sparsity)
        param._sparse_param = SparseParameter(module, name, init_method)

        # Keep a reference to the sparse parameter so that we can apply them
        # later on.
        # NOTE(review): keyed by the name relative to the module passed to
        # sparsify_module; sparsifying two modules with overlapping parameter
        # names under one algorithm would overwrite entries here.
        self.sparse_params[name] = param._sparse_param

    @final
    def apply(self, obj: Union[torch.nn.Module, torch.optim.Optimizer]):
        """
        Sparsifies the passed in object.

        .. note::
            This is called implicitly when calling ``module.apply(sparsity)``
            or ``optimizer.apply(sparsity)``

        Args:
            obj: a ``torch.nn.Module`` or a ``torch.optim.Optimizer`` object
                to sparsify

        Raises:
            TypeError: If ``obj`` is neither a module nor a cstorch optimizer.
        """
        if isinstance(obj, torch.nn.Module):
            return self.sparsify_module(obj)
        if isinstance(obj, cstorch.optim.Optimizer):
            return self.sparsify_optimizer(obj)

        raise TypeError(
            f"Expected torch.nn.Module or cstorch.optim.Optimizer, "
            f"but got {type(obj)}"
        )

    def sparsify_module(self, module: torch.nn.Module):
        """
        Sparsify the ``torch.nn.Module`` object

        Every parameter of ``module`` and its submodules is passed to
        :py:meth:`sparsify_parameter`, and a forward pre-hook that prunes the
        weights is registered on ``module`` if anything was sparsified.

        Args:
            module: the ``torch.nn.Module`` object to sparsify
        """

        def get_members_fn(submodule):
            if getattr(submodule, "requires_dense", False):
                # Module has been marked as not sparsifiable
                return ()

            if submodule in self.sparse_modules or getattr(
                submodule, "is_sparse", False
            ):
                # Already applied sparsity for this module
                warn(f"Module {submodule} has already been sparsified.")
                return ()

            self.sparse_modules[submodule] = True
            submodule.is_sparse = True

            # Each member is a ``(submodule, param)`` tuple so that the
            # owning module is available alongside the parameter. Since the
            # tuple itself is never None, ``None`` placeholder parameters
            # (e.g. an unset bias) must be filtered out explicitly here.
            return (
                (k, (submodule, p))
                for k, p in submodule._parameters.items()
                if p is not None
            )

        pre_sparsification_count = self.num_sparse_params

        # Recursively get all parameters in the module as well as the module
        # that owns them.
        for name, (submodule, param) in module._named_members(
            get_members_fn, recurse=True
        ):
            self.sparsify_parameter(submodule, name, param)

        if self.num_sparse_params == pre_sparsification_count:
            warn(f"No parameters were sparsified in module {module}")
            # No parameters were sparsified, so no need to register
            # a forward pre hook
            return

        module.register_forward_pre_hook(self._forward_pre_hook)

        with self._backend.device:
            if (
                self._backend.is_csx
                and not self._backend.device.config.lazy_initialization
            ):
                # We need to move the masks to the device if we are doing
                # eager initialization
                self._backend.device.move_to_device(module)

            # NOTE(review): the result of visit_state is discarded, so
            # subclasses presumably reassign their state inside visit_state.
            self.visit_state(lambda x: x.to(self._backend.torch_device))

    def _forward_pre_hook(self, module, input):
        """
        Hook the given module such that the sparsity pattern is applied to
        both the parameters before forward() and gradients after backward()
        """
        # Optimizer state tensors are (re-)registered by the optimizer step
        # pre-hook; clear them so the forward pass only prunes parameters.
        for sparse_param in self.sparse_params.values():
            sparse_param.registered_tensors.clear()

        self._annotate_sparse_params()
        self.prune_weights()

    def _annotate_sparse_params(self):
        """
        Annotate all sparse params with compiler hints via
        :py:meth:`csx_annotate_sparsity`.
        """
        for sparse_param in self.sparse_params.values():
            self.csx_annotate_sparsity(sparse_param)

    @torch.no_grad()
    def prune_weights(self):
        """
        Prune the dense weights.

        Also registers the gradient-pruning hook on each parameter that
        requires grad, the first time it is seen.

        .. note::
            This is called automatically in a module forward pre-hook
        """
        for sparse_param in self.sparse_params.values():
            sparse_param.apply()

            if (
                sparse_param.param.requires_grad
                and sparse_param.grad_hook is None
            ):
                sparse_param.grad_hook = sparse_param.param.register_hook(
                    partial(self._grad_hook, sparse_param.param)
                )

    def _grad_hook(self, p, grad):
        """
        Hook to prune the gradients after backward()

        .. note::
            This is called automatically in the parameter's backward grad hook

        Args:
            p: The original parameter
            grad: The gradient of the parameter

        Returns:
            The gradient with entries outside the mask replaced by zero.
        """
        # In the case there any NaNs in the unused gradients that correspond
        # to zero'd out weights, we use a selection to replace these NaNs
        # with zeros. (multiplying with the mask would preserve them).
        # DLS will skip a weight update if there is a NaN in the gradient,
        # but we only want this to happen if there is a NaN in gradients
        # corresponding to non-zero weights. This is the behavior of the CS2
        # which doesn't even compute the full gradients on most steps.
        zero = torch.zeros_like(grad)
        # Return modified gradient.
        with SparseParameter.disable_mask_access_warning():
            return torch.where(p.mask, grad, zero)

    def sparsify_optimizer(self, optimizer: torch.optim.Optimizer):
        """
        Sparsify the ``torch.optim.Optimizer`` object

        Registers a step pre-hook that prunes optimizer state tensors shaped
        like their sparse parameter, and a step post-hook that calls
        :py:meth:`update` when ``autoupdate`` is enabled.

        Args:
            optimizer: the ``torch.optim.Optimizer`` object to sparsify

        Raises:
            RuntimeError: If a different optimizer has already been
                sparsified by this algorithm.
        """
        if optimizer in self.sparse_optimizers or getattr(
            optimizer, "is_sparse", False
        ):
            # Already applied sparsity for this optimizer
            return

        # Reject before marking, so a rejected optimizer is not left flagged
        # as sparsified.
        if len(self.sparse_optimizers) > 0:
            # TODO: Support multiple optimizers
            # This is not a high priority as we never really use
            # more than one optimizer in practice
            raise RuntimeError(
                "Sparsifying multiple optimizers using the same sparsity "
                "algorithm is not supported."
            )

        self.sparse_optimizers[optimizer] = True
        optimizer.is_sparse = True

        def get_optimizer_states(optimizer):
            # Collect (param, SparseParameter) pairs for sparsified params
            params = [
                (p, sparse_param)
                for group in optimizer.param_groups
                for p in group["params"]
                if (sparse_param := getattr(p, "_sparse_param", None))
            ]
            if not params:
                raise RuntimeError(
                    "Detected that optimizer.apply(sparsity) was called "
                    "before model.apply(sparsity). This means that "
                    "no optimizer state got sparsified.\n"
                    "Please call model.apply(sparsity) first."
                )

            states = defaultdict(list)
            for p, sparse_param in params:
                for name, tensor in optimizer.state[p].items():
                    # sparsify all optimizer state tensors that match the
                    # original parameter's shape and don't require dense.
                    # Non-tensor state entries have no shape and are skipped.
                    if (
                        isinstance(tensor, torch.Tensor)
                        and tensor.shape == p.shape
                        and not getattr(tensor, "requires_dense", False)
                    ):
                        states[sparse_param].append((name, tensor))
            return states

        def step_pre_hook(optimizer, args, kwargs):
            optimizer_state = get_optimizer_states(optimizer)

            for sparse_param, items in optimizer_state.items():
                for name, tensor in items:
                    sparse_param.register_tensor(name, tensor)

            self._annotate_sparse_params()

            for sparse_param, items in optimizer_state.items():
                for _, tensor in items:
                    sparse_param._prune(tensor)

        # Store the hook handles so that they can be removed later
        self.step_pre_hook = optimizer.register_step_pre_hook(step_pre_hook)
        self.step_post_hook = optimizer.register_step_post_hook(
            lambda optimizer, args, kwargs: (
                self.update(optimizer) if self.autoupdate else None
            )
        )

    def _ensure_sparsity_applied(self):
        """
        Raise if any sparse parameter's mask was not applied since the last
        mask update.

        Raises:
            RuntimeError: If any mask has not been applied.
        """
        if not all(
            sparse_param._applied for sparse_param in self.sparse_params.values()
        ):
            raise RuntimeError(
                "Detected that sparsity masks were not properly applied. "
                "A module hook was installed which should have taken care "
                "of applying the sparsity masks, but did not. "
                "Check that you have not disabled module hooks."
            )

    @abstractmethod
    def update(self, optimizer: Optional[cstorch.optim.Optimizer] = None):
        """
        Update the parameter's sparsity masks

        Args:
            optimizer: The optimizer that is being used to update the sparse
                parameters
        """

    def visit_state(self, f: Callable):
        """Apply a callable to the stateful tensors"""

    def state_dict(self):
        """Return a dictionary of all stateful tensors"""
        return {}

    def load_state_dict(self, state_dict):
        """Load the state of all stateful tensors"""
class SparseParameter:
    """Representation of a sparse parameter

    This class does not own the original parameter or the mask. It registers
    the mask with the module that owns the parameter and provides convenient
    accessors and modifiers for the mask.
    """

    # When True, accessing ``param.mask`` does not emit a warning
    DISABLE_MASK_ACCESS_WARNING = False

    @staticmethod
    @contextmanager
    def disable_mask_access_warning():
        """Temporarily suppress the warning emitted on ``param.mask`` access.

        The previous setting is restored on exit, so nested uses are safe.
        """
        prev = SparseParameter.DISABLE_MASK_ACCESS_WARNING
        try:
            SparseParameter.DISABLE_MASK_ACCESS_WARNING = True
            yield
        finally:
            SparseParameter.DISABLE_MASK_ACCESS_WARNING = prev

    def __init__(
        self, module: torch.nn.Module, name: str, init_method: InitMethodType
    ):
        """
        Args:
            module: The module that owns the parameter. Only a weak reference
                to it is kept.
            name: The full name of the parameter. Its last dotted component
                is the parameter's attribute name on ``module``.
            init_method: Callable used by :py:meth:`initialize` to create the
                mask. It is called as ``init_method(param, device=...)``.
        """
        # Save a weak reference to the module so that we can access it
        # without creating a reference cycle.
        self._module_ref = ref(module)
        self.name = name
        self.param_name = name.rsplit(".", 1)[-1]
        self.mask_name = f"{self.param_name}_mask"
        self.init_method = init_method

        self._backend = current_backend_impl()

        # Register an all-True placeholder as a persistent buffer so the mask
        # is part of the module's state dict. The real mask is created in
        # initialize(), or loaded from a checkpoint.
        with self._backend.device:
            placeholder = cstorch.ones_like(self.param, dtype=torch.bool).to(
                self._backend.torch_device
            )
            module.register_buffer(self.mask_name, placeholder, persistent=True)

        self._initialized = False

        def load_state_dict_pre_hook(state_dict, prefix, *args, **kwargs):
            # If we are loading the mask from a checkpoint, then consider the
            # mask as already initialized so initialize() doesn't overwrite
            # it. ``prefix`` is this module's key prefix in the state dict
            # being loaded, which stays correct no matter which ancestor
            # module load_state_dict() was called on.
            if f"{prefix}{self.mask_name}" in state_dict:
                self._initialized = True

        module._register_load_state_dict_pre_hook(load_state_dict_pre_hook)

        # Set to True by apply() and cleared by update(); used to detect
        # masks that were updated without being applied.
        self._applied = False
        # Handle of the gradient-pruning hook (registered lazily elsewhere)
        self.grad_hook = None

        # Other tensors that depend on this mask can be registered.
        # When the mask is applied to the original param, the mask
        # will also be applied to these tensors as well.
        self._registered_tensors = torch.utils.weak.WeakTensorKeyDictionary()

        def mask_property(p):
            if not SparseParameter.DISABLE_MASK_ACCESS_WARNING:
                warn(
                    "Using the mask tensor haphazardly can lead to compile failures "
                    "and/or be very computationally expensive. Please only use the "
                    "mask tensor if you really know what you are doing."
                )
            if hasattr(p, "_sparse_param"):
                return p._sparse_param.mask
            return None

        # Add a property to the param so that the mask tensor can be accessed
        # as param.mask. This is installed on the parameter's *class*, so
        # every instance of that class gains a ``mask`` attribute, which is
        # None for parameters that are not sparse.
        type(self.param).mask = property(mask_property)

    def initialize(self):
        """Create the mask with ``init_method`` and store it in the module.

        Does nothing if the mask is already initialized (including when it
        was loaded from a checkpoint).

        Raises:
            TypeError: If ``init_method`` does not return a tensor.
        """
        if self._initialized:
            return

        # Use the CPU device if doing eager initialization on CSX.
        # Otherwise, use the parameter's device.
        # This allows us to trace the mask initialization during
        # lazy initialization.
        device = None
        if (
            self._backend.is_csx
            and not self._backend.device.config.lazy_initialization
        ):
            device = "cpu"

        # We mark the mask as applied here so that we can initialize it
        # without raising an error about the mask being updated before
        # it was applied.
        self._applied = True

        with self._backend.device:
            mask = self.init_method(self.param, device=device)
            if not isinstance(mask, torch.Tensor):
                raise TypeError(
                    f"Expected init_method to return a Tensor, "
                    f"but got {type(mask)}"
                )
            if mask.device.type != self._backend.torch_device.type:
                mask = mask.to(self._backend.torch_device)

        # overwrite buffer
        setattr(self.module, self.mask_name, mask)

        self._initialized = True

    @property
    def module(self):
        """The owning module.

        Raises:
            ValueError: If the module has been garbage collected.
        """
        m = self._module_ref()
        if m is None:
            raise ValueError("Attempting to access mask after module deleted")
        return m

    @property
    def param(self):
        """The original dense parameter, looked up on the owning module."""
        return self.module._parameters[self.param_name]

    @property
    def data(self):
        """Alias for :py:attr:`param`."""
        return self.param

    @property
    def mask(self):
        """The mask buffer registered on the owning module."""
        return self.module._buffers[self.mask_name]

    @mask.setter
    def mask(self, new_mask):
        self.update(new_mask)

    def register_tensor(self, name: str, tensor: torch.Tensor):
        """
        Register a tensor that depends on this mask. When the mask is
        applied to the original parameter, the mask will also be applied
        to these tensors as well.
        """
        self._registered_tensors[tensor] = name

    @property
    def registered_tensors(self):
        """Weak mapping of registered dependent tensor -> registration name."""
        return self._registered_tensors

    def annotate(self, name, value):
        """Set a backend attribute on the parameter and all registered tensors."""
        self._backend.set_attribute(self.param, name, value)
        for tensor in self.registered_tensors:
            self._backend.set_attribute(tensor, name, value)

    def _prune(self, tensor):
        # Zero out masked-off entries in place
        tensor.mul_(self.mask)

    @torch.no_grad()
    def apply(self):
        """Multiply the mask into the parameter and all registered tensors.

        Raises:
            RuntimeError: If the mask has not been initialized yet.
        """
        if not self._initialized:
            raise RuntimeError(
                "Cannot apply mask to parameter before it has been initialized"
            )

        self._applied = True

        self._prune(self.param)
        for tensor in self._registered_tensors:
            self._prune(tensor)

    def update(self, new_mask):
        """Copy ``new_mask`` into the mask buffer in place.

        Raises:
            RuntimeError: If the current mask was not applied since the last
                update.
        """
        if not self._applied:
            raise RuntimeError(
                "Detected that mask is being updated before it was applied"
            )

        self._applied = False

        # TODO: Make this update conditional for DLS
        self.module._buffers[self.mask_name].copy_(new_mask)

    def __str__(self):
        return f"SparseParameter({self.name})"