cerebras_pytorch.sparse#
Configuration routines#
The highest-level entry point to enabling sparsity is configure_sparsity_wrapper, which will return a drop-in replacement optimizer that automatically applies the sparsity algorithm according to the high-level configuration dictionary. These config dictionaries (actually passed as **kwargs) follow the same form as given in Sparsity.
- cerebras_pytorch.sparse.configure_sparsity_wrapper(model: torch.nn.Module, optimizer: torch.optim.Optimizer, sparsity_type: str, param_name_patterns: Optional[Union[str, List[str], Dict[str, dict]]] = None, sparse_param_filter: Optional[Callable[[str, torch.nn.Parameter], bool]] = None, **kwargs) cerebras_pytorch.sparse.wrapper.SparsityWrapperOptimizer [source]#
Returns a SparsityWrapperOptimizer ready to be a drop-in replacement for the given optimizer, while also constructing a SparsityOptimizer of the appropriate sparsity_type according to param_name_patterns or sparse_param_filter of the given model.named_parameters(). **kwargs are passed along to the SparsityOptimizer __init__.
- Parameters
model – Root module to extract parameters and hook the FWD pass
optimizer – Optimizer to wrap to sparsify the optimizer state.
sparsity_type – Type of sparsity optimizer to construct.
param_name_patterns – Filter to select which parameters are sparse and optionally if any more specific config should be applied to certain parameters.
sparse_param_filter – Callable to provide fallback selection of which parameters are sparse in case no param_name_patterns are provided.
kwargs – Passed along to the sparsity optimizer __init__.
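A minimal usage sketch (the cstorch import alias, the "static" sparsity_type name, and the forwarded sparsity kwarg are illustrative assumptions):
>>> import torch
>>> import cerebras_pytorch as cstorch
>>> model = torch.nn.Linear(128, 64)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> # Drop-in replacement: call step() as usual and sparsity is applied.
>>> optimizer = cstorch.sparse.configure_sparsity_wrapper(
...     model,
...     optimizer,
...     sparsity_type="static",
...     param_name_patterns=["weight"],
...     sparsity=0.5,  # forwarded to the SparsityOptimizer __init__
... )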
For more control, the helper configure_sparsity_optimizer can construct a BaseSparsityOptimizer from the same high-level configuration dictionary.
- cerebras_pytorch.sparse.configure_sparsity_optimizer(sparsity_type: str, named_parameters: List[Tuple[str, torch.nn.Parameter]], param_name_patterns: Optional[Union[str, List[str], Dict[str, dict]]] = None, sparse_param_filter: Optional[Callable[[str, torch.nn.Parameter], bool]] = None, **kwargs) cerebras_pytorch.sparse.base.BaseSparsityOptimizer [source]#
Construct a SparsityOptimizer of the appropriate sparsity_type according to param_name_patterns or sparse_param_filter of the given named_parameters. **kwargs are passed along to the SparsityOptimizer __init__.
- Parameters
sparsity_type – Type of sparsity optimizer to construct.
named_parameters – List of (name, param) from model.named_parameters()
param_name_patterns – Filter to select which parameters are sparse and optionally if any more specific config should be applied to certain parameters.
sparse_param_filter – Callable to provide fallback selection of which parameters are sparse if no param_name_patterns are provided.
kwargs – Passed along to the chosen sparsity optimizer __init__.
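For example, continuing the sketch above, a hedged illustration of constructing a sparsity optimizer directly (the "static" type name and the forwarded sparsity kwarg are assumptions):
>>> sparsity_opt = cstorch.sparse.configure_sparsity_optimizer(
...     sparsity_type="static",
...     named_parameters=list(model.named_parameters()),
...     sparse_param_filter=lambda name, param: param.dim() >= 2,
...     sparsity=0.5,
... )
>>> sparsity_opt.hook_module(model)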
These optimizers can be used manually like any other PyTorch optimizer, but see hook_module for having them automatically apply sparsity during a module’s forward() and backward(). The same high-level drop-in replacement optimizer wrapper from configure_sparsity_wrapper can also be constructed and used directly.
Sparsity Optimizers#
These classes are the built-in sparsity algorithms. StaticSparsityOptimizer is an “optimizer” that maintains a static sparsity pattern throughout all training. The rest implement published dynamic sparsity algorithms. These are the objects returned from configure_sparsity_optimizer et al.
Even though static sparsity never updates its sparsity pattern throughout training, it is still implemented as an “Optimizer” to provide a consistent API and allow static & dynamic sparsity to be easily swapped via configuration.
- class cerebras_pytorch.sparse.StaticSparsityOptimizer[source]#
Implements a static sparsity optimizer.
- Parameters
params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify
sparsity (float) – target sparsity
init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.
Example
>>> optimizer = torch.optim.SGD(
...     model.parameters(), lr=0.1, momentum=0.9
... )
>>> sparsity_opt = StaticSparsityOptimizer(
...     [p for n,p in model.named_parameters() if should_sparsify(n,p)],
...     sparsity=0.5,
... )
>>> sparsity_opt.hook_module(model)
>>> sparsity_opt.initialize_sparsity()
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
- __init__(params, sparsity=torch.optim.optimizer.required, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', **kwargs)[source]#
- class cerebras_pytorch.sparse.GMPSparsityOptimizer[source]#
Implements Gradual Magnitude Pruning
Sparsity increases monotonically based on weight magnitude.
See: https://arxiv.org/abs/1506.02626
- Parameters
params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify
init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.
sparsity – Sparsity, either constant or step-aware hyperparameter
schedule – Sparsity update schedule. May be one of:
- int: Single regular update frequency.
- list: Irregular update on the given steps.
- dict: Containing {"start": start, "freq": freq, "stop": stop} for regular updates with start & stop.
- ScheduleCallable: User function accepting a rankless torch.LongTensor and returning a rankless torch.BoolTensor.
Example
>>> optimizer = torch.optim.SGD(
...     model.parameters(), lr=0.1, momentum=0.9
... )
>>> sparsity_opt = GMPSparsityOptimizer(
...     [p for n,p in model.named_parameters() if should_sparsify(n,p)],
...     sparsity={"type": "exp", "init": 0, "gamma": 1000*math.log(0.3)},
...     schedule=1000,
... )
>>> sparsity_opt.hook_module(model)
>>> sparsity_opt.initialize_sparsity()
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
- __init__(params, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', sparsity: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = torch.optim.optimizer.required, schedule: Union[int, List[int], Dict, Callable[[torch.LongTensor], torch.BoolTensor]] = torch.optim.optimizer.required, tiebreak=None, **kwargs)[source]#
- class cerebras_pytorch.sparse.SETSparsityOptimizer[source]#
Implements Sparse Evolutionary Training (SET)
Sparsity levels stay constant throughout training, but the lowest magnitude weights are pruned and then regrown randomly.
See: https://arxiv.org/abs/1707.04780
- Parameters
params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify
init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.
sparsity – Sparsity, either constant or step-aware hyperparameter
schedule – Sparsity update schedule. May be one of:
- int: Single regular update frequency.
- list: Irregular update on the given steps.
- dict: Containing {"start": start, "freq": freq, "stop": stop} for regular updates with start & stop.
- ScheduleCallable: User function accepting a rankless torch.LongTensor and returning a rankless torch.BoolTensor.
drop_fraction – Fraction of non-pruned weights to drop each update step. Either a constant or a step-aware hyperparameter.
Example
>>> optimizer = torch.optim.SGD(
...     model.parameters(), lr=0.1, momentum=0.9
... )
>>> sparsity_opt = SETSparsityOptimizer(
...     [p for n,p in model.named_parameters() if should_sparsify(n,p)],
...     sparsity=0.9,
...     schedule={"freq": 100, "stop": 1000},
...     drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
... )
>>> sparsity_opt.hook_module(model)
>>> sparsity_opt.initialize_sparsity()
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
- __init__(params, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', sparsity: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = torch.optim.optimizer.required, schedule: Union[int, List[int], Dict, Callable[[torch.LongTensor], torch.BoolTensor]] = torch.optim.optimizer.required, drop_fraction: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = 0.3, **kwargs)[source]#
- class cerebras_pytorch.sparse.RigLSparsityOptimizer[source]#
Implements Rigging the Lottery (RigL)
Sparsity levels stay constant throughout training, but the lowest magnitude weights are pruned and then regrown using a proxy measure of where a pruned connection would have the most impact: the highest magnitude (dense) gradients of the pruned weights.
See: https://arxiv.org/abs/1911.11134
- Parameters
params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify
init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.
sparsity – Sparsity, either constant or step-aware hyperparameter
schedule – Sparsity update schedule. May be one of:
- int: Single regular update frequency.
- list: Irregular update on the given steps.
- dict: Containing {"start": start, "freq": freq, "stop": stop} for regular updates with start & stop.
- ScheduleCallable: User function accepting a rankless torch.LongTensor and returning a rankless torch.BoolTensor.
drop_fraction – Fraction of non-pruned weights to drop each update step. Either a constant or a step-aware hyperparameter.
Example
>>> optimizer = torch.optim.SGD(
...     model.parameters(), lr=0.1, momentum=0.9
... )
>>> sparsity_opt = RigLSparsityOptimizer(
...     [p for n,p in model.named_parameters() if should_sparsify(n,p)],
...     sparsity=0.9,
...     schedule={"freq": 100, "stop": 1000},
...     drop_fraction={"type": "cosine", "init": 0.3, "half_period": 1000},
... )
>>> sparsity_opt.hook_module(model)
>>> sparsity_opt.initialize_sparsity()
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
- __init__(params, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', sparsity: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = torch.optim.optimizer.required, schedule: Union[int, List[int], Dict, Callable[[torch.LongTensor], torch.BoolTensor]] = torch.optim.optimizer.required, drop_fraction: Union[int, float, List[int], List[float], Tuple, Dict, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], cerebras_pytorch.sparse.utils.BaseHyperParameter] = 0.3, **kwargs)[source]#
Customizing Sparsity & Reference#
Several building blocks can be inherited from or composed to help build new dynamic sparsity algorithms or customize the behavior of existing ones.
cerebras_pytorch.sparse.init#
Sparsity mask initialization methods and helpers, invoked by BaseSparsityOptimizer.
- cerebras_pytorch.sparse.init.checkerboard(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras_pytorch.sparse.utils.ScoreShaper] = None) torch.BoolTensor [source]#
Mostly for stress and performance testing, creates a sparsity mask that is maximally distributed in a checkerboard across the weight.
- cerebras_pytorch.sparse.init.from_zeros(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras_pytorch.sparse.utils.ScoreShaper] = None) torch.BoolTensor [source]#
Any zeros currently in the weights represent pruned connections. NOTE: Doesn’t actually honor the configured sparsity.
- cerebras_pytorch.sparse.init.random(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras_pytorch.sparse.utils.ScoreShaper] = None) torch.BoolTensor [source]#
Uniformly random sparsity pattern.
- cerebras_pytorch.sparse.init.topk(p: torch.nn.Parameter, sparsity: torch.FloatTensor, score_shaper: Optional[cerebras_pytorch.sparse.utils.ScoreShaper] = None) torch.BoolTensor [source]#
Prune lowest magnitude weights.
cerebras_pytorch.sparse.utils#
- class cerebras_pytorch.sparse.utils.BaseHyperParameter[source]#
Base class for step-aware hyperparameters used in Sparsity Optimizers.
- static get_cls(typename: str)[source]#
Looks up the class by its typename in the registry.
Raises a ValueError if none exist with that name.
- class cerebras_pytorch.sparse.utils.CosineHyperParameter[source]#
Cosine function for oscillating from an initial (maximum) value down to a minimum and back to the maximum every period.
\(y(step) = o + a \cdot \cos(step \cdot \pi / half\_period)\), where \(o = (init + minimum)/2\) and \(a = init - o\).
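For illustration (values chosen for this example, not defaults): with init = 0.3, minimum = 0.0, and half_period = 1000, then \(o = 0.15\) and \(a = 0.15\), so \(y(0) = 0.3\), \(y(1000) = 0.15 + 0.15\cos(\pi) = 0.0\), and \(y(2000) = 0.3\) again.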
- class cerebras_pytorch.sparse.utils.CyclingHyperParameter[source]#
Hyperparameter cycling between discrete values at update steps.
- class cerebras_pytorch.sparse.utils.ExpHyperParameter[source]#
Exponential, approaching an asymptotic final value
\(y(step) = final + (init-final) e^{step \cdot gamma}\)
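For illustration (values chosen for this example, not defaults): with init = 0.0, final = 0.7, and gamma = -0.001, \(y(0) = 0.0\) and \(y(1000) = 0.7 - 0.7e^{-1} \approx 0.44\), approaching 0.7 as the step grows.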
- class cerebras_pytorch.sparse.utils.InputGroupScoreShaper[source]#
A ScoreShaper interface when weights are logically shaped as [outsize, num_groups*in_per_group], but need to be scored in a “balanced” fashion as [num_groups, outsize*in_per_group]
Examples
>>> # Common score used for the following examples
>>> score=torch.tensor([[1.0, 0.0],
...                     [2.0, -1.0]])
>>> # 50% sparsity, drops the 2 lowest magnitude
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
... )
tensor([[ True, False],
        [ True, False]])
>>> # 50% sparsity, but computed columnwise
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
...     score_shaper=InputGroupScoreShaper(num_groups=2)
... )
tensor([[False,  True],
        [ True, False]])
- class cerebras_pytorch.sparse.utils.LambdaHyperParameter[source]#
Invoke a user’s lambda function of step to obtain the hyperparameter.
- class cerebras_pytorch.sparse.utils.LinearHyperParameter[source]#
Linear change from an initial value.
\(y(step) = init + step * slope\)
- class cerebras_pytorch.sparse.utils.OutputGroupScoreShaper[source]#
A ScoreShaper interface when weights are logically shaped as [num_groups*out_per_group, insize], but need to be scored in a “balanced” fashion as [num_groups, out_per_group*insize]
Examples
>>> # Common score used for the following examples
>>> score=torch.tensor([[1.0, 2.0],
...                     [0.0, -1.0]])
>>> # 50% sparsity, drops the 2 lowest magnitude
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
... )
tensor([[ True,  True],
        [False, False]])
>>> # 50% sparsity, but computed rowwise
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.5),
...     score_shaper=OutputGroupScoreShaper(num_groups=2)
... )
tensor([[False,  True],
        [ True, False]])
- class cerebras_pytorch.sparse.utils.PowerHyperParameter[source]#
Power law.
\(y(step) = init \cdot beta^{step}\)
- class cerebras_pytorch.sparse.utils.ScoreFlattener[source]#
Default ScoreShaper, in which everything is flattened, providing global competition for magnitude. If only sub-portions of the weight should compete for magnitude, provide an alternative shaper object.
- class cerebras_pytorch.sparse.utils.ScoreTiebreaker[source]#
Base class for all “tiebreaking” of score for deterministic execution. The default tiebreaking is “none”, leaving ties non-deterministic across systems. In particular, torch.topk behaves differently on CPU, GPU, and WSE; it is not stable and makes no guarantees about how ties are broken.
- cerebras_pytorch.sparse.utils.initialize_hyperparam(param)[source]#
Given a user-specified configuration, construct a BaseHyperParameter that is step-aware.
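A brief sketch of the configuration forms this helper accepts, mirroring the dict forms used in the optimizer examples above (treating plain constants as also accepted is an assumption):
>>> import math
>>> from cerebras_pytorch.sparse.utils import initialize_hyperparam
>>> constant = initialize_hyperparam(0.3)
>>> cosine = initialize_hyperparam(
...     {"type": "cosine", "init": 0.3, "half_period": 1000}
... )
>>> exp = initialize_hyperparam(
...     {"type": "exp", "init": 0, "gamma": 1000*math.log(0.3)}
... )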
- cerebras_pytorch.sparse.utils.make_mask_drop_minimum(score: torch.FloatTensor, mask: torch.BoolTensor, drop_fraction: torch.FloatTensor, score_shaper: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]] = None) torch.BoolTensor [source]#
Given a sparse score (with mask), return a new torch.BoolTensor the same shape as mask where a drop_fraction portion of the currently present (mask==True) connections are dropped (mask==False).
The connections are dropped at positions corresponding to the lowest values of score.
Equivalently, a subset of mask is returned corresponding to the highest magnitude elements of score.
- Parameters
score – Values used to evaluate which positions to drop
mask – Current connections, same shape as score
drop_fraction – What fraction of current connections to drop
score_shaper – If given, score (and mask) will be interpreted as multiple independent subtensors. This can be used to ensure sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score and mask are reinterpreted as 1D tensors, yielding completely unstructured sparsity.
- Returns
New mask that has existing connections dropped. No connections will be regrown (unless drop_fraction is negative).
- cerebras_pytorch.sparse.utils.make_mask_grow_maximum(score: torch.FloatTensor, mask: torch.BoolTensor, sparsity: torch.FloatTensor, mask_nonzero: Optional[torch.IntTensor] = None, score_shaper: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]] = None) torch.BoolTensor [source]#
Given a sparse score (with mask), return a new torch.BoolTensor the same shape as mask where some currently pruned connections are regrown (from those positions with the highest score) such that the returned mask has the given target sparsity.
If mask is already less sparse (has more connections) than the target, none are regrown and the original mask is returned as-is. That is, the given mask should be more sparse than the target sparsity.
- Parameters
score – Values used to evaluate which positions to regrow
mask – Current connections, same shape as score
sparsity – Target sparsity of the returned mask
mask_nonzero – If given, the number of nonzero elements currently in the mask, used to control the number of connections needing regrowth. If it is not given, it will be computed as mask.nonzero().int(). Since make_mask_grow_maximum is often used in conjunction with make_mask_drop_minimum, this value is commonly available.
score_shaper – If given, score (and mask) will be interpreted as multiple independent subtensors. This can be used to ensure sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score and mask are reinterpreted as 1D tensors, yielding completely unstructured sparsity.
- Returns
New mask with enough connections regrown to reach (decrease to) the target sparsity.
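An illustrative sketch (the values and the regrowth outcome described in the comments are assumptions based on the description above, not library output):
>>> import torch
>>> from cerebras_pytorch.sparse.utils import make_mask_grow_maximum
>>> score = torch.tensor([[1.0, 2.0],
...                       [0.0, -1.0]])
>>> # 75% sparse mask: only the highest-scoring connection is present.
>>> mask = torch.tensor([[False, True],
...                      [False, False]])
>>> # Regrow the highest-scoring pruned positions until 50% sparsity is
>>> # reached; per the description, the position holding 1.0 is regrown.
>>> new_mask = make_mask_grow_maximum(
...     score=score,
...     mask=mask,
...     sparsity=torch.tensor(0.5),
... )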
- cerebras_pytorch.sparse.utils.make_mask_topk_sparsity(score: torch.FloatTensor, sparsity: torch.FloatTensor, score_shaper: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]]] = None) torch.BoolTensor [source]#
Given a dense score, return a torch.BoolTensor which is True at positions corresponding to values in the top k = (1-sparsity)*score.numel() of score.
- Parameters
score – Values used to evaluate which positions to keep.
sparsity – Rankless tensor in range [0,1] controlling the fraction of the resulting mask that will be pruned.
score_shaper – If given, score will be interpreted as multiple independent subtensors. This can be used to ensure sparsity distribution is “balanced” or to produce blockwise sparsity. By default, score is reinterpreted as a 1D tensor, yielding completely unstructured sparsity.
- Returns
mask with given sparsity, keeping only the highest values from score.
Examples
>>> # Common score used for the following examples
>>> score=torch.tensor([[1.0, 2.0],
...                     [0.0, -1.0]])
>>> # 25% sparsity, drops the one lowest magnitude
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.25),
... )
tensor([[ True,  True],
        [ True, False]])
>>> # 75% sparsity, drops the 3 lowest magnitude
>>> make_mask_topk_sparsity(
...     score=score,
...     sparsity=torch.tensor(0.75),
... )
tensor([[False,  True],
        [False, False]])
cerebras_pytorch.sparse.dynamic#
Base class for all dynamic sparsity optimizers, plus dynamic schedule helpers.
- class cerebras_pytorch.sparse.dynamic.DynamicSparsityOptimizer[source]#
Abstract base class for a dynamic sparsity optimizer.
Subclasses must implement update_mask.
- Parameters
params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify
init_method – Method to initialize sparsity pattern. Can either be the name of a built-in method or a lambda.
sparsity – Sparsity, either constant or step-aware hyperparameter
schedule – Sparsity update schedule. May be one of:
- int: Single regular update frequency.
- list: Irregular update on the given steps.
- dict: Containing {"start": start, "freq": freq, "stop": stop} for regular updates with start & stop.
- ScheduleCallable: User function accepting a rankless torch.LongTensor and returning a rankless torch.BoolTensor.
- __init__(params, sparsity=torch.optim.optimizer.required, schedule: Union[int, List[int], Dict, Callable[[torch.LongTensor], torch.BoolTensor]] = torch.optim.optimizer.required, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', **kwargs)[source]#
- abstract update_mask(p, mask, sparsity, group)#
Compute an updated sparsity pattern.
- Parameters
p (torch.Tensor) – the parameter to sparsify
mask (torch.tensor(dtype=torch.bool)) – the current mask of param p
sparsity (torch.tensor(dtype=torch.float32)) – the desired sparsity level
group (dict) – The param group dict with any additional options
- Returns
The updated sparsity pattern on parameter p
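As a sketch of how these pieces compose (this subclass is hypothetical, not part of the library), a simple magnitude-pruning variant only needs to supply update_mask, reusing make_mask_topk_sparsity from the utilities above:
>>> from cerebras_pytorch.sparse.dynamic import DynamicSparsityOptimizer
>>> from cerebras_pytorch.sparse.utils import make_mask_topk_sparsity
>>> class MagnitudePruningOptimizer(DynamicSparsityOptimizer):
...     def update_mask(self, p, mask, sparsity, group):
...         # Keep only the highest-magnitude (1 - sparsity) fraction of
...         # weights, ignoring the previous mask entirely.
...         return make_mask_topk_sparsity(score=p.abs(), sparsity=sparsity)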
- class cerebras_pytorch.sparse.dynamic.FreqSchedule[source]#
When scheduling sparsity update steps on a regular interval, this class allows configuring the start and stop step in addition to the update frequency.
cerebras_pytorch.sparse.base#
- class cerebras_pytorch.sparse.base.BaseSparsityOptimizer[source]#
Abstract base class for a sparsity optimizer.
Subclasses must implement _get_target_sparsity_level_of_group and step.
- Parameters
params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify
init_method (InitMethodType) – method by which sparsity is initialized
defaults (dict) – Additional defaults for param_groups
- __init__(params, init_method: Union[str, Callable[[torch.nn.Parameter, torch.FloatTensor, Optional[cerebras_pytorch.sparse.utils.ScoreShaper]], torch.BoolTensor]] = 'random', defaults=None)[source]#
- apply_sparsity()#
Apply the sparsity pattern to the parameters and optimizer states.
- hook_module(module: torch.nn.Module)[source]#
Hook the given module such that the sparsity pattern is applied to both the parameters before forward() and gradients after backward()
- manage_optimizer_state_sparsity(optimizer: torch.optim.optimizer.Optimizer, state_names: List[str])[source]#
Manage the sparsity of an optimizer’s state. For any parameters that this SparsityOptimizer manages, apply the sparsity pattern to all states named state_names
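For example (a hedged sketch continuing the SGD examples above; "momentum_buffer" is the state name PyTorch's SGD uses for its momentum term):
>>> sparsity_opt.manage_optimizer_state_sparsity(
...     optimizer, state_names=["momentum_buffer"]
... )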
- zero_grad(set_to_none: bool = True)[source]#
Override default torch.optim.Optimizer to never zero gradients: This optimizer is slightly unique in that it isn’t responsible for the main weight update of the params it manages (and thus doesn’t consult or “maintain” their gradients), but it does manage the sparsity pattern of the params.
Can be further overridden in other SparsityOptimizers if they deal with gradients (like RigL).