Source code for cerebras.pytorch.sparse.group

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

import logging
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Union
from warnings import warn

import torch

import cerebras.pytorch as cstorch
from cerebras.pytorch.sparse.base import SparsityAlgorithm
from cerebras.pytorch.utils.utils import convert_glob_to_regex


class Group(SparsityAlgorithm):
    """Group sparsity algorithm.

    This algorithm allows for multiple sparsity algorithms to be applied to
    different groups of parameters. For example:

    .. code:: python

        sparsity = cstorch.sparse.Group({
            "fc1.*": cstorch.sparse.Static(sparsity=0.5),
            "fc2.*": cstorch.sparse.GMP(
                schedule=[0.3, 0.4, 0.5],
                update={"freq": 100},
            ),
        })
        sparsity.add("fc3.*", cstorch.sparse.RigL(sparsity=0.5))

        model.apply(sparsity)
        optimizer.apply(sparsity)

    The group sparsity algorithm will apply the sparsity algorithms to the
    parameters that match the filter. If a parameter name matches multiple
    filters, the first filter that matches will be used.
    """

    @dataclass
    class Filter:
        filter: Callable[[str, torch.Tensor], bool]
        algorithm: SparsityAlgorithm

    def __init__(self, groups: Dict[str, SparsityAlgorithm] = None):
        """
        Args:
            groups: A dictionary of filter -> algorithm pairs. See
                :py:meth:`~cerebras.pytorch.sparse.Group.add` for more details.
        """
        super().__init__(sparsity=None)

        self._groups = []
        if groups is not None:
            for group_filter, algorithm in groups.items():
                self.add(group_filter, algorithm)

    @property
    def num_sparse_params(self):
        return sum(len(g.algorithm.sparse_params) for g in self._groups)

    @property
    def sparsity(self):
        raise NotImplementedError(
            "Group sparsity algorithm does not have a sparsity level. "
            "You can access the sparsity of nested sparsity algorithms by "
            "indexing the Group object, i.e. group[0].sparsity"
        )

    def __getitem__(self, index) -> SparsityAlgorithm:
        """Returns the algorithm at the given index."""
        return self._groups[index].algorithm

    def add(
        self,
        filter: Union[str, Callable[[str, torch.Tensor], bool]],
        algorithm: SparsityAlgorithm,
    ):
        """Add a sparsity algorithm to the group.

        Args:
            filter: A string, list of strings, or callable that takes a
                parameter name and a parameter tensor and returns True if the
                parameter should be sparsified. If one or more strings are
                provided, the filter will match if any of the strings match
                the parameter name. The strings may contain glob patterns,
                e.g. "fc1.*" will match all parameters in the "fc1" module.
            algorithm: An instance of
                :py:class:`~cerebras.pytorch.sparse.SparsityAlgorithm`
        """
        if not isinstance(algorithm, SparsityAlgorithm):
            raise TypeError(
                f"algorithm must be an instance of SparsityAlgorithm, got {type(algorithm)}"
            )
        elif isinstance(algorithm, Group):
            raise TypeError(
                f"algorithm must not be a Group sparsity algorithm. "
                f"If you want to merge groups, use the extend method."
            )

        if isinstance(filter, str):
            filter = [filter]

        if isinstance(filter, (list, tuple)):
            filter_re = list(map(convert_glob_to_regex, filter))
            filter = lambda name, _: any(
                f.match(name) is not None for f in filter_re
            )
            self._groups.append(Group.Filter(filter, algorithm))
        elif callable(filter):
            self._groups.append(Group.Filter(filter, algorithm))
        else:
            raise TypeError(
                f"filter must be a string or callable, got {type(filter)}"
            )

        self.sparse_params.update(algorithm.sparse_params)
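
    # Illustrative usage sketch: ``add`` also accepts a callable filter. The
    # callable receives the parameter name and tensor and returns True to
    # sparsify it; the filter logic below is an assumption, not from this module.
    #
    #     sparsity = cstorch.sparse.Group()
    #     sparsity.add(
    #         # Only sparsify 2-D ".weight" parameters, leaving biases dense.
    #         lambda name, param: name.endswith(".weight") and param.dim() == 2,
    #         cstorch.sparse.Static(sparsity=0.5),
    #     )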

    def extend(self, group: "Group"):
        """Extend the group with the filters and algorithms from another group.

        Args:
            group: An instance of :py:class:`~cerebras.pytorch.sparse.Group`
        """
        if not isinstance(group, Group):
            raise TypeError(
                f"group must be an instance of Group, got {type(group)}"
            )

        for g in group._groups:
            self.add(g.filter, g.algorithm)
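
    # Illustrative usage sketch: merging two independently built groups with
    # ``extend``. The "encoder.*" / "decoder.*" patterns are assumptions.
    #
    #     encoder_sparsity = cstorch.sparse.Group(
    #         {"encoder.*": cstorch.sparse.Static(sparsity=0.5)}
    #     )
    #     decoder_sparsity = cstorch.sparse.Group(
    #         {"decoder.*": cstorch.sparse.Static(sparsity=0.75)}
    #     )
    #     encoder_sparsity.extend(decoder_sparsity)
    #     # encoder_sparsity now holds both filters, applied in insertion order.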

    def sparsify_parameter(
        self, module: torch.nn.Module, name: str, param: torch.Tensor
    ) -> None:
        if param is None:
            # Parameter is None, nothing to sparsify
            return
        if self.get_sparse_params(param):
            # Parameter is already sparsified
            return
        if getattr(param, "requires_dense", False):
            # Parameter has been marked as not sparsifiable
            return

        for group in self._groups:
            if group.filter(name, param):
                logging.debug(f"Sparsity filter matched: {name}")
                group.algorithm.sparsify_parameter(module, name, param)
                # Update the sparse_params
                self.sparse_params.update(group.algorithm.sparse_params)
                return
            else:
                logging.debug(f"Sparsity filter did *not* match: {name}")

    def sparsify_module(self, module):
        if len(self._groups) == 0:
            raise RuntimeError(
                "No groups were added to the Group sparsity algorithm"
            )

        super().sparsify_module(module)

        if sum(len(g.algorithm.sparse_params) for g in self._groups) == 0:
            warn(
                "No parameters were sparsified in the module. "
                "This is likely due to the parameter filter not matching any "
                "parameters in the module"
            )

    def _forward_pre_hook(self, module, input):
        for group in self._groups:
            group.algorithm._forward_pre_hook(module, input)

    def sparsify_optimizer(self, optimizer):
        super().sparsify_optimizer(optimizer)

        # Call sparsify optimizer on each algorithm
        # so that it can apply any optimizer hooks
        for group in self._groups:
            group.algorithm.sparsify_optimizer(optimizer)

    def update(self, optimizer: Optional[cstorch.optim.Optimizer] = None):
        for group in self._groups:
            group.algorithm.update(optimizer)

    def register_target_sparsity_hook(self, hook):
        return [
            group.algorithm.register_target_sparsity_hook(hook)
            for group in self._groups
        ]

    def register_computed_sparsity_hook(self, hook):
        return [
            group.algorithm.register_computed_sparsity_hook(hook)
            for group in self._groups
        ]

    def visit_state(self, f):
        for group in self._groups:
            group.algorithm.visit_state(f)

    def state_dict(self):
        return [group.algorithm.state_dict() for group in self._groups]

    def load_state_dict(self, state_dict):
        if isinstance(state_dict, dict):
            state_dict = [state_dict]
        if isinstance(state_dict, list):
            if len(state_dict) != len(self._groups):
                raise ValueError(
                    f"Expected a list of {len(self._groups)} state_dicts for "
                    f"the Group sparsity algorithm but got {len(state_dict)}."
                )
            for s, group in zip(state_dict, self._groups):
                group.algorithm.load_state_dict(s)
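
# Illustrative usage sketch: inspecting nested algorithms and round-tripping
# checkpoint state. The "fc1.*" / "fc2.*" patterns are assumptions;
# ``sparsity[0].sparsity`` follows the hint in the ``sparsity`` property above.
#
#     sparsity = cstorch.sparse.Group({
#         "fc1.*": cstorch.sparse.Static(sparsity=0.5),
#         "fc2.*": cstorch.sparse.GMP(schedule=[0.3, 0.4, 0.5], update={"freq": 100}),
#     })
#     static = sparsity[0]               # Static algorithm added for "fc1.*"
#     state = sparsity.state_dict()      # list with one entry per added algorithm
#     sparsity.load_state_dict(state)    # lengths must match or a ValueError is raised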