common.pytorch.optim.sparse package#

Submodules#

common.pytorch.optim.sparse.base module#

class common.pytorch.optim.sparse.base.BaseSparsityOptimizer#

Bases: torch.optim.optimizer.Optimizer, abc.ABC

Abstract base class for a sparsity optimizer.

Subclasses must implement the init_sparsity and step functions.

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • optimizers_and_state_names (tuple, list(tuple)) – a tuple or list of tuples where the first element of each tuple is an optimizer and the second is a list of optimizer state names to sparsify.

  • defaults (dict) – Defaults for param_groups
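
Example

A construction sketch for illustration: BaseSparsityOptimizer itself is abstract, so the concrete StaticSparsityOptimizer is used here, and the Adam state names "exp_avg" and "exp_avg_sq" (and the choice to sparsify them) are assumptions made only to show the optimizers_and_state_names format.

>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> sparsity_opt = StaticSparsityOptimizer(
        [p for n, p in model.named_parameters() if should_sparsify(n, p)],
        optimizers_and_state_names=[(optimizer, ["exp_avg", "exp_avg_sq"])],
        sparsity=0.5,
    )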

__init__(params, optimizers_and_state_names=[], defaults={}, **kwargs)#
add_param_group(param_group)#
apply_grad_sparsity()#

Apply the sparsity pattern to the gradients of the parameters.

apply_param_sparsity()#

Apply the sparsity pattern to the parameters.

apply_sparsity()#

Apply the sparsity pattern to the parameters and optimizer states.

abstract init_sparsity()#

Compute the initial sparsity pattern for each of the parameters.

load_state_dict(state_dict)#
sparse_params()#

Context manager that applies sparsity to the parameters upon entry and to the gradients upon exit.
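
For illustration, one way this context manager might wrap a training step, given the entry/exit behavior described above (the surrounding names follow the examples below; the usage pattern itself is an assumption):

>>> with sparsity_opt.sparse_params():
        loss_fn(model(input), target).backward()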

state_dict()#

common.pytorch.optim.sparse.dynamic module#

class common.pytorch.optim.sparse.dynamic.DynamicSparsityOptimizer#

Bases: modelzoo.common.pytorch.optim.sparse.base.BaseSparsityOptimizer, abc.ABC

Abstract base class for a dynamic sparsity optimizer.

Subclasses must implement the update_mask function.

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • optimizers_and_state_names (tuple, list(tuple)) – a tuple or list of tuples where the first element of each tuple is an optimizer and the second is a list of optimizer state names to sparsify.

  • sparsity_schedule (list) – Ordered list of (step, sparsity) tuples

Example

>>> optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9
    )
>>> sparsity_opt = DynamicSparsityOptimizer(
        [p for n,p in model.named_parameters() if should_sparsify(n,p)],
        sparsity_schedule=[(0, 0.0), (5, 0.5)],
    )
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
__init__(params, optimizers_and_state_names=[], sparsity_schedule=torch.optim.optimizer.required, **kwargs)#
add_param_group(param_group)#
init_sparsity()#
load_state_dict(state_dict)#
process_schedule(group)#

Given a parameter group, determine whether the sparsity pattern should be updated on this step, as well as the sparsity level at this step.

Parameters

group (dict) – a parameter group to sparsify. Contains schedule information at key “sparsity_schedule”

Returns

A pair (torch.tensor(dtype=torch.bool), torch.tensor(dtype=torch.float32)) corresponding to whether to sparsify on the current step and the current sparsity level.
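
For illustration, assuming a group with sparsity_schedule=[(0, 0.0), (1000, 0.3), (2000, 0.5)], a plausible reading of the returned pair (an assumption based on the description above, not output from the implementation):

>>> # step 0    -> (True, 0.0)
>>> # step 1000 -> (True, 0.3)
>>> # step 2000 -> (True, 0.5)
>>> # any other step -> (False, <the current sparsity level>)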

state_dict()#
step(closure=None)#
abstract update_mask(p, mask, sparsity, group)#

Compute an updated sparsity pattern.

Parameters
  • p (torch.Tensor) – the parameter to sparsify

  • mask (torch.tensor(dtype=torch.bool)) – the current mask of param p

  • sparsity (torch.tensor(dtype=torch.float32)) – the desired sparsity level

  • group (dict) – The param group dict with any additional options

Returns

The updated sparsity pattern on parameter p
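
For illustration, a minimal magnitude-pruning subclass might implement this hook as sketched below; MagnitudeSparsityOptimizer is hypothetical, and the sketch assumes make_mask_topk_sparsity (see the utils module) returns a boolean mask keeping the largest-scoring entries at the requested sparsity. It is not the library's own implementation.

>>> class MagnitudeSparsityOptimizer(DynamicSparsityOptimizer):
        def update_mask(self, p, mask, sparsity, group):
            # Rank weights by absolute magnitude and keep the largest ones.
            return make_mask_topk_sparsity(p.abs(), sparsity)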

class common.pytorch.optim.sparse.dynamic.GMPSparsityOptimizer#

Bases: common.pytorch.optim.sparse.dynamic.DynamicSparsityOptimizer

Implements Gradual Magnitude Pruning https://arxiv.org/abs/1506.02626.

The sparsity level increases monotonically over training; weights with the smallest magnitudes are pruned first.

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • optimizers_and_state_names (tuple, list(tuple)) – a tuple or list of tuples where the first element of each tuple is an optimizer and the second is a list of optimizer state names to sparsify.

  • sparsity_schedule (list) – List of (step, sparsity) tuples

Example

>>> optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9
    )
>>> sparsity_opt = GMPSparsityOptimizer(
        [p for n,p in model.named_parameters() if should_sparsify(n,p)],
        sparsity_schedule=[(0, 0.0), (5, 0.5)],
    )
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
__init__(params, optimizers_and_state_names=[], sparsity_schedule=torch.optim.optimizer.required, **kwargs)#
update_mask(p, mask, sparsity, group)#

common.pytorch.optim.sparse.static module#

class common.pytorch.optim.sparse.static.StaticSparsityOptimizer#

Bases: modelzoo.common.pytorch.optim.sparse.base.BaseSparsityOptimizer

Implements a static sparsity optimizer.

Parameters
  • params (iterable) – iterable of parameters to sparsify or dicts defining parameter groups to sparsify

  • optimizers_and_state_names (tuple, list(tuple)) – a tuple or list of tuples where the first element of each tuple is an optimizer and the second is a list of optimizer state names to sparsify.

  • sparsity (float) – target sparsity

  • init_method (str) – method by which masks are initialized

Example

>>> optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9
    )
>>> sparsity_opt = StaticSparsityOptimizer(
        [p for n,p in model.named_parameters() if should_sparsify(n,p)],
        sparsity=0.5,
    )
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
__init__(params, optimizers_and_state_names=[], sparsity=torch.optim.optimizer.required, init_method='topk', **kwargs)#
init_sparsity()#
step(closure=None)#

common.pytorch.optim.sparse.utils module#

common.pytorch.optim.sparse.utils.make_mask_top_atleast_k(score, num_dense_elem)#
common.pytorch.optim.sparse.utils.make_mask_topk_k(score, num_dense_elem)#
common.pytorch.optim.sparse.utils.make_mask_topk_sparsity(score, sparsity)#
common.pytorch.optim.sparse.utils.tiebreak_for_topk(score, method, eps)#
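
These helpers are listed without docstrings; for intuition, the sketch below shows the general top-k masking technique they are named after (topk_mask_sketch is hypothetical and is not necessarily how make_mask_topk_sparsity behaves):

>>> def topk_mask_sketch(score, sparsity):
        # Keep the highest-scoring (1 - sparsity) fraction of entries.
        num_dense = max(1, int((1 - sparsity) * score.numel()))
        flat = score.flatten()
        mask = torch.zeros_like(flat, dtype=torch.bool)
        mask[flat.topk(num_dense).indices] = True
        return mask.view_as(score)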

Module contents#