# Source code for modelzoo.vision.pytorch.input.transforms
# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from inspect import getfullargspec, signature
import numpy as np
from PIL import Image
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode
# Public API of this module.
__all__ = [
    "create_transform",
]
# Transform names that `create_transform` has a branch for, grouped by the
# kind of input each transform operates on.
SUPPORTED_TRANSFORMS = [
    # Transforms on PIL Image and torch.*Tensor
    "center_crop",
    "color_jitter_with_prob",
    "random_crop",
    "random_grayscale",
    "random_horizontal_flip",
    "random_resized_crop",
    "random_vertical_flip",
    "resize",
    # Transforms on torch.*Tensor only
    "normalize",
    "random_erase",
    # Conversion transforms
    "to_dtype",
    "to_tensor",
    # Automatic augmentation transforms
    "autoaug",
    "randaug",
    "trivialaug",
    # "augmix", # available from torchvision 0.13
    # Transforms on PIL Image only
    "resize_center_crop_pil_image",
]
def get_or_use_default(transform_spec, key, default_val):
    """
    Look up `key` in a transform spec, falling back to `default_val`.

    A value of `None` (whether the key is absent or explicitly set to `None`)
    is treated as "not specified": the fallback is returned and a debug
    message naming the transform and the parameter is logged.

    Args:
        transform_spec: mapping describing the transform; its "name" entry is
            only used for the log message.
        key: name of the parameter to look up.
        default_val: value returned when the parameter is not specified.

    Returns:
        The configured value, or `default_val` when it is not specified.
    """
    configured = transform_spec.get(key)
    if configured is not None:
        return configured

    transform_name = transform_spec.get("name")
    logging.debug(
        f"Transform {transform_name}'s {key} parameter is not specified. "
        f"Using default value {default_val}."
    )
    return default_val
