cerebras.modelzoo.common.pytorch_utils

General-purpose PyTorch utilities.

Functions

check_checkpoint_compatibility

Checks that the checkpoint is compatible with the current version of modelzoo.

get_checkpoints

Gather checkpoints in a model directory.
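
As an illustration of what gathering checkpoints from a model directory can involve, here is a minimal sketch that uses only the standard library. The helper name `list_checkpoints` and the `checkpoint_<step>.mdl` naming pattern are assumptions made for this example; they are not taken from this page, and the real `get_checkpoints` may use different rules and a different signature.

```python
import os
import re


def list_checkpoints(model_dir, pattern=r"checkpoint_(\d+)\.mdl$"):
    """Hypothetical stand-in: return checkpoint paths sorted by step number.

    The file naming pattern is an assumption made for this sketch.
    """
    regex = re.compile(pattern)
    found = []
    for name in os.listdir(model_dir):
        match = regex.match(name)
        if match:
            # Keep the numeric step so that step 10 sorts after step 2.
            found.append((int(match.group(1)), os.path.join(model_dir, name)))
    return [path for _, path in sorted(found)]
```

Sorting on the parsed step rather than on the file name avoids the lexicographic ordering problem where `checkpoint_10` would come before `checkpoint_2`.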

load_from_checkpoint_file

Loads a state dict from a checkpoint path and checks for version compatibility.

setup_artifact_dir

Create a unique subdirectory for this run, named with a generated timestamp, so that parallel runs using the same model_dir don't overwrite common files.
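
The timestamped-subdirectory idea can be sketched with the standard library alone. The helper `make_run_dir` below is a hypothetical stand-in, not the signature of `setup_artifact_dir`; the timestamp format and the retry on collision are assumptions made for this example.

```python
import os
import time
from datetime import datetime


def make_run_dir(model_dir):
    """Hypothetical stand-in: create and return a timestamped subdirectory."""
    while True:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        run_dir = os.path.join(model_dir, stamp)
        try:
            # makedirs raises FileExistsError if another run already
            # claimed this name, so two runs never share a directory.
            os.makedirs(run_dir)
            return run_dir
        except FileExistsError:
            # Wait briefly so the next timestamp differs, then retry.
            time.sleep(0.001)
```

Calling `make_run_dir(model_dir)` twice with the same `model_dir` returns two distinct directories, so each run can write its own artifacts without touching the other's.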

setup_logging

Configure default logging format.

setup_logging_excepthook

Set up a logging hook that runs whenever an exception is raised. The hook catches and logs the exception so that the full traceback is printed in the log file.
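
One common way to get this behavior in Python is to replace `sys.excepthook`, which the interpreter calls for uncaught exceptions. The sketch below assumes that approach; the helper name `install_logging_excepthook` is hypothetical, and this page does not confirm how `setup_logging_excepthook` is actually implemented.

```python
import logging
import sys


def install_logging_excepthook(logger):
    """Hypothetical stand-in: send uncaught exceptions to `logger`."""

    def hook(exc_type, exc_value, exc_tb):
        # Leave Ctrl-C handling to the interpreter's default hook.
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_tb)
            return
        # Passing exc_info makes the logging module format the full traceback.
        logger.error(
            "Uncaught exception", exc_info=(exc_type, exc_value, exc_tb)
        )

    sys.excepthook = hook
```

After `install_logging_excepthook(logging.getLogger("train"))`, any handler attached to that logger (for example a `logging.FileHandler`) receives the traceback of an uncaught exception. The logger name `"train"` is only an example.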

to_cpu

Move a tensor from the device to the CPU.

to_tensor

If the provided value is a Python int or float, converts it into a PyTorch tensor of type int32 or float32, respectively.

visit_structure

Recursively traverse nested structure and return the items accepted by the selector.
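
A recursive traversal with a selector can be sketched as a generator over dicts, lists, and tuples. The function `visit` below is a hypothetical stand-in: yielding `(path, item)` pairs and the choice of container types are assumptions for this example, and the real `visit_structure` may return items in a different form.

```python
def visit(structure, select, scope=()):
    """Hypothetical stand-in: yield (path, item) for leaves accepted by `select`."""
    if isinstance(structure, dict):
        for key, value in structure.items():
            yield from visit(value, select, scope + (str(key),))
    elif isinstance(structure, (list, tuple)):
        for index, value in enumerate(structure):
            yield from visit(value, select, scope + (str(index),))
    elif select(structure):
        # A leaf that the selector accepts; report where it was found.
        yield scope, structure


batch = {"input_ids": [1, 2], "meta": {"name": "a", "scale": 0.5}}
numbers = list(visit(batch, lambda item: isinstance(item, (int, float))))
print(numbers)
# [(('input_ids', '0'), 1), (('input_ids', '1'), 2), (('meta', 'scale'), 0.5)]
```

The string `"a"` is skipped because the selector rejects it, while the nested numeric leaves are returned together with the path that leads to them.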

Classes

BufferedShuffleDataset

Dataset shuffled from the original dataset.
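
The class name suggests shuffling through a fixed-size buffer, a common way to shuffle a stream without loading it all into memory. The generator below is a minimal sketch of that general technique in plain Python; `buffered_shuffle`, its parameters, and the seeding are assumptions for illustration, and this page does not confirm that `BufferedShuffleDataset` works this way.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Hypothetical stand-in: approximate shuffle using a bounded buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(item)
            continue
        # Emit a random buffered item and put the new item in its slot.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Drain whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer
```

Every input item is emitted exactly once. A larger `buffer_size` gives a more thorough shuffle at the cost of more memory, and a buffer of size 1 leaves the order unchanged.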

IterableDatasetSampler

This sampler can be used with a multi-worker distributed dataloader.

RunConfigParamsValidator

Validate Run Configs.

SampleGenerator

Iterator which returns multiple samples of a given input data.
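
One reading of this description is an iterable that yields the same input a fixed number of times, which is useful for feeding synthetic data. The class below sketches that reading only; `RepeatedSamples` and its `sample_count` parameter are hypothetical, and `SampleGenerator` may have a different constructor and different semantics.

```python
class RepeatedSamples:
    """Hypothetical stand-in: yield the same data `sample_count` times."""

    def __init__(self, data, sample_count):
        if sample_count < 0:
            raise ValueError("sample_count must be non-negative")
        self._data = data
        self._sample_count = sample_count

    def __len__(self):
        return self._sample_count

    def __iter__(self):
        for _ in range(self._sample_count):
            # The same object is yielded each time; nothing is copied.
            yield self._data
```

For example, `list(RepeatedSamples({"x": 1}, 3))` produces three references to the same dictionary.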