# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""
Configuration helpers for constructing SparsityOptimizer objects.
"""
import inspect
import logging
import re
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from .base import BaseSparsityOptimizer
from .wrapper import SparsityWrapperOptimizer
SparseParamFilterType = Callable[[str, torch.nn.Parameter], bool]
# Can be a single regex, a list of regex, or a dict of regex -> config
ParamNamePatternsType = Union[str, List[str], Dict[str, dict]]
LOGGER = logging.getLogger("cerebras.sparsity")
def default_sparse_param_filter(name: str, param: torch.nn.Parameter) -> bool:
"""
Return True if the given parameter should be sparse.
Args:
name: Name of the parameter
param: The parameter itself
"""
# By default, sparsify params that are > 1D and not embedding or norm.
name = name.lower()
if (
len(param.shape) <= 1
or "embedding" in name
or "norm" in name
or "lm_head" in name
):
return False
return True
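
# A minimal sketch (not part of the original module) of what the default
# filter selects for a hypothetical model:
#
#     class Tiny(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.embedding = torch.nn.Embedding(100, 8)
#             self.fc = torch.nn.Linear(8, 8)
#             self.norm = torch.nn.LayerNorm(8)
#
#     for name, param in Tiny().named_parameters():
#         print(name, default_sparse_param_filter(name, param))
#     # embedding.weight -> False ("embedding" in name)
#     # fc.weight        -> True  (2D, no excluded substring)
#     # fc.bias          -> False (1D)
#     # norm.weight      -> False ("norm" in name, also 1D)
#     # norm.bias        -> False ("norm" in name, also 1D)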
def validate_options(options: Dict):
"""
Validate and handle options given to a sparsity optimizer.
"""
if "sparsity_schedule" in options:
options["schedule"], options["sparsity"] = zip(
*options.pop("sparsity_schedule")
)
# TODO: handle more validation
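
# A worked example (illustrative, not from the original module):
# validate_options() splits a paired schedule in place.
#
#     opts = {"sparsity_schedule": [(0, 0.1), (6, 0.3), (8, 0.5)]}
#     validate_options(opts)
#     assert opts == {"schedule": (0, 6, 8), "sparsity": (0.1, 0.3, 0.5)}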
def sparsity_param_groups(
named_parameters: List[Tuple[str, torch.nn.Parameter]],
param_name_patterns: Optional[ParamNamePatternsType] = None,
sparse_param_filter: Optional[SparseParamFilterType] = None,
):
"""
    Returns a list of (name, param) tuples, or a list of param_group dicts
    (each carrying its own options) for passing to the sparsity optimizer's
    ``param_groups``.

    Three YAML examples::
        sparsity:
            type: gmp
            sparsity_schedule:
            - [0, 0.1]
            - [6, 0.3]
            - [8, 0.5]
            param_name_patterns: "fc_layers.*weight"

        sparsity:
            type: gmp
            sparsity_schedule:
            - [0, 0.1]
            - [6, 0.3]
            - [8, 0.5]
            param_name_patterns:
            - "fc_layers.*weight"
            - "final_layer.*weight"

        sparsity:
            type: gmp
            param_name_patterns:
                fc_layers.*weight:
                    sparsity_schedule:
                    - [0, 0.1]
                    - [6, 0.3]
                    - [8, 0.5]
                final_layer.*weight:
                    sparsity_schedule:
                    - [0, 0.2]
                    - [6, 0.5]
                    - [8, 0.7]
Args:
named_parameters: List of (name, param) from model.named_parameters()
param_name_patterns: Filter to select which parameters are sparse and
optionally if any more specific config should be
applied to certain parameters.
sparse_param_filter: Callable to provide fallback selection of which
parameters are sparse if no param_name_patterns
are provided.
"""
if not sparse_param_filter:
sparse_param_filter = default_sparse_param_filter
    # Check for a yaml-specified param name pattern.
if isinstance(param_name_patterns, str):
# Just a single config changing which params the defaults apply to.
pattern = re.compile(param_name_patterns)
def sparse_param_filter(
name, param
): # pylint: disable=function-redefined
            return bool(pattern.search(name))
elif isinstance(param_name_patterns, list):
# A list of several patterns, all of which get the default setting.
patterns = list(map(re.compile, param_name_patterns))
def sparse_param_filter(
name, param
): # pylint: disable=function-redefined
            return any(p.search(name) for p in patterns)
elif isinstance(param_name_patterns, dict):
# An entire param_group per pattern.
param_groups = []
for pattern, param_group in param_name_patterns.items():
            # Allow the yaml syntax of listing extra name patterns without
            # customizing their group options.
if param_group is None:
param_name_patterns[pattern] = param_group = {}
if not isinstance(param_group, dict):
raise ValueError(
f"To specify param groups, each `param_name_patterns` "
f"should be a dict containing the group's options. "
f"Instead, got `{param_group}` for `{pattern}`. "
f"To specify multiple patterns whose matching params "
f"all get default options, define "
f"`param_name_patterns` as a list instead."
)
param_group["params"] = []
validate_options(param_group)
param_groups.append(param_group)
patterns = [
(re.compile(pattern), param_group["params"])
for pattern, param_group in param_name_patterns.items()
]
        # Add each parameter to at most one group (first matching pattern wins).
for name, param in named_parameters:
for pattern, param_list in patterns:
if pattern.search(name):
param_list.append((name, param))
break
for pattern, param_list in patterns:
if len(param_list) == 0:
                raise ValueError(
                    f"`{pattern.pattern}` did not match any parameter names!"
                )
return param_groups
# Not returning param_groups, just list of params all getting defaults.
return [(n, p) for n, p in named_parameters if sparse_param_filter(n, p)]
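
# An illustrative sketch of the two return shapes, assuming the hypothetical
# Tiny model from the sketch above:
#
#     named = list(Tiny().named_parameters())
#     # No patterns: a flat list of (name, param) tuples chosen by the filter.
#     sparsity_param_groups(named)
#     # -> [("fc.weight", Parameter(...))]
#     # Dict patterns: one param_group dict per pattern, each with "params".
#     sparsity_param_groups(
#         named, {"fc.*weight": {"sparsity_schedule": [(0, 0.1)]}}
#     )
#     # -> [{"schedule": (0,), "sparsity": (0.1,),
#     #      "params": [("fc.weight", Parameter(...))]}]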
def configure_sparsity_optimizer(
sparsity_type: str,
named_parameters: List[Tuple[str, torch.nn.Parameter]],
param_name_patterns: Optional[ParamNamePatternsType] = None,
sparse_param_filter: Optional[SparseParamFilterType] = None,
**kwargs,
) -> BaseSparsityOptimizer:
"""
Construct a SparsityOptimizer of the appropriate sparsity_type according to
``param_name_patterns`` or ``sparse_param_filter`` of the given
``named_parameters``.
    ``**kwargs`` are passed along to the SparsityOptimizer ``__init__``.

Args:
sparsity_type: Type of sparsity optimizer to construct.
named_parameters: List of (name, param) from model.named_parameters()
param_name_patterns: Filter to select which parameters are sparse and
optionally if any more specific config should be
applied to certain parameters.
sparse_param_filter: Callable to provide fallback selection of which
parameters are sparse if no param_name_patterns
are provided.
kwargs: Passed along to the chosen sparsity optimizer ``__init__``.
"""
    # Allow user-derived sparsity optimizers to be configured via this helper.
def _retrieve_all_subclasses(cls):
for subcls in cls.__subclasses__():
yield subcls
yield from _retrieve_all_subclasses(subcls)
supported_sparsity_types = {}
for cls in _retrieve_all_subclasses(BaseSparsityOptimizer):
if inspect.isabstract(cls):
continue
        # Derive the lookup key from the class name,
        # e.g. GMPSparsityOptimizer -> "gmp".
        key = cls.__name__.lower().replace("sparsityoptimizer", "")
        supported_sparsity_types[key] = cls
# Ensure we have a known sparsity optimizer.
sparsity_opt_cls = supported_sparsity_types.get(sparsity_type)
if not sparsity_opt_cls:
raise ValueError(
f"Unsupported sparsity optimizer type: {sparsity_type}. "
f"Supported types: {list(supported_sparsity_types.keys())}"
)
# Determine which parameters need sparsification.
params_to_sparsify = sparsity_param_groups(
named_parameters, param_name_patterns, sparse_param_filter,
)
if len(params_to_sparsify) == 0:
        LOGGER.warning(
            "Sparsity is configured, but no parameters were selected for "
            "sparsification"
        )
else:
def log(opts, names):
base_msg = f"Will apply \"{sparsity_type}\" sparsity {opts} to"
            if not LOGGER.isEnabledFor(logging.INFO - 5):
                # 's'[:len(names) ^ 1] pluralizes "tensor": it yields ""
                # when len(names) == 1 and "s" otherwise.
                LOGGER.info(
                    f"{base_msg} {len(names)} tensor{'s'[:len(names)^1]}"
                )
else:
LOGGER.verbose(f"{base_msg} {names}")
        defnames = []
        for entry in params_to_sparsify:
            if isinstance(entry, tuple):
                # A bare (name, param) tuple gets the default options.
                name, _ = entry
                defnames.append(name)
            else:
                # A param_group dict carries its own more specific options.
                opts = {**kwargs, **entry}
                names, _ = zip(*opts.pop("params"))
                # Log the param group with its extra options.
                log(opts, names)
        if defnames:
            # Log all params which only have the default options.
            log(kwargs, defnames)
    # Adapt sparsity_schedule -> (schedule, sparsity) in the default options.
validate_options(kwargs)
return sparsity_opt_cls(
params=params_to_sparsify,
# Pass yaml config along to sparsity optimizer as its defaults.
**kwargs,
)
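
# A minimal usage sketch (illustrative; assumes a concrete subclass such as a
# GMP optimizer is registered under the key "gmp"):
#
#     model = Tiny()
#     sparsity_opt = configure_sparsity_optimizer(
#         "gmp",
#         list(model.named_parameters()),
#         param_name_patterns="fc.*weight",
#         sparsity_schedule=[(0, 0.1), (6, 0.3), (8, 0.5)],
#     )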
def configure_sparsity_wrapper(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
sparsity_type: str,
param_name_patterns: Optional[ParamNamePatternsType] = None,
sparse_param_filter: Optional[SparseParamFilterType] = None,
**kwargs,
) -> SparsityWrapperOptimizer:
"""
Returns a SparsityWrapperOptimizer ready to be a drop-in replacement for
the given ``optimizer``, while also constructing a SparsityOptimizer of the
appropriate ``sparsity_type`` according to ``param_name_patterns`` or
``sparse_param_filter`` of the given ``model.named_parameters()``.
    ``**kwargs`` are passed along to the SparsityOptimizer ``__init__``.

Args:
        model: Root module from which to extract parameters and hook the
               forward pass.
optimizer: Optimizer to wrap to sparsify the optimizer state.
sparsity_type: Type of sparsity optimizer to construct.
param_name_patterns: Filter to select which parameters are sparse and
optionally if any more specific config should be
applied to certain parameters.
sparse_param_filter: Callable to provide fallback selection of which
parameters are sparse in case no
param_name_patterns are provided.
kwargs: Passed along to the sparsity optimizer ``__init__``.
"""
sparsity_optimizer = configure_sparsity_optimizer(
sparsity_type,
list(model.named_parameters()),
param_name_patterns,
sparse_param_filter,
**kwargs,
)
sparsity_optimizer.hook_module(model)
return SparsityWrapperOptimizer(optimizer, sparsity_optimizer)
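
# A minimal usage sketch (illustrative; assumes a "gmp" optimizer is
# registered). The wrapper is a drop-in replacement, so the training loop
# keeps calling step()/zero_grad() as usual:
#
#     model = Tiny()
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     optimizer = configure_sparsity_wrapper(
#         model,
#         optimizer,
#         "gmp",
#         param_name_patterns="fc.*weight",
#         sparsity_schedule=[(0, 0.1), (6, 0.3), (8, 0.5)],
#     )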