Writing Custom Optimizers
To define a Cerebras-compliant optimizer, create a subclass of cstorch.optim.Optimizer following this skeleton:
```python
class CustomOptimizer(cstorch.optim.Optimizer):
    def __init__(self, params, ...):
        ...
        defaults = ...
        super().__init__(params, defaults, enable_global_step=...)
        ...

    def preinitialize(self):
        ...

    def step(self, closure=None):
        ...

    def state_names_to_sparsify(self):
        ...
```
As seen in the above example, similar to torch.optim.Optimizer, the constructor of the cstorch.optim.Optimizer class expects three arguments: the model parameters, the param group defaults, and the optional enable_global_step, which defines a global step state variable for each parameter.
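For instance, a constructor for a hypothetical momentum-based optimizer might be sketched as follows; the class name, hyperparameters, and validation logic here are illustrative, not part of the cstorch API, and the package is assumed to be imported as cstorch:

```python
import cerebras_pytorch as cstorch


class MomentumOptimizer(cstorch.optim.Optimizer):
    """Illustrative SGD-with-momentum style optimizer (a sketch)."""

    def __init__(self, params, lr=0.01, momentum=0.9):
        # Validate hyperparameters before handing them to the base class.
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")

        # Per-param-group defaults, as in torch.optim.
        defaults = dict(lr=lr, momentum=momentum)
        # enable_global_step=True defines a global step state variable
        # for each parameter.
        super().__init__(params, defaults, enable_global_step=True)
```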
In addition, there are three abstract methods that must be overridden: preinitialize, step, and state_names_to_sparsify.
The preinitialize method initializes any state variables that will be used by the optimizer. For example, SGD defines its momentum buffers in its preinitialize method. Note that to remain Cerebras-compliant, no optimizer state variables may be initialized outside of preinitialize.
For optimal performance, when initializing state tensors that are filled with some constant value, use the creation ops available in the cstorch package to lazily initialize them. These ops lazily initialize and fill the tensor, so they take up very little memory and can be initialized much more quickly than their torch counterparts when running on the Cerebras Wafer-Scale Cluster. See the source code of the optimizers in cerebras_pytorch for examples.
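Continuing the hypothetical MomentumOptimizer sketch, preinitialize could create one momentum buffer per parameter. cstorch.zeros_like is assumed here to be one of the lazy creation ops mentioned above:

```python
# Method of the hypothetical MomentumOptimizer sketched above.
def preinitialize(self):
    # All optimizer state must be created here, never in __init__ or step.
    for group in self.param_groups:
        for p in group["params"]:
            # Assumed lazy creation op: defers allocation and constant
            # fill until the tensor is actually materialized.
            self.state[p]["momentum_buffer"] = cstorch.zeros_like(p)
```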
The step method is where the optimizer step is implemented. Note that due to the nature of lazy tensor tracing and execution, Python-level conditions or loops may not be used to dynamically define the control flow; only torch ops (such as torch.where) may be used for data-dependent logic.
However, static structures are allowed: for example, a loop with a fixed number of iterations, or a Python conditional that does not involve any torch tensors, i.e. whose condition depends only on constant variables.
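Continuing the hypothetical MomentumOptimizer sketch (and assuming torch is imported), a step implementation that respects these constraints might look like the following; the torch.where masking of non-finite gradients is purely illustrative of data-dependent control flow expressed with torch ops:

```python
# Method of the hypothetical MomentumOptimizer sketched above.
@torch.no_grad()
def step(self, closure=None):
    loss = None
    if closure is not None:
        # Re-enable gradients so the closure can recompute the loss.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        lr = group["lr"]
        momentum = group["momentum"]
        for p in group["params"]:
            if p.grad is None:
                continue
            # Update the momentum buffer created in preinitialize, in place.
            buf = self.state[p]["momentum_buffer"]
            buf.mul_(momentum).add_(p.grad)
            # Data-dependent choices must be expressed with torch ops such
            # as torch.where, never Python conditionals on tensor values:
            # here, elements with non-finite gradients receive a zero update.
            update = torch.where(
                torch.isfinite(p.grad), buf, torch.zeros_like(buf)
            )
            p.sub_(lr * update)

    return loss
```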
The state_names_to_sparsify method returns the names of the state variables that should be sparsified. Refer to the existing optimizer implementations for examples.
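For the hypothetical MomentumOptimizer, and assuming the method simply returns a list of state-name strings (as the existing implementations suggest), this could be:

```python
# Method of the hypothetical MomentumOptimizer sketched above.
def state_names_to_sparsify(self):
    # The momentum buffer mirrors its parameter elementwise, so it should
    # follow the parameter's sparsity pattern.
    return ["momentum_buffer"]
```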