Source code for cerebras.pytorch.sparse.base

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Dict, Generator, Optional, Union, final
from warnings import warn
from weakref import WeakValueDictionary, ref

import torch
import torch.utils.hooks as hooks
from torch.utils.hooks import RemovableHandle

import cerebras.pytorch as cstorch
from cerebras.pytorch.backend import current_backend_impl
from cerebras.pytorch.utils.weak import DefaultWeakIdKeyDictionary

from .init import InitMethodType, make_init_method
from .utils import HyperParameterSchedule, make_hyperparam_schedule


[docs]class SparsityAlgorithm(ABC): """Base class for all sparsity algorithms. This class is responsible for sparsifying parameters and registering hooks to apply the sparsity pattern to the parameters before forward and to the gradients after backward. It also registers hooks to update the sparsity pattern after each optimizer step. .. warning:: The way that sparse parameters are represented in the cerebras.pytorch API is via a mask tensor. This mask tensor is multiplied inplace to the original dense parameter before forward and to the gradients after backward. However, this is not the way that sparse parameters are represented on a Cerebras system. There, sparse parameters are handled natively in CSR format. As such, there is no mask tensor that can be referenced on the system side. What this means is that using the mask tensor haphazardly can lead to compile failures. Even if compile succeeds, any operations performed on the mask can be very computationally expensive. Having said that, there are several operations on masks that are supported on the Cerebras system. Please see the usage in the prepackaged algorithms as a guide for when and how it is acceptable to use the mask. """ _sparsity_algorithm_count = defaultdict(int) def __init__( self, sparsity: Union[float, HyperParameterSchedule, None], init_method: InitMethodType = "random", ): """Constructs a `SparsityAlgorithm` instance. Args: sparsity: The sparsity level to use for the algorithm. This can be a float or a :py:class:`~cerebras.pytorch.sparse.utils.HyperParameterSchedule`. If a dictionary is passed in, then it is automatically converted to a :py:class:`~cerebras.pytorch.sparse.utils.HyperParameterSchedule` init_method: The method to use to initialize the sparsity mask. See :py:func:`~cerebras.pytorch.sparse.init.make_init_method` for more details. """ count = SparsityAlgorithm._sparsity_algorithm_count[self.__class__] self.name = f"sparsity_{self.__class__.__name__.lower()}_{count}" SparsityAlgorithm._sparsity_algorithm_count[self.__class__] += 1 if sparsity is not None: self.sparsity = sparsity self.init_method = make_init_method(init_method) self.sparse_modules = torch.utils.weak.WeakIdKeyDictionary() self.sparse_optimizers = torch.utils.weak.WeakIdKeyDictionary() self.sparse_params = WeakValueDictionary() self._backend = current_backend_impl() self._backend.setup_sparsity(self) self.autoupdate = True self._target_sparsity_hooks = OrderedDict() self._computed_sparsity_hooks = OrderedDict() @property def num_sparse_params(self) -> int: """Return the number of parameters that have been sparsified by this algorithm.""" return len(self.sparse_params)
[docs] def get_sparse_params( self, obj: Union[torch.Tensor, torch.nn.Module, torch.optim.Optimizer] ) -> Union["SparseParameter", Generator["SparseParameter", None, None]]: """Get all sparse parameters that were sparsified by this algorithm. Args: obj: The object to get sparse parameters from. Returns: If obj is a Tensor, returns the sparse parameter associated with that tensor (if any). If obj is a Module, returns an iterator over all sparse parameters of the module and its submodules recursively. If obj is an Optimizer, returns an iterator over all sparse parameters associated with the optimize param groups. """ if isinstance(obj, torch.Tensor): return getattr(obj, "_sparse_param", None) elif isinstance(obj, torch.nn.Module): return ( sparse_param for _, param in obj.named_parameters() if (sparse_param := getattr(param, "_sparse_param", None)) and sparse_param.name in self.sparse_params ) elif isinstance(obj, torch.optim.Optimizer): return ( sparse_param for group in obj.param_groups for param in group["params"] if (sparse_param := getattr(param, "_sparse_param", None)) and sparse_param.name in self.sparse_params ) else: raise TypeError( f"Expected torch.nn.Module or torch.optim.Optimizer, " f"but got {type(obj)}" )
[docs] def initialize(self) -> None: """Initialize the sparsity pattern for all parameters sparsified by this algorithm.""" for sparse_param in self.sparse_params.values(): sparse_param.initialize()
[docs] def csx_annotate_sparsity(self, param: "SparseParameter") -> None: """Annotate the parameter with hints about the sparsity pattern. These hints are used as performance hints for the Cerebras compiler. Args: param: The sparse parameter to annotate with hints. """
@property def sparsity(self) -> Dict[torch.Tensor, HyperParameterSchedule]: """Return the mapping between a parameter and its sparsity schedule.""" if not hasattr(self, "_sparsity"): def default_error(): raise ValueError( f"{self.__class__.__name__} sparsity algorithm expected " f"`sparsity` to be specified, but got none." ) self._sparsity = DefaultWeakIdKeyDictionary(default_error) return self._sparsity @sparsity.setter def sparsity(self, sparsity: Union[float, HyperParameterSchedule, None]): """Create a mapping between a parameter and its sparsity schedule. If a mapping already exists, it will be updated. """ if isinstance(sparsity, dict) and any( isinstance(k, torch.Tensor) for k in sparsity ): def default_error(): raise KeyError("No sparsity schedule found for parameter") self._sparsity = DefaultWeakIdKeyDictionary( default_error, {p: make_hyperparam_schedule(s) for p, s in sparsity.items()}, ) else: # If a mapping exists, this will effectively just set the default # schedule and keep the previously existing schedules prev = getattr(self, "_sparsity", {}) default_schedule = make_hyperparam_schedule(sparsity) self._sparsity = DefaultWeakIdKeyDictionary( lambda: default_schedule, prev )
[docs] def sparsify_parameter( self, module: torch.nn.Module, name: str, param: torch.Tensor ) -> None: """Initialize the mask for a parameter in the given module. Args: module: The module that owns the parameter name: The full name of the parameter param: The parameter to initialze the sparsity mask for. """ if param is None: # Parameter is None, nothing to sparsify return if self.get_sparse_params(param): # Parameter already sparsified return if getattr(param, "requires_dense", False): # Parameter has been marked as not sparsifiable return # This simple scalar computation does not need to be traced with torch.device("cpu"): # Get the sparsity schedule for the given parameter and then # call it with a step of 0 to get the initial sparsity value sparsity = self.sparsity[param](getattr(self, "step", 0)) # Ensure that the sparsity level is valid sparsity = torch.clamp(sparsity, min=0.0, max=1.0) init_method = partial(self.init_method, sparsity=sparsity) sparse_param = SparseParameter(module, name, init_method) param._sparse_param = sparse_param # Keep a reference to the sparse parameter so that we can query them later on self.sparse_params[name] = sparse_param
[docs] @final def apply( self, obj: Union[torch.nn.Module, cstorch.optim.Optimizer] ) -> None: """Sparsify the passed in object. .. note:: This is called implicitly when calling ``module.apply(sparsity)`` or ``optimizer.apply(sparsity)`` Args: obj: a ``torch.nn.Module`` or a ``cstorch.optim.Optimizer`` object to sparsify. """ if isinstance(obj, torch.nn.Module): self.sparsify_module(obj) elif isinstance(obj, cstorch.optim.Optimizer): self.sparsify_optimizer(obj) else: raise TypeError( f"Expected torch.nn.Module or cstorch.optim.Optimizer, " f"but got {type(obj)}" )
[docs] def sparsify_module(self, module: torch.nn.Module) -> None: """Sparsify the ``torch.nn.Module`` object. Args: module: the ``torch.nn.Module`` object to sparsify """ def get_members_fn(submodule): if getattr(submodule, "requires_dense", False): # Module has been marked as not sparsifiable return () if submodule in self.sparse_modules or getattr( submodule, "is_sparse", False ): # Already applied sparsity for this module warn(f"Module {submodule} has already been sparsified.") return () self.sparse_modules[submodule] = True submodule.is_sparse = True return ( (k, (submodule, p)) for k, p in submodule._parameters.items() ) pre_sparsification_count = self.num_sparse_params # Recursively get all parameters in the module as well as the module # that owns them. for name, (submodule, param) in module._named_members( get_members_fn, recurse=True ): self.sparsify_parameter(submodule, name, param) if self.num_sparse_params == pre_sparsification_count: warn(f"No parameters were sparsified in module {module}") # No parameters were sparsified, so no need to register # a forward pre hook return module.register_forward_pre_hook(self._forward_pre_hook) with self._backend.device: if ( self._backend.is_csx and not self._backend.device.config.lazy_initialization ): # We need to move the masks to the device if we are doing # eager initialization self._backend.device.move_to_device(module) self.visit_state(lambda x: x.to(self._backend.torch_device))
def _forward_pre_hook(self, module: torch.nn.Module, args: Any): """Hook the given module to apply sparsity patterns. The sparsity pattern is applied to both the parameters before `forward()` call and gradients after `backward()` call. Args: module: The module that `forward()` is called on. args: Positional arguments passed to the `forward()` call. """ for sparse_param in self.get_sparse_params(module): # Clear sparse param's internal state sparse_param.clear() # Annotate the sparse param with hints for the Cerebras compiler if cstorch.use_cs(): self.csx_annotate_sparsity(sparse_param) self.prune_weight(sparse_param) @torch.no_grad() def prune_weight(self, sparse_param: "SparseParameter"): """Prune the dense weight and register a hook to prune the gradients. .. note:: This is called automatically in a module forward pre-hook. """ p = sparse_param.param sparse_param.prune(p, sparse_param.name) if p.requires_grad and sparse_param.grad_hook is None: sparse_param.grad_hook = p.register_hook( partial(self._grad_hook, p) )
[docs] def _grad_hook(self, p: torch.Tensor, grad: torch.Tensor): """Hook to prune the gradients after backward(). .. note:: This is called automatically in the parameter's backward grad hook. Args: p: The original parameter. grad: The gradient of the parameter. """ # In the case there any NaNs in the unused gradients that correspond to # zero'd out weights, we use a selection to replace these NaNs with # zeros. (multiplying with the mask would preserve them). # DLS will skip a weight update if there is a NaN in the gradient, but # we only want this to happen if there is a NaN in gradients # corresponding to non-zero weights. This is the behavior of the CS2 # which doesn't even compute the full gradients on most steps. zero = torch.zeros_like(grad) # Return modified gradient. with SparseParameter.disable_mask_access_warning(): return torch.where(p.mask, grad, zero)
[docs] def sparsify_optimizer(self, optimizer: torch.optim.Optimizer) -> None: """Sparsify the ``torch.optim.Optimizer`` object. Args: optimizer: the ``torch.optim.Optimizer`` object to sparsify """ if optimizer in self.sparse_optimizers or getattr( optimizer, "is_sparse", False ): # Already applied sparsity for this optimizer return self.sparse_optimizers[optimizer] = True optimizer.is_sparse = True if len(self.sparse_optimizers) > 1: # TODO: Support multiple optimizers # This is not a high priority as we never really use # more than one optimizer in practice raise RuntimeError( "Sparsifying multiple optimizers using the same sparsity " "algorithm is not supported." ) def prune_optimizer_states(optimizer, args, kwargs): params = list(self.get_sparse_params(optimizer)) if len(params) == 0: raise RuntimeError( "Detected that optimizer.apply(sparsity) was called " "but model.apply(sparsity) was not.\n" "Please call model.apply(sparsity)." ) for sparse_param in params: p = sparse_param.param for name, s in optimizer.state[p].items(): # sparsify all optimizer state tensors that match the # original parameter's shape and doesn't require dense if s.shape == p.shape and not getattr( s, "requires_dense", False ): sparse_param.prune(s, name) # Mark the pruned tensor to be the value that GradScaler # restores to if DLS detects non-fininte grads. Note that # GradScaler may have already marked the state pre-pruning, # so this is overriding it with the pruned version, with # the assumption that `prune()` modifies `s` in-place. cstorch.amp.update_if_finite(optimizer, s) # Only prune optimizer state if optimizer step is called optimizer.register_step_pre_hook(prune_optimizer_states) def step_post_hook(optimizer, args, kwargs): # The weights and optimizer state were just updated. In case we # _decrease_ sparsity in the update instead of increasing it, prune # the weights using the current weight masks for sparse_param in self.get_sparse_params(optimizer): sparse_param.prune() if self.autoupdate: self.update(optimizer) optimizer.register_step_post_hook(step_post_hook)
[docs] @abstractmethod def update(self, optimizer: Optional[cstorch.optim.Optimizer] = None): """Update the parameter's sparsity masks. Args: optimizer: The optimizer that is being used to update the sparse parameters. """
[docs] def register_target_sparsity_hook( self, hook: Callable[[str], torch.Tensor] ) -> RemovableHandle: r"""Register a hook which will be called when a new target sparsity is computed. It should have the following signature: hook(sparsity, name, target) ``sparsity`` argument is the sparsity instance being used. ``name`` is the name of the group of parameters that the target sparsity is being computed for. ``target`` is the computed target sparsity value. Args: hook (Callable): The user defined hook to be registered. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` """ handle = hooks.RemovableHandle(self._target_sparsity_hooks) self._target_sparsity_hooks[handle.id] = hook return handle
[docs] def register_computed_sparsity_hook( self, hook: Callable[[str], torch.Tensor] ) -> RemovableHandle: r"""Register a hook which will be called when a new sparsity mask is computed. It should have the following signature: hook(sparsity, name, computed) ``sparsity`` argument is the sparsity instance being used. ``name`` is the name of the parameter that the mask belongs to. ``computed`` is the calculated sparsity level of the newly computed mask. Args: hook (Callable): The user defined hook to be registered. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` """ handle = hooks.RemovableHandle(self._computed_sparsity_hooks) self._computed_sparsity_hooks[handle.id] = hook return handle
[docs] def visit_state(self, f: Callable): """Apply a callable to the stateful tensors."""
[docs] def state_dict(self) -> Dict[str, torch.Tensor]: """Return a dictionary of all stateful tensors.""" return {}
[docs] def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: """Load the state of all stateful tensors."""
class SparseParameter: """Representation of a sparse parameter. This class does not own the original parameter or the mask. It registers the mask with the module that owns the parameter and provides convenient accessors and modifiers for the mask. """ DISABLE_MASK_ACCESS_WARNING = False @staticmethod @contextmanager def disable_mask_access_warning(): prev = SparseParameter.DISABLE_MASK_ACCESS_WARNING try: SparseParameter.DISABLE_MASK_ACCESS_WARNING = True yield finally: SparseParameter.DISABLE_MASK_ACCESS_WARNING = prev def __init__( self, module: torch.nn.Module, name: str, init_method: InitMethodType ): # Save a weak reference to the module so that we can access it # without creating a reference cycle. self._module_ref = ref(module) self.name = name self.param_name = name.rsplit(".", 1)[-1] self.mask_name = f"{self.param_name}_mask" self.init_method = init_method self._backend = current_backend_impl() with self._backend.device: placeholder = cstorch.ones_like(self.param, dtype=torch.bool).to( self._backend.torch_device ) module.register_buffer(self.mask_name, placeholder, persistent=True) self._initialized = False def load_state_dict_pre_hook(state_dict, *args, **kwargs): # If we are loading the mask from a checkpoint, then # consider the mask as already initialized if f"{self.name}_mask" in state_dict: self._initialized = True module._register_load_state_dict_pre_hook(load_state_dict_pre_hook) self.grad_hook = None # Keep track of all tensors that were sparsified by this mask self._pruned_tensors = torch.utils.weak.WeakTensorKeyDictionary() # Keep track of all annotations that were applied to this mask self._annotations = {} def mask_property(p): if hasattr(p, "_sparse_param"): if not SparseParameter.DISABLE_MASK_ACCESS_WARNING: warn( f"Using the mask tensor haphazardly can lead to compile failures " f"and/or be very computationally expensive. Please only use the " f"mask tensor directly if you really know what you are doing." ) return p._sparse_param.mask else: return None # Add a property to the param so that the mask tensor can # be accessed as param.mask type(self.param).mask = property(mask_property) def initialize(self): if self._initialized: return # Use the CPU device if doing eager initialization on CSX. # Otherwise, use the parameter's device. # This allows us to trace the mask initialization during # lazy initialization. device = None if ( self._backend.is_csx and not self._backend.device.config.lazy_initialization ): device = "cpu" with self._backend.device: mask = self.init_method(self.param, device=device) if not isinstance(mask, torch.Tensor): raise TypeError( f"Expected init_method to return a Tensor, " f"but got {type(mask)}" ) if mask.device.type != self._backend.torch_device.type: mask = mask.to(self._backend.torch_device) # overwrite buffer setattr(self.module, self.mask_name, mask) self._initialized = True @property def module(self): m = self._module_ref() if m is None: raise ValueError(f"Attempting to access mask after module deleted") return m @property def param(self): return self.module._parameters[self.param_name] @property def data(self): return self.param @property def mask(self): return self.module._buffers[self.mask_name] @mask.setter def mask(self, new_mask): self.update(new_mask) def annotate(self, name, value): if len(self._pruned_tensors) > 0: raise RuntimeError( f"Detected that annotations are being set after pruning tensors: " f"{sorted(self._pruned_tensors.values())}" ) self._annotations[name] = value @torch.no_grad() def prune(self, tensor=None, tensor_name=None): """ Prunes the tensor using the sparse parameter's mask. If no tensor is provided, re-prune all tensors in the pruned tensor registry. """ if tensor is None: for tensor, name in self._pruned_tensors.items(): self.prune(tensor, name) return # annotate the tensor before pruning for name, value in self._annotations.items(): self._backend.set_attribute(tensor, name, value) tensor.mul_(self.mask) self._pruned_tensors[tensor] = tensor_name def update(self, new_mask): if not self._initialized: raise RuntimeError( "Detected that mask is being updated before it was initialized" ) if len(self._pruned_tensors) == 0: raise RuntimeError( "Detected that mask is being updated before it was used" ) self.module._buffers[self.mask_name].copy_(new_mask) # Need to re-prune all tensors that depend on this mask # to let the compiler know that the sparsity pattern # has changed for all tensors sparsified by this mask self.prune() def clear(self): """Clear pruned tensors and annotations for the next iteration.""" self._pruned_tensors.clear() self._annotations.clear() def __str__(self): return f"SparseParameter({self.name})"