cerebras_pytorch.sparse

Configuration routines

The highest-level entry point for enabling sparsity is configure_sparsity_wrapper, which returns a drop-in replacement optimizer that automatically applies the sparsity algorithm according to a high-level configuration dictionary. These configuration dictionaries (actually passed as **kwargs) follow the same form as described in Sparsity.

cerebras_pytorch.sparse.configure_sparsity_wrapper(model: torch.nn.Module, optimizer: torch.optim.Optimizer, sparsity_type: str, param_name_patterns: Optional[Union[str, List[str], Dict[str, dict]]] = None, sparse_param_filter: Optional[Callable[[str, torch.nn.Parameter], bool]] = None, **kwargs) → cerebras_pytorch.sparse.wrapper.SparsityWrapperOptimizer

Returns a SparsityWrapperOptimizer that is a drop-in replacement for the given optimizer. It also constructs a SparsityOptimizer of the appropriate sparsity_type for the parameters in model.named_parameters() selected by param_name_patterns or sparse_param_filter.

**kwargs are passed along to the SparsityOptimizer __init__.

Parameters
  • model – Root module from which parameters are extracted and whose forward pass is hooked.

  • optimizer – Optimizer to wrap so that its state is also sparsified.

  • sparsity_type – Type of sparsity optimizer to construct.

  • param_name_patterns – Filter to select which parameters are sparse and, optionally, to apply more specific configuration to certain parameters.

  • sparse_param_filter – Callable to provide fallback selection of which parameters are sparse in case no param_name_patterns are provided.

  • kwargs – Passed along to the sparsity optimizer __init__.
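The sparse_param_filter callable is consulted only when param_name_patterns is not given; it receives each (name, param) pair and returns True for parameters that should be sparse. A minimal sketch of one such filter follows. The name sparsify_matrices_only, its selection rule (sparsify only parameters with two or more dimensions, excluding anything whose name contains "embedding"), and the StandInParam class, which stands in for torch.nn.Parameter so the sketch runs without PyTorch, are all hypothetical.

```python
from dataclasses import dataclass
from typing import Tuple


def sparsify_matrices_only(name: str, param) -> bool:
    """Hypothetical filter: sparsify weight matrices, keep everything else dense.

    `param` only needs a `.dim()` method, as torch.nn.Parameter provides.
    """
    # 1-D tensors (biases, normalization scales) stay dense, and so do embeddings.
    return param.dim() > 1 and "embedding" not in name


@dataclass
class StandInParam:
    """Stand-in for torch.nn.Parameter so this sketch runs without PyTorch."""
    shape: Tuple[int, ...]

    def dim(self) -> int:
        return len(self.shape)


named_parameters = [
    ("embedding.weight", StandInParam((1000, 256))),
    ("fc1.weight", StandInParam((256, 256))),
    ("fc1.bias", StandInParam((256,))),
    ("norm.weight", StandInParam((256,))),
    ("fc2.weight", StandInParam((10, 256))),
]
print([n for n, p in named_parameters if sparsify_matrices_only(n, p)])
# ['fc1.weight', 'fc2.weight']
```

With PyTorch installed, a function like this could be passed as sparse_param_filter=sparsify_matrices_only, since it relies only on the (name, param) signature and param.dim().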

For more control, the helper configure_sparsity_optimizer can construct a BaseSparsityOptimizer from the same high-level configuration dictionary.

cerebras_pytorch.sparse.configure_sparsity_optimizer(sparsity_type: str, named_parameters: List[Tuple[str, torch.nn.Parameter]], param_name_patterns: Optional[Union[str, List[str], Dict[str, dict]]] = None, sparse_param_filter: Optional[Callable[[str, torch.nn.Parameter], bool]] = None, **kwargs) → cerebras_pytorch.sparse.base.BaseSparsityOptimizer

Construct a SparsityOptimizer of the appropriate sparsity_type according to param_name_patterns or sparse_param_filter of the given named_parameters.

**kwargs are passed along to the SparsityOptimizer __init__.

Parameters
  • sparsity_type – Type of sparsity optimizer to construct.

  • named_parameters – List of (name, param) pairs from model.named_parameters().

  • param_name_patterns – Filter to select which parameters are sparse and, optionally, to apply more specific configuration to certain parameters.

  • sparse_param_filter – Callable to provide fallback selection of which parameters are sparse if no param_name_patterns are provided.

  • kwargs – Passed along to the chosen sparsity optimizer __init__.

These optimizers can be used manually like any other PyTorch optimizer; see hook_module to have them automatically apply sparsity during a module’s forward() and backward(). The same high-level drop-in replacement optimizer wrapper returned by configure_sparsity_wrapper can also be constructed and used directly.

class cerebras_pytorch.sparse.SparsityWrapperOptimizer

Helper optimizer that can be used as a drop-in replacement for the main optimizer and that also takes care of updating and applying sparsity.

__init__(optimizer, sparsity_optimizer)
visit_state(fn)

Applies a lambda to each stateful value.

Sparsity Optimizers

These classes are the built-in sparsity algorithms. StaticSparsityOptimizer is an “optimizer” that maintains a static sparsity pattern throughout training. The rest implement published dynamic sparsity algorithms. These are the objects returned by configure_sparsity_optimizer and related helpers.

Even though static sparsity never updates its sparsity pattern during training, it is still implemented as an “Optimizer” to provide a consistent API and to allow static and dynamic sparsity to be easily swapped via configuration.

class cerebras_pytorch.sparse.StaticSparsityOptimizer

Implements a static sparsity optimizer.

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • sparsity (float) – target sparsity

  • init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.

Example

>>> optimizer = torch.optim.SGD(
...     model.parameters(), lr=0.1, momentum=0.9
... )
>>> sparsity_opt = StaticSparsityOptimizer(
...     [p for n, p in model.named_parameters() if should_sparsify(n, p)],
...     sparsity=0.5,
... )
>>> sparsity_opt.hook_module(model)
>>> sparsity_opt.initialize_sparsity()
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
__init__(params, sparsity=torch.optim.optimizer.required, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', **kwargs)
class cerebras_pytorch.sparse.GMPSparsityOptimizer

Implements Gradual Magnitude Pruning.

Sparsity increases monotonically based on weight magnitude.

See: https://arxiv.org/abs/1506.02626

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.

  • sparsity – Sparsity, either constant or step-aware hyperparameter

  • schedule

    Sparsity update schedule. May be one of:

    • int: Single regular update frequency.

    • list: Irregular update on the given steps.

    • dict: Containing {"start": start, "freq": freq, "stop": stop} for regular updates with start & stop.

    • ScheduleCallable : User function accepting a rankless torch.LongTensor and returning a rankless torch.BoolTensor

Example

>>> import math
>>> optimizer = torch.optim.SGD(
...     model.parameters(), lr=0.1, momentum=0.9
... )
>>> sparsity_opt = GMPSparsityOptimizer(
...     [p for n, p in model.named_parameters() if should_sparsify(n, p)],
...     # Exponential ramp from 0 toward 1 that reaches 0.7 at step 1000
...     sparsity={"type": "exp", "init": 0, "gamma": math.log(0.3) / 1000},
...     schedule=1000,
... )
>>> sparsity_opt.hook_module(model)
>>> sparsity_opt.initialize_sparsity()
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
__init__(params, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', sparsity: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = torch.optim.optimizer.required, schedule: Union[int, List[int], Dict, Callable[[torch.LongTensor], torch.BoolTensor]] = torch.optim.optimizer.required, tiebreak=None, **kwargs)
class cerebras_pytorch.sparse.SETSparsityOptimizer

Implements Sparse Evolutionary Training (SET).

Sparsity levels stay constant throughout training, but the lowest magnitude weights are pruned and then regrown randomly.

See: https://arxiv.org/abs/1707.04780

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.

  • sparsity – Sparsity, either constant or step-aware hyperparameter

  • schedule

    Sparsity update schedule. May be one of:

    • int: Single regular update frequency.

    • list: Irregular update on the given steps.

    • dict: Containing {"start": start, "freq": freq, "stop": stop} for regular updates with start & stop.

    • ScheduleCallable : User function accepting a rankless torch.LongTensor and returning a rankless torch.BoolTensor

  • drop_fraction – Fraction of non-pruned weights to drop each update step. Either a constant or a step-aware hyperparameter.

Example

>>> optimizer = torch.optim.SGD(
...     model.parameters(), lr=0.1, momentum=0.9
... )
>>> sparsity_opt = SETSparsityOptimizer(
...     [p for n, p in model.named_parameters() if should_sparsify(n, p)],
...     sparsity=0.9,
...     schedule={"freq": 100, "stop": 1000},
...     drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
... )
>>> sparsity_opt.hook_module(model)
>>> sparsity_opt.initialize_sparsity()
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
__init__(params, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', sparsity: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = torch.optim.optimizer.required, schedule: Union[int, List[int], Dict, Callable[[torch.LongTensor], torch.BoolTensor]] = torch.optim.optimizer.required, drop_fraction: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = 0.3, **kwargs)
class cerebras_pytorch.sparse.RigLSparsityOptimizer

Implements Rigging the Lottery (RigL).

Sparsity levels stay constant throughout training, but the lowest-magnitude weights are pruned and then regrown where a pruned connection would likely have had the most impact. The magnitude of the (dense) gradients of pruned weights serves as the proxy: the pruned positions with the highest gradient magnitude are regrown.

See: https://arxiv.org/abs/1911.11134
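The drop-and-regrow cycle described above can be sketched on flat Python lists. This is an illustrative sketch of the RigL idea under stated assumptions, not the library's implementation: the function name rigl_update, the flat-list representation, and rounding the number of swapped connections with round() are all hypothetical.

```python
from typing import List


def rigl_update(weights: List[float], grads: List[float], mask: List[bool],
                drop_fraction: float) -> List[bool]:
    """One RigL-style drop-and-regrow step on flat lists (illustrative only)."""
    active = [i for i, m in enumerate(mask) if m]
    # Number of connections to drop (and later regrow); rounding is an assumption.
    num_swap = round(drop_fraction * len(active))

    # Drop: prune the active connections with the smallest |weight|.
    dropped = sorted(active, key=lambda i: abs(weights[i]))[:num_swap]
    new_mask = list(mask)
    for i in dropped:
        new_mask[i] = False

    # Regrow: among all now-inactive positions (just-dropped ones included),
    # activate those with the largest |dense gradient|.
    inactive = [i for i, m in enumerate(new_mask) if not m]
    regrown = sorted(inactive, key=lambda i: abs(grads[i]), reverse=True)[:num_swap]
    for i in regrown:
        new_mask[i] = True
    return new_mask


weights = [0.9, -0.1, 0.0, 0.5, 0.0, 0.3]
grads = [0.0, 0.0, 0.8, 0.0, 0.2, 0.0]
mask = [True, True, False, True, False, True]
print(rigl_update(weights, grads, mask, drop_fraction=0.5))
# [True, False, True, True, True, False]: 4 active connections before and after
```

Swapping the gradient-based regrowth for a random choice among inactive positions gives the SET variant described above. In the RigL paper, regrown connections also start from a zero weight; this sketch only updates the mask.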

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.

  • sparsity – Sparsity, either constant or step-aware hyperparameter

  • schedule

    Sparsity update schedule. May be one of:

    • int: Single regular update frequency.

    • list: Irregular update on the given steps.

    • dict: Containing {"start": start, "freq": freq, "stop": stop} for regular updates with start & stop.

    • ScheduleCallable : User function accepting a rankless torch.LongTensor and returning a rankless torch.BoolTensor

  • drop_fraction – Fraction of non-pruned weights to drop each update step. Either a constant or a step-aware hyperparameter.

Example

>>> optimizer = torch.optim.SGD(
...     model.parameters(), lr=0.1, momentum=0.9
... )
>>> sparsity_opt = RigLSparsityOptimizer(
...     [p for n, p in model.named_parameters() if should_sparsify(n, p)],
...     sparsity=0.9,
...     schedule={"freq": 100, "stop": 1000},
...     drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
... )
>>> sparsity_opt.hook_module(model)
>>> sparsity_opt.initialize_sparsity()
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
__init__(params, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', sparsity: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = torch.optim.optimizer.required, schedule: Union[int, List[int], Dict, Callable[[torch.LongTensor], torch.BoolTensor]] = torch.optim.optimizer.required, drop_fraction: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = 0.3, **kwargs)
zero_grad(set_to_none: bool = True)

Clears the accumulated dense gradients.

Customizing Sparsity & Reference

Several building blocks can be inherited from or composed to help build new dynamic sparsity algorithms or customize the behavior of existing ones.

cerebras_pytorch.sparse.init

Sparsity mask initialization methods and helpers, invoked by BaseSparsityOptimizer.

cerebras_pytorch.sparse.init.checkerboard(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras_pytorch.sparse.utils.ScoreShaper] = None) → torch.BoolTensor

Mostly intended for stress and performance testing: creates a sparsity mask that is maximally spread out across the weight in a checkerboard pattern.

cerebras_pytorch.sparse.init.from_zeros(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras_pytorch.sparse.utils.ScoreShaper] = None) → torch.BoolTensor

Treats any zeros currently in the weights as pruned connections. Note: this method does not actually honor the configured sparsity.

cerebras_pytorch.sparse.init.random(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras_pytorch.sparse.utils.ScoreShaper] = None) → torch.BoolTensor

Uniformly random sparsity pattern.

cerebras_pytorch.sparse.init.topk(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras_pytorch.sparse.utils.ScoreShaper] = None) → torch.BoolTensor

Prunes the lowest-magnitude weights.

cerebras_pytorch.sparse.utils

class cerebras_pytorch.sparse.utils.BaseHyperParameter

Base class for step-aware hyperparameters used in Sparsity Optimizers.

static get_cls(typename: str)

Looks up the class by its typename in the registry.

Raises a ValueError if none exists with that name.

get_min_max_end(begin: int, end: int) → Tuple[float, float, float]

Given a beginning and ending step, computes the statistics of this step-aware hyperparameter. Used for estimating memory requirements for dynamic sparsity.

Returns (min, max, ending).

visit_state(fn)

Applies a lambda to each stateful value.

class cerebras_pytorch.sparse.utils.ConstantHyperParameter

Constant at every step.

__init__(value)
class cerebras_pytorch.sparse.utils.CosineHyperParameter

Cosine function oscillating from an initial (maximum) value down to a minimum and back to the maximum every period of 2 · half_period steps.

\(y(step) = o + a \cdot \cos(step \cdot \pi / half\_period)\), where \(o = (init + minimum)/2\) and \(a = init - o\).
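As a worked check of the formula, the following sketch evaluates it directly in plain Python with the drop_fraction settings used in the SET and RigL examples (init=0.3, half_period=1000, default minimum=0.0). The helper name cosine_hyperparam is hypothetical; it mirrors the formula above, not the class's code.

```python
import math


def cosine_hyperparam(step: float, init: float, half_period: float,
                      minimum: float = 0.0) -> float:
    """Evaluate y(step) = o + a*cos(step*pi/half_period), o=(init+minimum)/2, a=init-o."""
    o = (init + minimum) / 2  # midpoint of the oscillation
    a = init - o              # amplitude
    return o + a * math.cos(step * math.pi / half_period)


for step in (0, 500, 1000, 2000):
    print(step, round(cosine_hyperparam(step, init=0.3, half_period=1000), 4))
# y(0) = 0.3, y(500) = 0.15, y(1000) = 0.0 (the minimum), y(2000) = 0.3 again
```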

__init__(init, half_period, minimum=0.0)
get_min_max_end(begin: int, end: int) → Tuple[float, float, float]

Given a beginning and ending step, computes the statistics of this step-aware hyperparameter. Used for estimating memory requirements for dynamic sparsity.

Returns (min, max, ending).

class cerebras_pytorch.sparse.utils.CyclingHyperParameter

Hyperparameter cycling through discrete values at update steps.

__init__(values)
get_min_max_end(begin: int, end: int) → Tuple[float, float, float]

Given a beginning and ending step, computes the statistics of this step-aware hyperparameter. Used for estimating memory requirements for dynamic sparsity.

Returns (min, max, ending).

class cerebras_pytorch.sparse.utils.ExpHyperParameter

Exponential, approaching an asymptotic final value.

\(y(step) = final + (init-final) e^{step \cdot gamma}\)
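Because gamma multiplies the step, it must be negative for y to approach final; a positive gamma makes y move away from final without bound. The following sketch, with the hypothetical helper exp_hyperparam, evaluates the formula and shows how to pick gamma so that, starting from init=0 with the default final=1, the value reaches 0.7 at step 1000.

```python
import math


def exp_hyperparam(step: float, init: float, gamma: float, final: float = 1.0) -> float:
    """Evaluate y(step) = final + (init - final) * e^(step * gamma)."""
    return final + (init - final) * math.exp(step * gamma)


# With init=0 and final=1, y(step) = 1 - e^(step * gamma).
# Picking gamma = ln(0.3) / 1000 makes e^(1000 * gamma) = 0.3, so y(1000) = 0.7.
gamma = math.log(0.3) / 1000
for step in (0, 1000, 2000):
    print(step, exp_hyperparam(step, init=0.0, gamma=gamma))
# y(0) = 0.0, y(1000) ≈ 0.7, y(2000) ≈ 0.91, approaching 1 from below
```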

__init__(init, gamma, final=1)
class cerebras_pytorch.sparse.utils.FirstScoreTiebreaker
__init__(eps)
class cerebras_pytorch.sparse.utils.InputGroupScoreShaper

A ScoreShaper interface when weights are logically shaped as [outsize, num_groups*in_per_group], but need to be scored in a “balanced” fashion as [num_groups, outsize*in_per_group]

Examples

>>> # Common score used for the following examples
>>> score=torch.tensor([[1.0, 0.0],
...                     [2.0, -1.0]])
>>> # 50% sparsity: drops the 2 lowest scores
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
... )
tensor([[ True, False],
        [ True, False]])
>>> # 50% sparsity, but computed columnwise
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
...     score_shaper=InputGroupScoreShaper(num_groups=2)
... )
tensor([[False,  True],
        [ True, False]])
__init__(num_groups)
class cerebras_pytorch.sparse.utils.LambdaHyperParameter

Invokes a user-provided function of the step to obtain the hyperparameter value.

__init__(fn)
get_min_max_end(begin: int, end: int) → Tuple[float, float, float]

Given a beginning and ending step, computes the statistics of this step-aware hyperparameter. Used for estimating memory requirements for dynamic sparsity.

Returns (min, max, ending).

class cerebras_pytorch.sparse.utils.LinearHyperParameter

Linear change from an initial value.

\(y(step) = init + step \cdot slope\)

__init__(init, slope)
class cerebras_pytorch.sparse.utils.OutputGroupScoreShaper

A ScoreShaper interface when weights are logically shaped as [num_groups*out_per_group, insize], but need to be scored in a “balanced” fashion as [num_groups, out_per_group*insize]

Examples

>>> # Common score used for the following examples
>>> score=torch.tensor([[1.0, 2.0],
...                     [0.0, -1.0]])
>>> # 50% sparsity: drops the 2 lowest scores
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
... )
tensor([[ True,  True],
        [False, False]])
>>> # 50% sparsity, but computed rowwise
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
...     score_shaper=OutputGroupScoreShaper(num_groups=2)
... )
tensor([[False,  True],
        [ True, False]])
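In the 2x2 examples for InputGroupScoreShaper and OutputGroupScoreShaper, each group is exactly one column or one row, so grouped scoring reduces to an independent top-k within each row (after transposing, for columns). The sketch below reproduces both examples in plain Python. The helper names are hypothetical, rounding k with round() is an assumption, and the sketch does not cover groups spanning several rows or columns, whose exact reshape order is not described here.

```python
from typing import List

Matrix = List[List[float]]


def topk_mask_per_row(score: Matrix, sparsity: float) -> List[List[bool]]:
    """Keep the top (1 - sparsity) fraction of each row; rows compete separately."""
    mask = []
    for row in score:
        k = round((1 - sparsity) * len(row))  # rounding is an assumption
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        keep = set(ranked[:k])
        mask.append([j in keep for j in range(len(row))])
    return mask


def transpose(m: List[list]) -> List[list]:
    return [list(col) for col in zip(*m)]


# Rowwise groups: the OutputGroupScoreShaper(num_groups=2) example above.
print(topk_mask_per_row([[1.0, 2.0], [0.0, -1.0]], 0.5))
# [[False, True], [True, False]]

# Columnwise groups: the InputGroupScoreShaper(num_groups=2) example above.
# Transposing turns columns into rows, so the same helper applies.
col_score = [[1.0, 0.0], [2.0, -1.0]]
print(transpose(topk_mask_per_row(transpose(col_score), 0.5)))
# [[False, True], [True, False]]
```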
__init__(num_groups)
class cerebras_pytorch.sparse.utils.PowerHyperParameter

Power law.

\(y(step) = init \cdot beta^{step}\)

__init__(init, beta)
class cerebras_pytorch.sparse.utils.RandomScoreTiebreaker
__init__(eps)
class cerebras_pytorch.sparse.utils.ScoreFlattener

Default ScoreShaper, in which everything is flattened, providing a global competition for magnitude. If only sub-portions of the weight should compete for magnitude, provide an alternative shaper object.

class cerebras_pytorch.sparse.utils.ScoreShaper
class cerebras_pytorch.sparse.utils.ScoreTiebreaker

Base class for all “tiebreaking” of scores for deterministic execution. The default tiebreaking is “none”, which leaves ties non-deterministic across systems. In particular, torch.topk behaves differently on CPU, GPU, and WSE: it is not stable and makes no guarantees about how ties are broken.
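One common way to make top-k deterministic is to add a tiny, position-dependent offset scaled by a small eps, so that equal scores no longer compare equal. The sketch below shows that idea in plain Python; the helper first_wins_tiebreak is hypothetical, and whether FirstScoreTiebreaker(eps) uses exactly this offset scheme is an assumption.

```python
from typing import List


def first_wins_tiebreak(score: List[float], eps: float) -> List[float]:
    """Hypothetical tiebreak: add a tiny offset that decreases with position.

    Equal scores become distinct, with earlier positions ranking higher.
    `eps` must be smaller than the gap between any two distinct scores,
    otherwise the offset could reorder genuinely different scores.
    """
    n = len(score)
    return [s + eps * (n - i) / n for i, s in enumerate(score)]


score = [0.5, 0.5, 0.5, 0.1]
adjusted = first_wins_tiebreak(score, eps=1e-6)
top2 = sorted(range(len(adjusted)), key=lambda i: adjusted[i], reverse=True)[:2]
print(top2)  # [0, 1]
```

Because all adjusted values are distinct, any top-k routine, stable or not, selects the same positions. A random tiebreaker would presumably draw the offsets at random instead, which stays deterministic only if the random draw is reproducible.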

cerebras_pytorch.sparse.utils.initialize_hyperparam(param)

Given some user-specified configuration, constructs a step-aware BaseHyperParameter.

cerebras_pytorch.sparse.utils.make_mask_drop_minimum(score: torch.FloatTensor, mask: torch.BoolTensor, drop_fraction: torch.FloatTensor, score_shaper: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]] = None) → torch.BoolTensor

Given a sparse score (with mask), return a new torch.BoolTensor the same shape as mask where a drop_fraction portion of the currently present (mask==True) connections are dropped (mask==False).

The connections are dropped at positions corresponding to the lowest values of score.

Equivalently, the returned mask is a subset of mask that keeps the positions with the highest values of score.

Parameters
  • score – Values used to evaluate which positions to drop

  • mask – Current connections, same shape as score

  • drop_fraction – What fraction of current connections to drop

  • score_shaper – If given, score (and mask) will be interpreted as multiple independent subtensors. This can be used to ensure sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score and mask are reinterpreted as 1D tensors, yielding completely unstructured sparsity.

Returns

New mask that has existing connections dropped. No connections will be regrown (unless drop_fraction is negative).
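The drop rule can be sketched on flat Python lists, with score and mask as parallel lists. The helper drop_minimum is hypothetical: it ignores score_shaper (treating everything as one group, like the default flattening), does not handle a negative drop_fraction, and rounds the drop count with round(), which may differ from the library.

```python
from typing import List


def drop_minimum(score: List[float], mask: List[bool], drop_fraction: float) -> List[bool]:
    """Drop the lowest-scoring `drop_fraction` of the currently active positions.

    Assumes 0 <= drop_fraction <= 1; pruned positions are never regrown here.
    """
    active = [i for i, m in enumerate(mask) if m]
    num_drop = round(drop_fraction * len(active))  # rounding is an assumption
    dropped = set(sorted(active, key=lambda i: score[i])[:num_drop])
    return [m and i not in dropped for i, m in enumerate(mask)]


score = [0.9, 0.2, 0.7, 0.4, 0.1]
mask = [True, True, False, True, True]
print(drop_minimum(score, mask, drop_fraction=0.5))
# [True, False, False, True, False]: positions 4 and 1 (scores 0.1, 0.2) dropped;
# position 2 stays pruned even though its score is high.
```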

cerebras_pytorch.sparse.utils.make_mask_grow_maximum(score: torch.FloatTensor, mask: torch.BoolTensor, sparsity: torch.FloatTensor, mask_nonzero: Optional[torch.IntTensor] = None, score_shaper: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]] = None) → torch.BoolTensor

Given a sparse score (with mask), return a new torch.BoolTensor the same shape as mask where some currently pruned connections are regrown (from those positions with the highest score) such that the returned mask has the given target sparsity.

If mask is already less sparse (has more connections) than the target, none are regrown and the original mask is returned as-is. That is, the given mask should be more sparse than the target sparsity.

Parameters
  • score – Values used to evaluate which positions to regrow

  • mask – Current connections, same shape as score

  • sparsity – Target sparsity of the returned mask.

  • mask_nonzero – If given, the number of nonzero elements currently in mask, used to determine how many connections need to be regrown. If not given, it is computed from mask. Since make_mask_grow_maximum is often used together with make_mask_drop_minimum, this value is commonly available.

  • score_shaper – If given, score (and mask) will be interpreted as multiple independent subtensors. This can be used to ensure sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score and mask are reinterpreted as 1D tensors, yielding completely unstructured sparsity.

Returns

New mask with enough connections regrown to reach (decrease to) the target sparsity.
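A minimal sketch of the regrowth rule on flat lists follows. The helper grow_maximum is hypothetical; sum(mask) plays the role of mask_nonzero, score_shaper is ignored, and the target number of active connections is rounded with round(), an assumption. In a RigL-style update, score would be the dense gradient magnitude.

```python
from typing import List


def grow_maximum(score: List[float], mask: List[bool], sparsity: float) -> List[bool]:
    """Regrow the highest-scoring pruned positions until `sparsity` is reached."""
    target_active = round((1 - sparsity) * len(mask))  # rounding is an assumption
    # If the mask is already denser than the target, nothing is regrown.
    num_grow = max(target_active - sum(mask), 0)
    pruned = [i for i, m in enumerate(mask) if not m]
    grown = set(sorted(pruned, key=lambda i: score[i], reverse=True)[:num_grow])
    return [m or i in grown for i, m in enumerate(mask)]


# Regrow to 20% sparsity (4 of 5 active) after a drop step left 2 active.
grad_score = [0.0, 0.3, 0.8, 0.0, 0.5]
mask = [True, False, False, True, False]
print(grow_maximum(grad_score, mask, sparsity=0.2))
# [True, False, True, True, True]
```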

cerebras_pytorch.sparse.utils.make_mask_topk_sparsity(score: torch.FloatTensor, sparsity: torch.FloatTensor, score_shaper: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]] = None) → torch.BoolTensor

Given a dense score, return a torch.BoolTensor which is True at positions corresponding to values in the top k = (1-sparsity)*score.numel() of score.

Parameters
  • score – Values used to evaluate which positions to keep.

  • sparsity – Rankless (0-dimensional) tensor in the range [0, 1] controlling the fraction of the resulting mask that will be pruned.

  • score_shaper – If given, score will be interpreted as multiple independent subtensors. This can be used to ensure sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score is reinterpreted as a 1D tensor, yielding completely unstructured sparsity.

Returns

mask with given sparsity, keeping only the highest values from score.

Examples

>>> # Common score used for the following examples
>>> score=torch.tensor([[1.0, 2.0],
...                     [0.0, -1.0]])
>>> # 25% sparsity: drops the single lowest score
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.25),
... )
tensor([[ True,  True],
        [ True, False]])
>>> # 75% sparsity: drops the 3 lowest scores
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.75),
... )
tensor([[False,  True],
        [False, False]])
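The top-k rule can be checked by hand with a plain-Python sketch that reproduces the two examples above. The helper topk_mask is hypothetical and uses round() for k, an assumption. Note that the ranking uses the raw score values, which is why -1.0 is dropped before 0.0; to prune by magnitude, the caller would pass absolute values as the score.

```python
from typing import List


def topk_mask(score: List[List[float]], sparsity: float) -> List[List[bool]]:
    """Keep the k = (1 - sparsity) * numel highest-scoring entries, globally."""
    rows, cols = len(score), len(score[0])
    flat = [v for row in score for v in row]
    k = round((1 - sparsity) * len(flat))  # rounding is an assumption
    keep = set(sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)[:k])
    # Map flat indices back to (row, col) positions.
    return [[r * cols + c in keep for c in range(cols)] for r in range(rows)]


score = [[1.0, 2.0], [0.0, -1.0]]
print(topk_mask(score, 0.25))  # [[True, True], [True, False]]
print(topk_mask(score, 0.75))  # [[False, True], [False, False]]
```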
cerebras_pytorch.sparse.utils.set_param_group_hyperparam(param_group, name)

Updates the param_group option name in place, as a step-aware hyperparameter.

cerebras_pytorch.sparse.dynamic

Base class for all dynamic sparsity optimizers, plus dynamic schedule helpers.

class cerebras_pytorch.sparse.dynamic.BaseSchedule
static get_cls(typename: str)

Looks up the class by its typename in the registry.

Raises a ValueError if none exist with that name.

class cerebras_pytorch.sparse.dynamic.DynamicSparsityOptimizer

Abstract base class for a dynamic sparsity optimizer.

Subclasses must implement update_mask.

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.

  • sparsity – Sparsity, either constant or step-aware hyperparameter

  • schedule

    Sparsity update schedule. May be one of:

    • int: Single regular update frequency.

    • list: Irregular update on the given steps.

    • dict: Containing {"start": start, "freq": freq, "stop": stop} for regular updates with start & stop.

    • ScheduleCallable : User function accepting a rankless torch.LongTensor and returning a rankless torch.BoolTensor

__init__(params, sparsity=torch.optim.optimizer.required, schedule: Union[int, List[int], Dict, Callable[[torch.LongTensor], torch.BoolTensor]] = torch.optim.optimizer.required, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', **kwargs)
abstract update_mask(p, mask, sparsity, group)

Compute an updated sparsity pattern.

Parameters
  • p (torch.Tensor) – the parameter to sparsify

  • mask (torch.tensor(dtype=torch.bool)) – the current mask of param p

  • sparsity (torch.tensor(dtype=torch.float32)) – the desired sparsity level

  • group (dict) – The param group dict with any additional options

Returns

The updated sparsity pattern on parameter p

visit_state(fn)

Applies a lambda to each stateful value.

class cerebras_pytorch.sparse.dynamic.FreqSchedule

When scheduling sparsity update steps at a regular interval, this class allows configuring the start and stop steps in addition to the update frequency.

__init__(start=None, freq=1000, stop=None)
class cerebras_pytorch.sparse.dynamic.ListSchedule

When scheduling requires an irregular update cadence, explicit steps can be provided as a list.

__init__(steps: Union[List[int], torch.Tensor])
cerebras_pytorch.sparse.dynamic.make_schedule(schedule: Union[int, List[int], Dict, Callable[[torch.LongTensor], torch.BoolTensor]]) → Callable[[torch.LongTensor], torch.BoolTensor]

Instantiate a supported schedule type.
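The four accepted schedule forms can be read as predicates that answer "is this step an update step?". The sketch below, using plain ints instead of rankless tensors, shows one plausible mapping. The name make_schedule_sketch is hypothetical, and details such as whether step 0 triggers an update and whether stop is exclusive are assumptions; only the FreqSchedule defaults (start=None, freq=1000, stop=None) are taken from the signature above.

```python
from typing import Callable, Dict, List, Optional, Union

ScheduleConfig = Union[int, List[int], Dict[str, Optional[int]], Callable[[int], bool]]


def make_schedule_sketch(schedule: ScheduleConfig) -> Callable[[int], bool]:
    """Turn a schedule config into a predicate: is `step` an update step?"""
    if callable(schedule):
        # A user function is used as-is.
        return schedule
    if isinstance(schedule, int):
        # Single regular update frequency.
        return lambda step: step % schedule == 0
    if isinstance(schedule, list):
        # Irregular updates on the listed steps.
        steps = set(schedule)
        return lambda step: step in steps
    if isinstance(schedule, dict):
        # Regular updates between start and stop (defaults mirror FreqSchedule).
        start = schedule.get("start") or 0
        freq = schedule.get("freq", 1000)
        stop = schedule.get("stop")

        def is_update_step(step: int) -> bool:
            if step < start or (stop is not None and step >= stop):
                return False
            return (step - start) % freq == 0

        return is_update_step
    raise TypeError(f"unsupported schedule: {schedule!r}")


every_100 = make_schedule_sketch(100)
windowed = make_schedule_sketch({"start": 50, "freq": 100, "stop": 300})
print([s for s in range(400) if every_100(s)])  # [0, 100, 200, 300]
print([s for s in range(400) if windowed(s)])   # [50, 150, 250]
```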

cerebras_pytorch.sparse.base

class cerebras_pytorch.sparse.base.BaseSparsityOptimizer

Abstract base class for all sparsity optimizers, static and dynamic.

Subclasses must implement _get_target_sparsity_level_of_group and step.

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • init_method (InitMethodType) – method by which sparsity is initialized

  • defaults (dict) – Additional defaults for param_groups

__init__(params, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', defaults=None)
annotate_sparsity()

Annotates sparsity as performance hints for the Cerebras compiler.

apply_sparsity()

Applies the sparsity pattern to the parameters and optimizer states.

hook_module(module: torch.nn.Module)

Hooks the given module so that the sparsity pattern is applied to the parameters before forward() and to the gradients after backward().

initialize_sparsity()

Computes the initial sparsity pattern for each parameter.

manage_optimizer_state_sparsity(optimizer: torch.optim.optimizer.Optimizer, state_names: List[str])

Manages the sparsity of an optimizer’s state: for every parameter that this SparsityOptimizer manages, applies the sparsity pattern to all states named in state_names.

visit_state(fn)

Applies a lambda to each stateful value.

zero_grad(set_to_none: bool = True)

Overrides the default torch.optim.Optimizer behavior to never zero gradients. This optimizer is unusual in that it is not responsible for the main weight update of the parameters it manages (and thus does not consult or “maintain” their gradients); it only manages their sparsity pattern.

Can be further overridden in other SparsityOptimizers that deal with gradients (such as RigL).