cerebras_pytorch package

Creation Ops

These ops can be used to lazily initialize tensors with a known shape, dtype, and value so that they do not take up memory unnecessarily.
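The idea behind lazy initialization can be sketched in plain Python. LazyFull and lazy_full_like below are hypothetical stand-ins, not cerebras_pytorch code. The object stores only the shape, dtype, and fill value, and the element data is built only when materialize() is called. The dtype is kept as a plain label, and lazy_full_like follows the rule documented for full_like: dtype falls back to the other tensor's dtype.

```python
from dataclasses import dataclass
from math import prod
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class LazyFull:
    """Hypothetical lazy tensor: stores metadata only, no element storage."""

    shape: Tuple[int, ...]
    dtype: str
    value: Any

    def numel(self) -> int:
        # Number of elements the tensor would hold once materialized.
        return prod(self.shape)

    def materialize(self) -> Any:
        # Build nested lists only on demand; this is where memory is spent.
        def build(dims: Tuple[int, ...]) -> Any:
            if not dims:
                return self.value
            return [build(dims[1:]) for _ in range(dims[0])]

        return build(self.shape)


def lazy_full_like(other: LazyFull, value: Any, dtype: Optional[str] = None) -> LazyFull:
    # Copy the shape from `other`; fall back to its dtype when none is given.
    return LazyFull(other.shape, dtype if dtype is not None else other.dtype, value)
```

Constructing LazyFull((1024, 1024), "float32", 0.0) costs a handful of fields; the roughly one million list entries only appear if materialize() is called.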

full

cerebras_pytorch.full(shape, value: float, dtype=None)

Returns a lazily initialized tensor filled with the provided value.

Parameters
  • shape – The shape of the tensor.

  • value – The value to fill the tensor with.

  • dtype – The dtype of the tensor.

full_like

cerebras_pytorch.full_like(other: torch.Tensor, value: float, dtype=None)

Returns a lazily initialized tensor filled with the provided value, with the same properties as the provided tensor.

Parameters
  • other – The tensor to copy the properties from.

  • value – The value to fill the tensor with.

  • dtype – The dtype of the tensor. If not provided, the dtype of other is used.

ones

cerebras_pytorch.ones(shape, dtype=None)

Returns a lazily initialized tensor filled with ones.

Parameters
  • shape – The shape of the tensor.

  • dtype – The dtype of the tensor.

ones_like

cerebras_pytorch.ones_like(other: torch.Tensor, dtype=None)

Returns a lazily initialized tensor filled with ones, with the same properties as the provided tensor.

Parameters
  • other – The tensor to copy the properties from.

  • dtype – The dtype of the tensor. If not provided, the dtype of other is used.

zeros

cerebras_pytorch.zeros(shape, dtype=None)

Returns a lazily initialized tensor filled with zeros.

Parameters
  • shape – The shape of the tensor.

  • dtype – The dtype of the tensor.

zeros_like

cerebras_pytorch.zeros_like(other: torch.Tensor, dtype=None)

Returns a lazily initialized tensor filled with zeros, with the same properties as the provided tensor.

Parameters
  • other – The tensor to copy the properties from.

  • dtype – The dtype of the tensor. If not provided, the dtype of other is used.

Checkpoint Saving/Loading utilities

cerebras_pytorch.save(obj: dict, checkpoint_file: str) -> None

Save a PyTorch state dict to the given file.

Parameters
  • obj – The object to save.

  • checkpoint_file – The path to save the object to.

cerebras_pytorch.load(checkpoint_file: Union[cerebras_appliance.utils.file.StrPath, IO], map_location: Optional[Union[str, torch.device, Callable, dict]] = None, **kwargs) -> Any

Load a PyTorch checkpoint from a file.

Parameters
  • checkpoint_file – The path to (or an open file object for) the checkpoint to load.

  • map_location – Where to load the checkpoint contents to. If map_location is None, tensors are lazily loaded from the checkpoint file every time they are accessed. If map_location is "cache", tensors are lazily loaded from the checkpoint file and cached after they are first loaded. If map_location is "cpu", tensors are eagerly loaded into memory from the checkpoint file.

  • **kwargs – Additional keyword arguments to pass to the vanilla torch checkpoint loader. These are ignored if the checkpoint is a Cerebras HDF5 checkpoint.

Returns

The loaded checkpoint contents.

Raises

RuntimeError – If the checkpoint file does not exist or the checkpoint is not a valid HDF5 or vanilla torch checkpoint.
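The three documented map_location modes differ only in when, and how often, a tensor is read from the checkpoint file. LazyEntry below is a hypothetical pure-Python model of that behavior, where read_fn stands in for reading one tensor from disk. It is not how cerebras_pytorch.load is implemented, and it ignores the other accepted map_location types such as torch.device or a callable.

```python
from typing import Any, Callable, Optional

_POLICIES = (None, "cache", "cpu")


class LazyEntry:
    """Hypothetical checkpoint entry modelling the three documented map_location modes."""

    def __init__(self, read_fn: Callable[[], Any], map_location: Optional[str] = None):
        if map_location not in _POLICIES:
            raise ValueError(f"unsupported map_location in this sketch: {map_location!r}")
        self._read_fn = read_fn
        self._map_location = map_location
        self._value: Any = None
        self._loaded = False
        if map_location == "cpu":
            # Eager: read from the file once, at load time.
            self._value = read_fn()
            self._loaded = True

    def get(self) -> Any:
        if self._map_location is None:
            # Lazy and uncached: go back to the file on every access.
            return self._read_fn()
        if not self._loaded:
            # "cache": read on first access, then reuse the stored value.
            self._value = self._read_fn()
            self._loaded = True
        return self._value
```

Counting calls to read_fn makes the trade-off visible: None re-reads on every access and keeps nothing resident, "cache" reads once on first access, and "cpu" pays the full read cost up front.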

Data Utilities

utils.data.DataLoader

class cerebras_pytorch.utils.data.DataLoader

Wrapper around torch.utils.data.DataLoader that facilitates moving data generated by the dataloader to a Cerebras system.

Parameters
  • input_fn – A callable that returns a torch.utils.data.DataLoader instance or an iterable that yields structures containing torch tensors.

  • *args – Any other positional arguments are passed to input_fn when each worker instantiates its respective dataloader.

  • **kwargs – Any other keyword arguments are passed to input_fn when each worker instantiates its respective dataloader.

__init__(*args: Any, **kwargs: Any) -> None
disable_dataloader_checkpointing()
load_state_dict()
serialize_state_dict()
state_dict()

Each worker calls input_fn to construct its own dataloader object. Consequently, if each worker is meant to stream a unique subset of the data, input_fn must implement some data sharding scheme.
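A common sharding scheme is a strided split of the sample indices, so that worker k reads samples k, k + W, k + 2W, and so on. The sketch below shows hypothetical helpers an input_fn could use; worker_id and num_workers are illustrative parameters, and how a worker learns its own index is not specified here.

```python
from typing import Iterator, List, Sequence


def shard_indices(num_samples: int, worker_id: int, num_workers: int) -> List[int]:
    """Hypothetical strided split: worker k gets samples k, k + W, k + 2W, ..."""
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    return list(range(worker_id, num_samples, num_workers))


def batches(data: Sequence[int], indices: List[int], batch_size: int) -> Iterator[List[int]]:
    # Yield this worker's shard in fixed-size batches, dropping a ragged tail.
    for start in range(0, len(indices) - batch_size + 1, batch_size):
        yield [data[i] for i in indices[start:start + batch_size]]
```

Because the shards are disjoint and together cover every index, no sample is streamed twice within an epoch. Dropping the incomplete tail batch is an arbitrary choice made for this sketch.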

utils.data.SyntheticDataset

class cerebras_pytorch.utils.data.SyntheticDataset

A synthetic dataset that generates samples from a SampleSpec.

Constructs a SyntheticDataset instance.

A synthetic dataset can be used to generate samples on the fly with an expected dtype and shape, without needing to create a full-blown dataset. This is especially useful for compile validation.

Parameters
  • sample_spec – Specification of the samples to generate. This can be a nested structure whose leaf nodes are of the following types:

    • torch.Tensor: A tensor to be cloned.

    • Callable: A callable that takes the sample index and returns a tensor.

    Supported data structures for holding the above leaf nodes are list, tuple, dict, OrderedDict, and NamedTuple.

  • num_samples – Total size of the dataset. If None, the dataset generates samples indefinitely.
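How a nested sample_spec could turn into individual samples can be illustrated with a small recursive resolver. resolve_spec and synthetic_samples are hypothetical pure-Python helpers, not the SyntheticDataset implementation: copy.deepcopy stands in for cloning a tensor, and num_samples=None yields samples forever, matching the description above.

```python
import copy
import itertools
from typing import Any, Iterator, Optional


def resolve_spec(spec: Any, index: int) -> Any:
    """Hypothetical resolver: callables get the sample index, other leaves are copied."""
    if callable(spec):
        return spec(index)
    if isinstance(spec, dict):  # also covers OrderedDict
        return type(spec)((k, resolve_spec(v, index)) for k, v in spec.items())
    if isinstance(spec, tuple) and hasattr(spec, "_fields"):  # NamedTuple
        return type(spec)(*(resolve_spec(v, index) for v in spec))
    if isinstance(spec, (list, tuple)):
        return type(spec)(resolve_spec(v, index) for v in spec)
    return copy.deepcopy(spec)  # stand-in for cloning a tensor leaf


def synthetic_samples(spec: Any, num_samples: Optional[int] = None) -> Iterator[Any]:
    # None means "generate forever", mirroring the documented num_samples semantics.
    indices = itertools.count() if num_samples is None else range(num_samples)
    for i in indices:
        yield resolve_spec(spec, i)
```

For example, the spec {"x": [0.0, 0.0], "idx": lambda i: i * 10} yields a fresh copy of x for every sample, while idx takes the values 0, 10, 20, and so on.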

__init__(sample_spec: SampleSpecT, num_samples: Optional[int] = None)

Here SampleSpecT is the recursive type Union[torch.Tensor, Callable[[int], torch.Tensor], List[SampleSpecT], Tuple[SampleSpecT, ...], Dict[str, SampleSpecT], OrderedDict[str, SampleSpecT], NamedTuple].


utils.data.DataExecutor

class cerebras_pytorch.utils.data.DataExecutor

Defines a single execution run on a Cerebras wafer-scale cluster.

Parameters
  • dataloader – The dataloader to use for the run.

  • num_steps – The number of steps to run. Defaults to 1 if the backend was configured for compile or validate only.

  • checkpoint_steps – The interval at which to schedule fetching checkpoints from the cluster.

  • activation_steps – The interval at which to schedule fetching activations from the cluster.

  • cs_config – Optionally, a CSConfig object to configure the Cerebras wafer-scale cluster. If none is provided, the default configuration values are used.

  • writer – The summary writer used to write any summarized scalars or tensors to TensorBoard.

  • profiler_activities – The list of activities to profile. By default, the total samples, the client-side rate, and the global rate are tracked and are accessible via the profiler attribute.

__init__(*args: Any, **kwargs: Any) -> None

Note

As of Cerebras Release 2.0, multiple CS runs in a single process are not officially supported. This means that a DataExecutor can only be run (iterated) once. Runs with different configurations must be executed in separate processes.
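checkpoint_steps and activation_steps are intervals rather than explicit step lists. Purely as an illustration, the hypothetical helper below lists the steps an interval would select. It assumes that steps are numbered from 1 and that a fetch happens on every multiple of the interval. This page does not specify how the appliance treats the final step or an unset interval, so both are assumptions here.

```python
from typing import List, Optional


def steps_to_fetch(num_steps: int, interval: Optional[int]) -> List[int]:
    """Hypothetical: 1-indexed steps that are multiples of `interval`."""
    if not interval:
        # Assumption for this sketch: None or 0 schedules nothing.
        return []
    return list(range(interval, num_steps + 1, interval))
```

Under these assumptions, num_steps=10 with checkpoint_steps=4 would select steps 4 and 8.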

utils.data.RestartableDataLoader

class cerebras_pytorch.utils.data.RestartableDataLoader

Defines the interface for the restartable dataloader protocol.

__init__(*args, **kwargs)
aggregate_state_dict(worker_states: List[Dict[str, Any]]) -> Dict[str, Any]

Use this method to specify how to combine the list of CSX Worker state dicts. Each entry in worker_states is the state dict returned by that worker's state_dict method.

Returns

The consolidated state dict that will be saved in a checkpoint.

Usage:

def aggregate_state_dict(self, worker_states):
    # Key each worker's state by its index so that any number of workers is
    # handled; with two workers this yields "worker_0" and "worker_1".
    return {f"worker_{i}": state for i, state in enumerate(worker_states)}

deaggregate_state_dict(aggregated_state_dict: Dict[str, Any]) -> Dict[str, Any]

Use this method to specify how to load an individual CSX Worker state from the consolidated state dict produced by aggregate_state_dict.

Returns

The state dict to be passed to load_state_dict.

Usage:

def deaggregate_state_dict(self, aggregated_state_dict):
    # Select worker_0's entry, defaulting to an empty dict if it is missing;
    # load_state_dict reads it back under the same "worker_0" key.
    return {
        "worker_0": aggregated_state_dict.get("worker_0", {})
    }

load_state_dict(state_dict: Dict[str, Any]) -> None

Use this method to load CSX Worker state for the dataloader instance, as captured from a previous run.

Parameters

state_dict – Dict holding worker state info, as returned by deaggregate_state_dict.

Usage:

def load_state_dict(self, state_dict):
    # Unpack the entry produced by deaggregate_state_dict.
    wrk_state_dict = state_dict.get("worker_0", {})

    worker_step = wrk_state_dict.get("worker_step", 0)
    worker_id = wrk_state_dict.get("worker_id")

    print(f"WRK {worker_id} loaded step: {worker_step}")

state_dict() -> Dict[str, Any]

Use this method to specify what state information each CSX Worker should save.

Returns

Dict holding state information for the CSX Worker.

To access Cerebras internal data checkpoint info per CSX Worker at a checkpoint step, follow the example below. The Cerebras internal data checkpoint format is recorded in the DataLoaderCheckpoint dataclass.

Usage:

from typing import Any, Dict

import cerebras_pytorch as cstorch
...
def state_dict(self) -> Dict[str, Any]:
    # get_worker_state is only valid here; elsewhere it raises a RuntimeError.
    worker_state = cstorch.distributed.get_worker_state()
    state_dict = {}
    if worker_state:
        state_dict["worker_step"] = worker_state.worker_step
        state_dict["worker_id"] = worker_state.global_worker_id
    return state_dict

Note

The call to get_worker_state is well-defined only inside the state_dict method; calling it anywhere else results in a RuntimeError. See get_worker_state for more details.
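The four methods can be exercised end to end without a cluster. In the sketch below, ToyRestartableLoader is an illustrative class, not part of cerebras_pytorch, and fake_worker_state is a hypothetical stand-in for cstorch.distributed.get_worker_state() that returns only the two fields used here. The deaggregate_state_dict and load_state_dict snippets above always read worker_0's entry. Here, by contrast, each instance stores its own index in a hypothetical worker_id attribute so that deaggregate_state_dict can pick out the matching entry.

```python
from typing import Any, Dict, List


def fake_worker_state(worker_id: int, step: int) -> Dict[str, int]:
    # Hypothetical stand-in for cstorch.distributed.get_worker_state().
    return {"global_worker_id": worker_id, "worker_step": step}


class ToyRestartableLoader:
    """Illustrative class implementing the four RestartableDataLoader methods."""

    def __init__(self, worker_id: int, step: int = 0):
        self.worker_id = worker_id  # hypothetical: how the instance knows its index
        self.step = step

    def state_dict(self) -> Dict[str, Any]:
        ws = fake_worker_state(self.worker_id, self.step)
        return {"worker_id": ws["global_worker_id"], "worker_step": ws["worker_step"]}

    def aggregate_state_dict(self, worker_states: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Key each worker's state by its position in the list.
        return {f"worker_{i}": s for i, s in enumerate(worker_states)}

    def deaggregate_state_dict(self, aggregated_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        # Pick out this worker's entry; an empty dict means "start fresh".
        return aggregated_state_dict.get(f"worker_{self.worker_id}", {})

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict.get("worker_step", 0)
```

The round trip calls state_dict on every worker, aggregate_state_dict once, and then deaggregate_state_dict and load_state_dict per worker. A worker with no saved entry starts from step 0.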

utils.data.DataLoaderCheckpoint

class cerebras_pytorch.utils.data.DataLoaderCheckpoint

Dataclass representing the Cerebras internal dataloader checkpoint format. Each CSX Worker captures its state information via this class at a checkpoint step.

Parameters
  • global_worker_id (int) – ID of this worker among all workers across all boxes.

  • local_worker_id (int) – ID of this worker among all workers on the same box.

  • total_num_workers (int) – The total number of workers for the run across all boxes.

  • num_workers_per_csx (int) – The total number of workers per box for the run.

  • num_csx (int) – The total number of CSXs (boxes) for the run.

  • wse_id (int) – ID of the Wafer-Scale Engine (CSX) to which this worker streams data.

  • appliance_step (int) – The appliance step at which this checkpoint state info is captured.

  • worker_step (int) – The worker step at which this state info is captured. If num_workers_per_csx = 1, this is simply equal to appliance_step; with multiple workers per box, appliance steps are distributed across the workers on a box in round-robin fashion based on the local worker ID.

  • samples_streamed (int) – The total number of samples streamed by this worker up to the checkpoint step. This is simply worker_step * per_box_batch_size.

Note

appliance_step, worker_step, and samples_streamed vary from step to step, whereas the other attributes are constant for the current run.
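The relationship between appliance_step, worker_step, and samples_streamed can be worked through in a few lines. This sketch assumes that appliance steps are numbered from 1 and that step s is handled by local worker (s - 1) % num_workers_per_csx. The exact starting offset of the round-robin is an assumption, and both helper names are hypothetical.

```python
def worker_step(appliance_step: int, local_worker_id: int, num_workers_per_csx: int) -> int:
    """Hypothetical: steps handled by one worker after `appliance_step` appliance steps."""
    # Assumed schedule: appliance step s (1-indexed) goes to local worker (s - 1) % W,
    # so worker w handles steps w + 1, w + 1 + W, w + 1 + 2W, ...
    first_step = local_worker_id + 1
    if appliance_step < first_step:
        return 0
    return (appliance_step - first_step) // num_workers_per_csx + 1


def samples_streamed(steps: int, per_box_batch_size: int) -> int:
    # Documented relation: samples_streamed = worker_step * per_box_batch_size.
    return steps * per_box_batch_size
```

With num_workers_per_csx = 1 the function returns appliance_step unchanged, which is the documented single-worker case. With two workers after 5 appliance steps, worker 0 has handled 3 steps and worker 1 has handled 2.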

get_worker_state

cerebras_pytorch.distributed.get_worker_state()

API exposing internal state info captured by each CSX Worker for the current run at a checkpoint step. This state info is represented in the DataLoaderCheckpoint dataclass format.

Returns

DataLoaderCheckpoint instance holding worker state information at the checkpoint step.

Note

  • This method may only be called inside a custom implementation of state_dict for dataloaders conforming to the RestartableDataLoader protocol, since state_dict is well-defined only at a checkpoint step.

  • Use this method to save any of the state info recorded by each worker when defining state_dict for custom implementations of restartable dataloaders.

  • The state info captured by each worker covers the current run only: if you pause and restart a run, the counters behind the values returned by this function are reset.

utils.CSConfig

class cerebras_pytorch.utils.CSConfig

Contains config details for the Cerebras Wafer-Scale Cluster.

Parameters
  • mgmt_address (Optional[str]) – Address used to connect to the appliance. If not provided, the cluster management node is queried for it. Default: None.

  • credentials_path (Optional[str]) – Path to the credentials for connecting to the appliance. If not provided, the cluster management node is queried for it. Default: None.

  • num_csx (int) – Number of Cerebras systems to run on. Default: 1.

  • max_wgt_servers (int) – Maximum number of weight servers to support the run. Default: 24.

  • max_act_per_csx (int) – Maximum number of activation servers per system. Default: 1.

  • num_workers_per_csx (int) – Number of streaming workers per system. Default: 1.

  • transfer_processes (int) – Number of processes used to transfer data to and from the appliance. Default: 5.

  • job_time_sec (int) – Time limit for the appliance jobs, not including queue time. Default: None.

  • mount_dirs (List[str]) – Local storage to mount to the appliance (e.g., training data). Default: None.

  • python_paths (List[str]) – A list of paths that worker pods add to PYTHONPATH, in addition to the PYTHONPATH set in the container image. Default: None.

  • job_labels (List[str]) – A list of key=value pairs (separated by an equals sign) that are applied as part of the job metadata. Default: None.

  • debug_args (DebugArgs) – Optional debugging arguments object. Default: None.

  • precision_opt_level (int) – The precision optimization level. Default: 1.
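Because job_labels is a list of key=value strings, it can help to validate the list before building a CSConfig. parse_job_labels is a hypothetical helper, not part of cerebras_pytorch. Splitting on the first equals sign, so that values may themselves contain "=", is a choice made for this sketch, not documented behavior.

```python
from typing import Dict, List, Optional


def parse_job_labels(labels: Optional[List[str]]) -> Dict[str, str]:
    """Hypothetical helper: turn ["k=v", ...] into a dict, rejecting malformed entries."""
    parsed: Dict[str, str] = {}
    for label in labels or []:
        key, sep, value = label.partition("=")
        if not sep or not key:
            raise ValueError(f"expected key=value, got {label!r}")
        parsed[key] = value
    return parsed
```

For example, ["team=ml", "priority=low"] parses to {"team": "ml", "priority": "low"}, while an entry such as "oops" is rejected.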

NumPy utilities

from_numpy

cerebras_pytorch.from_numpy(array: numpy.ndarray) -> torch.Tensor

Converts a NumPy array to a torch tensor.

to_numpy

cerebras_pytorch.to_numpy(tensor: torch.Tensor) -> numpy.ndarray

Converts a torch tensor to a NumPy array.

TensorBoard utilities

cerebras_pytorch.summarize_scalar(name: str, scalar: Union[int, float, torch.Tensor])

Save the scalar to the event file of the writer specified in the DataExecutor.

Parameters
  • name – The key under which to save the scalar in the event file.

  • scalar – The scalar value to summarize. If a torch.Tensor is provided, it must be a scalar tensor for which scalar.item() can be called.

Note

Scalars summarized using this API are only visible in TensorBoard if a SummaryWriter was passed to the DataExecutor object.
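The scalar requirement above amounts to this: pass an int or a float, or pass a tensor whose .item() call succeeds. as_python_scalar is a hypothetical helper that sketches this check with duck typing, so it runs without torch installed. With a real torch.Tensor, .item() raises an error when the tensor has more than one element.

```python
from typing import Any, Union


def as_python_scalar(value: Any) -> Union[int, float]:
    """Hypothetical: normalize an int, float, or scalar-tensor-like object."""
    if isinstance(value, (int, float)):
        return value
    item = getattr(value, "item", None)
    if not callable(item):
        raise TypeError(f"cannot summarize {type(value).__name__} as a scalar")
    # For torch.Tensor, .item() only succeeds on single-element tensors.
    return item()
```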

cerebras_pytorch.summarize_tensor(name: str, tensor: torch.Tensor)

Save the tensor to the event file of the writer specified in the DataExecutor.

Parameters
  • name – The key under which to save the tensor in the event file.

  • tensor – The torch.Tensor to summarize.

Note

Tensors summarized using this API are only visible if a SummaryWriter was passed to the DataExecutor object.

class cerebras_pytorch.utils.tensorboard.SummaryWriter

Thin wrapper around torch.utils.tensorboard.SummaryWriter.

Additional features include the ability to add a tensor summary.

Parameters
  • base_step – The base step to use in the summarize_scalar and summarize_tensor functions.

  • *args – Any other positional arguments are forwarded directly to the base class.

  • **kwargs – Any other keyword arguments are forwarded directly to the base class.

__init__(*args: Any, **kwargs: Any) -> None
add_tensor()

class cerebras_pytorch.utils.tensorboard.SummaryReader

Class for reading summaries saved using the SummaryWriter.

Parameters
  • log_dir – The directory in which the event files can be found.

  • kwargs – The remaining keyword arguments are forwarded to the internal EventAccumulator object.

__init__(*args: Any, **kwargs: Any) -> None
read_scalar()
read_tensor()
reload()
scalar_names()
tensor_names()

Dataloader benchmark utilities

cerebras_pytorch.utils.benchmark.benchmark_dataloader(input_fn: Callable[..., Iterable], num_epochs: Optional[int] = None, steps_per_epoch: Optional[int] = None, sampling_frequency: Optional[int] = None, profile_activities: Optional[List[str]] = None, print_metrics: bool = True) -> cerebras_pytorch.utils.benchmark.utils.dataloader.Metrics

Utility to benchmark a dataloader.

Parameters
  • input_fn – Function that creates and returns a dataloader.

  • num_epochs – Number of epochs to iterate over the dataloader. If None, the dataloader is iterated for only one epoch.

  • steps_per_epoch – Number of steps to iterate over the dataloader in each epoch. If None, the dataloader is iterated in its entirety.

  • sampling_frequency – Frequency at which to sample metrics. If None, a default value of 100 (i.e., every 100 steps) is used. The first step of each epoch is always sampled.

  • profile_activities – List of optional activities to profile. If None, no extra activities are profiled. Note that these may incur additional overhead and could affect the overall performance of the dataloader, especially if the sampling frequency is high.

  • print_metrics – Whether to pretty-print the final metrics to the console.

Returns

Metrics for the dataloader experiment.

class cerebras_pytorch.utils.benchmark.utils.dataloader.Metrics

Metrics for a single dataloader experiment.

Parameters
  • dataloader_build_time – Time to build the dataloader.

  • epoch_metrics – List of metrics for each epoch.

  • batch_specs – Mapping between the unique batch specs found and their occurrences.

  • total_time – Total time to iterate through all epochs.

  • global_rate – Overall global rate in steps/second.

  • is_partial – Whether the metrics are partial. This can happen if the benchmark is interrupted in the middle of execution.

  • start_time_ns – Time (in nanoseconds) at which the experiment started.

  • end_time_ns – Time (in nanoseconds) at which the experiment ended.

batch_specs: Dict[cerebras_pytorch.utils.benchmark.utils.dataloader.BatchSpec, cerebras_pytorch.utils.benchmark.utils.dataloader.BatchSpecOccurence]
dataloader_build_time: numpy.timedelta64
end_time_ns: int = 0
epoch_metrics: List[cerebras_pytorch.utils.benchmark.utils.dataloader.EpochMetrics]
global_rate: float = 0.0
property global_sample_rate: Optional[float]

Returns the overall global rate in samples/second.

Note that this value only exists if all batches have the exact same structure, dtypes, and shapes. Otherwise, this value is None.

is_partial: bool = True
start_time_ns: int
property total_steps: int

Returns the total number of steps across all epochs.

total_time: numpy.timedelta64
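The following worked sketch shows how the two overall rates relate. It assumes that global_rate is total_steps divided by total_time in seconds, and that the samples-per-second figure is that rate multiplied by a single, uniform batch size. Batch specs are modeled here as hashable tuples of shapes with the batch dimension first, and the helper names are hypothetical.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]
BatchSpec = Tuple[Shape, ...]  # hypothetical: one shape per tensor in the batch


def global_rate(total_steps: int, total_time_s: float) -> float:
    # Steps per second over the whole experiment; 0.0 mirrors the field's default.
    return total_steps / total_time_s if total_time_s > 0 else 0.0


def global_sample_rate(step_rate: float, batch_specs: Sequence[BatchSpec]) -> Optional[float]:
    """Samples/second, or None unless every batch has the same spec."""
    unique = set(batch_specs)
    if len(unique) != 1:
        return None
    (spec,) = unique
    batch_size = spec[0][0]  # assumption: dim 0 of the first tensor is the batch size
    return step_rate * batch_size
```

For 200 steps in 4 seconds with every batch shaped ((32, 128), (32,)), this gives 50 steps/second and 1600 samples/second. A single batch with a different shape makes the sample rate None, which matches the note on global_sample_rate.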

class cerebras_pytorch.utils.benchmark.utils.dataloader.EpochMetrics

Metrics for a single epoch of a dataloader experiment.

Parameters
  • iterator_creation – Time to create the dataloader iterator.

  • iteration_time – Time to iterate the entire epoch, excluding the creation of the iterator.

  • total_steps – Total number of steps in the epoch.

  • batch_metrics – List of metrics for batches generated in the epoch.

  • start_time_ns – Time (in nanoseconds) at which the epoch started.

  • end_time_ns – Time (in nanoseconds) at which the epoch ended.

batch_metrics: List[cerebras_pytorch.utils.benchmark.utils.dataloader.BatchMetrics]
end_time_ns: int = 0
iteration_time: numpy.timedelta64
iterator_creation: numpy.timedelta64
start_time_ns: int
total_steps: int = 0
property total_time: numpy.timedelta64

Returns the total time to create and iterate the epoch.

class cerebras_pytorch.utils.benchmark.utils.dataloader.BatchMetrics

Metrics for a single batch of a dataloader experiment.

Parameters
  • epoch_step – Epoch step at which the batch was generated.

  • global_step – Global step at which the batch was generated.

  • local_rate – Local rate (in steps/second) at the sampling frequency. This is the instantaneous rate (relative to the previous batch) at which the batch was generated.

  • global_rate – Global rate (in steps/second) at the sampling frequency. This is the overall rate since iteration over the epochs started.

  • profile_activities – Dictionary of profile activities and their values.

  • sampling_time_ns – Time (in nanoseconds) at which the batch was sampled.

epoch_step: int
global_rate: float
global_step: int
local_rate: float
profile_activities: Dict[str, Any]
sampling_time_ns: int