def create_transform(transform_spec):
    """
    Create the specified transform. For each transform, the parameter list (name
    and default value) follows those in torchvision 0.12
    (https://pytorch.org/vision/0.12/transforms.html)

    Args:
        transform_spec (dict): description of the transform. Must contain a
            "name" entry (case-insensitive) selecting the transform; every
            other entry is a parameter of that transform. Parameters that are
            absent or `None` fall back to the defaults listed in each branch.
            An "interpolation" entry is converted to an `InterpolationMode`.
            The passed-in mapping is not modified.

    Returns:
        A callable transform object.

    Raises:
        KeyError: if `transform_spec` has no "name" entry.
        ValueError: if the transform name is not supported.
    """
    # Work on a shallow copy so that normalizing "interpolation" below does
    # not mutate the caller's spec.
    transform_spec = dict(transform_spec)
    name = transform_spec["name"].lower()
    # A `None` interpolation means "not specified" (same convention as
    # `get_or_use_default`), so only convert actual values.
    if transform_spec.get("interpolation") is not None:
        transform_spec["interpolation"] = InterpolationMode(
            transform_spec["interpolation"]
        )
    # Transforms on PIL Image and torch.*Tensor
    if name == "center_crop":
        return transforms.CenterCrop(size=transform_spec.get("size"))
    elif name == "color_jitter_with_prob":
        transform = transforms.ColorJitter(
            brightness=get_or_use_default(transform_spec, "brightness", 0),
            contrast=get_or_use_default(transform_spec, "contrast", 0),
            saturation=get_or_use_default(transform_spec, "saturation", 0),
            hue=get_or_use_default(transform_spec, "hue", 0),
        )
        # With the default p of 0 the jitter is returned unwrapped; a
        # positive p wraps it in RandomApply.
        color_jitter_prob = get_or_use_default(transform_spec, "p", 0)
        if color_jitter_prob > 0:
            transform = transforms.RandomApply([transform], p=color_jitter_prob)
        return transform
    elif name == "random_crop":
        pad_if_needed = get_or_use_default(
            transform_spec, "pad_if_needed", False
        )
        if pad_if_needed:
            logging.info(
                f"For RandomCrop, pad_if_needed is set to {pad_if_needed}, which "
                f"is different from torchvision's default (False)."
            )
        fill = get_or_use_default(transform_spec, "fill", 0)
        if fill != 0:
            logging.info(
                f"For RandomCrop, fill is set to {fill}, which is different "
                f"from torchvision's default (0)."
            )
        padding_mode = get_or_use_default(
            transform_spec, "padding_mode", "constant"
        )
        if padding_mode != "constant":
            logging.info(
                f"For RandomCrop, padding_mode is set to {padding_mode}, which "
                f"is different from torchvision's default (`constant`)."
            )
        return transforms.RandomCrop(
            size=transform_spec.get("size"),
            padding=get_or_use_default(transform_spec, "padding", None),
            pad_if_needed=pad_if_needed,
            fill=fill,
            padding_mode=padding_mode,
        )
    elif name == "random_grayscale":
        return transforms.RandomGrayscale(
            p=get_or_use_default(transform_spec, "p", 0.1),
        )
    elif name == "random_horizontal_flip":
        return transforms.RandomHorizontalFlip(
            p=get_or_use_default(transform_spec, "p", 0.5),
        )
    elif name == "random_resized_crop":
        return transforms.RandomResizedCrop(
            size=transform_spec.get("size"),
            scale=get_or_use_default(transform_spec, "scale", (0.08, 1.0)),
            ratio=get_or_use_default(
                transform_spec, "ratio", (0.75, 4.0 / 3.0)
            ),
            interpolation=get_or_use_default(
                transform_spec, "interpolation", InterpolationMode.BILINEAR
            ),
        )
    elif name == "random_vertical_flip":
        return transforms.RandomVerticalFlip(
            p=get_or_use_default(transform_spec, "p", 0.5),
        )
    elif name == "resize":
        return transforms.Resize(
            size=transform_spec.get("size"),
            interpolation=get_or_use_default(
                transform_spec, "interpolation", InterpolationMode.BILINEAR
            ),
            max_size=get_or_use_default(transform_spec, "max_size", None),
            antialias=get_or_use_default(transform_spec, "antialias", None),
        )
    # Transforms on torch.*Tensor only
    elif name == "normalize":
        inplace = get_or_use_default(transform_spec, "inplace", False)
        if inplace:
            logging.info(
                f"For Normalize, inplace is set to {inplace}, which "
                f"is different from torchvision's default (False)."
            )
        return transforms.Normalize(
            mean=transform_spec.get("mean"),
            std=transform_spec.get("std"),
            inplace=inplace,
        )
    elif name == "random_erase":
        return transforms.RandomErasing(
            p=get_or_use_default(transform_spec, "p", 0.5),
            scale=get_or_use_default(transform_spec, "scale", (0.02, 0.33)),
            ratio=get_or_use_default(transform_spec, "ratio", (0.3, 3.3)),
            value=get_or_use_default(transform_spec, "value", 0),
            inplace=get_or_use_default(transform_spec, "inplace", False),
        )
    # Conversion transforms
    elif name == "to_dtype":
        return LambdaWithParam(dtype_transform, transform_spec.get("mp_type"))
    elif name == "to_tensor":
        return transforms.ToTensor()
    elif name == "resize_center_crop_pil_image":
        # Used for the DiT model image transform
        return LambdaWithParam(
            resize_center_crop_pil_image, transform_spec.get("size")
        )
    # Automatic augmentation transforms
    elif name == "autoaug":
        policy = get_or_use_default(transform_spec, "policy", "imagenet")
        interpolation = get_or_use_default(
            transform_spec, "interpolation", InterpolationMode.NEAREST
        )
        if interpolation != InterpolationMode.NEAREST:
            logging.info(
                f"For AutoAugment, interpolation is set to {interpolation}, which "
                f"is different from torchvision's default (InterpolationMode.NEAREST)."
            )
        return autoaugment.AutoAugment(
            policy=autoaugment.AutoAugmentPolicy(policy),
            interpolation=interpolation,
            fill=get_or_use_default(transform_spec, "fill", None),
        )
    elif name == "randaug":
        magnitude = get_or_use_default(transform_spec, "magnitude", 9)
        # magnitude_std and magnitude_max are extra parameters to keep
        # consistent with timm's implementation: https://timm.fast.ai/RandAugment
        magnitude_std = get_or_use_default(transform_spec, "magnitude_std", 0)
        # If magnitude_std is > 0, we introduce some randomness in the fixed
        # policy and sample magnitude from a normal distribution with mean
        # `magnitude` and std-dev of `magnitude_std`.
        # NOTE: the magnitude is sampled once, here, when the transform is
        # created.
        if magnitude_std > 0:
            logging.info(
                f"RandAugment's magnitude_std={magnitude_std}. We will "
                f"introduce some randomness in the usually fixed policy and "
                f"sample from a distribution according to timm."
            )
            # If magnitude_std is inf, we sample from a uniform distribution
            if magnitude_std == float("inf"):
                magnitude = random.uniform(0, magnitude)
            else:
                magnitude = random.gauss(magnitude, magnitude_std)
        upper_bound = get_or_use_default(transform_spec, "magnitude_max", 10)
        if magnitude > upper_bound:
            logging.info(
                f"Capping magnitude for RandAugment from {magnitude} to "
                f"magnitude_max={upper_bound} following timm."
            )
        magnitude = max(0.0, min(magnitude, upper_bound))
        # torchvision documents RandAugment's `magnitude` as an int (it is
        # used as an index into the magnitude bins), so the sampled/clamped
        # value must be converted back to an integer.
        magnitude = int(round(magnitude))
        return autoaugment.RandAugment(
            num_ops=get_or_use_default(transform_spec, "num_ops", 2),
            magnitude=magnitude,
            num_magnitude_bins=get_or_use_default(
                transform_spec, "num_magnitude_bins", 31
            ),
            interpolation=get_or_use_default(
                transform_spec, "interpolation", InterpolationMode.NEAREST
            ),
            fill=get_or_use_default(transform_spec, "fill", None),
        )
    elif name == "trivialaug":
        return autoaugment.TrivialAugmentWide(
            num_magnitude_bins=get_or_use_default(
                transform_spec, "num_magnitude_bins", 31
            ),
            interpolation=get_or_use_default(
                transform_spec, "interpolation", InterpolationMode.NEAREST
            ),
            fill=get_or_use_default(transform_spec, "fill", None),
        )
    # Only available starting torchvision 0.13
    # elif name == "augmix":
    #     return autoaugment.AugMix(
    #         severity=get_or_use_default(transform_spec, "severity", 3),
    #         mixture_width=get_or_use_default(transform_spec, "mixture_width", 3),
    #         chain_depth=get_or_use_default(transform_spec, "chain_depth", -1),
    #         alpha=get_or_use_default(transform_spec, "alpha", 1.0),
    #         all_ops=get_or_use_default(transform_spec, "all_ops", True),
    #         interpolation=get_or_use_default(transform_spec, "interpolation", InterpolationMode.BILINEAR),
    #         fill=get_or_use_default(transform_spec, "fill", None),
    #     )
    else:
        raise ValueError(f"Unsupported or invalid transform name: {name}.")
