# Source code for cerebras_pytorch.core.compile

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""
Provides the fundamental and helper functions
required to compile a model for a Cerebras system
"""
from contextlib import nullcontext
from functools import wraps
from inspect import ismethod
from typing import Union

import torch

import cerebras_pytorch as cstorch
from cerebras_pytorch.backend import (
    Backend,
    current_backend,
    current_backend_impl,
)
from cerebras_pytorch.utils.step_closures import RepeatStepClosure


def compile(  # pylint: disable=redefined-builtin
    model: torch.nn.Module, backend: Union[str, Backend, None] = None,
):
    """Prepare a PyTorch module so that it can be traced and compiled.

    The backend is resolved first, then the backend implementation is asked
    to set the model up (``setup_model``). Parameter initialization must be
    finished before calling this function, because the parameters are moved
    to the device as part of this call.

    Args:
        model: The PyTorch module to be compiled.
        backend: The Cerebras backend to compile with. May be a backend
            name, a ``Backend`` instance, or None. If None, the current
            backend is used; if no current backend is set, the CPU backend
            is initialized and used. Defaults to None.

    Returns:
        A callable pseudo-module. Calling it as ``module(*args, **kwargs)``
        runs the forward pass through the backend. Public bound methods of
        the original module are exposed on it as attributes, along with a
        ``device`` attribute; properties and private methods of the original
        module are not available.

    Raises:
        RuntimeError: If a ``Backend`` instance is passed that is not the
            currently initialized backend, or if the model carries a
            ``cerebras_device`` that differs from the backend's device.
        TypeError: If ``backend`` is not a str, a ``Backend`` or None.
    """
    if isinstance(backend, str):
        # A backend name: have cstorch construct/look up the backend.
        backend = cstorch.backend(backend)
    elif isinstance(backend, Backend):
        # An explicit instance is only accepted if it is the active backend.
        active_backend = current_backend(raise_exception=False)
        if active_backend is not backend:
            raise RuntimeError(
                "Compile got a different backend than the currently "
                "initialized backend. "
            )
    elif backend is None:
        # Fall back to the active backend, or to CPU when none is active.
        active_backend = current_backend(raise_exception=False)
        if active_backend is None:
            backend = cstorch.backend("cpu")
        else:
            backend = active_backend
    else:
        raise TypeError(
            f"Expected backend to be one of str, Backend or None. "
            f"Got: {type(backend)}"
        )

    # A model that records the device it was initialized on must be
    # compiled with a backend that targets that same device.
    if hasattr(model, "cerebras_device"):
        if model.cerebras_device != backend.device:
            raise RuntimeError(
                "Attempting to compile a model using a different backend "
                "than what was used to initialize its parameters. "
                "Please make sure that you are using the same backend "
                "in initialization and compilation. \n"
            )

    # pylint: disable=protected-access
    cs_backend_impl = backend._impl
    cs_backend_impl.setup_model(model)

    @wraps(model.__call__)
    def compiled_forward(*args, **kwargs):
        return cs_backend_impl.forward(model, *args, **kwargs)

    # Expose every public bound method of the model on the returned callable
    # so that it can stand in for the module in common usage.
    for attr_name in dir(model):
        attr = getattr(model, attr_name)
        if attr_name.startswith("_") or not ismethod(attr):
            continue
        setattr(compiled_forward, attr_name, attr)

    compiled_forward.device = cs_backend_impl.torch_device

    return compiled_forward
def trace(step_fn: callable) -> callable:
    """Decorate a training/evaluation step function so that it is traced.

    Any operation that is meant to be executed on the Cerebras Wafer-Scale
    Cluster must be wrapped with this decorator. This includes the forward
    pass, backward pass, optimizer steps, and more.

    For example, the following code snippet shows how to wrap a training
    step that does the forward and backward pass and optimizer step:

    ::

        @cstorch.trace
        def training_step(batch, model, optimizer, loss_fn):
            features, labels = batch
            outputs = model(features)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            return loss

    Args:
        step_fn: The training/evaluation step function to be wrapped.

    Returns:
        The wrapped training/evaluation step function. Once the backend's
        run context reports the step as traced, the wrapper stops calling
        ``step_fn`` and returns the outputs captured by the last real call.
    """
    # Outputs of the most recent real invocation of ``step_fn``; shared
    # across calls of the wrapper through the closure.
    traced_outputs = None

    @wraps(step_fn)
    def generated_trace_fn(*args, **kwargs):
        nonlocal traced_outputs

        impl = current_backend_impl()

        # Inside a run context that has already been traced there is nothing
        # to re-run: hand back the previously captured outputs.
        if impl.in_run_context and impl.run_context.traced.is_set():
            return traced_outputs

        step_ctx = (
            nullcontext()
            if impl.retrace_every_iteration
            else RepeatStepClosure()
        )

        # NOTE(review): the nesting below (mark_output inside the context
        # manager, traced.set() after it) was reconstructed from a
        # flattened listing -- confirm against the upstream file.
        with step_ctx:
            traced_outputs = step_fn(*args, **kwargs)

            # Set force=True to mark the outputs as if they were added to a
            # step closure. This ensures that if caller passes these outputs
            # to a step closure, we don't get duplicates.
            impl.mark_output(traced_outputs, force=True)

        if impl.in_run_context:
            impl.run_context.traced.set()

        return traced_outputs

    return generated_trace_fn