# Source code for cerebras_pytorch.utils.data.utils

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""Utility and helper functions used by the Cerebras dataloader"""
import math
from typing import Optional
from warnings import warn

import numpy as np
import torch

import cerebras_pytorch as cstorch
from cerebras_appliance.data.dtypes import bf16, is_bf16
from cerebras_pytorch.utils._num import ceildiv
from cerebras_pytorch.utils.nest import visit_torch_tensors


def compute_num_steps(
    dataloader: torch.utils.data.DataLoader,
    initial_step: int = 0,
    num_steps: Optional[int] = None,
    max_steps: Optional[int] = None,
    num_epochs: Optional[int] = None,
    steps_per_epoch: Optional[int] = None,
    grad_accum_steps: int = 1,
):
    """
    Computes the number of steps to execute on the system based on the
    provided step information.

    Args:
        dataloader: The dataloader itself, used to determine the number of
            batches per epoch when ``len(dataloader)`` is available.
        initial_step: The step to begin on. An error is raised if
            ``max_steps`` is given and ``initial_step`` is not strictly less
            than it.
        num_steps: The number of steps to run.
        max_steps: The maximum global step; at most
            ``max_steps - initial_step`` steps are run.
        num_epochs: The number of epochs to run. Not allowed when the
            dataloader length is unknown.
        steps_per_epoch: The number of batches to consume each epoch. Must
            not exceed ``len(dataloader)``; defaults to ``len(dataloader)``.
        grad_accum_steps: The number of batches to accumulate gradients over
            before stepping.

    Note:
        At least one of num_steps, max_steps, or num_epochs must be specified.

    Returns:
        The calculated total number of steps to execute: the minimum of the
        limits implied by num_epochs, num_steps and max_steps.

    Raises:
        ValueError: If ``steps_per_epoch`` exceeds the dataloader length.
        AssertionError: On other invalid argument combinations (empty
            dataloader, too many grad accumulation steps, num_epochs with an
            unknown dataset length, no step limit given, or initial_step not
            below max_steps).
    """
    try:
        # Only this call probes whether the dataset length is known. Keeping
        # the try body minimal stops unrelated TypeErrors from being misread
        # as "unknown length".
        dataloader_size = len(dataloader)
    except TypeError:
        # Dataset length is not known
        assert num_epochs is None, (
            "Specifying num_epochs for datasets with unknown length is "
            "not allowed. Please control training behavior through "
            "number of steps instead."
        )
        steps_per_epoch = 1
    else:
        # Dataset length is known
        assert dataloader_size > 0, "Dataloader does not generate any batches."
        if steps_per_epoch is not None:
            if steps_per_epoch > dataloader_size:
                raise ValueError(
                    f"The requested steps per epoch of {steps_per_epoch} "
                    f"exceeds total steps in an epoch, which is "
                    f"{dataloader_size}."
                )
        else:
            steps_per_epoch = dataloader_size

        # With grad accumulation, the global step is incremented every Nth
        # batch, so our effective steps per epoch needs to be adjusted.
        assert grad_accum_steps <= steps_per_epoch, (
            f"Gradient accumulation steps of {grad_accum_steps} is "
            f"greater than batches per epoch of {steps_per_epoch}."
        )

        steps_per_epoch //= grad_accum_steps

    # The total is the tightest of all provided limits.
    total_steps = math.inf
    if num_epochs is not None:
        total_steps = min(total_steps, num_epochs * steps_per_epoch)
    if num_steps is not None:
        total_steps = min(total_steps, num_steps)
    if max_steps is not None:
        remaining_steps = max_steps - initial_step
        assert remaining_steps > 0, (
            f"Initial global step {initial_step} already exceeds "
            f"max step {max_steps}."
        )
        total_steps = min(total_steps, remaining_steps)

    # total_steps stays infinite only if no limit was provided at all.
    assert not math.isinf(
        total_steps
    ), "One of num_epochs, num_steps, or max_steps must be provided"

    return total_steps


def infer_batch_size(data, batch_size=None) -> Optional[int]:
    """Infers the batch size from a dataloader batch.

    The batch size of each tensor is the size of its first dimension; 0-d
    (scalar) tensors count as batch size 1.

    Args:
        data: A nested structure of tensors.
        batch_size: A previously observed batch size to compare against.
            If None, the batch size is only inferred from the data.

    Returns:
        The inferred batch size, or None if the tensors within ``data``
        disagree on their batch size (only possible outside of CS runs).

    Raises:
        RuntimeError: If no tensors are found in ``data``, or, in CS runs,
            if the batch sizes are non-uniform within ``data`` or differ
            from ``batch_size``.
    """
    inferred_batch_sizes = {
        1 if tensor.dim() == 0 else tensor.size(0)
        for _, tensor in visit_torch_tensors(data)
    }

    if not inferred_batch_sizes:
        raise RuntimeError(
            "We could not detect any torch tensors in the input data "
            "returned by the dataloader. We expect the dataloader to "
            "return a nested dict/list/tuple of tensors. If there are "
            "custom types that internally hold tensors, we are not "
            "currently able to detect them. Please ensure that the "
            "dataloader returns tensors in the expected format."
        )

    if len(inferred_batch_sizes) > 1:
        if cstorch.use_cs():
            raise RuntimeError(
                f"Only uniform batch sizes are supported in CS runs, but "
                f"the dataloader returned a batch with batch sizes "
                f"{inferred_batch_sizes}. "
            )
        warn(
            f"Detected non-uniform batch sizes within the same batch: "
            f"{inferred_batch_sizes}. While this is allowed in non-CSX "
            f"runs, it may throw off metrics such as rate profiling. "
            f"The run will proceed assuming no batch size."
        )
        return None

    (inferred_batch_size,) = inferred_batch_sizes

    # Compare against the batch size seen previously (if any) to detect
    # non-uniform batches across iterations.
    if batch_size is not None and inferred_batch_size != batch_size:
        if cstorch.use_cs():
            raise RuntimeError(
                f"Only uniform batch sizes are supported in CS runs, but "
                f"the dataloader returned two different batches with "
                f"batch sizes {batch_size} and {inferred_batch_size}. "
                f"Make sure to set `drop_last=True` in the dataloader."
            )
        warn(
            f"Detected non-uniform batch sizes between batches "
            f"({batch_size} vs {inferred_batch_size}). "
            f"While this is allowed in non-CSX runs, it may throw off "
            f"metrics such as rate profiling. "
        )

    return inferred_batch_size


def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Converts a torch tensor to a numpy array.

    bfloat16 tensors are reinterpreted bit-for-bit as int16 and then viewed
    as the appliance ``bf16`` numpy dtype. All other tensors go through
    ``Tensor.numpy()`` directly.

    Args:
        tensor: The tensor to convert.

    Returns:
        A numpy array view of ``tensor``'s data.
    """
    if tensor.dtype == torch.bfloat16:
        # The int16 -> bf16 reinterpretation is only valid for a 2-byte dtype.
        assert bf16.itemsize == 2  # Sanity check
        return tensor.view(torch.int16).numpy().view(bf16)
    return tensor.numpy()
def from_numpy(array: np.ndarray) -> torch.Tensor:
    """Converts a numpy array to a torch tensor.

    Arrays that ``torch.from_numpy`` cannot wrap directly (non-writeable
    arrays, or arrays with negative strides) are copied first. Arrays of the
    appliance bf16 dtype are reinterpreted bit-for-bit as
    ``torch.bfloat16``.

    Args:
        array: The numpy array to convert.

    Returns:
        A torch tensor sharing memory with ``array`` (or with its copy, when
        a copy was required).
    """
    # Copy non-writeable arrays to make them writable for torch.from_numpy,
    # and copy negative-stride arrays (e.g. ``a[::-1]``), which
    # torch.from_numpy rejects.
    if not array.flags.writeable or any(stride < 0 for stride in array.strides):
        array = array.copy()
    if is_bf16(array.dtype):
        # Mirror to_numpy: expose the 2-byte bf16 payload as int16, which
        # torch understands, then reinterpret those bits as bfloat16.
        return torch.from_numpy(array.view(np.int16)).view(torch.bfloat16)
    return torch.from_numpy(array)