Source code for cerebras_pytorch.sparse.wrapper

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Wrap a regular optimizer with a sparsity optimizer for auto-stepping.
"""

from torch.optim.optimizer import Optimizer


class SparsityWrapperOptimizer(Optimizer):
    """
    Helper Optimizer that can be used as a drop-in replacement for the main
    optimizer that also takes care of updating and applying sparsity.
    """

    def __init__(self, optimizer, sparsity_optimizer):
        if hasattr(optimizer, "state_names_to_sparsify"):
            # Determine which optimizer states need sparsification
            opt_states_to_sparsify = optimizer.state_names_to_sparsify()
            sparsity_optimizer.manage_optimizer_state_sparsity(
                optimizer, opt_states_to_sparsify
            )
        self.optimizer = optimizer
        self.sparsity_optimizer = sparsity_optimizer

        backend = sparsity_optimizer.backend
        backend.unregister_optimizer(optimizer)
        backend.unregister_optimizer(sparsity_optimizer)
        backend.register_optimizer(self)

    @property
    def state(self):
        # Merge the states, including nested merging for shared params
        o = self.optimizer.state
        s = self.sparsity_optimizer.state

        def merge(k):
            ov = o.get(k)
            os = s.get(k)
            if ov and os:
                return {**ov, **os}
            elif ov:
                return ov
            return os

        return {k: merge(k) for k in o.keys() | s.keys()}

    @property
    def param_groups(self):
        # Only expose the param groups of the main optimizer, otherwise there
        # would appear to be duplicates in the param_groups[i]["params"]
        return self.optimizer.param_groups

    def zero_grad(self, set_to_none: bool = True):
        self.optimizer.zero_grad(set_to_none)
        self.sparsity_optimizer.zero_grad(set_to_none)

    def state_dict(self):
        state_dict = self.optimizer.state_dict()
        state_dict["sparsity"] = self.sparsity_optimizer.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        sparsity = state_dict.pop("sparsity", None)
        if sparsity is not None:
            self.sparsity_optimizer.load_state_dict(sparsity)
        self.optimizer.load_state_dict(state_dict)
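
    # Illustrative behavior of the merged ``state`` property above, using
    # hypothetical state keys: if the main optimizer holds
    # ``{p: {"exp_avg": t1}}`` and the sparsity optimizer holds
    # ``{p: {"mask": t2}}`` for the same parameter ``p``, then
    # ``self.state[p]`` is ``{"exp_avg": t1, "mask": t2}``.
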
    def visit_state(self, fn):
        """
        Applies a lambda to each stateful value.
        """
        self.optimizer.visit_state(fn)
        self.sparsity_optimizer.visit_state(fn)

    def step(self, closure=None):
        self.optimizer.step(closure)
        self.sparsity_optimizer.step()
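
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): how this wrapper
# is typically combined with a main optimizer and a sparsity optimizer. The
# ``model``, ``main_opt``, and ``sparsity_opt`` names below are hypothetical;
# construct them with whatever optimizer and sparsity algorithm your setup
# uses.
#
#     main_opt = make_optimizer(model.parameters(), lr=1e-3)      # hypothetical
#     sparsity_opt = make_sparsity_optimizer(model.parameters())  # hypothetical
#     opt = SparsityWrapperOptimizer(main_opt, sparsity_opt)
#
#     # From here on, ``opt`` is a drop-in replacement: a single opt.step()
#     # steps the main optimizer and then the sparsity optimizer (the
#     # "auto-stepping" described in the module docstring), and
#     # opt.state_dict() / opt.load_state_dict() checkpoint and restore both
#     # optimizers together.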