# Source code for cerebras_pytorch.utils.data.utils

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""Utility and helper functions used by the Cerebras dataloader"""
import math
from typing import Optional
from warnings import warn

import numpy as np
import torch

import cerebras_pytorch as cstorch
from cerebras_appliance.data.dtypes import bf16, is_bf16
from cerebras_pytorch.utils._num import ceildiv
from cerebras_pytorch.utils.nest import visit_torch_tensors


def compute_num_steps(
    dataloader: torch.utils.data.DataLoader,
    initial_step: int = 0,
    num_steps: Optional[int] = None,
    max_steps: Optional[int] = None,
    num_epochs: Optional[int] = None,
    steps_per_epoch: Optional[int] = None,
    grad_accum_steps: int = 1,
):
    """
    Computes the number of steps to execute on the system based on the
    provided step information

    Args:
        dataloader: The dataloader itself which is used to determine the length
            of the dataset if available. Objects for which ``len()`` raises
            ``TypeError`` are treated as having unknown length.
        initial_step: The step to begin on. An error is thrown if the initial
            step has already reached ``max_steps``.
        num_steps: The number of steps to run. Mutually exclusive with
            ``num_epochs``.
        max_steps: The global step to stop at; at most
            ``max_steps - initial_step`` steps are executed.
        num_epochs: The number of epochs to run. Not allowed when the
            dataloader length is unknown.
        steps_per_epoch: The number of batches to consume each epoch. Must not
            exceed the dataloader length; defaults to the dataloader length.
        grad_accum_steps: The number of batches over which gradients are
            accumulated before stepping.

    Note:
        At least one of num_steps, max_steps, or num_epochs must be specified

    Returns:
        The calculated total number of steps to execute

    Raises:
        ValueError: If an argument has an invalid type/value, both
            ``num_epochs`` and ``num_steps`` are given, the dataloader yields
            no batches, ``num_epochs`` is given for a dataloader of unknown
            length, or none of num_epochs/num_steps/max_steps is provided.
        RuntimeError: If ``initial_step`` is already at or past ``max_steps``.
    """

    def _check_steps(name, value, allow_none=False, allow_zero=False):
        """Validate that ``value`` is a positive int (or zero/None if allowed)."""
        if value is None:
            if not allow_none:
                raise ValueError(f"`{name}` cannot be None.")
            return
        # ``bool`` is a subclass of ``int``; reject it explicitly so that e.g.
        # ``num_steps=True`` is not silently treated as a step count of 1.
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError(
                f"`{name}` must be an integer, but got {type(value)}."
            )
        if value == 0 and not allow_zero:
            raise ValueError(f"`{name}` must be greater than zero.")
        if value < 0:
            raise ValueError(
                f"`{name}` cannot be negative, but got {value}."
            )

    if num_epochs is not None and num_steps is not None:
        raise ValueError(
            "Only one of `num_epochs` or `num_steps` can be specified."
        )

    _check_steps(
        "initial_step", initial_step, allow_none=False, allow_zero=True
    )
    _check_steps("num_steps", num_steps, allow_none=True)
    _check_steps("max_steps", max_steps, allow_none=True)
    _check_steps("num_epochs", num_epochs, allow_none=True)
    _check_steps("steps_per_epoch", steps_per_epoch, allow_none=True)
    _check_steps("grad_accum_steps", grad_accum_steps, allow_none=False)

    # Only ``len()`` signals an unknown length; keep the ``try`` narrow so a
    # TypeError raised anywhere else is not mistaken for "length unknown".
    try:
        dataloader_size = len(dataloader)
    except TypeError:
        dataloader_size = None

    if dataloader_size is None:
        # Dataset length is not known, so epochs cannot be converted to steps.
        if num_epochs is not None:
            raise ValueError(
                "Specifying num_epochs for datasets with unknown length is "
                "not allowed. Please control training behavior through "
                "number of steps instead."
            )
    else:
        # Dataset length is known. Raise (not assert) so the check still
        # runs under ``python -O``.
        if dataloader_size == 0:
            raise ValueError("Dataloader does not generate any batches.")
        if steps_per_epoch is None:
            steps_per_epoch = dataloader_size
        elif steps_per_epoch > dataloader_size:
            raise ValueError(
                f"The requested steps per epoch of {steps_per_epoch} "
                f"exceeds total steps in an epoch, which is "
                f"{dataloader_size}."
            )

        # With grad accumulation, the global step is incremented every Nth
        # batch, so our effective steps per epoch needs to be adjusted.
        if grad_accum_steps > steps_per_epoch:
            raise ValueError(
                f"Gradient accumulation steps of {grad_accum_steps} is "
                f"greater than batches per epoch of {steps_per_epoch}."
            )
        steps_per_epoch //= grad_accum_steps

    # Calculate total steps as the tightest of the provided limits.
    total_steps = math.inf
    if num_epochs is not None:
        # Only reachable for known-length dataloaders (checked above), so
        # steps_per_epoch is a positive int here.
        total_steps = min(total_steps, num_epochs * steps_per_epoch)
    if num_steps is not None:
        total_steps = min(total_steps, num_steps)
    if max_steps is not None:
        remaining_steps = max_steps - initial_step
        if remaining_steps <= 0:
            raise RuntimeError(
                f"Initial global step {initial_step} already exceeds "
                f"max step {max_steps}."
            )
        total_steps = min(total_steps, remaining_steps)

    # None of the limits above was provided.
    if math.isinf(total_steps):
        raise ValueError(
            "One of num_epochs, num_steps, or max_steps must be provided"
        )

    return total_steps


def infer_batch_size(data, batch_size=None) -> Optional[int]:
    """Infers the batch size from a dataloader batch.

    Each tensor's batch size is the size of its first dimension; a
    zero-dimensional (scalar) tensor counts as batch size 1.

    Args:
        data: A nested structure of tensors.
        batch_size: The batch size to compare against.
            If None, the batch size is inferred from the data.

    Returns:
        The batch size shared by every tensor in ``data``, or None when the
        tensors disagree (only possible outside CS runs, where a warning is
        emitted instead of an error).

    Raises:
        RuntimeError: If ``data`` contains no torch tensors. In CS runs, also
            if the tensors in ``data`` have differing batch sizes, or if the
            inferred batch size differs from ``batch_size``.
    """
    inferred_batch_sizes = {
        1 if tensor.dim() == 0 else tensor.size(0)
        for _, tensor in visit_torch_tensors(data)
    }

    if not inferred_batch_sizes:
        raise RuntimeError(
            "We could not detect any torch tensors in the input data "
            "returned by the dataloader. We expect the dataloader to "
            "return a nested dict/list/tuple of tensors. If there are "
            "custom types that internally hold tensors, we are not "
            "currently able to detect them. Please ensure that the "
            "dataloader returns tensors in the expected format."
        )

    if len(inferred_batch_sizes) > 1:
        # Tensors within this single batch disagree on their batch size.
        if cstorch.use_cs():
            raise RuntimeError(
                f"Only uniform batch sizes are supported in CS runs, but "
                f"the dataloader returned a batch with batch sizes "
                f"{inferred_batch_sizes}. "
            )
        warn(
            f"Detected non-uniform batch sizes within the same batch: "
            f"{inferred_batch_sizes}. While this is allowed in non-CSX "
            f"runs, it may throw off metrics such as rate profiling. "
            f"The run will proceed assuming no batch size."
        )
        return None

    (inferred_batch_size,) = inferred_batch_sizes

    # Compare against the batch size seen previously (if the caller has one).
    if batch_size is not None and inferred_batch_size != batch_size:
        if cstorch.use_cs():
            raise RuntimeError(
                f"Only uniform batch sizes are supported in CS runs, but "
                f"the dataloader returned two different batches with "
                f"batch sizes {batch_size} and {inferred_batch_size}. "
                f"Make sure to set `drop_last=True` in the dataloader."
            )
        warn(
            f"Detected non-uniform batch sizes between batches "
            f"({batch_size} vs {inferred_batch_size}). "
            f"While this is allowed in non-CSX runs, it may throw off "
            f"metrics such as rate profiling. "
        )

    return inferred_batch_size


def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Converts a torch tensor to a numpy array.

    The conversion goes through ``Tensor.numpy()``, so the result shares
    memory with ``tensor``. ``Tensor.numpy()`` requires a CPU tensor that does
    not require grad.

    Numpy has no native bfloat16 dtype. A ``torch.bfloat16`` tensor is
    therefore reinterpreted bit-for-bit as ``int16`` and then viewed as the
    project's ``bf16`` numpy dtype.

    Args:
        tensor: The tensor to convert.

    Returns:
        A numpy array with the same data as ``tensor``.
    """
    if tensor.dtype == torch.bfloat16:
        # The int16 reinterpretation is only valid if bf16 is also 2 bytes.
        assert bf16.itemsize == 2  # Sanity check
        return tensor.view(torch.int16).numpy().view(bf16)
    return tensor.numpy()
def from_numpy(array: np.ndarray) -> torch.Tensor:
    """Converts a numpy array to a torch tensor.

    A read-only array is copied first. Otherwise the result is built with
    ``torch.from_numpy`` and shares memory with ``array``.

    Arrays whose dtype is the project's ``bf16`` dtype are reinterpreted
    bit-for-bit as ``int16`` and then viewed as ``torch.bfloat16``. This is
    the inverse of ``to_numpy``.

    Args:
        array: The array to convert.

    Returns:
        A tensor with the same data as ``array``.
    """
    # Copy non-writeable array to make it writable for torch.from_numpy
    if not array.flags.writeable:
        array = array.copy()
    if is_bf16(array.dtype):
        # torch.from_numpy only accepts standard numpy dtypes. Hand it an
        # int16 view of the same 2-byte elements (mirroring to_numpy), then
        # reinterpret those bits as bfloat16.
        assert array.dtype.itemsize == 2  # Sanity check
        return torch.from_numpy(array.view(np.int16)).view(torch.bfloat16)
    return torch.from_numpy(array)