Writing Custom Optimizers

To define a Cerebras-compliant optimizer, create a subclass of cstorch.optim.Optimizer.


For example:

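import cerebras_pytorch as cstorch

class CustomOptimizer(cstorch.optim.Optimizer):

    def __init__(self, params, ...):
        defaults = ...
        super().__init__(params, defaults, enable_global_step=...)

    def preinitialize(self):
        # Initialize every optimizer state variable here.
        ...

    def step(self, closure=None):
        # Implement the optimizer update using torch ops.
        ...

    def state_names_to_sparsify(self):
        # Return the names of the state variables that need to be sparsified.
        ...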

As the example above shows, the base Optimizer class, like torch.optim.Optimizer, takes the model parameters and the param group defaults. It also accepts an optional third argument, enable_global_step, which defines a global step state variable for each parameter.

In addition, there are three abstract methods that must be overridden:

  • preinitialize

    This method is used to initialize any state variables that will be used by the optimizer. For example, SGD defines its momentum buffers in its preinitialize method.

    Note that to remain Cerebras-compliant, no optimizer state variables may be initialized outside of the preinitialize method.

    For optimal performance, when initializing state tensors that are filled with a constant value, use the creation ops available in the cstorch package. These ops initialize and fill the tensor lazily, so they take up very little memory and initialize much faster than their torch counterparts when running on the Cerebras Wafer-Scale Cluster. See the source code for the optimizers in cerebras_pytorch for examples.

  • step

    This method is where the optimizer step is implemented. Because of the nature of lazy tensor tracing and execution, Python-level conditionals and loops cannot be used to define control flow dynamically, that is, based on tensor values. Such logic must be expressed with torch ops (such as torch.where) instead.

    However, static structures are allowed: for example, a loop with a fixed number of iterations, or a Python conditional whose condition involves only constant values and no torch tensors.

  • state_names_to_sparsify

    This method returns the names of the state variables that need to be sparsified. Refer to the existing optimizer implementations for examples.
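
The rules for preinitialize and step can be sketched with a small momentum optimizer. This is an illustration under stated assumptions, not the Cerebras implementation. It uses torch.optim.Optimizer as a stand-in base class so that it runs without cerebras_pytorch. The class name SketchSGD, the max_update hyperparameter, and the explicit call to preinitialize from __init__ are all hypothetical choices made for the example, and state_names_to_sparsify is omitted because the stand-in base class has no notion of it.

```python
import torch


class SketchSGD(torch.optim.Optimizer):
    """Hypothetical momentum SGD laid out like the skeleton above.

    torch.optim.Optimizer is only a stand-in base class here; it is not
    the Cerebras base class and does not take enable_global_step.
    """

    def __init__(self, params, lr=0.1, momentum=0.9, max_update=None):
        defaults = dict(lr=lr, momentum=momentum, max_update=max_update)
        super().__init__(params, defaults)
        # Assumption: the stand-in base class knows nothing about
        # preinitialize, so this sketch calls it explicitly.
        self.preinitialize()

    def preinitialize(self):
        # All optimizer state is created here and nowhere else.
        # On Cerebras hardware, the text above recommends the lazy
        # creation ops from the cstorch package over torch.zeros_like.
        for group in self.param_groups:
            for p in group["params"]:
                self.state[p]["momentum_buffer"] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:  # static: closure is not a tensor
            with torch.enable_grad():
                loss = closure()

        # Static structure: the loops run over a fixed set of groups
        # and parameters, independent of any tensor value.
        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            max_update = group["max_update"]
            for p in group["params"]:
                if p.grad is None:  # static: depends on structure only
                    continue
                buf = self.state[p]["momentum_buffer"]
                buf.mul_(momentum).add_(p.grad)

                # Static conditional: max_update is a Python constant.
                if max_update is not None:
                    # Data-dependent choice expressed as a torch op
                    # rather than a Python `if` on tensor values.
                    update = torch.where(
                        buf.abs() > max_update,
                        max_update * buf.sign(),
                        buf,
                    )
                else:
                    update = buf
                p.add_(update, alpha=-lr)
        return loss
```

The elementwise clamp is the part a Python `if buf.abs().max() > max_update:` would get wrong under lazy tracing, since that condition depends on tensor values. The check on max_update itself is fine because it is resolved once from a constant hyperparameter.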