Source code for cerebras_pytorch.sparse.dynamic

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Base class for all dynamic sparsity optimizers, plus dynamic schedule helpers.
"""
import inspect
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Union

import torch
from torch.optim.optimizer import required

import cerebras_pytorch as cstorch
from cerebras_pytorch.utils.typing import signature_matches_type_hint

from .base import BaseSparsityOptimizer, InitMethodType
from .utils import set_param_group_hyperparam


class BaseSchedule(ABC):
    TYPE_REGISTRY = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if hasattr(cls, "TYPE"):
            cls.TYPE_REGISTRY[cls.TYPE] = cls
        cls.TYPE_REGISTRY[cls.__name__] = cls

    @staticmethod
    def get_cls(typename: str):
        """
        Looks up the class by its typename in the registry. Raises a
        ValueError if none exists with that name.
        """
        tr = BaseSchedule.TYPE_REGISTRY
        if typename in tr:
            return tr[typename]
        raise ValueError(
            f'Unknown scheduler `type`: "{typename}". Valid options are '
            f"{list(tr.keys())}"
        )

    @abstractmethod
    def __call__(self, step: torch.LongTensor) -> torch.BoolTensor:
        """
        Given a rankless training step tensor, return a rankless bool tensor
        indicating whether this is a sparsity update step.
        """

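# For example, a subclass declaring a TYPE attribute (hypothetical name shown
# purely as a sketch) is registered under both that type name and its class
# name:
#
#   class CosineSchedule(BaseSchedule):
#       TYPE = "cosine"
#       ...
#
#   BaseSchedule.get_cls("cosine")          # -> CosineSchedule
#   BaseSchedule.get_cls("CosineSchedule")  # -> CosineSchedule
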
class FreqSchedule(BaseSchedule):
    """
    When scheduling sparsity update steps on a regular interval, this class
    allows configuring the start and stop step in addition to the update
    frequency.
    """
    def __init__(self, start=None, freq=1000, stop=None):
        self.start = start
        self.freq = freq
        self.stop = stop

    def __call__(self, step: torch.LongTensor) -> torch.BoolTensor:
        """
        Returns a rankless bool tensor indicating whether this step is an
        update step.
        """
        # First, check if this is (after offsetting from start) an update
        # step based on the frequency.
        check_step = step
        if self.start is not None:
            check_step = step - self.start
        is_update_step = check_step % self.freq == 0

        # Next, add the bounds checking if applicable.
        if self.start is not None:
            is_update_step &= step >= self.start
        if self.stop is not None:
            is_update_step &= step < self.stop

        return is_update_step

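# For example, FreqSchedule(start=1000, freq=100, stop=2000) fires on steps
# 1000, 1100, ..., 1900 (the stop step is exclusive):
#
#   sched = FreqSchedule(start=1000, freq=100, stop=2000)
#   sched(torch.tensor(1100))  # tensor(True)
#   sched(torch.tensor(950))   # tensor(False), before start
#   sched(torch.tensor(2000))  # tensor(False), stop is exclusive
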
class ListSchedule(BaseSchedule):
    """
    When scheduling requires an irregular update cadence, explicit steps can
    be provided as a list.
    """
    def __init__(self, steps: Union[List[int], torch.Tensor]):
        steps = tuple(steps)
        self.steps = steps
        self.start = min(steps)
        self.stop = max(steps)

    def __call__(self, step: torch.LongTensor) -> torch.BoolTensor:
        """
        Returns a rankless bool tensor indicating whether this step is an
        update step.
        """
        is_update_step = torch.tensor(False, device=step.device)
        for s in self.steps:
            is_update_step |= step == s
        return is_update_step

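# For example, ListSchedule([0, 1000, 3000]) fires exactly on the listed
# steps and on no others:
#
#   sched = ListSchedule([0, 1000, 3000])
#   sched(torch.tensor(1000))  # tensor(True)
#   sched(torch.tensor(2000))  # tensor(False)
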
ScheduleCallable = Callable[
    # torch.tensor(shape=[], dtype=int64) -> torch.tensor(shape=[], dtype=bool)
    [torch.LongTensor],
    torch.BoolTensor,
]
ScheduleType = Union[int, List[int], Dict, ScheduleCallable]

def make_schedule(schedule: ScheduleType) -> ScheduleCallable:
    """
    Instantiate a supported schedule type.
    """
    if isinstance(schedule, int):
        # Single update frequency
        return FreqSchedule(freq=schedule)
    elif isinstance(schedule, dict):
        schedule = schedule.copy()
        typename = schedule.pop("type", None)
        if typename:
            return BaseSchedule.get_cls(typename)(**schedule)
        if "freq" in schedule:
            return FreqSchedule(**schedule)
    elif isinstance(schedule, (list, tuple)):
        return ListSchedule(schedule)
    elif callable(schedule):
        signature = inspect.signature(schedule)
        if signature_matches_type_hint(signature, ScheduleCallable):
            return schedule
    valid_types = list(BaseSchedule.TYPE_REGISTRY.keys())
    raise ValueError(
        f"Invalid `schedule`: {schedule}. Valid options are:\n"
        f"* int: Regularly updating sparsity at fixed interval\n"
        f"* list[int]: List of specific update steps\n"
        f'* {{"start": start, "freq": freq, "stop": stop}}\n'
        f"* Callable: Used as-is\n"
        f'* {{"type": ...}} as one of {valid_types}'
    )

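# A sketch of the accepted `schedule` forms and what they resolve to:
#
#   make_schedule(1000)                                    # FreqSchedule(freq=1000)
#   make_schedule([0, 500, 1000])                          # ListSchedule on those steps
#   make_schedule({"start": 0, "freq": 100, "stop": 1000}) # FreqSchedule with bounds
#   make_schedule({"type": "FreqSchedule", "freq": 100})   # looked up in TYPE_REGISTRY
#
# A callable whose signature matches ScheduleCallable is returned unchanged.
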
class DynamicSparsityOptimizer(BaseSparsityOptimizer, ABC):
    r"""Abstract base class for a dynamic sparsity optimizer.

    Subclasses must implement :meth:`update_mask`.

    Args:
        params (iterable): iterable of parameters to sparsify or dicts
            defining parameter groups to sparsify
        init_method: Method to initialize sparsity pattern. Can either be the
            name of a built-in method or a lambda.
        sparsity: Sparsity, either constant or step-aware hyperparameter
        schedule: Sparsity update schedule. May be one of:

            * ``int``: Single regular update frequency.
            * ``list``: Irregular update on the given steps.
            * ``dict``: Containing ``{"start": start, "freq": freq,
              "stop": stop}`` for regular updates with start & stop.
            * ``ScheduleCallable``: User function accepting a rankless
              ``torch.LongTensor`` and returning a rankless
              ``torch.BoolTensor``
    """
    def __init__(
        self,
        params,
        sparsity=required,
        schedule: ScheduleType = required,
        init_method: InitMethodType = "random",
        **kwargs,
    ):
        defaults = {"sparsity": sparsity, "schedule": schedule, **kwargs}

        # When using CS, we execute the initial step 0 schedule and initialize
        # the masks on CPU, though during training it all happens on device:
        # | Operation             | GPU | CS  |
        # | --------------------- | --- | --- |
        # | step 0 schedule       | CPU | CPU |
        # | initial mask          | GPU | CPU |
        # | training schedule     | GPU | CS  |
        # | training mask update  | GPU | CS  |
        self._step = torch.tensor(0, dtype=torch.int64)

        super().__init__(
            params=params,
            init_method=init_method,
            defaults=defaults,
        )

    def add_param_group(self, param_group):
        param_group = super().add_param_group(param_group)
        param_group["schedule"] = make_schedule(param_group["schedule"])
        set_param_group_hyperparam(param_group, "sparsity")
        return param_group

    def _init_sparsity_of_group(self, group):
        # Called from __init__ via BaseSparsityOptimizer.init_sparsity
        starts_sparse = group["schedule"](self._step)
        if not starts_sparse:
            # Then use the all 1's "mask".
            for p in group['params']:
                self.state[p]['mask'] = cstorch.ones_like(p, dtype=torch.bool)
        else:
            # Base implementation calls _get_target_sparsity_level_of_group,
            # which needs group["is_update_step"] set.
            group["is_update_step"] = starts_sparse
            super()._init_sparsity_of_group(group)
            group.pop("is_update_step")

        if self.backend.is_csx:
            # To provide a hint to the CSX compiler for performance
            # optimization, annotate the (min, max, ending) sparsity.
            begin_step = getattr(group["schedule"], "start", None) or 0
            # If the schedule has a `stop` step use that, otherwise pick
            # 100,000 arbitrarily.
            end_step = getattr(group["schedule"], "stop", None) or 100000
            min_max_end = group["sparsity"].get_min_max_end(
                begin_step, end_step
            )
            if min_max_end and not starts_sparse:
                # If we don't start sparse, there is a period of dense
                # training, or 0% sparsity.
                _, max_v, end_v = min_max_end
                min_max_end = (0.0, max_v, end_v)
            group["csx_annotated_sparsity"] = min_max_end

    def _get_target_sparsity_level_of_group(self, group) -> torch.FloatTensor:
        """
        Returns the target sparsity level at the current step, including
        during _init_sparsity_of_group.
        """
        is_update_step = group["is_update_step"]
        sparsity = group["sparsity"](self._step, is_update_step)
        # Ensure dynamic sparsity stays between [0, 1)
        sparsity = torch.clamp(sparsity, min=0.0, max=1.0)
        return sparsity

    def state_dict(self):
        state_dict = super(DynamicSparsityOptimizer, self).state_dict()
        state_dict["step"] = self._step
        return state_dict

    def visit_state(self, fn):
        """
        Applies a lambda to each stateful value.
        """
        super().visit_state(fn)
        new_val = fn(self._step)
        if new_val is not None:
            self._step = new_val

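    # For example, checkpoint machinery can map over the stateful step
    # counter, e.g. to move it to another device (an illustrative sketch;
    # `sparsity_opt` is a hypothetical instance of a concrete subclass):
    #
    #   sparsity_opt.visit_state(lambda t: t.to("cpu"))
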
    def load_state_dict(self, state_dict):
        super(DynamicSparsityOptimizer, self).load_state_dict(state_dict)

        with self.backend.device:
            self._step = state_dict['step'].to(self.backend.torch_device)

    @abstractmethod
    @torch.no_grad()
    def update_mask(self, p, mask, sparsity, group):
        """
        Compute an updated sparsity pattern.

        Args:
            p (torch.Tensor): the parameter to sparsify
            mask (torch.tensor(dtype=torch.bool)): the current mask of param p
            sparsity (torch.tensor(dtype=torch.float32)): the desired sparsity
                level
            group (dict): The param group dict with any additional options
        Returns:
            The updated sparsity pattern on parameter p
        """

    @torch.no_grad()
    def step(self, closure=None):
        # Ensure we've called apply_sparsity before step
        self._ensure_sparsity_applied()

        # The weights and optimizer state were just updated. In case we
        # _decrease_ sparsity here instead of increasing it, apply the current
        # sparsity pattern.
        self.apply_sparsity()

        # By convention, `step` counts number of fwd/bwd/gradient evaluations
        # of the model (`step==0` is model initialization time). If
        # `sparsity_optimizer.step()` is called after weights have been
        # updated (which is recommended), we are effectively setting up the
        # sparsity pattern for the next step. Thus, increment step here so
        # self.process_schedule can indicate if this is a step to update.
        self._step.add_(1)

        for group in self.param_groups:
            is_update_step = group["schedule"](self._step)
            # Cache this group's is_update_step for use by update_mask.
            group["is_update_step"] = is_update_step

            sparsity = self._get_target_sparsity_level_of_group(group)

            add_summaries = group.get("add_summaries", False)
            if add_summaries:
                if len(self.param_groups) > 1:
                    name = "/" + group["name"]
                else:
                    name = ""
                cstorch.summarize_scalar(f"sparsity/target{name}", sparsity)

            for name, p in zip(group["param_names"], group["params"]):
                # In case there are multiple devices, ensure the sparsity is
                # on the parameter's device; it comes from the device we
                # evaluated the schedule on, usually the device of step.
                sparsity = sparsity.to(p.device)

                mask = self.state[p]['mask']

                updated_mask = self.update_mask(p, mask, sparsity, group)

                # Rewrite into the existing mask tensor for state tracking.
                new_mask = torch.where(is_update_step, updated_mask, mask)
                self.state[p]['mask'] = new_mask

                if add_summaries:
                    cstorch.summarize_scalar(
                        f"sparsity/{name}",
                        1 - new_mask.sum() / new_mask.numel(),
                    )

            # Remove is_update_step; this shouldn't be stateful.
            group.pop("is_update_step")

        self.apply_sparsity()
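
# A minimal sketch of a concrete subclass (hypothetical name and heuristic,
# shown only for illustration): keep the largest-magnitude (1 - sparsity)
# fraction of each parameter's weights on every update step.
#
#   class MagnitudePruneOptimizer(DynamicSparsityOptimizer):
#       @torch.no_grad()
#       def update_mask(self, p, mask, sparsity, group):
#           scores = p.abs().flatten()
#           # Number of weights to keep at the target sparsity level.
#           k = torch.clamp(
#               ((1 - sparsity) * p.numel()).long(), min=1, max=p.numel()
#           )
#           # The k-th largest magnitude serves as the keep threshold.
#           threshold = torch.sort(scores, descending=True).values[k - 1]
#           return p.abs() >= threshold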