# Source code for modelzoo.common.pytorch.dump_context

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Provides DumpContext, a debug utility for dumping activations and gradients on
a CPU/GPU run, and setting up debug names for dumped WSE activations to be
automatically correlated.
"""
import functools
import os
import warnings
from collections import defaultdict
from contextlib import ContextDecorator

import numpy as np
import torch

import cerebras_pytorch as cstorch
from cerebras_pytorch.utils.nest import visit_torch_tensors


class DumpContext(ContextDecorator):
    """
    A debug utility context manager. When provided with a torch.nn.Module, the
    resulting context manager can be entered to enable dumping of all module
    forward and backward outputs to a npz, for comparing numerics between
    implementations.
    """

    def __init__(
        self, outdir: str, model: torch.nn.Module, buffer_steps: int = None
    ):
        """
        Sets up global module hooks to either dump intermediate activations on
        CPU/GPU or name the traced tensors for correlating with debug dumps on
        CS2.

        The recursive name of the torch.nn.Module is memoized, and the output
        of FWD and BWD of each module is saved as keys in a .npz file.

        Args:
            outdir: Directory (created if missing) to write `act_dumps.npz`
                and, on later flushes, `act_dumps_{i}.npz` into.
            model: root module to name its children
            buffer_steps: If given, flush to a new .npz file once this many
                steps have been captured.
        """
        self._outdir = outdir
        os.makedirs(self._outdir, exist_ok=True)

        # The actual hook functions to install
        self._forward_pre_hook = None
        self._forward_hook = None
        self._backward_hook = None
        self._full_backward_hook = None

        self.setup_hooks(model)

        # Any installed hooks, set during enable_collection()
        self._module_hooks = []
        # Per-debug-name call counters used to disambiguate repeated calls.
        self._call_counter = {}
        self._buffer_steps = buffer_steps
        self._flush_count = 0
        # Maps dump key -> list of numpy arrays, one appended per capture.
        self._buffer = defaultdict(list)

    def __enter__(self):
        self.enable_collection()
        return self

    def __del__(self):
        # Flush any leftover data, but don't write an empty archive (e.g.
        # right after a buffer_steps-triggered flush), and tolerate an
        # instance whose __init__ raised before `_buffer` was assigned.
        if getattr(self, "_buffer", None):
            self.flush()

    def __exit__(self, *exc):
        self.disable_collection()

        # Check if we need to flush, using the number of entries in the first
        # buffer as a proxy for how many steps we've captured.
        if self._buffer_steps and self._buffer:
            first_buffer = next(iter(self._buffer.values()))
            if len(first_buffer) >= self._buffer_steps:
                self.flush()

    def setup_hooks(self, model):
        """
        Define hooking functions on the given torch.nn.Module, but don't
        install them.

        Args:
            model: torch.nn.Module that serves as the root for recursive names
        """
        cstorch.add_debug_name(model)

        # Helpers for hooks
        def get_name(module, counter_increment=0):
            """
            Return the debug name for this invocation of `module`, suffixed
            with `.call{i}` when `i` (the call index) is non-zero.

            Forward hooks pass a non-negative increment and use the counter
            value *before* updating it, so the i-th forward call (0-based) gets
            index i. The backward hook passes -1 and uses the value *after*
            decrementing. Backward typically revisits repeated calls in reverse
            order, so it hands out n-1, ..., 0 and each `.bwd` key matches the
            `.fwd` key of the same call.
            """
            name = cstorch.get_debug_name(module)
            if counter_increment >= 0:
                call_index = self._call_counter.setdefault(name, 0)
                self._call_counter[name] = call_index + counter_increment
            else:
                # A default of 1 means a backward with no recorded forward
                # still maps to index 0 (the plain name).
                call_index = (
                    self._call_counter.setdefault(name, 1) + counter_increment
                )
                self._call_counter[name] = call_index
            if call_index != 0:
                name = f"{name}.call{call_index}"
            return name

        def recurse(top_scope, output):
            for scope, tensor in visit_torch_tensors(output, scope=top_scope):
                yield ".".join(scope), tensor

        if cstorch.use_cs():

            def fwd_pre_name_scope(module, inputs):
                name = get_name(module) + ".fwd"
                cstorch.set_debug_scope(name)

            def fwd_post_name_scope(module, inputs, output):
                # This will actually be the name for the bwd pass entered from
                # the module's output's grad hook.
                # Also, increment the counter for the next fwd_pre to hit.
                name = get_name(module, 1) + ".bwd"

                def enter_bwd_scope(grad):  # pylint: disable=unused-argument
                    # Set scope before beginning bwd pass from any output.
                    # Deliberately returns None: a tensor hook that returns a
                    # value replaces the gradient.
                    cstorch.set_debug_scope(name)

                for _, tensor in recurse([name], output):
                    # In case a module returns a tensor unmodified, don't
                    # change its scope.
                    existing_name = getattr(tensor, "_debug_name_scope", None)
                    if tensor.requires_grad and not existing_name:
                        # pylint: disable=protected-access
                        tensor._debug_name_scope = name
                        tensor.register_hook(enter_bwd_scope)

                # Clear any scope in case this is the last module.
                cstorch.set_debug_scope(None)

            def bwd_post_name_scope(module, grad_input, grad_output):
                # Clear scope after bwd pass is complete.
                cstorch.set_debug_scope(None)

            self._forward_pre_hook = fwd_pre_name_scope
            self._forward_hook = fwd_post_name_scope
            self._backward_hook = bwd_post_name_scope
        else:
            # CPU/GPU
            def save_output(key, module, inputs, output):
                """
                Saves to numpy arrays in the output directory.
                """
                counter_increment = 1
                if key == "bwd":
                    counter_increment = -1
                    # hook args are `grad_input, grad_output`, where
                    # grad_input is the _gradient_ of the module's input i.e.
                    # the output of the backward pass and the more interesting
                    # value to dump. This way, the dump named `module.fwd` is
                    # the output of the forward pass (i.e. txact), and
                    # `module.bwd` is the output of the backward pass (i.e.
                    # txdelta) for the corresponding kernel.
                    output = inputs

                name = get_name(module, counter_increment)
                for scope, tensor in recurse([name, key], output):
                    tensor = tensor.detach().to("cpu").clone()
                    if tensor.dtype == torch.bfloat16:
                        warnings.warn(
                            "Encountered bfloat16 tensor in summary "
                            "collection. Numpy does not natively support "
                            "bfloat16, so any torch.bfloat16 tensors will be "
                            "saved as np.float32."
                        )
                        tensor = tensor.float()
                    self._buffer[scope].append(tensor.numpy())

            self._forward_hook = functools.partial(save_output, "fwd")
            self._backward_hook = functools.partial(save_output, "bwd")

    def enable_collection(self):
        """
        Install the hooks defined during `setup_hooks`, enabling the
        collection of the dumps.
        """

        def install_if_set(hook):
            hook_fn = getattr(self, f"_{hook}_hook")
            if hook_fn:
                register_fn = f"register_module_{hook}_hook"
                return getattr(torch.nn.modules.module, register_fn)(hook_fn)
            return None

        hooks = ("forward_pre", "forward", "backward", "full_backward")
        self._module_hooks = [install_if_set(hook) for hook in hooks]

        # Clear call counters so `.call{i}` indices restart every collection.
        self._call_counter = {}

    def disable_collection(self):
        """
        Uninstall the hooks installed during `enable_collection`, disabling
        further dump collection.
        """
        for hook in self._module_hooks:
            if hook:
                hook.remove()
        self._module_hooks = []

    def flush(self):
        """
        Write all dump buffers out to disk.

        The first flush writes `act_dumps.npz`; subsequent flushes write
        `act_dumps_{i}.npz`. Each key's arrays are stacked with `np.stack`,
        so every capture under one key must have the same shape.
        """
        if self._flush_count:
            outfile = f"act_dumps_{self._flush_count}.npz"
        else:
            outfile = "act_dumps.npz"

        np.savez(
            os.path.join(self._outdir, outfile),
            **{key: np.stack(values) for key, values in self._buffer.items()},
        )
        self._buffer.clear()
        self._flush_count += 1