cerebras.pytorch.sparse#

Sparsity Algorithms#

These classes are the built-in sparsity algorithms. SparsityAlgorithm is the abstract base class that all sparsity algorithms should derive from.

class cerebras.pytorch.sparse.SparsityAlgorithm[source]#

Base class for all sparsity algorithms

This class is responsible for sparsifying parameters and registering hooks to apply the sparsity pattern to the parameters before forward and to the gradients after backward. It also registers hooks to update the sparsity pattern after each optimizer step.

Warning

In the cerebras.pytorch API, sparse parameters are represented with a mask tensor. The mask is multiplied in place into the original dense parameter before forward and into the gradients after backward. A Cerebras system does not represent sparse parameters this way: it handles them natively in CSR format, so there is no mask tensor to reference on the system side. Using the mask tensor haphazardly can therefore lead to compile failures, and even when compilation succeeds, operations performed on the mask can be very computationally expensive. That said, several operations on masks are supported on the Cerebras system. See the usage in the prepackaged algorithms as a guide for when and how it is acceptable to use the mask.
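The masking described above can be illustrated with a minimal pure-Python sketch. It only mimics the host-side idea (elementwise multiplication by a boolean mask); the helper name `apply_mask` and the list-based "tensors" are illustrative assumptions, not part of the cerebras.pytorch API.

```python
def apply_mask(values, mask):
    """Zero out every entry whose mask bit is False (an elementwise multiply)."""
    return [v if keep else 0.0 for v, keep in zip(values, mask)]


# Hypothetical dense weights, their gradients, and a sparsity mask.
weights = [0.5, -1.2, 0.3, 2.0]
grads = [0.1, 0.4, -0.2, 0.05]
mask = [True, False, True, False]

# Before forward: pruned weights are forced to zero.
sparse_weights = apply_mask(weights, mask)

# After backward: gradients of pruned weights are zeroed as well,
# so a plain gradient step leaves those weights at zero.
sparse_grads = apply_mask(grads, mask)

print(sparse_weights)
print(sparse_grads)
```

On a Cerebras system no such mask tensor exists, which is why the warning advises restricting mask operations to the patterns used by the prepackaged algorithms.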

Parameters
  • sparsity – The sparsity level to use for the algorithm. This can be a float or a HyperParameterSchedule. If a dictionary is passed in, then it is automatically converted to a HyperParameterSchedule

  • init_method – The method to use to initialize the sparsity mask. See make_init_method for more details.

property num_sparse_params#

Returns the number of parameters that have been sparsified by this algorithm

initialize()[source]#

Initialize the sparsity pattern for all parameters sparsified by this algorithm

csx_annotate_sparsity(param: cerebras.pytorch.sparse.base.SparseParameter)[source]#

Annotate the parameter with hints about its sparsity pattern, which the Cerebras compiler uses as performance hints

Parameters

param – The sparse parameter to annotate with hints

property sparsity: Dict[torch.Tensor, cerebras.pytorch.sparse.utils.HyperParameterSchedule]#

Returns a mapping between a parameter and its sparsity schedule

sparsify_parameter(module: torch.nn.Module, name: str, param: torch.Tensor) None[source]#

Initialize the mask for a parameter in the given module.

Parameters
  • module – The module that owns the parameter

  • name – The full name of the parameter

final apply(obj: Union[torch.nn.Module, torch.optim.Optimizer])[source]#

Sparsifies the passed in object.

Note

This is called implicitly when calling module.apply(sparsity) or optimizer.apply(sparsity)

Parameters

obj – a torch.nn.Module or a torch.optim.Optimizer object to sparsify

sparsify_module(module: torch.nn.Module)[source]#

Sparsify the torch.nn.Module object

Parameters

module – the torch.nn.Module object to sparsify

prune_weights()#

Prune the dense weights.

Note

This is called automatically in a module forward pre-hook

_grad_hook(p, grad)[source]#

Hook to prune the gradients after backward()

Note

This is called automatically in the parameter’s backward grad hook

Parameters
  • p – The original parameter

  • grad – The gradient of the parameter

sparsify_optimizer(optimizer: torch.optim.Optimizer)[source]#

Sparsify the torch.optim.Optimizer object

Parameters

optimizer – the torch.optim.Optimizer object to sparsify

abstract update(optimizer: Optional[cerebras.pytorch.optim.optimizer.Optimizer] = None)[source]#

Update the parameters’ sparsity masks

Parameters

optimizer – The optimizer that is being used to update the sparse parameters

visit_state(f: Callable)[source]#

Apply a callable to the stateful tensors

state_dict()[source]#

Return a dictionary of all stateful tensors

load_state_dict(state_dict)[source]#

Load the state of all stateful tensors

Static Sparsity Algorithms#

class cerebras.pytorch.sparse.Static[source]#

Bases: cerebras.pytorch.sparse.base.SparsityAlgorithm

Parameters

sparsity – A float specifying the level of sparsity to apply to each parameter

Dynamic Sparsity Algorithms#

class cerebras.pytorch.sparse.DynamicSparsityAlgorithm[source]#

Bases: cerebras.pytorch.sparse.base.SparsityAlgorithm, abc.ABC

Parameters
  • sparsity

    A float specifying the level of sparsity to apply to each parameter or a dictionary specifying the schedule to use for sparsity. The dictionary must have a “type” key, which specifies the type of schedule to use. The remaining keys are schedule-specific. See the HyperParameterSchedule subclasses under cerebras.pytorch.sparse.utils for the available schedules.

  • update – A dictionary specifying the schedule to use for updating the sparsity pattern. The dictionary must contain keys that can be used to construct either a FreqSchedule or a ListSchedule. If not provided, the sparsity pattern will be updated every step.

  • add_summaries – Whether to add summaries for the sparsity patterns

property is_update_step#

Returns True if the current step is an update step according to the update schedule.

abstract update_mask(p, mask, sparsity) torch.Tensor#

Compute an updated sparsity pattern.

Parameters
  • p (torch.Tensor) – the parameter to sparsify

  • mask (torch.tensor(dtype=torch.bool)) – the current mask of param p

  • sparsity (torch.tensor(dtype=torch.float32)) – the desired sparsity level

Returns

The updated sparsity pattern on parameter p

class cerebras.pytorch.sparse.GMP[source]#

Bases: cerebras.pytorch.sparse.dynamic.DynamicSparsityAlgorithm

Implements Gradual Magnitude Pruning

Sparsity increases monotonically based on weight magnitude.

See: https://arxiv.org/abs/1710.01878

Parameters

**kwargs – All arguments are passed to the DynamicSparsityAlgorithm’s constructor.

Example

sparsity_opt = cstorch.sparse.GMP(
    schedule={"type": "exp", "init": 0, "gamma": 1000*math.log(0.3)},
    update={"freq": 1000},
)

class cerebras.pytorch.sparse.SET[source]#

Bases: cerebras.pytorch.sparse.dynamic.DynamicSparsityAlgorithm

Implements Sparse Evolutionary Training (SET)

Sparsity levels stay constant throughout training, but the lowest magnitude weights are pruned and then regrown randomly.

See: https://arxiv.org/abs/1707.04780

Parameters
  • drop_fraction – Fraction of non-pruned weights to drop each update step. Either a constant or a step-aware hyperparameter.

  • **kwargs – Any additional arguments are passed to the cstorch.sparse.DynamicSparsityAlgorithm’s constructor.

Example:

sparsity_opt = cstorch.sparse.SET(
    sparsity=0.9,
    update={"freq": 100, "stop": 1000},
    drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
)

class cerebras.pytorch.sparse.RigL[source]#

Bases: cerebras.pytorch.sparse.dynamic.DynamicSparsityAlgorithm

Implements Rigging the Lottery (RigL)

Sparsity levels stay constant throughout training, but the lowest magnitude weights are pruned and then regrown using a proxy measure of where a pruned connection would have had the most impact by finding the highest magnitude (dense) gradients of pruned weights.

See: https://arxiv.org/abs/1911.11134

Parameters
  • drop_fraction – Fraction of non-pruned weights to drop each update step. Either a constant or a step-aware hyperparameter.

  • balance_in_groups – The number of groups used by InputGroupScoreShaper

  • balance_out_groups – The number of groups used by OutputGroupScoreShaper

  • **kwargs – Any additional arguments are passed to the DynamicSparsityAlgorithm’s constructor.


Example:

sparsity = cstorch.sparse.RigL(
    sparsity=0.9,
    update={"freq": 100, "stop": 1000},
    drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
)

Group Sparsity Algorithm#

class cerebras.pytorch.sparse.Group[source]#

Bases: cerebras.pytorch.sparse.base.SparsityAlgorithm

Group sparsity algorithm. This algorithm allows for multiple sparsity algorithms to be applied to different groups of parameters.

For example:

sparsity = cstorch.sparse.Group({
    "fc1.*": cstorch.sparse.Static(sparsity=0.5),
    "fc2.*": cstorch.sparse.GMP(
        schedule=[0.3, 0.4, 0.5],
        update={"freq": 100},
    ),
})
sparsity.add("fc3.*", cstorch.sparse.RigL(sparsity=0.5))

model.apply(sparsity)
optimizer.apply(sparsity)

The group sparsity algorithm will apply the sparsity algorithms to the parameters that match the filter. If a parameter name matches multiple filters, the first filter that matches will be used.
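The first-match rule can be sketched with the standard-library `fnmatch` module. This is only an illustration under the assumption that the glob patterns behave like `fnmatch` patterns; the function `pick_algorithm` and the string labels standing in for algorithm instances are hypothetical.

```python
from fnmatch import fnmatchcase


def pick_algorithm(param_name, groups):
    """Return the entry of the first filter that matches param_name, else None."""
    # Dictionaries preserve insertion order, so earlier filters win.
    for pattern, algorithm in groups.items():
        if fnmatchcase(param_name, pattern):
            return algorithm
    return None


# String labels stand in for sparsity algorithm instances.
groups = {
    "fc1.*": "static-0.5",
    "fc*": "gmp",
    "fc3.*": "rigl",
}

print(pick_algorithm("fc1.weight", groups))   # matches "fc1.*" and "fc*"
print(pick_algorithm("fc3.bias", groups))     # "fc*" comes before "fc3.*"
print(pick_algorithm("conv.weight", groups))  # no filter matches
```

Because "fc*" is listed before "fc3.*", the more specific filter never gets a chance to match, so ordering filters from most to least specific is the safer choice.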

Parameters

groups – A dictionary of filter -> algorithm pairs. See add for more details.

add(filter: Union[str, Callable[[str, torch.Tensor], bool]], algorithm: cerebras.pytorch.sparse.base.SparsityAlgorithm)[source]#

Add a sparsity algorithm to the group.

Parameters
  • filter

    A string, list of strings, or callable that takes a parameter name and a parameter tensor and returns True if the parameter should be sparsified.

    If one or more strings are provided, the filter will match if any of the strings match the parameter name. The strings may contain glob patterns, e.g. “fc1.*” will match all parameters in the “fc1” module.

  • algorithm – An instance of SparsityAlgorithm

extend(group: cerebras.pytorch.sparse.group.Group)[source]#

Extend the group with the filters and algorithms from another group.

Parameters

group – An instance of Group

Configuration routine#

The highest level entry-point to enabling sparsity is configure, which will configure a sparsity algorithm and return it. The config dictionary follows the same form as given in Sparsity via YAML.

cerebras.pytorch.sparse.configure(config: Union[float, dict, List[dict]]) cerebras.pytorch.sparse.group.Group[source]#

If param_filter is not provided, the following default param filter gets applied.

cerebras.pytorch.sparse.configure.default_sparse_param_filter(name: str, param: torch.nn.Parameter) bool[source]#

Return True if the given parameter should be sparse.

Returns True only if the parameter has more than one dimension and is not an embedding, norm, lm_head, or pe_helper parameter.

Parameters
  • name – Name of the parameter

  • param – The parameter itself
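A rough sketch of a filter with this behavior follows. It assumes the exclusions are detected by substring matches on the parameter name, which the description above does not confirm; the function name, the `ndim` argument, and the `EXCLUDED_SUBSTRINGS` tuple are illustrative.

```python
# Assumption: exclusions are recognized by substrings of the parameter name.
EXCLUDED_SUBSTRINGS = ("embedding", "norm", "lm_head", "pe_helper")


def sketch_sparse_param_filter(name, ndim):
    """Return True if a parameter with this name and rank should be sparsified."""
    if ndim <= 1:
        # Biases and other 1D (or scalar) parameters stay dense.
        return False
    lowered = name.lower()
    return not any(token in lowered for token in EXCLUDED_SUBSTRINGS)


print(sketch_sparse_param_filter("decoder.layers.0.ffn.weight", 2))
print(sketch_sparse_param_filter("decoder.layers.0.ffn.bias", 1))
print(sketch_sparse_param_filter("embedding_layer.weight", 2))
print(sketch_sparse_param_filter("lm_head.weight", 2))
```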

Customizing Sparsity & Reference#

Several building blocks can be inherited from or composed to help build new dynamic sparsity algorithms or customize the behavior of existing ones.

cerebras.pytorch.sparse.init#

Sparsity mask initialization methods and helpers, invoked by SparsityAlgorithm.

cerebras.pytorch.sparse.init.random(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras.pytorch.sparse.utils.ScoreShaper] = None, device: Optional[torch.device] = None) torch.BoolTensor[source]#

Uniformly random sparsity pattern.

A score tensor with the same shape as the parameter is randomly generated with values between 0.0 and 1.0. The mask is then created by taking the top-k of the score tensor, where k is determined by the sparsity level.
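The procedure can be sketched in pure Python on a flattened parameter. The function name `random_mask`, the `seed` argument, and the rounding of k are assumptions made for illustration; the library operates on tensors and may round differently.

```python
import random


def random_mask(numel, sparsity, seed=0):
    """Build a random boolean mask that keeps a (1 - sparsity) fraction of entries."""
    rng = random.Random(seed)
    # One uniform score in [0.0, 1.0) per parameter element.
    scores = [rng.random() for _ in range(numel)]
    # k is the number of entries to keep (rounding is an assumption).
    k = int(round((1.0 - sparsity) * numel))
    keep = set(sorted(range(numel), key=lambda i: scores[i], reverse=True)[:k])
    return [i in keep for i in range(numel)]


mask = random_mask(numel=10, sparsity=0.7)
print(mask, sum(mask))
```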

cerebras.pytorch.sparse.init.topk(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras.pytorch.sparse.utils.ScoreShaper] = None, device: Optional[torch.device] = None) torch.BoolTensor[source]#

Prune lowest magnitude weights.

cerebras.pytorch.sparse.init.from_zeros(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras.pytorch.sparse.utils.ScoreShaper] = None, device: Optional[torch.device] = None) torch.BoolTensor[source]#

Any zeros currently in the weights represent pruned connections. Note: this method does not honor the configured sparsity level.

cerebras.pytorch.sparse.init.checkerboard(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras.pytorch.sparse.utils.ScoreShaper] = None, device: Optional[torch.device] = None) torch.BoolTensor[source]#

Mostly for stress and performance testing, creates a sparsity mask that is maximally distributed in a checkerboard across the weight.

cerebras.pytorch.sparse.init.make_init_method(init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras.pytorch.sparse.utils.ScoreShaper], Optional[torch.device]], torch.BoolTensor]]) Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras.pytorch.sparse.utils.ScoreShaper], Optional[torch.device]], torch.BoolTensor][source]#

Returns the corresponding init method callable for the given init_method.

Parameters

init_method

The method to use to initialize the sparsity mask. This can be a string or a callable. If a string, it must be one of

  • "random": Randomly initialize the mask

  • "topk": Prune the lowest magnitude weights

  • "from_zeros": Any zeros in the weights represent pruned connections

  • "checkerboard": Creates a sparsity mask that is maximally distributed across the weight

If a callable, it must have the signature:

def init_method(
    param: torch.Tensor,
    sparsity: float,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None
) -> torch.Tensor:

where
  • param is the original dense parameter

  • sparsity is the sparsity level

  • score_shaper is an optional callable that can be used to reshape the mask

  • device is optionally the device to use to initialize the mask

cerebras.pytorch.sparse.utils#

class cerebras.pytorch.sparse.utils.HyperParameterSchedule[source]#

Base class for step-aware hyperparameters used in Sparsity Optimizers.

abstract compute(step: torch.Tensor) torch.Tensor[source]#

Return a torch.Tensor with the value of the hyperparameter at the given step.

Parameters

step – int64 tensor holding current step

Returns

torch.Tensor on the device of step with the value of the hyperparameter

__call__(step: torch.Tensor) torch.Tensor[source]#

Call self as a function.

update(is_update_step: torch.Tensor)[source]#

Given a boolean tensor indicating if this is an update step, update the internal state of this hyperparameter.

Parameters

is_update_step – A boolean tensor indicating if this is an update step.

visit_state(fn)[source]#

Applies a lambda to each stateful value.

get_min_max_end(begin: int, end: int) Tuple[float, float, float][source]#

Given a beginning and ending step, compute the statistics of this step-aware hyper parameter. Used for estimating memory requirements for dynamic sparsity.

Return [min, max, ending]

class cerebras.pytorch.sparse.utils.Constant[source]#

Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule

Constant at every step.

Parameters

value – The constant value of the hyperparameter

compute(step: torch.Tensor)[source]#

Return a torch.Tensor with the value of the hyperparameter at the given step.

Parameters

step – int64 tensor holding current step

Returns

torch.Tensor on the device of step with the value of the hyperparameter

class cerebras.pytorch.sparse.utils.Linear[source]#

Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule

Linear change from an initial value.

\(y(step) = init + step \cdot slope\)

Parameters
  • init – The initial value of the hyperparameter

  • slope – The rate of change of the hyperparameter

compute(step: torch.Tensor)[source]#

Return a torch.Tensor with the value of the hyperparameter at the given step.

Parameters

step – int64 tensor holding current step

Returns

torch.Tensor on the device of step with the value of the hyperparameter

class cerebras.pytorch.sparse.utils.Exp[source]#

Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule

Exponential, approaching an asymptotic final value

\(y(step) = final + (init-final) e^{step \cdot gamma}\)

Parameters
  • init – The initial value of the hyperparameter

  • gamma – The rate of change of the hyperparameter

  • final – The final value of the hyperparameter (Default: 1.0)
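The formula can be evaluated directly to see how gamma controls the approach to the final value. The sketch below is plain Python rather than the library class, and the chosen gamma (log(0.3) / 1000) is a hypothetical value picked so that the gap to final shrinks to 30% every 1000 steps.

```python
import math


def exp_schedule(step, init, gamma, final=1.0):
    """y(step) = final + (init - final) * exp(step * gamma)."""
    return final + (init - final) * math.exp(step * gamma)


# gamma must be negative for y to approach `final` as step grows.
gamma = math.log(0.3) / 1000  # hypothetical rate

values = [exp_schedule(s, init=0.0, gamma=gamma) for s in (0, 1000, 2000)]
print(values)  # approximately 0.0, 0.7, 0.91
```

With a positive gamma the exponential term grows instead of decaying, so the value moves away from final.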

compute(step: torch.Tensor)[source]#

Return a torch.Tensor with the value of the hyperparameter at the given step.

Parameters

step – int64 tensor holding current step

Returns

torch.Tensor on the device of step with the value of the hyperparameter

class cerebras.pytorch.sparse.utils.Power[source]#

Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule

Power law.

\(y(step) = init \cdot beta^{step}\)

Parameters
  • init – The initial value of the hyperparameter

  • beta – The rate of change of the hyperparameter

compute(step: torch.Tensor)[source]#

Return a torch.Tensor with the value of the hyperparameter at the given step.

Parameters

step – int64 tensor holding current step

Returns

torch.Tensor on the device of step with the value of the hyperparameter

class cerebras.pytorch.sparse.utils.Cosine[source]#

Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule

Cosine function for oscillating between an initial (maximum) value and a minimum, returning to the maximum every period.

\(y(step) = o + a \cdot \cos(step \cdot \pi / half\_period)\), where \(o = (init + minimum)/2\) and \(a = init - o\).

Parameters
  • init – The initial value of the hyperparameter

  • half_period – The number of steps to go from the initial (maximum) value to the minimum, that is, half of a full cycle

  • minimum – The minimum value of the hyperparameter
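Evaluating the formula at a few steps shows where the extremes fall. This is a plain-Python sketch of the equation above, not the library class; the sample values (init 0.3, minimum 0.0, half_period 1000) are chosen only for illustration.

```python
import math


def cosine_schedule(step, init, half_period, minimum):
    """y(step) = o + a * cos(step * pi / half_period)."""
    offset = (init + minimum) / 2   # o in the formula
    amplitude = init - offset       # a in the formula
    return offset + amplitude * math.cos(step * math.pi / half_period)


for step in (0, 500, 1000, 2000):
    print(step, cosine_schedule(step, init=0.3, half_period=1000, minimum=0.0))
```

According to the formula, the value starts at init, reaches minimum after half_period steps, and returns to init after 2 * half_period steps.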

compute(step: torch.Tensor)[source]#

Return a torch.Tensor with the value of the hyperparameter at the given step.

Parameters

step – int64 tensor holding current step

Returns

torch.Tensor on the device of step with the value of the hyperparameter

get_min_max_end(begin: int, end: int) Tuple[float, float, float][source]#

Given a beginning and ending step, compute the statistics of this step-aware hyper parameter. Used for estimating memory requirements for dynamic sparsity.

Return [min, max, ending]

class cerebras.pytorch.sparse.utils.Cycling[source]#

Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule

Hyper parameter cycling between discrete values at update steps.

Parameters

values – A list of discrete values to cycle through
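One way to picture this schedule is a counter that advances only on update steps. The class below is a hypothetical pure-Python sketch; in particular, reading the value before advancing the counter is an assumption about ordering, not documented behavior.

```python
class CyclingSketch:
    """Cycle through discrete values, advancing only on update steps."""

    def __init__(self, values):
        self.values = list(values)
        self.index = 0  # the only piece of state

    def compute(self):
        # Wrap around once every value has been used.
        return self.values[self.index % len(self.values)]

    def update(self, is_update_step):
        if is_update_step:
            self.index += 1


schedule = CyclingSketch([0.1, 0.2, 0.3])
seen = []
for step in range(8):
    seen.append(schedule.compute())
    schedule.update(is_update_step=(step % 2 == 1))  # update on odd steps

print(seen)
```

The counter is the stateful value here, which is consistent with this class overriding both update and visit_state.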

compute(step: torch.Tensor) torch.Tensor[source]#

Return a torch.Tensor with the value of the hyperparameter at the given step.

Parameters

step – int64 tensor holding current step

Returns

torch.Tensor on the device of step with the value of the hyperparameter

update(is_update_step: torch.Tensor)[source]#

Given a boolean tensor indicating if this is an update step, update the internal state of this hyperparameter.

Parameters

is_update_step – A boolean tensor indicating if this is an update step.

get_min_max_end(begin: int, end: int) Tuple[float, float, float][source]#

Given a beginning and ending step, compute the statistics of this step-aware hyper parameter. Used for estimating memory requirements for dynamic sparsity.

Return [min, max, ending]

visit_state(fn)[source]#

Applies a lambda to each stateful value.

class cerebras.pytorch.sparse.utils.Lambda[source]#

Bases: cerebras.pytorch.sparse.utils.HyperParameterSchedule

Invoke a user’s lambda function of step to obtain the hyper parameter.

Parameters

fn – A lambda function that takes a step and returns a hyperparameter

compute(step: torch.Tensor)[source]#

Return a torch.Tensor with the value of the hyperparameter at the given step.

Parameters

step – int64 tensor holding current step

Returns

torch.Tensor on the device of step with the value of the hyperparameter

get_min_max_end(begin: int, end: int) Tuple[float, float, float][source]#

Given a beginning and ending step, compute the statistics of this step-aware hyper parameter. Used for estimating memory requirements for dynamic sparsity.

Return [min, max, ending]

cerebras.pytorch.sparse.utils.make_hyperparam_schedule(schedule)[source]#

Given some user specified configuration, construct a HyperParameterSchedule object that is step aware.

class cerebras.pytorch.sparse.utils.UpdateSchedule[source]#

Bases: abc.ABC

abstract is_update_step(step: torch.LongTensor) torch.BoolTensor[source]#

Given a training step rankless tensor, return a rankless bool tensor indicating whether this is a sparsity update step.

class cerebras.pytorch.sparse.utils.FreqSchedule[source]#

Bases: cerebras.pytorch.sparse.utils.UpdateSchedule

When scheduling sparsity update steps on a regular interval, this class allows configuring the start and stop step in addition to the update frequency.

Parameters
  • freq – The frequency of steps at which to update the sparsity pattern (Default: 1)

  • start – The step at which to start updating the sparsity pattern (Default: 0)

  • stop – The step at which to stop updating the sparsity pattern (Default: None)
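The interaction of the three parameters can be sketched as a simple predicate on the step number. The semantics assumed here (counting from start, with stop exclusive) are a guess for illustration and should be checked against the library; the function name is hypothetical.

```python
def freq_is_update_step(step, freq=1, start=0, stop=None):
    """Return True if `step` is an update step under the assumed semantics."""
    if step < start:
        return False
    # Assumption: `stop` is exclusive.
    if stop is not None and step >= stop:
        return False
    # Assumption: the interval is counted from `start`.
    return (step - start) % freq == 0


update_steps = [
    s for s in range(1200) if freq_is_update_step(s, freq=100, start=0, stop=1000)
]
print(update_steps)
```

With the defaults (freq 1, start 0, no stop), every step is an update step, which matches the documented default of updating the sparsity pattern every step.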

is_update_step(step: torch.LongTensor) torch.BoolTensor[source]#

Given a training step rankless tensor, return a rankless bool tensor indicating whether this is a sparsity update step.

class cerebras.pytorch.sparse.utils.ListSchedule[source]#

Bases: cerebras.pytorch.sparse.utils.UpdateSchedule

When scheduling requires an irregular update cadence, explicit steps can be provided as a list.

Parameters

steps – A list of steps at which to update the sparsity pattern

is_update_step(step: torch.LongTensor) torch.BoolTensor[source]#

Given a training step rankless tensor, return a rankless bool tensor indicating whether this is a sparsity update step.

cerebras.pytorch.sparse.utils.make_update_schedule(update: Union[Dict, Callable[[torch.LongTensor], torch.BoolTensor]]) Callable[[torch.LongTensor], torch.BoolTensor][source]#

Instantiate a supported schedule type.

class cerebras.pytorch.sparse.utils.ScoreShaper[source]#

Bases: abc.ABC

class cerebras.pytorch.sparse.utils.ScoreFlattener[source]#

Bases: cerebras.pytorch.sparse.utils.ScoreShaper

Default ScoreShaper in which everything is flattened, providing a global competition for magnitude. If only sub-portions of the weight should compete for magnitude, provide an alternative shaper object.

class cerebras.pytorch.sparse.utils.OutputGroupScoreShaper[source]#

Bases: cerebras.pytorch.sparse.utils.ScoreShaper

A ScoreShaper interface when weights are logically shaped as [num_groups*out_per_group, insize], but need to be scored in a “balanced” fashion as [num_groups, out_per_group*insize]

Examples

>>> # Common score used for the following examples
>>> score=torch.tensor([[1.0, 2.0],
...                     [0.0, -1.0]])
>>> # 50% sparsity, drops the 2 lowest magnitude
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
... )
tensor([[ True,  True],
        [False, False]])
>>> # 50% sparsity, but computed rowwise
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
...     score_shaper=OutputGroupScoreShaper(num_groups=2)
... )
tensor([[False,  True],
        [ True, False]])

class cerebras.pytorch.sparse.utils.InputGroupScoreShaper[source]#

Bases: cerebras.pytorch.sparse.utils.ScoreShaper

A ScoreShaper interface when weights are logically shaped as [outsize, num_groups*in_per_group], but need to be scored in a “balanced” fashion as [num_groups, outsize*in_per_group]

Examples

>>> # Common score used for the following examples
>>> score=torch.tensor([[1.0, 0.0],
...                     [2.0, -1.0]])
>>> # 50% sparsity, drops the 2 lowest magnitude
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
... )
tensor([[ True, False],
        [ True, False]])
>>> # 50% sparsity, but computed columnwise
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
...     score_shaper=InputGroupScoreShaper(num_groups=2)
... )
tensor([[False,  True],
        [ True, False]])

cerebras.pytorch.sparse.utils.make_mask_drop_minimum(score: torch.FloatTensor, mask: torch.BoolTensor, drop_fraction: torch.FloatTensor, score_shaper: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]] = None) torch.BoolTensor[source]#

Given a sparse score (with mask), return a new torch.BoolTensor the same shape as mask where a drop_fraction portion of the currently present (mask==True) connections are dropped (mask==False).

The connections are dropped at positions corresponding to the lowest values of score.

Equivalently, a subset of mask is returned corresponding to the highest magnitude elements of score.

Parameters
  • score – Values used to evaluate which positions to drop

  • mask – Current connections, same shape as score

  • drop_fraction – What fraction of current connections to drop

  • score_shaper – If given, score (and mask) will be interpreted as multiple independent subtensors. This can be used to ensure sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score and mask are reinterpreted as 1D tensors, yielding completely unstructured sparsity.

Returns

New mask that has existing connections dropped. No connections will be regrown (unless drop_fraction is negative).
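The drop step can be sketched on flattened Python lists. This is not the library implementation: the function name `drop_minimum`, the rounding of the drop count, and the absence of a score shaper are assumptions made for illustration.

```python
def drop_minimum(score, mask, drop_fraction):
    """Drop the lowest-scoring `drop_fraction` of the currently kept positions."""
    present = [i for i, keep in enumerate(mask) if keep]
    # Number of connections to drop (rounding is an assumption).
    num_drop = int(round(drop_fraction * len(present)))
    # Among kept positions, pick the ones with the lowest score.
    to_drop = set(sorted(present, key=lambda i: score[i])[:num_drop])
    return [keep and i not in to_drop for i, keep in enumerate(mask)]


score = [0.9, 0.1, 0.5, 0.7, 0.3, 0.2]
mask = [True, True, False, True, True, False]

new_mask = drop_minimum(score, mask, drop_fraction=0.5)
print(new_mask)
```

Positions that were already pruned are never considered, so the result is always a subset of the input mask.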

cerebras.pytorch.sparse.utils.make_mask_grow_maximum(score: torch.FloatTensor, mask: torch.BoolTensor, sparsity: torch.FloatTensor, mask_nonzero: Optional[torch.IntTensor] = None, score_shaper: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]] = None) torch.BoolTensor[source]#

Given a sparse score (with mask), return a new torch.BoolTensor the same shape as mask where some currently pruned connections are regrown (from those positions with the highest score) such that the returned mask has the given target sparsity.

If mask is already less sparse (has more connections) than the target, none are regrown and the original mask is returned as-is. That is, the given mask should be more sparse than the target sparsity.

Parameters
  • score – Values used to evaluate which positions to regrow

  • mask – Current connections, same shape as score

  • sparsity – The target sparsity of the returned mask

  • mask_nonzero – If given, the number of nonzero elements currently in the mask, used to control the number of connections needing regrowth. If it is not given, will be computed as mask.nonzero().int(). Since make_mask_grow_maximum is often used in conjunction with make_mask_drop_minimum, this value is commonly available.

  • score_shaper – If given, score (and mask) will be interpreted as multiple independent subtensors. This can be used to ensure sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score and mask are reinterpreted as 1D tensors, yielding completely unstructured sparsity.

Returns

New mask with as many connections regrown as necessary to reach (decrease to) the target sparsity.
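The regrowth step can be sketched the same way on flattened Python lists. The function name `grow_maximum` and the rounding of the target count are illustrative assumptions, and no score shaper is modeled.

```python
def grow_maximum(score, mask, sparsity):
    """Regrow the highest-scoring pruned positions until the target sparsity is met."""
    target_keep = int(round((1.0 - sparsity) * len(mask)))
    num_grow = target_keep - sum(mask)
    if num_grow <= 0:
        # Already at or below the target sparsity: nothing is regrown.
        return list(mask)
    pruned = [i for i, keep in enumerate(mask) if not keep]
    to_grow = set(sorted(pruned, key=lambda i: score[i], reverse=True)[:num_grow])
    return [keep or i in to_grow for i, keep in enumerate(mask)]


score = [0.0, 0.8, 0.1, 0.0, 0.4, 0.9, 0.2, 0.3]
mask = [True, False, False, True, False, False, False, False]

grown = grow_maximum(score, mask, sparsity=0.5)
print(grown)
```

Per the class descriptions above, the score used for regrowth is what distinguishes the algorithms: SET regrows randomly, while RigL uses the magnitude of the dense gradients of pruned weights.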

cerebras.pytorch.sparse.utils.make_mask_topk_sparsity(score: torch.FloatTensor, sparsity: torch.FloatTensor, score_shaper: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]] = None) torch.BoolTensor[source]#

Given a dense score, return a torch.BoolTensor which is True at positions corresponding to values in the top k = (1-sparsity)*score.numel() of score.

Parameters
  • score – Values used to evaluate which positions to keep.

  • sparsity – rankless tensor in range [0,1] controlling fraction of the resulting mask that will be pruned.

  • score_shaper – If given, score will be interpreted as multiple independent subtensors. This can be used to ensure sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score is reinterpreted as a 1D tensor, yielding completely unstructured sparsity.

Returns

mask with given sparsity, keeping only the highest values from score.

Examples

>>> # Common score used for the following examples
>>> score=torch.tensor([[1.0, 2.0],
...                     [0.0, -1.0]])
>>> # 25% sparsity, drops the one lowest magnitude
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.25),
... )
tensor([[ True,  True],
        [ True, False]])
>>> # 75% sparsity, drops the 3 lowest magnitude
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.75),
... )
tensor([[False,  True],
        [False, False]])
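For readers without access to the library, the selection logic in the examples above can be mimicked with nested Python lists. The helpers `topk_mask` and `rowwise_topk_mask` are hypothetical sketches, and the rounding of k is an assumption; the row-wise variant imitates scoring each row as its own group.

```python
def topk_mask(score, sparsity):
    """Keep the k = (1 - sparsity) * numel highest values of a 2D score."""
    flat = [value for row in score for value in row]
    k = int(round((1.0 - sparsity) * len(flat)))
    keep = set(sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)[:k])
    ncols = len(score[0])
    return [
        [(r * ncols + c) in keep for c in range(ncols)]
        for r in range(len(score))
    ]


def rowwise_topk_mask(score, sparsity):
    """Apply top-k independently to each row, so every row has equal sparsity."""
    return [topk_mask([row], sparsity)[0] for row in score]


score = [[1.0, 2.0], [0.0, -1.0]]
print(topk_mask(score, 0.25))
print(topk_mask(score, 0.75))
print(rowwise_topk_mask(score, 0.5))
```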