Source code for cerebras_pytorch.sparse.init

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Sparsity mask initialization methods and helpers, invoked by
``BaseSparsityOptimizer``.
"""
import inspect
from typing import Callable, Optional, Union

import numpy as np
import torch

from cerebras_pytorch.utils.typing import signature_matches_type_hint

from .utils import ScoreShaper, make_mask_topk_sparsity

# Signature a custom mask-initialization function must have: it receives the
# parameter being sparsified, the sparsity tensor, and an optional
# ``ScoreShaper``, and returns a boolean mask.
InitMethodCallable = Callable[
    [
        torch.nn.Parameter,
        torch.FloatTensor,
        Optional[ScoreShaper],
    ],
    torch.BoolTensor,
]
# An ``init_method`` is given either by the name of a built-in method or as a
# callable matching ``InitMethodCallable``.
InitMethodType = Union[str, InitMethodCallable]


def random(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Uniformly random sparsity pattern.

    Every element of ``p`` is scored with an independent draw from
    ``torch.rand_like``. The mask is then derived from those scores by
    ``make_mask_topk_sparsity`` at the requested ``sparsity``.

    Args:
        p: Parameter to build a mask for. Its values are not used, only the
            tensor properties ``torch.rand_like`` copies from it.
        sparsity: Sparsity tensor, moved to the target device before use.
        score_shaper: Optional shaper forwarded to ``make_mask_topk_sparsity``.
        device: Device on which the scores are generated. Defaults to
            ``p.device``.

    Returns:
        The boolean mask produced by ``make_mask_topk_sparsity``.
    """
    target_device = p.device if device is None else device
    # Keep sparsity on the same device as the scores so that the random
    # initialization can be traced there.
    sparsity_on_device = sparsity.to(target_device)
    random_scores = torch.rand_like(p, device=target_device)
    return make_mask_topk_sparsity(
        random_scores, sparsity_on_device, score_shaper
    )


