Source code for cerebras.modelzoo.data.vision.classification.data.retinopathy

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import os

import cv2
import numpy as np
from PIL import Image
from torchvision.datasets.utils import verify_str_arg
from torchvision.datasets.vision import VisionDataset

from cerebras.modelzoo.data.vision.classification.dataset_factory import (
    Processor,
    VisionSubset,
)


class DiabeticRetinopathy(VisionDataset):
    """Diabetic Retinopathy Detection dataset from the 2015 Kaggle competition.

    Expects the Kaggle archive to be extracted under ``<root>/retinopathy``.
    """

    _TARGET_PIXELS = {
        "original": None,
        "1M": 1000000,
        "250K": 250000,
        "btgraham-300": 300,
    }
    _LABEL_FILE = {
        "train": "trainLabels.csv",
        "val": "retinopathy_solution.csv",
        "test": "retinopathy_solution.csv",
    }
    def __init__(
        self,
        root,
        split="train",
        config="btgraham-300",
        use_heavy_train_aug=False,
        transform=None,
        target_transform=None,
    ):
        super().__init__(
            os.path.join(root, "retinopathy"),
            transform=transform,
            target_transform=target_transform,
        )
        if not os.path.exists(self.root):
            raise RuntimeError(
                "Dataset not found. Download from "
                "https://www.kaggle.com/c/diabetic-retinopathy-detection/data"
            )
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        verify_str_arg(config, "config", self._TARGET_PIXELS.keys())
        self._use_heavy_aug = use_heavy_train_aug

        if config == "btgraham-300":
            self.img_processing = _btgraham_processing
        else:
            self.img_processing = _resize_image_if_necessary
        self.target_pixels = self._TARGET_PIXELS[config]

        self._image_files = []
        self.labels = []
        label_file = os.path.join(self.root, self._LABEL_FILE[split])
        if os.path.exists(label_file):
            with open(label_file, "r") as f:
                content = csv.reader(f, delimiter=",")
                next(content, None)  # skip the headers
                for line in content:
                    if split == "train":
                        self._image_files.append(line[0])
                        self.labels.append(float(line[1]))
                    else:
                        assert len(line) == 3
                        if (split == "val" and line[2] == "Public") or (
                            split == "test" and line[2] == "Private"
                        ):
                            self._image_files.append(line[0])
                            self.labels.append(float(line[1]))
        else:
            if split == "train" or split == "val":
                raise RuntimeError(
                    f"Missing label file {label_file} for {split} split"
                )
            for img_file in os.listdir(os.path.join(self.root, split)):
                if img_file.endswith(".jpeg"):
                    self._image_files.append(img_file[:-5])
            self.labels = [-1] * len(self._image_files)
    def __getitem__(self, idx):
        image_file = self._image_files[idx]
        split_dir = "train" if self._split == "train" else "test"
        image = cv2.imread(
            os.path.join(self.root, split_dir, f"{image_file}.jpeg"), flags=3
        )
        image = self.img_processing(image, self.target_pixels).astype("uint8")
        # Convert the color from BGR (cv2 format) to RGB (PIL format).
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self._use_heavy_aug:
            image = _heavy_data_augmentation(np.array(image))
        image = Image.fromarray(image.astype("uint8"), "RGB")

        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self._image_files)
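
# Illustrative usage sketch (not part of the original module): constructing the
# dataset directly from the Kaggle directory layout described in __init__. The
# "/path/to/data" root is a placeholder; it must contain a "retinopathy"
# directory with trainLabels.csv, retinopathy_solution.csv, and the train/test
# image folders.
#
#     from torchvision import transforms
#
#     dataset = DiabeticRetinopathy(
#         root="/path/to/data",
#         split="train",
#         config="btgraham-300",
#         transform=transforms.ToTensor(),
#     )
#     image, label = dataset[0]  # 3xHxW tensor and a severity grade (0-4) as a float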

def _btgraham_processing(image, target_pixels, crop_to_radius=False):
    """Process an image as done by the winner of the 2015 Kaggle competition."""
    image = _scale_radius_size(image, target_radius_size=target_pixels)
    image = _subtract_local_average(image, target_radius_size=target_pixels)
    image = _mask_and_crop_to_radius(
        image,
        target_radius_size=target_pixels,
        radius_mask_ratio=0.9,
        crop_to_radius=crop_to_radius,
    )
    return image


def _resize_image_if_necessary(image, target_pixels=None):
    """Resize an image to have (roughly) the given number of target pixels."""
    if target_pixels is None:
        return image
    # Get image height and width.
    height, width, _ = image.shape
    actual_pixels = height * width
    if actual_pixels > target_pixels:
        factor = np.sqrt(target_pixels / actual_pixels)
        image = cv2.resize(image, dsize=None, fx=factor, fy=factor)
    return image


def _scale_radius_size(image, target_radius_size):
    """Rescale the image so that the eyeball radius matches the target size."""
    x = image[image.shape[0] // 2, :, :].sum(axis=1)
    r = (x > x.mean() / 10.0).sum() / 2.0
    if r < 1.0:
        # Some images in the dataset are corrupted, causing the radius
        # heuristic to fail. In these cases, just assume that the radius is
        # half the height of the original image.
        r = image.shape[0] / 2.0
    s = target_radius_size / r
    return cv2.resize(image, dsize=None, fx=s, fy=s)


def _subtract_local_average(image, target_radius_size):
    """Subtract a Gaussian-blurred local average to normalize illumination."""
    image_blurred = cv2.GaussianBlur(image, (0, 0), target_radius_size / 30)
    image = cv2.addWeighted(image, 4, image_blurred, -4, 128)
    return image


def _mask_and_crop_to_radius(
    image, target_radius_size, radius_mask_ratio=0.9, crop_to_radius=False
):
    """Mask and crop the image to the given radius ratio."""
    mask = np.zeros(image.shape)
    center = (image.shape[1] // 2, image.shape[0] // 2)
    radius = int(target_radius_size * radius_mask_ratio)
    cv2.circle(
        mask, center=center, radius=radius, color=(1, 1, 1), thickness=-1
    )
    image = image * mask + (1 - mask) * 128
    if crop_to_radius:
        x_max = min(image.shape[1] // 2 + radius, image.shape[1])
        x_min = max(image.shape[1] // 2 - radius, 0)
        y_max = min(image.shape[0] // 2 + radius, image.shape[0])
        y_min = max(image.shape[0] // 2 - radius, 0)
        image = image[y_min:y_max, x_min:x_max, :]
    return image


def _sample_heavy_data_augmentation_parameters():
    """Sample random scale, rotation, shear, flip, and translation parameters."""
    # Scale image +/- 10%.
    s = np.random.uniform(-0.1, 0.1)
    # Rotate image [0, 2pi).
    a = np.random.uniform(0.0, 2.0 * np.pi)
    # Vertically shear image +/- 20%.
    b = np.random.uniform(-0.2, 0.2) + a
    # Horizontal and vertical flipping.
    flip = [-1.0, 1.0]
    np.random.shuffle(flip)
    hf = flip[0]
    np.random.shuffle(flip)
    vf = flip[0]
    # Relative x,y translation.
    dx = np.random.uniform(-0.1, 0.1)
    dy = np.random.uniform(-0.1, 0.1)
    return s, a, b, hf, vf, dx, dy


def _heavy_data_augmentation(image):
    """Apply a randomly sampled affine transformation to the image."""
    height = float(image.shape[0])
    width = float(image.shape[1])
    # Sample data augmentation parameters.
    s, a, b, hf, vf, dx, dy = _sample_heavy_data_augmentation_parameters()
    # Rotation + scale.
    c00 = (1 + s) * np.cos(a)
    c01 = (1 + s) * np.sin(a)
    c10 = (s - 1) * np.sin(b)
    c11 = (1 - s) * np.cos(b)
    # Horizontal and vertical flipping.
    c00 = c00 * hf
    c01 = c01 * hf
    c10 = c10 * vf
    c11 = c11 * vf
    # Convert x,y translation to absolute values.
    dx = width * dx
    dy = height * dy
    # Build the affine matrix so that the transform is applied w.r.t. the
    # center of the image.
    cy = height / 2.0
    cx = width / 2.0
    affine_matrix = np.array(
        [
            [c00, c01, (1.0 - c00) * cx - c01 * cy + dx],
            [c10, c11, (1.0 - c11) * cy - c10 * cx + dy],
            [0.0, 0.0, 1.0],
        ],
        dtype=float,
    )
    affine_inv_matrix = np.linalg.inv(affine_matrix)
    # Since the background is grey in these configs, put pixels in the [-1, 1]
    # range to avoid artifacts from the affine transformation.
    image = image.astype(float)
    image = (image / 127.5) - 1.0
    # Apply the affine transformation.
    image = np.matmul(image, affine_inv_matrix)
    # Put pixels back in the [0, 255] range and cast to uint8, since this is
    # what our preprocessing pipeline usually expects.
    image = (1.0 + image) * 127.5
    return image
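
# Minimal sketch (an assumption, not part of the original module) of how the
# Graham-style helpers above compose on a raw BGR image read with cv2,
# mirroring what DiabeticRetinopathy.__getitem__ does for the "btgraham-300"
# config; "sample.jpeg" is a placeholder path.
#
#     raw = cv2.imread("sample.jpeg", flags=3)
#     processed = _btgraham_processing(raw, target_pixels=300).astype("uint8")
#     rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)  # ready for PIL / transforms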

class DiabeticRetinopathyProcessor(Processor):
    def __init__(self, params):
        super().__init__(params)
        self.allowable_split = ["train", "val", "test"]
        self.num_classes = 5
    def create_dataset(
        self, use_training_transforms=True, config="btgraham-300", split="train"
    ):
        self.check_split_valid(split)
        transform, target_transform = self.process_transform(
            use_training_transforms
        )
        dataset = DiabeticRetinopathy(
            root=self.data_dir,
            split=split,
            config=config,
            transform=transform,
            target_transform=target_transform,
        )
        return dataset

    def create_vtab_dataset(
        self, use_heavy_train_aug=False, use_1k_sample=True, seed=42
    ):
        train_transform, train_target_transform = self.process_transform(
            use_training_transforms=True
        )
        eval_transform, eval_target_transform = self.process_transform(
            use_training_transforms=False
        )
        config = "btgraham-300"
        train_set = DiabeticRetinopathy(
            root=self.data_dir,
            split="train",
            use_heavy_train_aug=use_heavy_train_aug,
            config=config,
            transform=train_transform,
            target_transform=train_target_transform,
        )
        val_set = DiabeticRetinopathy(
            root=self.data_dir,
            split="val",
            config=config,
            transform=eval_transform,
            target_transform=eval_target_transform,
        )
        test_set = DiabeticRetinopathy(
            root=self.data_dir,
            split="test",
            config=config,
            transform=eval_transform,
            target_transform=eval_target_transform,
        )

        if use_1k_sample:
            rng = np.random.default_rng(seed)
            sample_idx = self.create_shuffled_idx(len(train_set), rng)
            train_set = VisionSubset(train_set, sample_idx[:800])
            sample_idx = self.create_shuffled_idx(len(val_set), rng)
            val_set = VisionSubset(val_set, sample_idx[:200])

        return train_set, val_set, test_set
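
# Illustrative sketch (an assumption, not part of the original module): building
# the VTAB-style 800-train / 200-val subsets through the processor. The params
# dict below is hypothetical and only hints at the fields Processor expects
# (e.g. a data_dir); see the dataset_factory module for the full schema.
#
#     params = {"data_dir": "/path/to/data"}  # plus any other required Processor fields
#     processor = DiabeticRetinopathyProcessor(params)
#     train_set, val_set, test_set = processor.create_vtab_dataset(
#         use_heavy_train_aug=False, use_1k_sample=True, seed=42
#     )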