# Source code for cerebras_pytorch.utils.data.dataloader

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""Defines the Cerebras DataLoader class and RestartableDataLoader protocol class."""
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Union

import torch
from typing_extensions import Protocol, runtime_checkable

from cerebras_appliance.log import ClassLogger, named_class_logger
from cerebras_pytorch.backend import Backend, current_backend_impl
from cerebras_pytorch.utils.data.utils import infer_batch_size


@named_class_logger
class DataLoader(ClassLogger):
    """
    Thin wrapper around ``torch.utils.data.DataLoader`` (or any iterable of
    tensor structures) used to feed data to a Cerebras system.

    Args:
        input_fn: Callable that builds and returns either a
            ``torch.utils.data.DataLoader`` or an iterable yielding
            structures that contain torch tensors.
        *args, **kwargs: Extra positional and keyword arguments forwarded
            to ``input_fn`` when each worker instantiates its own
            dataloader.
    """

    def __init__(
        self,
        input_fn: Callable[..., Union[torch.utils.data.DataLoader, Iterable]],
        *args,
        **kwargs,
    ):
        if not callable(input_fn):
            message = (
                "Expected a callable that constructs and returns a "
                "`torch.utils.data.DataLoader` or an iterable that "
                "returns a structure containing torch tensors."
            )
            raise TypeError(message)

        self.input_fn = input_fn
        # Keep a deep copy of the arguments, taken before `input_fn` runs, so
        # the stored parameters are independent of the objects passed in.
        self.input_fn_params = deepcopy((args, kwargs))
        self.dataloader = input_fn(*args, **kwargs)

        # Restart / checkpoint bookkeeping.
        self.loaded_state_dict = None
        self.worker_states = None
        self.batch_size = None  # updated while iterating, see `__iter__`
        self.enable_dataloader_checkpointing = True

    @property
    def is_restartable(self) -> bool:
        """Whether the wrapped dataloader satisfies the
        ``RestartableDataLoader`` protocol."""
        return isinstance(self.dataloader, RestartableDataLoader)

    @property
    def _backend(self) -> Backend:
        """The backend implementation that is currently active."""
        return current_backend_impl()

    def disable_dataloader_checkpointing(self):
        """Turn off DataLoader checkpointing (only has an effect on CSX)."""
        if not self._backend.backend_type.is_csx:
            return
        self.enable_dataloader_checkpointing = False

    def state_dict(self) -> Dict[str, Any]:
        """Return the dataloader state to be saved in a checkpoint.

        The state is gathered through the
        :py:class:`~cerebras_pytorch.utils.data.RestartableDataLoader` API:
        the per-worker state dicts are fetched from the appliance and merged
        by the wrapped dataloader's ``aggregate_state_dict`` method.

        Returns:
            The aggregated state ``dict``. An empty ``dict`` is returned when
            the backend is not CSX, when checkpointing has been disabled,
            when the wrapped dataloader is not restartable (a warning is
            issued in that case), or before the initial step.

        Raises:
            RuntimeError: If called on a CSX backend with checkpointing
                enabled while the run is not at a checkpoint step.
        """
        if not (
            self._backend.backend_type.is_csx
            and self.enable_dataloader_checkpointing
        ):
            return {}

        if not self._backend.run_context.is_checkpoint_step:
            raise RuntimeError(
                "DataLoader state can only be requested at a checkpoint step. Please "
                "ensure that `state_dict` is called on the `cstorch.utils.DataLoader` "
                "at a checkpoint step. If you're calling it inside of a method, please "
                "decorate it with the `cstorch.checkpoint_closure` method decorator."
            )

        if not self.is_restartable:
            # TODO: Add link to `RestartableDataLoader` docs when
            # ready in the log message below.
            warnings.warn(
                "DataLoader is not configured for restart. "
                "Please implement `state_dict`, `aggregate_state_dict`, "
                "`load_state_dict` and `deaggregate_state_dict` methods "
                "to enable deterministic restart."
            )
            return {}

        # When the initial checkpoint is written the `appliance` instance
        # does not exist yet, so there is nothing to fetch.
        if self._backend.run_context.is_pre_initial_step:
            return {}

        checkpoint_step = self._backend.run_context.iteration + 1
        per_worker = self._fetch_worker_state_dicts(checkpoint_step)
        # Only the part of each worker's state that the user chose to save
        # in their own `state_dict` implementation is handed to aggregation.
        user_states = [entry["user_state_dict"] for entry in per_worker]
        return self.dataloader.aggregate_state_dict(user_states)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Store dataloader state loaded from a checkpoint.

        The state is consumed later through the
        :py:class:`~cerebras_pytorch.utils.data.RestartableDataLoader` API.
        Nothing happens on non-CSX backends.

        Args:
            state_dict: dataloader state dict read from a checkpoint.
        """
        if self._backend.backend_type.is_csx:
            self.loaded_state_dict = state_dict
            self.logger.debug(f"Loaded dataloader state dict: {state_dict}")

    def _fetch_worker_state_dicts(self, step: int) -> List[Dict[str, Any]]:
        """Ask the appliance for the list of per-worker state dicts recorded
        at ``step``."""
        self.logger.verbose(
            f"Fetching dataloader state at checkpoint step: {step}"
        )
        return self._backend.appliance.fetch_dataloader_state(step)

    def serialize_state_dict(self):
        """Hand the previously loaded state dict over to the appliance."""
        if self.loaded_state_dict is None:
            return
        if not self._backend.backend_type.is_csx:
            return
        self._backend.appliance.serialized_dataloader_state = (
            self.loaded_state_dict
        )

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        for sample in self.dataloader:
            # Track the batch size as it is observed in the stream.
            self.batch_size = infer_batch_size(sample, self.batch_size)
            yield sample


