Source code for cerebras.modelzoo.data.vision.classification.data.dsprites

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial

import h5py
import numpy as np
from PIL import Image
from torchvision import transforms
from torchvision.datasets.vision import VisionDataset

from cerebras.modelzoo.data.vision.classification.dataset_factory import (
    Processor,
)


class DSprites(VisionDataset):
    """
    dSprites is a dataset of 2D shapes procedurally generated from 6 ground
    truth independent latent factors. These factors are color, shape, scale,
    rotation, x and y positions of a sprite.

    All possible combinations of these latents are present exactly once,
    generating N = 737280 total images.

    ### Latent factor values
    * Color: white
    * Shape: square, ellipse, heart
    * Scale: 6 values linearly spaced in [0.5, 1]
    * Orientation: 40 values in [0, 2 pi]
    * Position X: 32 values in [0, 1]
    * Position Y: 32 values in [0, 1]

    We varied one latent at a time (starting from Position Y, then Position X,
    etc), and sequentially stored the images in fixed order. Hence the order
    along the first dimension is fixed and allows you to map back to the value
    of the latents corresponding to that image.

    We chose the latent values deliberately to have the smallest step changes
    while ensuring that all pixel outputs were different. No noise was added.
    """

    _file = "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5"
    def __init__(
        self,
        root,
        transform=None,
        target_transform=None,
    ):
        super().__init__(
            os.path.join(root, "dsprites"),
            transform=transform,
            target_transform=target_transform,
        )
        if not os.path.exists(self.root):
            raise RuntimeError(
                f"Dataset not found. Download from "
                f"https://github.com/deepmind/dsprites-dataset/blob/master/{self._file}"
            )

        with h5py.File(os.path.join(self.root, self._file), "r") as fx:
            self.length = len(fx["imgs"])

        self.h5dataset = None
        self.images = None
        self.labels = None
    def __getitem__(self, index):
        # Workaround so that the dataset is pickleable and works with
        # multiprocessing. See discussion:
        # https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/16
        if self.h5dataset is None:
            self.h5dataset = h5py.File(
                os.path.join(self.root, self._file), "r"
            )
            self.images = self.h5dataset["imgs"]
            self.labels = self.h5dataset["latents"]["classes"]

        # image has shape (64, 64), expand and tile to (64, 64, 3)
        img = np.tile(np.expand_dims(self.images[index] * 255, -1), (1, 1, 3))
        img = Image.fromarray(img.astype('uint8'), 'RGB')
        target = self.labels[index]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return self.length
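For reference, a minimal usage sketch of the dataset class above. It assumes the dSprites HDF5 file has already been downloaded into a "dsprites" subfolder of the data root; the path below is a placeholder, not part of the module.

# Usage sketch (assumption: the HDF5 file lives at /path/to/data/dsprites/).
example_transform = transforms.ToTensor()
dsprites = DSprites(root="/path/to/data", transform=example_transform)

img, target = dsprites[0]   # img: 3x64x64 tensor; target: 6 latent class indices
print(len(dsprites))        # 737280 images in total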
def _process_predicted_attribute_by_index(
    label_index, class_division_factor, label
):
    target = label[label_index]
    return np.floor(target / class_division_factor)
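A worked example of how the helper above coarsens one latent into fewer classes. With label_index=4 (the x-position latent, 32 original classes) and a target of 16 classes, class_division_factor is 32 / 16 = 2, so original class 7 maps to floor(7 / 2) = 3. The latent row below uses made-up values for illustration.

# Worked example (hypothetical latent row): collapse the 32 x-position
# classes to 16 by dividing the class index by class_division_factor = 2.
example_label = np.array([0, 2, 3, 19, 7, 21])  # one row of latents["classes"]
coarse = _process_predicted_attribute_by_index(4, 2.0, example_label)
assert coarse == 3.0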
class DSpritesProcessor(Processor):
    _TASK_DICT = {
        "label_x_position": {
            "preprocess_fn": partial(_process_predicted_attribute_by_index, 4),
            "num_classes": 32,
        },
        "label_orientation": {
            "preprocess_fn": partial(_process_predicted_attribute_by_index, 3),
            "num_classes": 40,
        },
    }
    def __init__(self, params):
        super().__init__(params)
        self.allowable_split = ["train"]
        self.allowable_task = self._TASK_DICT.keys()
        self.num_classes = 16
    def create_dataset(self, use_training_transforms=True, split="train"):
        self.check_split_valid(split)
        transform, target_transform = self.process_transform(
            use_training_transforms
        )
        dataset = DSprites(
            root=self.data_dir,
            transform=transform,
            target_transform=target_transform,
        )
        return dataset

    def create_vtab_dataset(
        self,
        task="label_x_position",
        num_classes=16,
        use_1k_sample=True,
        seed=42,
    ):
        if task not in self.allowable_task:
            raise ValueError(
                f"Task {task} is not supported, choose from "
                f"{self.allowable_task} instead"
            )
        num_original_classes = self._TASK_DICT[task]["num_classes"]
        if num_classes is None:
            num_classes = num_original_classes
        if (
            not isinstance(num_classes, int)
            or num_classes <= 1
            or (num_classes > num_original_classes)
        ):
            raise ValueError(
                f"The number of classes should be None or in "
                f"[2, {num_original_classes}]"
            )
        class_division_factor = float(num_original_classes) / num_classes

        target_transform = partial(
            self._TASK_DICT[task]["preprocess_fn"], class_division_factor
        )

        train_transform, train_tgt_transform = self.process_transform(
            use_training_transforms=True
        )
        eval_transform, eval_tgt_transform = self.process_transform(
            use_training_transforms=False
        )
        train_target_transform = transforms.Compose(
            [target_transform, train_tgt_transform]
        )
        eval_target_transform = transforms.Compose(
            [target_transform, eval_tgt_transform]
        )

        dataset = DSprites(
            root=self.data_dir,
            transform=None,
        )

        # DSprites only comes with a training set. Therefore, the training,
        # validation, and test sets are split out of the original training set.
        # By default, 80% is used as a new training split, 10% is used for
        # validation, and 10% is used for testing.
        split_percent = [80, 10, 10]
        train_set, val_set, test_set = self.split_dataset(
            dataset, split_percent, seed
        )

        if use_1k_sample:
            train_set.truncate_to_idx(800)
            val_set.truncate_to_idx(200)

        train_set.set_transforms(
            transform=train_transform, target_transform=train_target_transform
        )
        val_set.set_transforms(
            eval_transform, target_transform=eval_target_transform
        )
        test_set.set_transforms(
            eval_transform, target_transform=eval_target_transform
        )

        return train_set, val_set, test_set
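A hedged sketch of how the processor above might be driven. The contents of the params dictionary are defined by the base Processor class and are assumed here; only data_dir is referenced directly in this file, and the path is a placeholder.

# Illustrative sketch only: `params` below is a placeholder configuration,
# not the full schema expected by the base Processor class.
example_params = {"data_dir": "/path/to/data"}
processor = DSpritesProcessor(example_params)

# VTAB-style splits for the x-position task, coarsened to 16 classes and
# truncated to the 1k-example train/validation sample.
train_set, val_set, test_set = processor.create_vtab_dataset(
    task="label_x_position", num_classes=16, use_1k_sample=True, seed=42
)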