Source code for modelzoo.fc_mnist.pytorch.data

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torchvision import datasets, transforms

import cerebras_pytorch as cstorch
import cerebras_pytorch.distributed as dist
from modelzoo.common.pytorch.input_utils import get_streaming_batch_size
from modelzoo.common.pytorch.utils import SampleGenerator


[docs]def get_train_dataloader(params): """ :param <dict> params: dict containing input parameters for creating dataset. Expects the following fields: - "data_dir" (string): path to the data files to use. - "batch_size" (int): batch size - "to_float16" (bool): whether to convert to float16 or not - "drop_last_batch" (bool): whether to drop the last batch or not """ input_params = params["train_input"] use_cs = cstorch.use_cs() batch_size = get_streaming_batch_size(input_params.get("batch_size")) to_float16 = input_params.get("to_float16", True) shuffle = input_params["shuffle"] dtype = torch.float32 if to_float16: if use_cs or torch.cuda.is_available(): dtype = cstorch.amp.get_half_dtype() else: print( f"Input dtype float16 is not supported with " f"vanilla PyTorch CPU workflow. Using float32 instead." ) if input_params.get("use_fake_data", False): num_streamers = dist.num_streamers() if dist.is_streamer() else 1 train_loader = SampleGenerator( data=( torch.zeros(batch_size, 1, 28, 28, dtype=dtype), torch.zeros( batch_size, dtype=torch.int32 if use_cs else torch.int64 ), ), sample_count=60000 // batch_size // num_streamers, ) else: train_dataset = datasets.MNIST( input_params["data_dir"], train=True, download=dist.is_master_ordinal(), transform=transforms.Compose( [ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.Lambda( lambda x: torch.as_tensor(x, dtype=dtype) ), ] ), target_transform=transforms.Lambda( lambda x: torch.as_tensor(x, dtype=torch.int32) ) if use_cs else None, ) train_sampler = None if use_cs and dist.num_streamers() > 1 and dist.is_streamer(): train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=dist.num_streamers(), rank=dist.get_streaming_rank(), shuffle=shuffle, ) else: train_sampler = None train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, sampler=train_sampler, drop_last=input_params["drop_last_batch"], shuffle=False if train_sampler else shuffle, num_workers=input_params.get("num_workers", 0), ) return train_loader
[docs]def get_eval_dataloader(params): input_params = params["eval_input"] use_cs = cstorch.use_cs() batch_size = get_streaming_batch_size(input_params.get("batch_size")) to_float16 = input_params.get("to_float16", True) dtype = torch.float32 if to_float16: if use_cs or torch.cuda.is_available(): dtype = cstorch.amp.get_half_dtype() else: print( f"Input dtype float16 is not supported with " f"vanilla PyTorch CPU workflow. Using float32 instead." ) if input_params.get("use_fake_data", False): num_streamers = dist.num_streamers() if dist.is_streamer() else 1 eval_loader = SampleGenerator( data=( torch.zeros(batch_size, 1, 28, 28, dtype=dtype), torch.zeros( batch_size, dtype=torch.int32 if use_cs else torch.int64 ), ), sample_count=10000 // batch_size // num_streamers, ) else: eval_dataset = datasets.MNIST( input_params["data_dir"], train=False, download=dist.is_master_ordinal(), transform=transforms.Compose( [ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.Lambda( lambda x: torch.as_tensor(x, dtype=dtype) ), ] ), target_transform=transforms.Lambda( lambda x: torch.as_tensor(x, dtype=torch.int32) ) if use_cs else None, ) eval_loader = torch.utils.data.DataLoader( eval_dataset, batch_size=batch_size, drop_last=input_params["drop_last_batch"], shuffle=False, num_workers=input_params.get("num_workers", 0), ) return eval_loader