Source code for cerebras.modelzoo.data.vision.segmentation.preprocessing_utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torchvision import transforms


def normalize_tensor_transform(img, normalize_data_method):
    """
    Function to normalize img.

    :params img: Input torch.Tensor of any shape
    :params normalize_data_method: One of `None` (no-op), "zero_centered",
        "zero_one" or "standard_score"
    """
    if normalize_data_method is None:
        pass
    elif normalize_data_method == "zero_centered":
        img = torch.div(img, 127.5) - 1
    elif normalize_data_method == "zero_one":
        img = torch.div(img, 255.0)
    elif normalize_data_method == "standard_score":
        img = (img - img.mean()) / img.std()
    else:
        raise ValueError(
            f"Invalid arg={normalize_data_method} passed to `normalize_data_method`"
        )
    return img
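
# Example usage (a minimal sketch; the sample values below are illustrative only):
#
#     img = torch.tensor([0.0, 127.5, 255.0])
#     normalize_tensor_transform(img, "zero_one")       # -> tensor([0.0, 0.5, 1.0])
#     normalize_tensor_transform(img, "zero_centered")  # -> tensor([-1.0, 0.0, 1.0])
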
def adjust_brightness_transform(img, p, delta):
    """
    Function equivalent to `tf.image.adjust_brightness`, but executed
    probabilistically: `delta` is added to `img` only when a uniform
    random draw exceeds `p`.

    :params img: Input torch.Tensor of any shape
    :params p: Float probability threshold in [0, 1]
    :params delta: Float value by which the img Tensor is increased
        or decreased.
    """
    if (torch.rand(1) > p).item():
        img = torch.add(img, delta)
    return img
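
# Example usage (a minimal sketch; p and delta below are illustrative only).
# With p=1.0 the draw from torch.rand(1) never exceeds p, so the image is
# returned unchanged; with p=0.0 delta is (almost surely) added:
#
#     img = torch.zeros(3, 4, 4)
#     adjust_brightness_transform(img, p=1.0, delta=0.2).max()  # -> tensor(0.)
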
def rotation_90_transform(img, num_rotations):
    """
    Function equivalent to `tf.image.rot90`.
    Rotates img in the counter-clockwise direction.

    :params img: torch.Tensor of shape (C, H, W) or (H, W)
    :params num_rotations: int value representing the number of
        counter-clockwise rotations of img
    """
    if len(img.shape) == 3:
        # If image is of shape (C, H, W), rotate along H, W
        # in the counter-clockwise direction
        dims = [1, 2]
    else:
        dims = [0, 1]
    img = torch.rot90(img, k=num_rotations, dims=dims)
    return img
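
# Example usage (a minimal sketch; the 2x2 tensor below is illustrative only):
#
#     img = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
#     rotation_90_transform(img, num_rotations=1)
#     # -> tensor([[2., 4.],
#     #            [1., 3.]])  (one counter-clockwise quarter turn)
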
def resize_image_with_crop_or_pad_transform(img, target_height, target_width):
    """
    Function equivalent to `tf.image.resize_with_crop_or_pad`.

    :params img: torch.Tensor of shape (C, H, W) or (H, W)
    :params target_height: int value representing output image height
    :params target_width: int value representing output image width
    :returns: torch.Tensor of shape (C, target_height, target_width)
    """

    def _pad_image(img):
        """
        Pad image until it reaches target_height and target_width.
        """
        img_shape = img.shape
        img_width = img_shape[-1]
        img_height = img_shape[-2]
        lft_rgt_pad = max((target_width - img_width) // 2, 0)
        top_bot_pad = max((target_height - img_height) // 2, 0)
        excess_right_pad = target_width - img_width - 2 * lft_rgt_pad
        excess_bot_pad = target_height - img_height - 2 * top_bot_pad
        pad = [
            lft_rgt_pad,
            lft_rgt_pad + excess_right_pad,
            top_bot_pad,
            top_bot_pad + excess_bot_pad,
        ]
        img = torch.nn.functional.pad(img, pad)
        return img

    def _crop_image(img):
        img_shape = img.shape
        # Crop only when necessary. CenterCrop pads if the
        # crop dimensions are greater, hence taking the min.
        crop_height = min(img_shape[-2], target_height)
        crop_width = min(img_shape[-1], target_width)
        img = transforms.CenterCrop((crop_height, crop_width))(img)
        return img

    cropped_img = _crop_image(img)
    padded_img = _pad_image(cropped_img)

    assert padded_img.shape[-1] == target_width
    assert padded_img.shape[-2] == target_height
    return padded_img
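
# Example usage (a minimal sketch; shapes below are illustrative only). The
# height is center-cropped from 10 to 6 and the width is zero-padded from
# 8 to 12:
#
#     img = torch.zeros(3, 10, 8)
#     resize_image_with_crop_or_pad_transform(img, 6, 12).shape
#     # -> torch.Size([3, 6, 12])
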
def tile_image_transform(img, target_height, target_width):
    """
    Function to tile an image to target_height and target_width.
    If target_height < image height, the image is not tiled in this dimension.
    If target_width < image width, the image is not tiled in this dimension.

    :params img: input torch.Tensor of shape (C, H, W)
    :params target_height: int value representing output tiled image height
    :params target_width: int value representing output tiled image width
    :returns: torch.Tensor of shape (C, target_height, target_width)
    """
    assert len(img.shape) == 3
    img_channels, img_height, img_width = img.shape
    tgt_img_shape = [img_channels, target_height, target_width]

    def _get_tiled_image(img, tgt_img_shape, axis):
        if tgt_img_shape[axis] <= img.shape[axis]:
            # No tiling since the image already satisfies the requirement
            return img
        else:
            diff = tgt_img_shape[axis] - img.shape[axis]
            q, r = divmod(diff, img.shape[axis])
            temp_img = img
            for _ in range(q):
                temp_img = torch.concat((img, temp_img), axis=axis)
            if r > 0:
                if axis == 1:
                    sliced_img = temp_img[:, :r, :]
                elif axis == 2:
                    sliced_img = temp_img[:, :, :r]
                else:
                    raise ValueError(
                        f"Incorrect value of {axis} passed. Valid integers are 1, 2"
                    )
                temp_img = torch.concat((temp_img, sliced_img), axis=axis)
            return temp_img

    v_tiled_img = _get_tiled_image(img, tgt_img_shape=tgt_img_shape, axis=1)
    tiled_img = _get_tiled_image(
        v_tiled_img, tgt_img_shape=tgt_img_shape, axis=2
    )
    return tiled_img
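
# Example usage (a minimal sketch; shapes below are illustrative only). A
# (3, 4, 4) image tiled to a (10, 6) target repeats itself along H and W
# until the target size is reached:
#
#     img = torch.arange(48, dtype=torch.float32).reshape(3, 4, 4)
#     tile_image_transform(img, target_height=10, target_width=6).shape
#     # -> torch.Size([3, 10, 6])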