# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""
Base class for all dynamic sparsity optimizers, plus dynamic schedule helpers.
"""
import inspect
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Union
import torch
from torch.optim.optimizer import required
import cerebras_pytorch as cstorch
from cerebras_pytorch.utils.typing import signature_matches_type_hint
from .base import BaseSparsityOptimizer, InitMethodType
from .utils import set_param_group_hyperparam
class BaseSchedule(ABC):
TYPE_REGISTRY = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, "TYPE"):
cls.TYPE_REGISTRY[cls.TYPE] = cls
cls.TYPE_REGISTRY[cls.__name__] = cls
@staticmethod
def get_cls(typename: str):
"""
Looks up the class by its typename in the registry.
Raises a ValueError if no schedule class is registered under that name.
"""
tr = BaseSchedule.TYPE_REGISTRY
if typename in tr:
return tr[typename]
raise ValueError(
f"Uknown scheduler `type`:\"{typename}\". Valid options are "
f"{list(tr.keys())}"
)
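# Illustrative lookup sketch (assumes FreqSchedule, defined below, ends up in
# the registry under its class name):
#
#     cls = BaseSchedule.get_cls("FreqSchedule")
#     sched = cls(freq=1000)          # equivalent to FreqSchedule(freq=1000)
#     BaseSchedule.get_cls("bogus")   # raises ValueError listing valid names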
@abstractmethod
def __call__(self, step: torch.LongTensor) -> torch.BoolTensor:
"""
Given the training step as a rankless integer tensor, return a rankless
bool tensor indicating whether this is a sparsity update step.
"""
class FreqSchedule(BaseSchedule):
"""
When scheduling sparsity update steps on a regular interval, this class
allows configuring the start and stop step in addition to the update
frequency.
"""
def __init__(self, start=None, freq=1000, stop=None):
self.start = start
self.freq = freq
self.stop = stop
def __call__(self, step: torch.LongTensor) -> torch.BoolTensor:
"""
Returns a rankless bool tensor indicating whether this step is an update step.
"""
# First, check if this is (after offsetting from start) an update step
# based on the frequency
check_step = step
if self.start is not None:
check_step = step - self.start
is_update_step = check_step % self.freq == 0
# Next add the bounds checking if applicable
if self.start is not None:
is_update_step &= step >= self.start
if self.stop is not None:
is_update_step &= step < self.stop
return is_update_step
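# Example with illustrative values: start=1000, freq=500, stop=5000 yields
# update steps 1000, 1500, ..., 4500.
#
#     sched = FreqSchedule(start=1000, freq=500, stop=5000)
#     sched(torch.tensor(1500))  # tensor(True)
#     sched(torch.tensor(1750))  # tensor(False), not a multiple of freq
#     sched(torch.tensor(5000))  # tensor(False), at/after `stop`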
class ListSchedule(BaseSchedule):
"""
When scheduling requires an irregular update cadence, explicit steps can
be provided as a list.
"""
def __init__(self, steps: Union[List[int], torch.Tensor]):
steps = tuple(steps)
self.steps = steps
self.start = min(steps)
self.stop = max(steps)
def __call__(self, step: torch.LongTensor) -> torch.BoolTensor:
"""
Returns a rankless bool tensor indicating whether this step is an update step.
"""
is_update_step = torch.tensor(False, device=step.device)
for s in self.steps:
is_update_step |= step == s
return is_update_step
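# Example with illustrative values: updates occur exactly on the listed steps.
#
#     sched = ListSchedule([1000, 2000, 4000, 8000])
#     sched(torch.tensor(2000))  # tensor(True)
#     sched(torch.tensor(3000))  # tensor(False)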
ScheduleCallable = Callable[
# torch.tensor(shape=[], dtype=int64) -> torch.tensor(shape=[], dtype=bool)
[torch.LongTensor],
torch.BoolTensor,
]
ScheduleType = Union[int, List[int], Dict, ScheduleCallable]
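# A user-provided callable matching ScheduleCallable may also serve as a
# schedule. Minimal sketch (the rule itself is purely illustrative):
#
#     def every_100_after_warmup(step: torch.LongTensor) -> torch.BoolTensor:
#         return (step >= 500) & (step % 100 == 0)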
def make_schedule(schedule: ScheduleType) -> ScheduleCallable:
"""
Instantiate a supported schedule type.
"""
if isinstance(schedule, int):
# Single update frequency
return FreqSchedule(freq=schedule)
elif isinstance(schedule, dict):
schedule = schedule.copy()
typename = schedule.pop("type", None)
if typename:
return BaseSchedule.get_cls(typename)(**schedule)
if "freq" in schedule:
return FreqSchedule(**schedule)
elif isinstance(schedule, (list, tuple)):
return ListSchedule(schedule)
elif callable(schedule):
signature = inspect.signature(schedule)
if signature_matches_type_hint(signature, ScheduleCallable):
return schedule
valid_types = list(BaseSchedule.TYPE_REGISTRY.keys())
raise ValueError(
f"Invalid `schedule`: {schedule}. Valid options are:\n"
f"* int: Regularly updating sparsity at fixed interval\n"
f"* list[int]: List of specific update steps\n"
f'* {{"start": start, "freq": freq, "stop": stop}}\n'
f"* Callable: Used as-is\n"
f"* {{\"type\": ...}} as one of {valid_types}"
)
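# Illustrative examples of the accepted forms (values are arbitrary):
#
#     make_schedule(1000)                                      # FreqSchedule(freq=1000)
#     make_schedule([1000, 2000, 4000])                        # ListSchedule
#     make_schedule({"start": 0, "freq": 100, "stop": 1000})   # FreqSchedule
#
# An annotated callable such as `every_100_after_warmup` sketched above is
# returned as-is.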
class DynamicSparsityOptimizer(BaseSparsityOptimizer, ABC):
r"""Abstract base class for a dynamic sparsity optimizer.
Subclasses must implement :meth:`update_mask`.
Args:
params (iterable): iterable of parameters to sparsify or dicts defining
parameter groups to sparsify
init_method: Method to initialize sparsity pattern. Can either be the
name of a built-in method or a lambda.
sparsity: Sparsity, either constant or step-aware hyperparameter
schedule: Sparsity update schedule. May be one of:
* ``int``: Single regular update frequency.
* ``list``: Irregular update on the given steps.
* ``dict``: Containing ``{"start": start, "freq": freq, "stop":
stop}`` for regular updates with start & stop.
* ``ScheduleCallable`` : User function accepting a rankless
``torch.LongTensor`` and returning a rankless
``torch.BoolTensor``
"""
def __init__(
self,
params,
sparsity=required,
schedule: ScheduleType = required,
init_method: InitMethodType = "random",
**kwargs,
):
defaults = {"sparsity": sparsity, "schedule": schedule, **kwargs}
# When using CS, we execute the initial step 0 schedule and initialize
# the masks on CPU, though during training it all happens on device:
# | Operation            | GPU | CS  |   (columns = training device)
# | -------------------- | --- | --- |
# | step 0 schedule      | CPU | CPU |
# | initial mask         | GPU | CPU |
# | training schedule    | GPU | CS  |
# | training mask update | GPU | CS  |
self._step = torch.tensor(0, dtype=torch.int64)
super().__init__(
params=params, init_method=init_method, defaults=defaults,
)
def add_param_group(self, param_group):
param_group = super().add_param_group(param_group)
param_group["schedule"] = make_schedule(param_group["schedule"])
set_param_group_hyperparam(param_group, "sparsity")
return param_group
def _init_sparsity_of_group(self, group):
# Called from __init__ via BaseSparsityOptimizer.init_sparsity
starts_sparse = group["schedule"](self._step)
if not starts_sparse:
# Then use the all 1's "mask".
for p in group['params']:
self.state[p]['mask'] = cstorch.ones_like(p, dtype=torch.bool)
else:
# Base implementation calls _get_target_sparsity_level_of_group,
# which needs group["is_update_step"] set.
group["is_update_step"] = starts_sparse
super()._init_sparsity_of_group(group)
group.pop("is_update_step")
if self.backend.is_csx:
# To provide a hint to the CSX compiler for performance
# optimization, annotate the (min, max, ending) sparsity.
begin_step = getattr(group["schedule"], "start", None) or 0
# If the schedule has a `stop` step, use that; otherwise pick
# 100,000 arbitrarily.
end_step = getattr(group["schedule"], "stop", None) or 100000
# This simple scalar computation does not need to be traced
with torch.device("cpu"):
min_max_end = group["sparsity"].get_min_max_end(
begin_step, end_step
)
if min_max_end and not starts_sparse:
# If we don't start sparse, there is a period of dense
# training, i.e. 0% sparsity.
_, max_v, end_v = min_max_end
min_max_end = (0.0, max_v, end_v)
group["csx_annotated_sparsity"] = min_max_end
def _get_target_sparsity_level_of_group(self, group) -> torch.FloatTensor:
"""
Returns the target sparsity level at the current step, including during
:meth:`_init_sparsity_of_group`.
"""
is_update_step = group["is_update_step"]
sparsity = group["sparsity"](self._step, is_update_step)
# Ensure dynamic sparsity stays between [0, 1)
sparsity = torch.clamp(sparsity, min=0.0, max=1.0)
return sparsity
def state_dict(self):
state_dict = super().state_dict()
state_dict["step"] = self._step
return state_dict
def visit_state(self, fn):
"""
Applies a lambda to each stateful value.
"""
super().visit_state(fn)
new_val = fn(self._step)
if new_val is not None:
self._step = new_val
def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
with self.backend.device:
self._step = state_dict['step'].to(self.backend.torch_device)
@abstractmethod
@torch.no_grad()
def update_mask(self, p, mask, sparsity, group):
"""
Compute an updated sparsity pattern.
Args:
p (torch.Tensor): the parameter to sparsify
mask (torch.tensor(dtype=torch.bool)): the current mask
of param p
sparsity (torch.tensor(dtype=torch.float32)): the desired
sparsity level
group (dict): The param group dict with any additional options
Returns:
The updated sparsity pattern on parameter p
"""
@torch.no_grad()
def step(self, closure=None):
# Ensure we've called apply_sparsity before step
self._ensure_sparsity_applied()
# The weights and optimizer state were just updated. In case we
# _decrease_ sparsity here instead of increasing it, apply the current
# sparsity pattern.
self.apply_sparsity()
# By convention, `step` counts number of fwd/bwd/gradient evaluations of
# the model (`step==0` is model initialization time). If
# `sparsity_optimizer.step()` is called after weights have been updated
# (which is recommended), we are effectively setting up the sparsity
# pattern for the next step. Thus, increment the step here so the schedule
# below can indicate whether this is an update step.
self._step.add_(1)
for group in self.param_groups:
is_update_step = group["schedule"](self._step)
# cache this group's is_update_step for use by update_mask
group["is_update_step"] = is_update_step
sparsity = self._get_target_sparsity_level_of_group(group)
add_summaries = group.get("add_summaries", False)
if add_summaries:
if len(self.param_groups) > 1:
name = "/" + group["name"]
else:
name = ""
cstorch.summarize_scalar(f"sparsity/target{name}", sparsity)
for name, p in zip(group["param_names"], group["params"]):
if p.grad is None:
# If the gradient is None, then the parameter is unused
# and there is no need to update the mask
continue
# In case there are multiple devices, ensure the sparsity is
# on the parameter's device; it comes from the device we
# evaluated the schedule on, usually the device of step.
sparsity = sparsity.to(p.device)
mask = self.state[p]['mask']
updated_mask = self.update_mask(p, mask, sparsity, group)
# Rewrite into the existing mask tensor for state tracking
new_mask = torch.where(is_update_step, updated_mask, mask)
self.state[p]['mask'] = new_mask
if add_summaries:
cstorch.summarize_scalar(
f"sparsity/{name}",
1 - new_mask.sum() / new_mask.numel(),
)
# Remove is_update_step; it shouldn't be stateful.
group.pop("is_update_step")
self.apply_sparsity()
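# Illustrative training-loop usage, a sketch under assumptions (the subclass
# name is the hypothetical one from the sketch above; how masks are applied
# to the forward pass is handled elsewhere in the library):
#
#     sparsity_opt = MagnitudeSparsityOptimizer(
#         model.parameters(), sparsity=0.9, schedule=1000
#     )
#     for batch in loader:
#         loss = compute_loss(model, batch)  # hypothetical helper
#         loss.backward()
#         optimizer.step()       # update the dense weights first (recommended)
#         sparsity_opt.step()    # then advance the schedule / update masks
#         optimizer.zero_grad()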