def dtype_transform(x, mp_type, *args, **kwargs):
    """
    Convert `x` by calling its `.to(mp_type)` method.

    Args:
        x: object exposing a `.to(...)` method (presumably a torch tensor;
            verify against callers).
        mp_type: target passed straight through to `x.to`.
        *args, **kwargs: accepted for compatibility with `LambdaWithParam`
            and ignored.

    Returns:
        Whatever `x.to(mp_type)` returns.
    """
    converted = x.to(mp_type)
    return converted
def resize_center_crop_pil_image(pil_image, image_size, *args, **kwargs):
    """
    Resize a PIL image and take a centered square crop from it.

    Uses the same cropping mechanism as the source DiT repo
    https://github.com/facebookresearch/DiT/blob/main/train.py#L85
    which is based on the center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image: PIL image to transform.
        image_size: indexable with two entries that must be equal; the common
            value is the side length of the square crop.
        *args, **kwargs: accepted for compatibility with `LambdaWithParam`
            and ignored.

    Returns:
        A PIL image built from the centered `image_size[0]` x `image_size[0]`
        window of the resized image.
    """
    assert (
        image_size[0] == image_size[1]
    ), "This transform is supported only for resizing to square shapes"
    side = image_size[0]

    # Repeatedly halve the image with a BOX filter while its shorter side is
    # still at least twice the target side.
    while True:
        width, height = pil_image.size
        if min(width, height) < 2 * side:
            break
        pil_image = pil_image.resize(
            (width // 2, height // 2), resample=Image.BOX
        )

    # Final BICUBIC resize that scales the shorter side to the target side.
    width, height = pil_image.size
    factor = side / min(width, height)
    pil_image = pil_image.resize(
        (round(width * factor), round(height * factor)),
        resample=Image.BICUBIC,
    )

    # Crop the centered window on the array view (rows first, then columns).
    pixels = np.array(pil_image)
    top = (pixels.shape[0] - side) // 2
    left = (pixels.shape[1] - side) // 2
    cropped = pixels[top : top + side, left : left + side]
    return Image.fromarray(cropped)
class LambdaWithParam(object):
    """
    Transform that applies a user-supplied function with fixed extra arguments.

    Calling the instance with an image invokes
    `lambd(img, *args, **kwargs)` using the arguments captured at
    construction time.

    Args:
        lambd: callable to wrap. Its signature must declare both `*args` and
            `**kwargs`, i.e. `function(img, positional args, *args, **kwargs)`.
        *args: positional arguments forwarded after the image on every call.
        **kwargs: keyword arguments forwarded on every call.

    Raises:
        AssertionError: if `lambd` is not callable.
        TypeError: if `lambd` does not accept both `*args` and `**kwargs`.
    """

    def __init__(self, lambd, *args, **kwargs):
        assert callable(
            lambd
        ), f"{type(lambd).__name__!r} object is not callable"

        self.lambd = lambd
        self.args = args
        self.kwargs = kwargs

        # Enforce the documented signature: both a var-positional and a
        # var-keyword parameter must be present.
        spec = getfullargspec(lambd)
        accepts_var_args = bool(spec.varargs) and bool(spec.varkw)
        if not accepts_var_args:
            raise TypeError(
                "User-defined lambda transform function must have signature: "
                "function(img, positional args, *args, **kwargs). Instead, "
                f"got function{signature(lambd)!s}."
            )

    def __call__(self, img):
        """Apply the wrapped function to `img` with the stored arguments."""
        return self.lambd(img, *self.args, **self.kwargs)

    def __repr__(self):
        """Return the class name together with the stored args and kwargs."""
        cls_name = self.__class__.__name__
        return f"{cls_name}(args={self.args}, kwargs={self.kwargs})"