def topk(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Prune lowest magnitude weights.

    Each element is scored by its absolute value. The mask is derived from
    those scores by ``make_mask_topk_sparsity`` at the requested ``sparsity``.

    Args:
        p: Parameter whose magnitudes drive the mask.
        sparsity: Sparsity tensor, moved to the target device before use.
        score_shaper: Optional shaper forwarded to ``make_mask_topk_sparsity``.
        device: Device on which scoring happens. Defaults to ``p.device``.

    Returns:
        The boolean mask produced by ``make_mask_topk_sparsity``.
    """
    target_device = p.device if device is None else device
    # The sparsity tensor must sit on the scoring device so topk can be traced.
    sparsity_on_device = sparsity.to(target_device)
    magnitudes = torch.abs(p.to(target_device))
    return make_mask_topk_sparsity(magnitudes, sparsity_on_device, score_shaper)


def from_zeros(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Any zeros currently in the weights represent pruned connections.

    The mask is ``True`` wherever ``p`` is non-zero and ``False`` wherever it
    is exactly zero.

    NOTE: Doesn't actually honor the configured sparsity. ``sparsity`` and
    ``score_shaper`` are ignored, so the resulting density is whatever the
    current weights dictate.

    Args:
        p: Parameter whose existing zeros define the mask.
        sparsity: Ignored.
        score_shaper: Ignored.
        device: Device for the returned mask. Defaults to ``p.device``.

    Returns:
        Boolean mask with ``p``'s shape.
    """
    target_device = device if device is not None else p.device
    return torch.ne(p.to(target_device), 0)


def checkerboard(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    Mostly for stress and performance testing, creates a sparsity mask that is
    maximally distributed in a checkerboard across the weight.

    Every row along the last dimension gets an evenly spaced pattern of kept
    entries at the requested density. Successive rows are rotated so that the
    kept entries are staggered from row to row. All leading dimensions are
    flattened into rows, so the mask always has ``p``'s shape.

    Args:
        p: Parameter to build a mask for. Only its shape and device are used.
        sparsity: Scalar tensor holding the fraction of entries to prune.
        score_shaper: Accepted for interface compatibility; ignored.
        device: Device for the returned mask. Defaults to ``p.device``, to
            match the other init methods.

    Returns:
        Boolean mask with the same shape as ``p`` (``True`` means kept).
    """
    if device is None:
        device = p.device
    density = 1 - sparsity.item()
    if density <= 0 or p.numel() == 0:
        # Fully sparse (or empty) weight: nothing is kept. This is handled up
        # front because the padding computation below divides by ``density``.
        return torch.zeros(p.shape, dtype=torch.bool, device=device)

    # Create a row with a uniformly distributed sparsity pattern
    col = p.shape[-1]
    # Allocate padding so that rolling rows below still results in balance.
    padding = int(np.ceil(col / density + 1e-5))
    # e.g. density=1/3: [ 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0 ]
    steps = torch.floor(torch.arange(col + padding) * density + 1e-5)
    # A position is kept where the step count increments: [ F, F, T, F, F, T, F, F]
    mask = steps[1:] != steps[:-1]

    if p.dim() >= 2:
        # Flatten every leading dimension into rows (for 2-D weights this is
        # simply p.shape[0]). numel > 0 here, so col >= 1.
        rows = p.numel() // col
        # Now evenly distribute this over the rows as well by rolling each
        # This offset computation is equivalent to `-np.nonzero(mask)[0][0]`
        # but is more efficient, and more importantly allows torch.roll
        # to be traceable.
        offset = -int(np.floor(1 / density - 1e-5))
        mask = torch.stack([torch.roll(mask, r * offset) for r in range(rows)])

    # Trim off padding columns, restore p's shape, and place on the device.
    return mask[..., :col].reshape(p.shape).to(device)


def _noop_compile_only(
    p: torch.nn.Parameter,
    sparsity: torch.FloatTensor,
    score_shaper: Optional[ScoreShaper] = None,
    device: Optional[torch.device] = None,
) -> torch.BoolTensor:
    """
    "init" method that doesn't init to be used only with compile_only.

    This avoids computing masks on the CPU that aren't ultimately used. The
    returned tensor is uninitialized (``torch.empty_like``), has ``p``'s shape
    and device, and has dtype bool. ``sparsity``, ``score_shaper`` and
    ``device`` are ignored.
    """
    return torch.empty_like(p, dtype=torch.bool)


def make_init_method(init_method: InitMethodType) -> InitMethodCallable:
    """
    Resolve ``init_method`` into a mask-initialization callable.

    Args:
        init_method: Either the name of a built-in method ("random", "topk",
            "from_zeros", "checkerboard") or a callable whose signature
            matches ``InitMethodCallable``.

    Returns:
        The resolved callable. When the current backend is compile-only,
        ``init_method`` is still validated, but a no-op method is returned
        so that no mask is computed.

    Raises:
        ValueError: If ``init_method`` is an unknown name or is not callable.
            Also raised if it is a callable whose signature cannot be
            inspected or does not match ``InitMethodCallable``.
    """
    from cerebras_pytorch.backend import current_backend_impl

    init_methods = {
        "random": random,
        "topk": topk,
        "from_zeros": from_zeros,
        "checkerboard": checkerboard,
    }
    init_method_error = (
        f'Unknown `init_method`: "{init_method}". Valid options are one '
        f'of the built-in {list(init_methods.keys())} or a function with '
        f'signature {InitMethodCallable}.'
    )

    if isinstance(init_method, str):
        if init_method not in init_methods:
            raise ValueError(init_method_error)
        resolved = init_methods[init_method]
    elif callable(init_method):
        try:
            signature = inspect.signature(init_method)
        except (TypeError, ValueError) as e:
            # Some callables (e.g. certain builtins) expose no signature;
            # report them with the same error as any other invalid option.
            raise ValueError(init_method_error) from e
        if not signature_matches_type_hint(signature, InitMethodCallable):
            raise ValueError(init_method_error)
        resolved = init_method
    else:
        raise ValueError(init_method_error)

    # Validation happens first so that a misconfigured init_method is reported
    # even in compile-only runs; only then is the unused mask computation skipped.
    if current_backend_impl().compile_only:
        return _noop_compile_only
    return resolved