# Source code for cerebras.modelzoo.data.vision.segmentation.transforms.crop_and_pad_augmentations

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from: https://github.com/MIC-DKFZ/batchgenerators/
# blob/master/batchgenerators/augmentations/crop_and_pad_augmentations.py (commit id: 01f225d)

# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from builtins import range

import numpy as np


def center_crop(data, crop_size, seg=None):
    """Center-crop ``data`` (and ``seg``, if given) to ``crop_size``.

    Convenience wrapper around :func:`crop` with ``margins=0`` and
    ``crop_type='center'``.

    :param data: b, c, x, y(, z)
    :param crop_size: int or one entry per spatial axis
    :param seg: optional segmentation, cropped with the same window as data
    :return: tuple ``(data_cropped, seg_cropped)``; ``seg_cropped`` is None
        when ``seg`` is None
    """
    return crop(
        data,
        seg=seg,
        crop_size=crop_size,
        margins=0,
        crop_type='center',
    )
def get_lbs_for_random_crop(crop_size, data_shape, margins):
    """Draw random crop lower bounds, one per spatial axis.

    For spatial axis ``d`` the lower bound is drawn uniformly from the closed
    interval ``[margins[d], data_shape[d + 2] - crop_size[d] - margins[d]]``,
    so the crop never comes closer than ``margins[d]`` to either border. If
    that interval contains at most one position, the crop is centered
    instead, i.e. ``(data_shape[d + 2] - crop_size[d]) // 2``.

    :param crop_size: sequence with one crop length per spatial axis
    :param data_shape: (b,c,x,y(,z)) must be the whole thing! The first two
        entries (batch, channel) are skipped.
    :param margins: sequence with one margin per spatial axis; may be negative,
        in which case lower bounds can be negative (the caller pads)
    :return: list of integer lower bounds, one per spatial axis
    """
    lbs = []
    for d in range(len(data_shape) - 2):
        axis_len = data_shape[d + 2]
        highest_lb = axis_len - crop_size[d] - margins[d]
        if highest_lb > margins[d]:
            # np.random.randint excludes its upper bound; add 1 so that
            # ``highest_lb`` itself is reachable. Without it the last valid
            # position is never sampled and negative margins pad asymmetrically.
            lbs.append(np.random.randint(margins[d], highest_lb + 1))
        else:
            lbs.append((axis_len - crop_size[d]) // 2)
    return lbs
def get_lbs_for_center_crop(crop_size, data_shape):
    """Compute the lower bounds that center a crop on every spatial axis.

    :param crop_size: sequence with one crop length per spatial axis
    :param data_shape: (b,c,x,y(,z)) must be the whole thing! The first two
        entries (batch, channel) are skipped.
    :return: list with ``(axis_len - crop_len) // 2`` per spatial axis; the
        value is negative when the crop is larger than the axis
    """
    spatial_shape = data_shape[2:]
    return [
        (axis_len - crop_size[axis]) // 2
        for axis, axis_len in enumerate(spatial_shape)
    ]
def crop(
    data,
    seg=None,
    crop_size=128,
    margins=(0, 0, 0),
    crop_type="center",
    pad_mode='constant',
    pad_kwargs=None,
    pad_mode_seg='constant',
    pad_kwargs_seg=None,
):
    """Crop ``data`` (and ``seg``, if given) to ``crop_size``.

    ``crop_type`` determines whether a center or a random crop is taken.
    ``margins`` are only respected for random crops, where they prevent the
    crop from coming closer than the margin to the respective image border.
    ``crop_size`` may be larger than the (margin-reduced) data shape; the
    missing region is then filled via ``np.pad`` using ``pad_mode`` /
    ``pad_kwargs`` (zeros by default). Margins may be negative, which lets the
    random crop window extend beyond the data, so those axes are padded.

    :param data: b, c, x, y(, z); a numpy array or a list/tuple of per-sample
        arrays shaped c, x, y(, z)
    :param seg: optional segmentation laid out like ``data`` with the same
        spatial dimensions; cropped with the same window as ``data``
    :param crop_size: int (used for every spatial axis) or list/tuple/array
        with one entry per spatial axis
    :param margins: distance from each border, int or list/tuple/array of
        ints (one element per spatial axis). Can be negative (data/seg will
        be padded if needed).
    :param crop_type: "center" or "random"
    :param pad_mode: ``mode`` argument of ``np.pad`` for ``data``
    :param pad_kwargs: extra keyword arguments of ``np.pad`` for ``data``;
        None means ``{'constant_values': 0}``
    :param pad_mode_seg: ``mode`` argument of ``np.pad`` for ``seg``
    :param pad_kwargs_seg: extra keyword arguments of ``np.pad`` for ``seg``;
        None means ``{'constant_values': 0}``
    :return: tuple ``(data_cropped, seg_cropped)`` of arrays shaped
        ``(b, c) + crop_size`` with the dtype of the first sample;
        ``seg_cropped`` is None when ``seg`` is None
    :raises TypeError: if ``data`` or ``seg`` is not a numpy array/list/tuple
    :raises NotImplementedError: if ``crop_type`` is neither "center" nor
        "random"
    """
    # Fresh dicts per call instead of shared mutable defaults.
    if pad_kwargs is None:
        pad_kwargs = {'constant_values': 0}
    if pad_kwargs_seg is None:
        pad_kwargs_seg = {'constant_values': 0}

    if not isinstance(data, (list, tuple, np.ndarray)):
        raise TypeError("data has to be either a numpy array or a list")

    data_shape = tuple([len(data)] + list(data[0].shape))
    data_dtype = data[0].dtype
    dim = len(data_shape) - 2

    if seg is not None:
        # Validate the type before indexing seg[0], so bad input raises the
        # intended TypeError rather than an unrelated indexing error.
        if not isinstance(seg, (list, tuple, np.ndarray)):
            raise TypeError("seg has to be either a numpy array or a list")
        seg_shape = tuple([len(seg)] + list(seg[0].shape))
        seg_dtype = seg[0].dtype
        assert all(i == j for i, j in zip(seg_shape[2:], data_shape[2:])), (
            "data and seg must have the same spatial "
            "dimensions. Data: %s, seg: %s" % (str(data_shape), str(seg_shape))
        )

    if not isinstance(crop_size, (tuple, list, np.ndarray)):
        crop_size = [crop_size] * dim
    else:
        assert len(crop_size) == len(data_shape) - 2, (
            "If you provide a list/tuple as center crop make sure it has the same dimension as your "
            "data (2d/3d)"
        )

    if not isinstance(margins, (np.ndarray, tuple, list)):
        margins = [margins] * dim

    # Check once up front instead of on every sample.
    if crop_type not in ("center", "random"):
        raise NotImplementedError("crop_type must be either center or random")

    data_return = np.zeros(
        [data_shape[0], data_shape[1]] + list(crop_size), dtype=data_dtype
    )
    if seg is not None:
        seg_return = np.zeros(
            [seg_shape[0], seg_shape[1]] + list(crop_size), dtype=seg_dtype
        )
    else:
        seg_return = None

    for b in range(data_shape[0]):
        data_shape_here = [data_shape[0]] + list(data[b].shape)
        if seg is not None:
            seg_shape_here = [seg_shape[0]] + list(seg[b].shape)

        if crop_type == "center":
            lbs = get_lbs_for_center_crop(crop_size, data_shape_here)
        else:
            lbs = get_lbs_for_random_crop(crop_size, data_shape_here, margins)

        # Padding per side wherever the crop window leaves the sample
        # (negative lower bound, or upper bound past the axis length); the
        # channel axis is never padded.
        need_to_pad = [[0, 0]] + [
            [
                abs(min(0, lbs[d])),
                abs(min(0, data_shape_here[d + 2] - (lbs[d] + crop_size[d]))),
            ]
            for d in range(dim)
        ]

        # we should crop first, then pad -> reduces i/o for memmaps, reduces
        # RAM usage and improves speed
        ubs = [
            min(lbs[d] + crop_size[d], data_shape_here[d + 2]) for d in range(dim)
        ]
        lbs = [max(0, lbs[d]) for d in range(dim)]
        spatial_slices = [slice(lbs[d], ubs[d]) for d in range(dim)]

        slicer_data = [slice(0, data_shape_here[1])] + spatial_slices
        data_cropped = data[b][tuple(slicer_data)]
        if seg_return is not None:
            slicer_seg = [slice(0, seg_shape_here[1])] + spatial_slices
            seg_cropped = seg[b][tuple(slicer_seg)]

        if any(amount > 0 for side in need_to_pad for amount in side):
            data_return[b] = np.pad(data_cropped, need_to_pad, pad_mode, **pad_kwargs)
            if seg_return is not None:
                seg_return[b] = np.pad(
                    seg_cropped, need_to_pad, pad_mode_seg, **pad_kwargs_seg
                )
        else:
            data_return[b] = data_cropped
            if seg_return is not None:
                seg_return[b] = seg_cropped

    return data_return, seg_return
def random_crop(data, seg=None, crop_size=128, margins=(0, 0, 0)):
    """Randomly crop ``data`` (and ``seg``, if given) to ``crop_size``.

    Convenience wrapper around :func:`crop` with ``crop_type='random'``.
    The default for ``margins`` is an immutable tuple rather than a shared
    mutable list; ``crop`` accepts either form.

    :param data: b, c, x, y(, z)
    :param seg: optional segmentation, cropped with the same window as data
    :param crop_size: int or one entry per spatial axis
    :param margins: int or one margin per spatial axis; may be negative
    :return: tuple ``(data_cropped, seg_cropped)``; ``seg_cropped`` is None
        when ``seg`` is None
    """
    return crop(data, seg, crop_size, margins, 'random')