# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
import functools
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterable
from typing import List
import torch
from torch.optim.optimizer import Optimizer
from .init import InitMethodType, make_init_method
class BaseSparsityOptimizer(Optimizer, ABC):
r"""
Abstract base class for a dynamic sparsity optimizer.
Subclasses must implement :meth:`_get_target_sparsity_level_of_group` and
:meth:`step`.
Args:
params (iterable): iterable of parameters to sparsify or dicts defining
parameter groups to sparsify
init_method (InitMethodType): method by which sparsity is initialized
defaults (dict): Additional defaults for param_groups
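    A minimal usage sketch (illustrative only: ``MySparsityOptimizer`` is a
    hypothetical concrete subclass, and ``model``, ``dense_opt`` and ``data``
    are assumed to be an existing ``torch.nn.Module``, a dense
    ``torch.optim.Optimizer``, and a data loader)::

        sparsity_opt = MySparsityOptimizer(
            list(model.named_parameters()), init_method="random"
        )
        # Re-apply masks automatically before every forward()
        sparsity_opt.hook_module(model)

        for batch in data:
            loss = model(batch)
            loss.backward()
            dense_opt.step()
            dense_opt.zero_grad()
            sparsity_opt.step()  # update the sparsity pattern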
"""
    def __init__(
self, params, init_method: InitMethodType = 'random', defaults=None
):
defaults = defaults or {}
defaults['init_method'] = init_method
super().__init__(params, defaults)
self.optimizers_and_state_names = []
self._init_sparsity_called = False
self._hook_module_called = False
self._apply_sparsity_called = False
self._param_grad_hooks = {}
from cerebras_pytorch.backend import current_backend_impl
self.backend = current_backend_impl()
self.backend.register_optimizer(self)
    def initialize_sparsity(self):
"""
Compute the initial sparsity pattern for each parameter.
"""
if self._init_sparsity_called:
# Don't re-initialize
return
self._init_sparsity_called = True
num_params = sum(len(group["params"]) for group in self.param_groups)
        # Set up the initialization progress bar
if self.backend.progress_tracker is not None:
self.backend.progress_tracker.reset(total=num_params)
self.backend.progress_tracker.set_postfix()
self.backend.progress_tracker.set_description(
"Initializing sparsity patterns"
)
with self.backend.device:
for group in self.param_groups:
self._init_sparsity_of_group(group)
self.visit_state(lambda x: x.to(self.backend.torch_device))
        # After initializing new masks, double-check that apply_sparsity()
        # gets called once before step().
self._apply_sparsity_called = False
def _init_sparsity_of_group(self, group):
"""
Compute the initial sparsity pattern for each of the parameters in the
given group.
"""
# This simple scalar computation does not need to be traced
with torch.device("cpu"):
sparsity = self._get_target_sparsity_level_of_group(group)
# Use the CPU device if doing eager initialization on CSX.
# Otherwise, use the parameter's device.
# This allows us to trace the mask initialization during
# lazy initialization.
device = None
if (
self.backend.is_csx
and not self.backend.device.config.lazy_initialization
):
device = "cpu"
initializer = group['init_method']
for p in group['params']:
self.state[p]["mask"] = initializer(p, sparsity, device=device)
if self.backend.progress_tracker is not None:
self.backend.progress_tracker.update()
@abstractmethod
def _get_target_sparsity_level_of_group(self, group) -> torch.FloatTensor:
"""
Returns the target sparsity level for parameters in the group.
Returns:
            sparsity_level: a rankless (0-dimensional) FloatTensor holding the sparsity level
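        For example, a subclass with a fixed sparsity level might simply
        return it as a 0-dim tensor (a sketch, assuming a hypothetical
        ``"sparsity"`` entry in the group)::

            return torch.tensor(group["sparsity"], dtype=torch.float32)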
"""
    def manage_optimizer_state_sparsity(
self, optimizer: Optimizer, state_names: List[str]
):
"""
        Manage the sparsity of an optimizer's state. For any parameters that
        this SparsityOptimizer manages, apply the sparsity pattern to all
        states named in `state_names`.
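        For example, to keep Adam's momentum buffers sparse as well (a sketch;
        ``sparsity_opt`` is this optimizer and ``adam`` is assumed to be a
        ``torch.optim.Adam`` instance over the same parameters, whose state
        entries are named ``"exp_avg"`` and ``"exp_avg_sq"``)::

            sparsity_opt.manage_optimizer_state_sparsity(
                adam, state_names=["exp_avg", "exp_avg_sq"]
            )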
"""
self.optimizers_and_state_names.append((optimizer, state_names))
def _yield_optimizer_states(self, p):
"""
Yield the given parameter's optimizer states which need sparsity
applied.
"""
for opt, state_names in self.optimizers_and_state_names:
if p in opt.state:
state = opt.state[p]
for s_name in state_names:
if s_name in state:
yield state[s_name]
    def annotate_sparsity(self):
"""
        Annotate sparsity as performance hints for the Cerebras compiler.
"""
for group in self.param_groups:
sparsity = group.get("csx_annotated_sparsity")
if sparsity is None:
continue
min_v, max_v, ending_v = sparsity
for p in group['params']:
self.backend.set_attribute(p, "min_sparsity", min_v)
self.backend.set_attribute(p, "max_sparsity", max_v)
self.backend.set_attribute(p, "sparsity", ending_v)
for state in self._yield_optimizer_states(p):
self.backend.set_attribute(state, "min_sparsity", min_v)
self.backend.set_attribute(state, "max_sparsity", max_v)
self.backend.set_attribute(state, "sparsity", ending_v)
    def hook_module(self, module: torch.nn.Module):
"""
        Hook the given module such that the sparsity pattern is applied both
        to the parameters before forward() and to the gradients after
        backward().
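        A sketch of the roughly equivalent manual calls, for setups where
        module hooks are disabled (``model``, ``inputs`` and ``sparsity_opt``
        are assumed names)::

            # Instead of sparsity_opt.hook_module(model), do this before
            # every forward pass:
            sparsity_opt.annotate_sparsity()
            sparsity_opt.apply_sparsity()
            out = model(inputs)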
"""
self._hook_module_called = True
def forward_pre_hook(module, input):
self.annotate_sparsity()
self.apply_sparsity()
module.register_forward_pre_hook(forward_pre_hook)
def _ensure_sparsity_applied(self):
if not self._apply_sparsity_called:
error = (
"apply_sparsity() must be called before forward() to apply "
"sparsity to parameters and optimizer state. "
)
if self._hook_module_called:
error += (
"A module hook was installed which should have taken care "
"of calling it, but did not. Check that you have not "
"disabled module hooks."
)
else:
error += (
"For your convenience, the SparsityOptimizer method "
"``.hook_module()`` can add a torch.nn.Module forward_pre "
"hook to automatically apply sparsity."
)
raise RuntimeError(error)
    def zero_grad(self, set_to_none: bool = True):
"""
        Override the default torch.optim.Optimizer behavior to never zero
        gradients: this optimizer is slightly unusual in that it isn't
        responsible for the `main` weight update of the params it manages
        (and thus doesn't consult or "maintain" their gradients), but it does
        manage the sparsity pattern of the params.
        Can be further overridden in other SparsityOptimizers if they deal
        with gradients (like RigL).
"""
def state_dict(self):
        # Adapted from torch.optim.Optimizer.state_dict(), but param_names
        # are used in place of params to key the saved state.
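        # The returned dict is keyed by parameter name rather than by index,
        # e.g. {"state": {"fc.weight": {"mask": <tensor>}}, "param_groups": [...]}
        # (the "fc.weight" name here is illustrative).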
param_groups = []
# map parameter -> name
name_map = {}
for group in self.param_groups:
name_map.update(dict(zip(group["params"], group["param_names"])))
group = group.copy()
del group["params"]
# Some objects may themselves be stateful, so we store their state
# instead of them
for k, v in list(group.items()):
if hasattr(v, "state_dict"):
group[k] = v.state_dict()
elif callable(v):
# Don't serialize callable objects
del group[k]
param_groups.append(group)
state = {name_map[p]: v for p, v in self.state.items()}
return {"state": state, "param_groups": param_groups}
def load_state_dict(self, state_dict):
# Adapted from torch.optim.Optimizer, but we use param_names
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']
if len(groups) != len(saved_groups):
raise ValueError(
"loaded state dict has a different number of parameter groups"
)
# map name -> parameter
name_map = {}
for group in self.param_groups:
name_map.update(dict(zip(group["param_names"], group["params"])))
for group, saved_group in zip(groups, saved_groups):
if group["param_names"] != saved_group["param_names"]:
raise ValueError(
"loaded state dict contains different parameters than "
"the current optimizer"
)
def to_device(param, value):
"""
Transfer each value to the same device as param.
"""
if isinstance(value, torch.Tensor):
return value.to(param.device)
elif isinstance(value, dict):
return {k: to_device(param, v) for k, v in value.items()}
elif isinstance(value, Iterable):
return type(value)(to_device(param, v) for v in value)
else:
return value
# Copy state associated with params (moving tensors to param device).
state = defaultdict(dict)
for param_name, v in state_dict['state'].items():
param = name_map[param_name]
state[param] = to_device(param, v)
# Update parameter groups, resetting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
# Some Sparsity param_group entries are complex and need to be
# serialized specially.
for k, v in group.items():
if hasattr(v, "load_state_dict"):
# Use the old object, but with loaded state.
v.load_state_dict(new_group[k])
new_group[k] = v
elif k not in new_group:
# Some items were omitted from the state_dict. Keep their
# old value.
new_group[k] = v
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)
]
self.__setstate__({'state': state, 'param_groups': param_groups})
# Loading state counts as initializing it, don't re-init
self._init_sparsity_called = True
    def visit_state(self, fn):
"""
        Applies a callable to each stateful value; if the callable returns a
        value other than None, it replaces the stored entry.
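        For example, to move every mask (and any other stateful tensor) to
        the CPU while leaving non-tensor entries untouched (a sketch; entries
        for which ``fn`` returns None are kept as-is)::

            sparsity_opt.visit_state(
                lambda x: x.cpu() if isinstance(x, torch.Tensor) else None
            )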
"""
for state in self.state.values():
for key, val in state.items():
new_val = fn(val)
if new_val is not None:
state[key] = new_val
for group in self.param_groups:
for v in group.values():
if hasattr(v, "visit_state"):
v.visit_state(fn)
def add_param_group(self, param_group):
        # Unlike the base Optimizer, SparsityOptimizer accepts (name, param)
        # tuples in place of bare parameters.
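        # Accepted shapes for param_group["params"] (names are illustrative):
        #   [("fc.weight", fc_weight), ("fc.bias", fc_bias)]  # list of tuples
        #   ("fc.weight", fc_weight)                          # a single tuple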
named_params = param_group["params"]
if isinstance(named_params, list):
# list of tuples
names, params = zip(*named_params)
elif isinstance(named_params, tuple):
# single tuple
names, params = named_params
params = [params]
names = [names]
param_group["params"] = params
param_group["param_names"] = names
super().add_param_group(param_group)
# Hydrate the initializer
param_group["init_method"] = make_init_method(
param_group["init_method"]
)
# Ensure every group has a name
if "name" not in param_group:
if len(names) == 1:
# Single weight group
param_group["name"] = names[0]
else:
param_group["name"] = f"group_{len(self.param_groups)}"
# Return the newly added param_group
return self.param_groups[-1]
@torch.no_grad()
def apply_sparsity(self):
"""
Apply the sparsity pattern to the parameters and optimizer states.
"""
if not self._init_sparsity_called:
if self.backend.is_csx:
raise RuntimeError(
"Sparsity must be initialized before execution"
)
# We can init lazily on CPU/GPU though.
self.initialize_sparsity()
self._apply_sparsity_called = True
self._apply_masks_to_params()
self._apply_masks_to_opt_state()
def _grad_hook(self, p, grad):
        # In case there are any NaNs in the unused gradients that correspond
        # to zeroed-out weights, use a selection to replace those NaNs with
        # zeros (multiplying by the mask would preserve them).
        # DLS will skip a weight update if there is a NaN in the gradient, but
        # we only want that to happen if there is a NaN in the gradients
        # corresponding to non-zero weights. This matches the behavior of the
        # CS2, which doesn't even compute the full gradients on most steps.
zero = torch.zeros_like(grad)
mask = self.state[p]['mask']
# Return modified gradient.
return torch.where(mask, grad, zero)
@torch.no_grad()
def _apply_masks_to_params(self):
for group in self.param_groups:
for p in group['params']:
# Apply sparsity.
p.mul_(self.state[p]['mask'])
# Set up autograd to apply sparsity to gradients too.
if p not in self._param_grad_hooks:
self._param_grad_hooks[p] = p.register_hook(
functools.partial(self._grad_hook, p)
)
@torch.no_grad()
def _apply_masks_to_opt_state(self):
for group in self.param_groups:
for p in group['params']:
for state in self._yield_optimizer_states(p):
state.mul_(self.state[p]['mask'])