# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Utility and helper functions used by the Cerebras dataloader"""
import math
from typing import Optional
from warnings import warn

import numpy as np
import torch

import cerebras_pytorch as cstorch
from cerebras_appliance.data.dtypes import bf16, is_bf16
from cerebras_pytorch.utils._num import ceildiv
from cerebras_pytorch.utils.nest import visit_torch_tensors


def compute_num_steps(
dataloader: torch.utils.data.DataLoader,
initial_step: int = 0,
num_steps: Optional[int] = None,
max_steps: Optional[int] = None,
num_epochs: Optional[int] = None,
steps_per_epoch: Optional[int] = None,
grad_accum_steps: int = 1,
):
"""
Computes the number of steps to execute on the system based on the
provided step information
Args:
dataloader: The dataloader itself which is used to determine the length
of the dataset if available
initial_step: The step to begin on. An error is thrown if the initial
step exceeds the maximal steps calulated below
num_steps: The number of steps to run
max_steps: The maximum number of steps to run
num_epochs: The number of epochs to run
steps_per_epoch: The number of steps to run each epoch
grad_accum_steps: The number of steps accumulate gradients before stepping
Note:
At least one of num_steps, max_steps, or num_epochs must be specified
Returns:
The calculated total number of steps to execute
"""
def _check_steps(name, value, allow_none=False, allow_zero=False):
if value is None:
if not allow_none:
raise ValueError(f"`{name}` cannot be None.")
else:
if not isinstance(value, int):
raise ValueError(
f"`{name}` must be an integer, but got {type(value)}."
)
if value == 0 and not allow_zero:
raise ValueError(f"`{name}` must be greater than zero.")
if value < 0:
raise ValueError(
f"`{name}` cannot be negative, but got {value}."
)
if num_epochs is not None and num_steps is not None:
raise ValueError(
"Only one of `num_epochs` or `num_steps` can be specified."
)
_check_steps(
"initial_step", initial_step, allow_none=False, allow_zero=True
)
_check_steps("num_steps", num_steps, allow_none=True)
_check_steps("max_steps", max_steps, allow_none=True)
_check_steps("num_epochs", num_epochs, allow_none=True)
_check_steps("steps_per_epoch", steps_per_epoch, allow_none=True)
_check_steps("grad_accum_steps", grad_accum_steps, allow_none=False)
try:
# Dataset length is known
dataloader_size = len(dataloader)
assert dataloader_size > 0, "Dataloader does not generate any batches."
if steps_per_epoch is not None:
if steps_per_epoch > dataloader_size:
raise ValueError(
f"The requested steps per epoch of {steps_per_epoch} "
f"exceeds total steps in an epoch, which is "
f"{dataloader_size}."
)
else:
steps_per_epoch = dataloader_size
# With grad accumulation, the global step is incremented every Nth
# batch, so our effective steps per epoch needs to be adjusted.
if grad_accum_steps > steps_per_epoch:
raise ValueError(
f"Gradient accumulation steps of {grad_accum_steps} is "
f"greater than batches per epoch of {steps_per_epoch}."
)
steps_per_epoch //= grad_accum_steps
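        # e.g. with 100 batches per epoch and grad_accum_steps=4, this
        # yields 25 optimizer steps per epoch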
except TypeError:
# Dataset length is not known
if num_epochs is not None:
raise ValueError(
"Specifying num_epochs for datasets with unknown length is "
"not allowed. Please control training behavior through "
"number of steps instead."
)
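        # With an unknown dataset length, fall back to treating every step
        # as its own epoch so the calculation below depends only on steps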
steps_per_epoch = 1
# Calculate total steps
total_steps = math.inf
if num_epochs is not None:
total_steps = min(total_steps, num_epochs * steps_per_epoch)
if num_steps is not None:
total_steps = min(total_steps, num_steps)
if max_steps is not None:
remaining_steps = max_steps - initial_step
if remaining_steps <= 0:
raise RuntimeError(
f"Initial global step {initial_step} already exceeds "
f"max step {max_steps}."
)
total_steps = min(total_steps, remaining_steps)
    # At least one of the above if blocks must have been true, so raise an
    # error in case none of the step arguments was specified.
if math.isinf(total_steps):
raise ValueError(
"One of num_epochs, num_steps, or max_steps must be provided"
)
if num_epochs is None:
steps_per_epoch = total_steps
    # Recompute num_epochs and steps_per_epoch so that they are consistent
    # with the final total step count
num_epochs = ceildiv(total_steps, steps_per_epoch)
steps_per_epoch = ceildiv(total_steps, num_epochs)
return total_steps
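

# A minimal usage sketch for `compute_num_steps` (illustrative values only,
# assuming a standard in-memory dataset of known length):
#
#   >>> from torch.utils.data import DataLoader, TensorDataset
#   >>> loader = DataLoader(TensorDataset(torch.zeros(100, 1)), batch_size=10)
#   >>> compute_num_steps(loader, num_epochs=2)  # 10 batches/epoch * 2 epochs
#   20
#   >>> compute_num_steps(loader, num_steps=50, max_steps=30)  # capped by max
#   30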


def infer_batch_size(data, batch_size=None) -> Optional[int]:
"""Infers the batch size from a dataloader batch.
Args:
data: A nested structure of tensors.
batch_size: The batch size to compare against.
If None, the batch size is inferred from the data.
"""
    # A scalar (0-dim) tensor is treated as having batch size 1
    inferred_batch_sizes = set(
        tensor.shape[0] if tensor.dim() > 0 else 1
        for _, tensor in visit_torch_tensors(data)
    )
if len(inferred_batch_sizes) > 1:
if cstorch.use_cs():
raise RuntimeError(
f"Only uniform batch sizes are supported in CS runs, but "
f"the dataloader returned a batch with batch sizes "
f"{inferred_batch_sizes}. "
)
        warn(
            f"Detected non-uniform batch sizes within the same batch: "
            f"{inferred_batch_sizes}. While this is allowed in non-CSX "
            f"runs, it may throw off metrics such as rate profiling. "
            f"The run will proceed without an inferred batch size."
        )
return None
if len(inferred_batch_sizes) == 1:
inferred_batch_size = inferred_batch_sizes.pop()
if batch_size is not None and inferred_batch_size != batch_size:
if cstorch.use_cs():
raise RuntimeError(
f"Only uniform batch sizes are supported in CS runs, but "
f"the dataloader returned two different batches with "
f"batch sizes {batch_size} and {inferred_batch_size}. "
f"Make sure to set `drop_last=True` in the dataloader."
)
else:
warn(
f"Detected non-uniform batch sizes between batches "
f"({batch_size} vs {inferred_batch_size}). "
f"While this is allowed in non-CSX runs, it may throw off "
f"metrics such as rate profiling. "
)
return inferred_batch_size
raise RuntimeError(
"We could not detect any torch tensors in the input data "
"returned by the dataloader. We expect the dataloader to "
"return a nested dict/list/tuple of tensors. If there are "
"custom types that internally hold tensors, we are not "
"currently able to detect them. Please ensure that the "
"dataloader returns tensors in the expected format."
)
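

# A minimal usage sketch for `infer_batch_size` (hypothetical batch layout):
#
#   >>> batch = {"input_ids": torch.zeros(16, 128), "labels": torch.zeros(16)}
#   >>> infer_batch_size(batch)
#   16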


def to_numpy(tensor: torch.Tensor) -> np.ndarray:
"""Converts a torch tensor to a numpy array."""
if tensor.dtype == torch.bfloat16:
assert bf16.itemsize == 2 # Sanity check
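        # NumPy has no native bfloat16 dtype, so reinterpret the raw bits as
        # int16 for the conversion and then view the resulting array as the
        # custom `bf16` dtype from the appliance package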
return tensor.view(torch.int16).numpy().view(bf16)
return tensor.numpy()


def from_numpy(array: np.ndarray) -> torch.Tensor:
"""Converts a numpy array to a torch tensor."""
    # Copy non-writeable arrays to make them writeable, since
    # `torch.from_numpy` shares memory with the source array and torch
    # tensors do not support read-only storage
if not array.flags.writeable:
array = array.copy()
if is_bf16(array.dtype):
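        # Reverse of `to_numpy`: torch reads the raw 16-bit values and the
        # view reinterprets those bits as bfloat16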
return torch.from_numpy(array).view(torch.bfloat16)
return torch.from_numpy(array)
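

# A minimal round-trip sketch for `to_numpy`/`from_numpy`; bfloat16 values
# survive the conversion because only the raw bits are reinterpreted:
#
#   >>> t = torch.tensor([1.5, -2.0], dtype=torch.bfloat16)
#   >>> torch.equal(from_numpy(to_numpy(t)), t)
#   True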