# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""
Sparsity mask initialization methods and helpers, invoked by
``BaseSparsityOptimizer``.
"""
import inspect
from typing import Callable, Optional, Union
import numpy as np
import torch
from cerebras_pytorch.utils.typing import signature_matches_type_hint
from .utils import ScoreShaper, make_mask_topk_sparsity
# Signature every mask init method must satisfy: it receives the parameter to
# be sparsified, the target sparsity as a tensor, and an optional ScoreShaper,
# and returns a boolean mask.
InitMethodCallable = Callable[
    [torch.nn.Parameter, torch.FloatTensor, Optional[ScoreShaper]],
    torch.BoolTensor,
]
# An init method may be given either by name (one of the built-ins resolved in
# ``make_init_method``) or directly as a callable matching InitMethodCallable.
InitMethodType = Union[str, InitMethodCallable]
def random(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
) -> torch.BoolTensor:
    """
    Uniformly random sparsity pattern.

    Draws a uniform random score for every element of ``p`` and turns the
    scores into a mask with ``make_mask_topk_sparsity``.

    Args:
        p: The parameter to build a mask for. The random scores take its shape
            and dtype; its values are not read.
        sparsity: Target sparsity. Its device is where the scores are drawn.
        score_shaper: Optional shaper forwarded to ``make_mask_topk_sparsity``.

    Returns:
        The boolean mask produced by ``make_mask_topk_sparsity``.
    """
    score = torch.rand_like(p, device=sparsity.device)
    return make_mask_topk_sparsity(score, sparsity, score_shaper)
def topk(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
) -> torch.BoolTensor:
    """
    Prune lowest magnitude weights.

    Uses the absolute value of each weight as its score and builds the mask
    with ``make_mask_topk_sparsity``.

    Args:
        p: The parameter whose magnitudes are used as scores.
        sparsity: Target sparsity. ``p`` is moved to its device before
            scoring.
        score_shaper: Optional shaper forwarded to ``make_mask_topk_sparsity``.

    Returns:
        The boolean mask produced by ``make_mask_topk_sparsity``.
    """
    # We transfer the param to the sparsity device because for CSX this
    # involves reading the data back on CPU.
    score = p.to(sparsity.device).abs()
    return make_mask_topk_sparsity(score, sparsity, score_shaper)
def from_zeros(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
) -> torch.BoolTensor:
    """
    Any zeros currently in the weights represent pruned connections.

    NOTE: Doesn't actually honor the configured sparsity; the resulting
    density is whatever fraction of ``p`` is currently non-zero.

    Args:
        p: The parameter whose existing zero pattern defines the mask.
        sparsity: Only its device is used (the mask is created there).
        score_shaper: Unused; accepted for signature compatibility.

    Returns:
        A boolean tensor shaped like ``p`` that is ``True`` where ``p`` is
        non-zero and ``False`` where it is zero.
    """
    return p.to(sparsity.device) != 0
def checkerboard(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
) -> torch.BoolTensor:
    """
    Mostly for stress and performance testing, creates a sparsity mask that is
    maximally distributed in a checkerboard across the weight.

    Args:
        p: The parameter to build a mask for. Only its shape is used.
        sparsity: Scalar tensor with the target sparsity. The mask has roughly
            a ``1 - sparsity`` fraction of ``True`` entries per row and is
            created on this tensor's device.
        score_shaper: Unused; accepted for signature compatibility.

    Returns:
        A boolean mask. For a 2-D ``p`` it has shape ``p.shape``, with each
        row's pattern shifted relative to the previous one. For any other rank
        only a single row of length ``p.shape[-1]`` is produced.
    """
    density = 1 - sparsity.item()
    col = p.shape[-1]
    is_2d = len(p.shape) == 2

    if density <= 0:
        # Fully sparse: no entry is kept. The general construction below
        # would divide by zero computing the padding, and the row-rolling
        # step would find no True entry to derive its offset from.
        out_shape = (p.shape[0], col) if is_2d else (col,)
        return torch.zeros(out_shape, dtype=torch.bool, device=sparsity.device)

    # Create a row with a uniformly distributed sparsity pattern.
    # Allocate padding so that rolling rows below still results in balance
    # before the padding is trimmed off again.
    padding = int(np.ceil(col / density + 1e-5))
    # steps[k] == floor(k * density); e.g. for density 1/3:
    # [ 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0 ]
    # The epsilon guards against k * density rounding to just below an
    # integer.
    steps = np.floor(np.arange(col + padding) * density + 1e-5)
    # An entry is True wherever the step count increments:
    # [ F, F, T, F, F, T, F, F]
    mask = steps[1:] != steps[:-1]
    if is_2d:
        row = p.shape[0]
        # Evenly distribute the pattern over the rows as well by rolling row
        # ``x`` left by ``x`` times the index of the first True entry.
        offset = -np.nonzero(mask)[0][0]
        mask = np.stack([np.roll(mask, x * offset) for x in range(row)])
    # Trim off padding columns.
    mask = mask[..., :col]
    return torch.tensor(mask, device=sparsity.device)
def _noop_compile_only(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
) -> torch.BoolTensor:
    """
    Placeholder init method used only in compile-only mode.

    Allocates a boolean tensor like ``p`` but never fills it in, so no mask
    values are computed on the CPU for a result that is not ultimately used.
    ``sparsity`` and ``score_shaper`` are ignored.
    """
    mask_dtype = torch.bool
    # Contents are intentionally left uninitialized.
    placeholder = torch.empty_like(p, dtype=mask_dtype)
    return placeholder
def make_init_method(init_method: InitMethodType) -> InitMethodCallable:
    """
    Resolve ``init_method`` into a mask initialization callable.

    Args:
        init_method: Either the name of a built-in init method (``"random"``,
            ``"topk"``, ``"from_zeros"`` or ``"checkerboard"``) or a callable
            whose signature matches ``InitMethodCallable``.

    Returns:
        The resolved init callable. When the current backend is in
        compile-only mode, a no-op initializer is returned instead, after
        ``init_method`` has been validated.

    Raises:
        ValueError: If ``init_method`` is an unknown name, a callable whose
            signature does not match ``InitMethodCallable``, or neither a
            string nor a callable.
    """
    from cerebras_pytorch.backend import current_backend_impl

    init_methods = {
        "random": random,
        "topk": topk,
        "from_zeros": from_zeros,
        "checkerboard": checkerboard,
    }
    init_method_error = (
        f'Unknown `init_method`: "{init_method}". Valid options are one '
        f'of the built-in {list(init_methods.keys())} or a function with '
        f'signature {InitMethodCallable}.'
    )

    if isinstance(init_method, str):
        if init_method not in init_methods:
            raise ValueError(init_method_error)
        resolved = init_methods[init_method]
    elif callable(init_method):
        signature = inspect.signature(init_method)
        if not signature_matches_type_hint(signature, InitMethodCallable):
            raise ValueError(init_method_error)
        resolved = init_method
    else:
        raise ValueError(init_method_error)

    # Validate before taking the compile-only shortcut, so a misconfigured
    # init_method is reported during compilation instead of being silently
    # accepted there and failing only in a later real run.
    if current_backend_impl().compile_only:
        return _noop_compile_only
    return resolved