# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""
Provide an "optimizer" implementing static sparsity.
"""
import warnings
import torch
from torch.optim.optimizer import required
from .base import BaseSparsityOptimizer, InitMethodType


class StaticSparsityOptimizer(BaseSparsityOptimizer):
r"""Implements a static sparsity optimizer.
Args:
params (iterable): iterable of parameters to sparsify or dicts defining
parameter groups to sparsify
sparsity (float): target sparsity
init_method: Method to initialize sparsity pattern. Can either be the
name of a built-in method or a lambda.
Example:
>>> optimizer = torch.optim.SGD(
model.parameters(), lr=0.1, momentum=0.9
)
>>> sparsity_opt = StaticSparsityOptimizer(
[p for n,p in model.named_parameters() if should_sparsify(n,p)],
sparsity=0.5,
)
>>> sparsity_opt.hook_module(model)
>>> sparsity_opt.initialize_sparsity()
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> sparsity_opt.step()
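
        Sparsity can also be configured per parameter group by passing dicts
        instead of a flat parameter list (illustrative sketch; ``model.fc1``
        and ``model.fc2`` are placeholder submodules, and the ``params`` /
        ``sparsity`` keys follow the usual param-group convention):

        >>> sparsity_opt = StaticSparsityOptimizer(
        ...     [
        ...         {"params": model.fc1.parameters(), "sparsity": 0.3},
        ...         {"params": model.fc2.parameters(), "sparsity": 0.7},
        ...     ],
        ... )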
"""

    def __init__(
        self,
        params,
        sparsity=required,
        init_method: InitMethodType = "random",
        **kwargs,
    ):
        if kwargs:
            warnings.warn(f"Unused arguments: {kwargs}")

        defaults = {
            'sparsity': sparsity,
        }
        super().__init__(
            params=params, init_method=init_method, defaults=defaults
        )

    def add_param_group(self, param_group):
        # Verify all required values are specified.
        param_group = super().add_param_group(param_group)

        # Perform static-sparsity-specific validation.
        sparsity = param_group["sparsity"]
        if not isinstance(sparsity, float):
            raise ValueError(
                "StaticSparsityOptimizer only supports constant sparsity"
            )
        if not 0.0 <= sparsity < 1.0:
            raise ValueError(
                f"Invalid sparsity level {sparsity}. Must be 0.0 <= s < 1.0"
            )
        # Sparsity is constant, so all annotated values are the same level.
        param_group["csx_annotated_sparsity"] = (sparsity, sparsity, sparsity)

    def _get_target_sparsity_level_of_group(self, group) -> torch.FloatTensor:
        # Always the same static sparsity level.
        return torch.tensor(group["sparsity"])

    @torch.no_grad()
    def step(self, closure=None):
        """Re-apply the static sparsity masks to maintain the initial pattern."""
        # Ensure apply_sparsity has been called before stepping.
        self._ensure_sparsity_applied()
        # Merely apply the mask to maintain the initial sparsity pattern.
        self.apply_sparsity()
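

# ---------------------------------------------------------------------------
# Minimal illustrative sketch (not part of the library API). It makes the
# docstring example above concrete with a toy model, and assumes only the
# base-class methods already referenced there (``hook_module`` and
# ``initialize_sparsity``).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    sparsity_opt = StaticSparsityOptimizer(
        [model.weight],  # sparsify only the weight matrix
        sparsity=0.5,
    )
    sparsity_opt.hook_module(model)
    sparsity_opt.initialize_sparsity()

    x = torch.randn(8, 16)
    target = torch.randn(8, 4)
    optimizer.zero_grad()
    torch.nn.functional.mse_loss(model(x), target).backward()
    optimizer.step()
    sparsity_opt.step()  # re-apply the static mask after the weight update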