Writing Custom Optimizers
To define a Cerebras-compliant optimizer, create a subclass of cerebras_pytorch.optim.Optimizer.
For example:
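```python
import cerebras_pytorch as cstorch


class CustomOptimizer(cstorch.optim.Optimizer):
    def __init__(self, params, ...):
        ...
        defaults = ...
        super().__init__(params, defaults, enable_global_step=...)
        ...

    def preinitialize(self):
        # Create all optimizer state variables here
        ...

    def step(self, closure=None):
        # Implement the update using torch ops only
        ...

    def state_names_to_sparsify(self):
        # Return the names of the state variables to sparsify
        ...
```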
As the example above shows, the base Optimizer class, like torch.optim.Optimizer, expects the model parameters and the param group defaults. It also accepts a third, optional argument, enable_global_step, which defines a global step state variable for each parameter.
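To make the role of the defaults argument concrete, here is a minimal sketch built on stock torch.optim.Optimizer rather than cstorch.optim.Optimizer, assuming (as the comparison above suggests) that param groups and defaults behave the same way in both. The class ScaledSGD and its scale hyperparameter are hypothetical, and enable_global_step is not modeled because the plain PyTorch base class has no such argument.

```python
import torch


class ScaledSGD(torch.optim.Optimizer):
    """Hypothetical optimizer used only to show how `defaults` work (plain PyTorch)."""

    def __init__(self, params, lr=0.1, scale=1.0):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # Every key in `defaults` becomes a per-group hyperparameter;
        # a group that omits a key inherits the value given here.
        defaults = dict(lr=lr, scale=scale)
        super().__init__(params, defaults)


w = torch.nn.Parameter(torch.ones(3))
b = torch.nn.Parameter(torch.zeros(1))

# Two param groups: the second overrides `lr` but inherits `scale`.
opt = ScaledSGD([{"params": [w]}, {"params": [b], "lr": 0.01}], lr=0.1, scale=2.0)

for group in opt.param_groups:
    print(group["lr"], group["scale"])
# 0.1 2.0
# 0.01 2.0
```

Each key in defaults fills in any param group that does not set it explicitly, which is why the second group keeps scale=2.0 while using its own lr.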
In addition, there are three abstract methods that must be overridden:
preinitialize
This method initializes any state variables that the optimizer uses. For example, SGD defines its momentum buffers in its preinitialize method.

Note: to remain Cerebras-compliant, no optimizer state variables may be initialized outside of the preinitialize method.

For optimal performance, when initializing state tensors that are filled with a constant value, use the creation ops available in the cstorch package to initialize them lazily. These ops defer creating and filling the tensor, so they take up very little memory and can be initialized much faster than their torch counterparts when running on the Cerebras Wafer Scale cluster. See the source code of the optimizers in cerebras_pytorch for examples.

step
This method implements the optimizer step. Because of the way lazy tensors are traced and executed, Python-level conditionals and loops may not be used to define control flow dynamically from tensor values; such data-dependent logic must be expressed with torch ops such as torch.where. Static structures, however, are allowed: for example, a loop with a fixed number of iterations, or a Python conditional that involves only constant values and no torch tensors.
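The sketch below illustrates both rules on stock PyTorch: all state is created in preinitialize, and the data-dependent decision in step (whether to clip the gradient) is made with torch.where instead of a Python if on a tensor value. It is a hypothetical example, not a cerebras_pytorch optimizer: ClippedMomentumSGD and its max_norm hyperparameter are invented for illustration, torch.zeros_like stands in for the lazy cstorch creation ops, and __init__ calls preinitialize() itself only so the code runs without the Cerebras base class, which may invoke it differently.

```python
import torch


class ClippedMomentumSGD(torch.optim.Optimizer):
    """Hypothetical momentum SGD with per-tensor gradient clipping (plain PyTorch)."""

    def __init__(self, params, lr=0.1, momentum=0.9, max_norm=1.0):
        defaults = dict(lr=lr, momentum=momentum, max_norm=max_norm)
        super().__init__(params, defaults)
        # Called here only so the sketch runs on stock PyTorch.
        self.preinitialize()

    def preinitialize(self):
        # Every state tensor is created here, before the first step.
        # torch.zeros_like stands in for a lazy cstorch creation op.
        for group in self.param_groups:
            for p in group["params"]:
                self.state[p]["momentum_buffer"] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # These loops have a fixed trip count (number of groups and
        # parameters), so they are static structures.
        for group in self.param_groups:
            for p in group["params"]:
                # Checks a Python attribute, not a tensor value.
                if p.grad is None:
                    continue
                grad = p.grad
                norm = torch.linalg.vector_norm(grad)
                # Data-dependent choice expressed as a tensor op instead of
                # `if norm > max_norm:`, which would read the tensor's value.
                scale = torch.where(
                    norm > group["max_norm"],
                    group["max_norm"] / norm.clamp(min=1e-12),
                    torch.ones_like(norm),
                )
                buf = self.state[p]["momentum_buffer"]
                buf.mul_(group["momentum"]).add_(grad * scale)
                p.add_(buf, alpha=-group["lr"])
        return loss


w = torch.nn.Parameter(torch.tensor([3.0, 4.0]))
opt = ClippedMomentumSGD([w], lr=0.1, momentum=0.0, max_norm=1.0)
w.grad = torch.tensor([3.0, 4.0])  # norm 5 > max_norm, so it is rescaled to norm 1
opt.step()
print(w.detach().tolist())  # approximately [2.94, 3.92]
```

Here the gradient [3, 4] has norm 5, so torch.where selects the scale factor 1/5 and the parameter moves by lr times the unit-norm gradient [0.6, 0.8]. Both branches of torch.where are computed and one is selected elementwise, which is what keeps the traced graph independent of the gradient's value.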
state_names_to_sparsify
This method returns the names of the state variables that need to be sparsified. Refer to the existing optimizer implementations for examples.
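As a rough illustration of why only some state names belong in this list, the plain-PyTorch sketch below keeps a momentum buffer shaped like its parameter alongside a scalar step counter, and reports only the former. The interpretation that the listed states are the ones that should follow the parameter's sparsity pattern is an assumption, and mask_optimizer_state is a hypothetical helper standing in for whatever consumes the list; it is not part of cerebras_pytorch.

```python
import torch


class MomentumSGD(torch.optim.Optimizer):
    """Hypothetical plain-PyTorch sketch; not a cerebras_pytorch class."""

    def __init__(self, params, lr=0.1, momentum=0.9):
        super().__init__(params, dict(lr=lr, momentum=momentum))
        # Called here only so the sketch runs on stock PyTorch.
        self.preinitialize()

    def preinitialize(self):
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                # Same shape as the parameter: assumed to follow its sparsity.
                state["momentum_buffer"] = torch.zeros_like(p)
                # Scalar counter: no per-element layout to sparsify.
                state["step"] = torch.zeros(())

    def state_names_to_sparsify(self):
        return ["momentum_buffer"]


def mask_optimizer_state(optimizer, param, mask):
    """Hypothetical helper: zero the masked-out entries of each listed state."""
    for name in optimizer.state_names_to_sparsify():
        optimizer.state[param][name].mul_(mask)


w = torch.nn.Parameter(torch.ones(4))
opt = MomentumSGD([w])
opt.state[w]["momentum_buffer"].fill_(0.5)
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])
mask_optimizer_state(opt, w, mask)
print(opt.state[w]["momentum_buffer"].tolist())  # [0.5, 0.0, 0.5, 0.0]
```

Because the step counter is not listed, the helper leaves it untouched, while the parameter-shaped buffer picks up the same zero pattern as the mask.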