@runtime_checkable
class RestartableDataLoader(Protocol):
    """Defines interface for the restartable dataloader protocol.

    A dataloader conforms to this protocol by implementing ``state_dict``,
    ``load_state_dict``, ``aggregate_state_dict`` and
    ``deaggregate_state_dict``. Because the protocol is
    ``runtime_checkable``, conformance can be tested with
    ``isinstance(obj, RestartableDataLoader)``, which only checks that these
    methods are present; no inheritance is required.
    """

    def state_dict(self) -> Dict[str, Any]:
        """Use this method to specify what state information should be saved
        by each CSX Worker.

        Returns:
            dict holding state information for the CSX Worker

        In order to access Cerebras internal data checkpoint info per CSX
        Worker at some checkpoint step, follow the steps in the example below.
        Cerebras internal data checkpoint format is recorded in the
        :py:class:`~cerebras_pytorch.utils.data.DataLoaderCheckpoint`
        dataclass.

        Usage:
        ::

            import cerebras_pytorch as cstorch
            ...
            def state_dict(self) -> Dict[str, Any]:
                worker_state = cstorch.distributed.get_worker_state()
                state_dict = {}
                if worker_state:
                    state_dict["worker_step"] = worker_state.worker_step
                    state_dict["worker_id"] = worker_state.global_worker_id
                return state_dict

        .. note::
            The call to
            :py:func:`~cerebras_pytorch.distributed.get_worker_state` is
            well-defined only inside of the `state_dict` method; using this
            anywhere else will result in a RuntimeError exception. See linked
            docs for more details.
        """

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Use this method to load CSX Worker state for the dataloader
        instance, as captured from a previous run.

        Args:
            state_dict: dict holding worker state info, as returned by
                :py:meth:`~cerebras_pytorch.utils.data.RestartableDataLoader.deaggregate_state_dict`

        Usage:
        ::

            def load_state_dict(self, state_dict):
                wrk_state_dict = state_dict.get("worker_0", {})
                worker_step = wrk_state_dict.get("worker_step", 0)
                worker_id = wrk_state_dict.get("worker_id")
                print(f"WRK {worker_id} loaded step: {worker_step}")
        """

    def aggregate_state_dict(
        self, worker_states: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Use this method to specify how to combine the list of CSX Worker
        state dicts.

        Args:
            worker_states: list with one state dict per CSX Worker. Each entry
                has the format returned by
                :py:meth:`~cerebras_pytorch.utils.data.RestartableDataLoader.state_dict`.

        Returns:
            The consolidated state dict that will be saved in a checkpoint.

        Usage:
        ::

            def aggregate_state_dict(self, worker_states):
                return {
                    "worker_0": worker_states[0],
                    "worker_1": worker_states[1]
                }
        """

    def deaggregate_state_dict(
        self, aggregated_state_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Use this method to specify how to extract an individual CSX Worker
        state from the consolidated state dict produced by
        :py:meth:`~cerebras_pytorch.utils.data.RestartableDataLoader.aggregate_state_dict`.

        Args:
            aggregated_state_dict: the consolidated state dict, as previously
                returned by ``aggregate_state_dict`` and saved in a
                checkpoint.

        Returns:
            The state dict that will be passed to the
            :py:meth:`~cerebras_pytorch.utils.data.RestartableDataLoader.load_state_dict`
            method.

        Usage:
        ::

            def deaggregate_state_dict(self, aggregated_state_dict):
                return {
                    "worker_0": aggregated_state_dict.get("worker_0", {})
                }
        """