# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Defines the Cerebras DataLoader class and RestartableDataLoader protocol class."""
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Union
import torch
from typing_extensions import Protocol, runtime_checkable
from cerebras_appliance.log import ClassLogger, named_class_logger
from cerebras_pytorch.backend import Backend, current_backend_impl
from cerebras_pytorch.utils.data.utils import infer_batch_size
@named_class_logger
class DataLoader(ClassLogger):
    """
    Wrapper around torch.utils.data.DataLoader that facilitates
    moving data generated by the dataloader to a Cerebras system.

    The wrapped dataloader is built once, eagerly, by calling ``input_fn``
    with the supplied arguments. Checkpoint-related methods only do work
    when the current backend type is CSX.

    Args:
        input_fn: A callable that returns a torch.utils.data.DataLoader
            instance or an iterable that returns a structure containing
            torch tensors.
        *args, **kwargs: Any other positional or keyword arguments
            are passed into the input_fn when each worker instantiates
            their respective dataloaders
    """

    def __init__(
        self,
        input_fn: Callable[..., Union[torch.utils.data.DataLoader, Iterable]],
        *args,
        **kwargs,
    ):
        """Validate ``input_fn``, record its arguments and build the dataloader.

        Raises:
            TypeError: if ``input_fn`` is not callable.
        """
        if not callable(input_fn):
            raise TypeError(
                "Expected a callable that constructs and returns a "
                "`torch.utils.data.DataLoader` or an iterable that "
                "returns a structure containing torch tensors."
            )

        self.input_fn = input_fn
        # Deep copy taken *before* `input_fn` runs, so the stored parameters
        # are independent of anything `input_fn` does to its arguments.
        self.input_fn_params = deepcopy((args, kwargs))
        self.dataloader = input_fn(*args, **kwargs)

        # State handed to `load_state_dict`, if any.
        self.loaded_state_dict = None
        self.worker_states = None
        # Updated on every batch yielded by `__iter__`.
        self.batch_size = None
        self.enable_dataloader_checkpointing = True

    @property
    def is_restartable(self) -> bool:
        """Whether the wrapped dataloader satisfies `RestartableDataLoader`."""
        return isinstance(self.dataloader, RestartableDataLoader)

    @property
    def _backend(self) -> Backend:
        """The backend implementation that is current at access time."""
        return current_backend_impl()

    def disable_dataloader_checkpointing(self):
        """Turn off DataLoader checkpointing.

        The flag is only cleared when the current backend type is CSX.
        """
        if not self._backend.backend_type.is_csx:
            return
        self.enable_dataloader_checkpointing = False

    def state_dict(self) -> Dict[str, Any]:
        """Returns dataloader state to save in a checkpoint
        by invoking the saving mechanism of the
        :py:class:`~cerebras_pytorch.utils.data.RestartableDataLoader` API.

        Returns:
            `dict` capturing dataloader state as specified in the
            implementation of the dataloader's `aggregate_state_dict`
            method. An empty dict is returned when the backend is not CSX,
            when checkpointing has been disabled, when the dataloader is
            not restartable (a warning is emitted), or when called before
            the initial step with no previously loaded state.

        Raises:
            RuntimeError: if called on a CSX backend, with checkpointing
                enabled, at a step that is not a checkpoint step.
        """
        checkpointing_active = (
            self._backend.backend_type.is_csx
            and self.enable_dataloader_checkpointing
        )
        if not checkpointing_active:
            return {}

        if not self._backend.run_context.is_checkpoint_step:
            raise RuntimeError(
                "DataLoader state can only be requested at a checkpoint step. Please "
                "ensure that `state_dict` is called on the `cstorch.utils.DataLoader` "
                "at a checkpoint step. If you're calling it inside of a method, please "
                "decorate it with the `cstorch.checkpoint_closure` method decorator."
            )

        if not self.is_restartable:
            # TODO: Add link to `RestartableDataLoader` docs when
            # ready in the log message below.
            warnings.warn(
                "DataLoader is not configured for restart. "
                "Please implement `state_dict`, `aggregate_state_dict`, "
                "`load_state_dict` and `deaggregate_state_dict` methods "
                "to enable deterministic restart."
            )
            return {}

        # NOTE: An initial checkpoint can be requested before the run
        # begins. In that case, the only state available is whatever was
        # previously loaded from a checkpoint (if anything).
        if self._backend.run_context.is_pre_initial_step:
            if self.loaded_state_dict is None:
                return {}
            return deepcopy(self.loaded_state_dict)

        # Otherwise the state is fetched from the appliance workers.
        step = self._backend.run_context.iteration + 1
        per_worker_states = self._fetch_worker_state_dicts(step)

        # NOTE: Only the per WRK state dict that users explicitly chose to
        # save in their `state_dict` implementation is passed on for
        # aggregation.
        return self.dataloader.aggregate_state_dict(
            [entry["user_state_dict"] for entry in per_worker_states]
        )

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Loads dataloader state from the provided `state_dict`
        by invoking the loading mechanism of the
        :py:class:`~cerebras_pytorch.utils.data.RestartableDataLoader` API.

        Nothing is stored when the current backend type is not CSX.

        Args:
            state_dict: dict capturing dataloader state loaded from a
                checkpoint
        """
        if self._backend.backend_type.is_csx:
            self.loaded_state_dict = state_dict
            self.logger.debug(f"Loaded DataLoader state: {state_dict}")

    def _fetch_worker_state_dicts(self, step: int) -> List[Dict[str, Any]]:
        """Fetch the list of individual Worker state dicts from the
        appliance at the given step and return it.
        """
        self.logger.verbose(
            f"Fetching dataloader state at checkpoint step: {step}"
        )
        return self._backend.appliance.fetch_dataloader_state(step)

    def serialize_state_dict(self):
        """Communicates the loaded dataloader state to the appliance.

        Does nothing unless state has been loaded and the current backend
        type is CSX.
        """
        if self.loaded_state_dict is None:
            return
        if not self._backend.backend_type.is_csx:
            return
        self._backend.appliance.serialized_dataloader_state = (
            self.loaded_state_dict
        )

    def __len__(self):
        """Length of the wrapped dataloader."""
        return len(self.dataloader)

    def __iter__(self):
        """Yield batches from the wrapped dataloader, refreshing
        ``self.batch_size`` from each batch before it is yielded.
        """
        for data in self.dataloader:
            self.batch_size = infer_batch_size(data, self.batch_size)
            yield data
@runtime_checkable
class RestartableDataLoader(Protocol):
    """Defines interface for the restartable dataloader protocol.

    The protocol is runtime-checkable: an object is treated as a
    `RestartableDataLoader` by ``isinstance`` when it provides all four of
    `state_dict`, `load_state_dict`, `aggregate_state_dict` and
    `deaggregate_state_dict`.
    """

    def state_dict(self) -> Dict[str, Any]:
        """Use this method to specify what state information should be saved
        by each CSX Worker.

        Returns:
            dict holding state information for the CSX Worker

        In order to access Cerebras internal data checkpoint info per
        CSX Worker at some checkpoint step, follow the steps in the example
        below. Cerebras internal data checkpoint format is recorded in the
        :py:class:`~cerebras_pytorch.utils.data.DataLoaderCheckpoint` dataclass.

        Usage:

        ::

            import cerebras_pytorch as cstorch
            ...
            def state_dict(self) -> Dict[str, Any]:
                worker_state = cstorch.distributed.get_worker_state()
                state_dict = {}
                if worker_state:
                    state_dict["worker_step"] = worker_state.worker_step
                    state_dict["worker_id"] = worker_state.global_worker_id
                return state_dict

        .. note::
            The call to :py:func:`~cerebras_pytorch.distributed.get_worker_state`
            is well-defined only inside of the `state_dict` method; using this
            anywhere else will result in a RuntimeError exception. See linked
            docs for more details.
        """

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Use this method to load CSX Worker state for the dataloader instance,
        as captured from a previous run.

        Args:
            state_dict: dict holding worker state info, specified in
                :py:meth:`~cerebras_pytorch.utils.data.RestartableDataLoader.deaggregate_state_dict`

        Usage:

        ::

            def load_state_dict(self, state_dict):
                wrk_state_dict = state_dict.get("worker_0", {})
                worker_step = wrk_state_dict.get("worker_step", 0)
                worker_id = wrk_state_dict.get("worker_id")
                print(f"WRK {worker_id} loaded step: {worker_step}")
        """

    def aggregate_state_dict(
        self, worker_states: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Use this method to specify how to combine the list of CSX Worker state dicts.
        Each CSX Worker state in the `worker_states` list is to be specified in
        :py:meth:`~cerebras_pytorch.utils.data.RestartableDataLoader.state_dict`

        Args:
            worker_states: list of per-Worker state dicts to combine.

        Returns:
            The consolidated state dict that will be saved in a checkpoint.

        Usage:

        ::

            def aggregate_state_dict(self, worker_states):
                return {
                    "worker_0": worker_states[0],
                    "worker_1": worker_states[1]
                }
        """

    def deaggregate_state_dict(
        self, aggregated_state_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Use this method to specify how to load an individual CSX Worker state given
        a consolidated state dict, as specified in
        :py:meth:`~cerebras_pytorch.utils.data.RestartableDataLoader.aggregate_state_dict`.

        Args:
            aggregated_state_dict: the consolidated state dict.

        Returns:
            The state dict that will be passed to the above-defined
            :py:meth:`~cerebras_pytorch.utils.data.RestartableDataLoader.load_state_dict` method.

        Usage:

        ::

            def deaggregate_state_dict(self, aggregated_state_dict):
                return {
                    "worker_0": aggregated_state_dict.get("worker_0", {})
                }
        """