# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import torch
from PIL import Image
from torchvision import transforms
from torchvision.datasets import VisionDataset
import cerebras_pytorch as cstorch
import cerebras_pytorch.distributed as dist
from modelzoo.common.pytorch.input_utils import get_streaming_batch_size
from modelzoo.vision.pytorch.input.classification.dataset_factory import (
VisionSubset,
)
from modelzoo.vision.pytorch.input.transforms import LambdaWithParam
from modelzoo.vision.pytorch.input.utils import (
FastDataLoader,
ShardedSampler,
create_worker_cache,
num_tasks,
task_id,
)
from modelzoo.vision.pytorch.unet.input.preprocessing_utils import (
adjust_brightness_transform,
normalize_tensor_transform,
rotation_90_transform,
)


class InriaAerialDataset(VisionDataset):
    def __init__(
self,
root,
split="train",
transforms=None,
transform=None,
target_transform=None,
use_worker_cache=False,
):
super(InriaAerialDataset, self).__init__(
root, transforms, transform, target_transform
)
        if split not in ["train", "val", "test"]:
            raise ValueError(
                f"Invalid value={split} passed to `split` argument. "
                "Valid values are 'train', 'val' and 'test'."
            )
self.split = split
        if split == "test" and target_transform is not None:
            raise ValueError(
                f"split {split} has no mask images and hence "
                f"target_transform should be None. Got {target_transform}"
            )
        if use_worker_cache and dist.is_streamer():
            if not cstorch.use_cs():
                raise RuntimeError(
                    "use_worker_cache is not supported for non-CS runs"
                )
            else:
                self.root = create_worker_cache(self.root)
self.data_dir = os.path.join(self.root, self.split)
self.image_dir = os.path.join(self.data_dir, "images")
self.mask_dir = os.path.join(self.data_dir, "gt")
self.file_list = sorted(os.listdir(self.image_dir))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the segmentation
                mask of the image.
        """
image_file_path = os.path.join(self.image_dir, self.file_list[index])
image = Image.open(image_file_path) # 3-channel PILImage
mask_file_path = os.path.join(self.mask_dir, self.file_list[index])
target = Image.open(mask_file_path) # PILImage
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target

    def __len__(self):
        return len(self.file_list)
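

# A minimal usage sketch (an illustration, not part of the original API): it
# assumes the Inria Aerial Image Labeling dataset has been extracted under
# ``root`` with ``train/images`` and ``train/gt`` subdirectories, matching
# the layout ``InriaAerialDataset.__init__`` expects.
def _example_dataset_usage(root="./inria_aerial"):
    dataset = InriaAerialDataset(root=root, split="train")
    # Without `transforms`, both are PIL Images; the mask comes from `gt/`.
    image, mask = dataset[0]
    return len(dataset), image.size, mask.size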


class InriaAerialDataProcessor:
    def __init__(self, params):
self.use_worker_cache = params["use_worker_cache"]
self.data_dir = params["data_dir"]
self.num_classes = params["num_classes"]
self.image_shape = params["image_shape"] # of format (H, W, C)
self.duplicate_act_worker_data = params.get(
"duplicate_act_worker_data", False
)
self.loss_type = params["loss"]
self.normalize_data_method = params.get("normalize_data_method")
self.shuffle_seed = params.get("shuffle_seed", None)
if self.shuffle_seed is not None:
torch.manual_seed(self.shuffle_seed)
self.augment_data = params.get("augment_data", True)
        # Per-streamer batch size (the global batch may be split across
        # multiple activation workers).
        self.batch_size = get_streaming_batch_size(params["batch_size"])
self.shuffle = params.get("shuffle", True)
# Multi-processing params.
self.num_workers = params.get("num_workers", 0)
self.drop_last = params.get("drop_last", True)
self.prefetch_factor = params.get("prefetch_factor", 10)
self.persistent_workers = params.get("persistent_workers", True)
self.mixed_precision = params.get("mixed_precision")
if self.mixed_precision:
self.mp_type = (
torch.bfloat16 if params["use_bfloat16"] else torch.float16
)
else:
self.mp_type = torch.float32
# Debug params:
self.overfit = params.get("overfit", False)
        # By default, each activation worker sends `num_workers` batches,
        # i.e. batch_size * num_act_workers * num_pytorch_workers samples
        # in total.
        self.overfit_num_batches = params.get(
            "overfit_num_batches", num_tasks() * self.num_workers
        )
self.random_indices = params.get("overfit_indices", None)
if self.overfit:
            logging.info(
                f"---- Overfitting to {self.overfit_num_batches} batches! ----"
            )
        # Use FastDataLoader for map-style datasets.
        self.use_fast_dataloader = params.get("use_fast_dataloader", False)

    def create_dataset(self, is_training):
split = "train" if is_training else "val"
dataset = InriaAerialDataset(
root=self.data_dir,
split=split,
transforms=self.transform_image_and_mask,
use_worker_cache=self.use_worker_cache,
)
if self.overfit:
random.seed(self.shuffle_seed)
if self.random_indices is None:
indices = random.sample(
range(0, len(dataset)),
self.overfit_num_batches * self.batch_size,
)
else:
indices = self.random_indices
dataset = VisionSubset(dataset, indices)
print(f"---- Overfitting {indices}! ----")
return dataset
def create_dataloader(self, is_training=False):
dataset = self.create_dataset(is_training)
shuffle = self.shuffle and is_training
generator_fn = torch.Generator(device="cpu")
if self.shuffle_seed is not None:
generator_fn.manual_seed(self.shuffle_seed)
if shuffle:
            if self.duplicate_act_worker_data:
                # Multiple activation workers, each sending the same data in
                # a different order, since the dataset is extremely small.
                if self.shuffle_seed is None:
                    seed = task_id()
                else:
                    seed = self.shuffle_seed + task_id()
                generator_fn.manual_seed(seed)
data_sampler = torch.utils.data.RandomSampler(
dataset, generator=generator_fn
)
else:
data_sampler = ShardedSampler(
dataset, shuffle, self.shuffle_seed, self.drop_last
)
else:
data_sampler = torch.utils.data.SequentialSampler(dataset)
        num_samples_per_task = len(data_sampler)
        assert num_samples_per_task >= self.batch_size, (
            f"Number of samples available per task (={num_samples_per_task}) "
            f"is less than batch_size (={self.batch_size})"
        )
if self.use_fast_dataloader:
dataloader_fn = FastDataLoader
print("-- Using FastDataloader -- ")
else:
dataloader_fn = torch.utils.data.DataLoader
print("-- Using torch.utils.data.DataLoader -- ")
if self.num_workers:
dataloader = dataloader_fn(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
persistent_workers=self.persistent_workers,
drop_last=self.drop_last,
generator=generator_fn,
sampler=data_sampler,
)
else:
dataloader = dataloader_fn(
dataset,
batch_size=self.batch_size,
drop_last=self.drop_last,
generator=generator_fn,
sampler=data_sampler,
)
        return dataloader

    def _apply_normalization(
        self, image, normalize_data_method, *args, **kwargs
    ):
        return normalize_tensor_transform(image, normalize_data_method)

    def preprocess_image(self, image):
        if self.image_shape[-1] == 1:
            # Convert 3-channel PIL Image to grayscale (H, W, 1).
            image = image.convert("L")

        # Convert to (C, H, W) format.
        to_tensor_transform = transforms.PILToTensor()
# Normalize
normalize_transform = LambdaWithParam(
self._apply_normalization, self.normalize_data_method
)
transforms_list = [
to_tensor_transform,
normalize_transform,
]
image = transforms.Compose(transforms_list)(image)
        return image

    def preprocess_mask(self, mask):
to_tensor_transform = transforms.PILToTensor()
normalize_transform = LambdaWithParam(
self._apply_normalization, "zero_one"
)
transforms_list = [
to_tensor_transform,
normalize_transform,
]
        # Output has shape (1, 5000, 5000).
        mask = transforms.Compose(transforms_list)(mask)
        return mask

    def transform_image_and_mask(self, image, mask):
image = self.preprocess_image(image)
mask = self.preprocess_mask(mask)
if self.augment_data:
            do_horizontal_flip = torch.rand(size=(1,)).item() > 0.5
            # n_rotations in range [0, 3)
            n_rotations = torch.randint(low=0, high=3, size=(1,)).item()
            if self.image_shape[0] != self.image_shape[1]:  # H != W
                # For a rectangular image, only multiples of 180-degree
                # rotations preserve the spatial shape.
                n_rotations = n_rotations * 2
augment_transform_image = self.get_augment_transforms(
do_horizontal_flip=do_horizontal_flip,
n_rotations=n_rotations,
do_random_brightness=True,
)
augment_transform_mask = self.get_augment_transforms(
do_horizontal_flip=do_horizontal_flip,
n_rotations=n_rotations,
do_random_brightness=False,
)
image = augment_transform_image(image)
mask = augment_transform_mask(mask)
        # Handle dtypes and mask shapes based on `loss_type`
        # and `mixed_precision`.
if self.loss_type == "bce":
mask = mask.to(self.mp_type)
if self.mixed_precision:
image = image.to(self.mp_type)
        return image, mask

    def get_augment_transforms(
self, do_horizontal_flip, n_rotations, do_random_brightness
):
augment_transforms_list = []
if do_horizontal_flip:
horizontal_flip_transform = transforms.Lambda(
lambda x: transforms.functional.hflip(x)
)
augment_transforms_list.append(horizontal_flip_transform)
if n_rotations > 0:
rotation_transform = transforms.Lambda(
lambda x: rotation_90_transform(x, num_rotations=n_rotations)
)
augment_transforms_list.append(rotation_transform)
if do_random_brightness:
brightness_transform = transforms.Lambda(
lambda x: adjust_brightness_transform(x, p=0.5, delta=0.2)
)
augment_transforms_list.append(brightness_transform)
return transforms.Compose(augment_transforms_list)
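

# A hedged end-to-end sketch (an illustration, not part of the original
# module) of how the processor might be driven. The ``params`` keys mirror
# the ones read in ``__init__`` above; the concrete values (data_dir,
# image_shape, etc.) are placeholder assumptions, not values taken from any
# original config.
if __name__ == "__main__":
    params = {
        "use_worker_cache": False,
        "data_dir": "./inria_aerial",  # assumed dataset location
        "num_classes": 2,
        "image_shape": [5000, 5000, 3],  # (H, W, C)
        "loss": "bce",
        "normalize_data_method": "zero_one",
        "batch_size": 2,
        "num_workers": 0,
        "mixed_precision": False,
    }
    processor = InriaAerialDataProcessor(params)
    dataloader = processor.create_dataloader(is_training=True)
    images, masks = next(iter(dataloader))
    print(images.shape, images.dtype, masks.shape, masks.dtype)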