vision.pytorch.input.classification package

Submodules

vision.pytorch.input.classification.dataset_factory module

class vision.pytorch.input.classification.dataset_factory.Processor

Bases: object

__init__(params)
check_split_valid(split)
create_dataloader(dataset, is_training=False)
create_dataset(use_training_transforms=True, split='train')
create_shuffled_idx(num_sample, rng)
process_transform(use_training_transforms=True)
split_dataset(dataset, split_percent, seed)
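The method names create_shuffled_idx and split_dataset suggest a seeded shuffle-then-split scheme for carving one dataset into two disjoint parts. The following is a minimal plain-Python sketch of that idea, not the Processor implementation. It assumes split_percent is a percentage in [0, 100] and that the first portion of the shuffled indices forms the first split; split_indices is a hypothetical helper, and check_split_valid and create_dataloader are not modeled.

```python
import random


def create_shuffled_idx(num_sample, rng):
    """Return a random permutation of range(num_sample) drawn from rng."""
    idx = list(range(num_sample))
    rng.shuffle(idx)
    return idx


def split_indices(num_sample, split_percent, seed):
    """Hypothetical helper: split shuffled indices into two disjoint lists.

    Assumes split_percent is a percentage; the first list receives
    int(num_sample * split_percent / 100) indices.
    """
    rng = random.Random(seed)  # same seed -> same split on every call
    idx = create_shuffled_idx(num_sample, rng)
    cut = int(num_sample * split_percent / 100)
    return idx[:cut], idx[cut:]


first, second = split_indices(10, 80, seed=0)
print(len(first), len(second))  # 8 2
```

Seeding a dedicated random.Random instance, rather than the global generator, keeps the split reproducible across runs and independent of any other code that consumes random numbers.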
class vision.pytorch.input.classification.dataset_factory.VisionSubset

Bases: torch.utils.data.Subset

__init__(dataset, indices)
set_transforms(transforms=None, transform=None, target_transform=None)

Parameters

  • transforms (callable, optional) – A function/transform that takes in an image and a label and returns the transformed versions of both.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g., transforms.RandomCrop.

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

truncate_to_idx(new_length)
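A subset that owns its transforms lets two subsets of the same parent dataset (for example a training split and a validation split) be augmented differently. The sketch below illustrates that pattern without requiring torch; VisionSubsetSketch is a hypothetical stand-in, not the VisionSubset class. Two behaviors are assumptions: that the joint transforms argument excludes transform/target_transform (mirroring torchvision's VisionDataset convention), and that truncate_to_idx keeps the first new_length indices.

```python
class VisionSubsetSketch:
    """Hypothetical stand-in for VisionSubset that needs no torch install.

    Wraps a parent dataset of (image, target) pairs and applies transforms
    owned by the subset, so two subsets of one dataset can differ.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)
        self.transforms = None
        self.transform = None
        self.target_transform = None

    def set_transforms(self, transforms=None, transform=None, target_transform=None):
        # Assumption: like torchvision's VisionDataset, the joint `transforms`
        # cannot be combined with transform/target_transform.
        if transforms is not None and (transform is not None or target_transform is not None):
            raise ValueError("pass either transforms or transform/target_transform")
        self.transforms = transforms
        self.transform = transform
        self.target_transform = target_transform

    def truncate_to_idx(self, new_length):
        # Assumption: keep only the first new_length indices of the subset.
        self.indices = self.indices[:new_length]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        image, target = self.dataset[self.indices[i]]
        if self.transforms is not None:
            return self.transforms(image, target)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target


base = [(1, 0), (2, 1), (3, 0)]
val = VisionSubsetSketch(base, [2, 0])
val.set_transforms(transform=lambda x: x * 10)
print(val[0])  # (30, 0)
```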

vision.pytorch.input.classification.mixup module

Mixup and CutMix

This is adapted from the torchvision repository: https://github.com/pytorch/vision/blob/main/references/classification/transforms.py

class vision.pytorch.input.classification.mixup.RandomCutmix

Bases: torch.nn.Module

Randomly apply CutMix to the provided batch and targets. The class implements the data augmentation described in the paper “CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features”.

Parameters

  • num_classes (int) – Number of classes used for one-hot encoding.

  • p (float) – Probability of the batch being transformed. Default value is 0.5.

  • alpha (float) – Hyperparameter of the Beta distribution used for CutMix. Default value is 1.0.

  • inplace (bool) – Whether to apply the transform in place. Default value is False.

__init__(num_classes, p=0.5, alpha=1.0, inplace=False)
forward(batch, target)
Parameters
  • batch (Tensor) – Float tensor of size (B, C, H, W)

  • target (Tensor) – Integer tensor of size (B, )

Returns

Randomly transformed batch.

Return type

Tensor
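To make the CutMix description concrete, the sketch below reproduces the box-sampling step as it appears in the torchvision reference transform linked at the top of this module: sample lam from Beta(alpha, alpha), cut a box whose side lengths scale with sqrt(1 - lam) around a random center, clip it to the image, then recompute lam from the clipped box area so the label weights match the pixels actually pasted. In the reference transform the patch comes from the batch rolled by one position. It uses only the standard library; cutmix_box and mix_targets are hypothetical helpers, not part of this module.

```python
import math
import random


def cutmix_box(height, width, alpha=1.0, rng=random):
    """Hypothetical helper: sample a CutMix box and the kept-image label weight."""
    lam = rng.betavariate(alpha, alpha)  # lam ~ Beta(alpha, alpha)
    # Random box center anywhere in the image.
    cx = rng.randrange(width)
    cy = rng.randrange(height)
    # Half-sizes scale with sqrt(1 - lam): the unclipped area is about (1 - lam) * H * W.
    r = 0.5 * math.sqrt(1.0 - lam)
    half_w = int(r * width)
    half_h = int(r * height)
    # Clip the box to the image borders.
    x1, y1 = max(cx - half_w, 0), max(cy - half_h, 0)
    x2, y2 = min(cx + half_w, width), min(cy + half_h, height)
    # Recompute lam from the clipped area: the weight of the original label.
    lam = 1.0 - (x2 - x1) * (y2 - y1) / (width * height)
    return (x1, y1, x2, y2), lam


def mix_targets(one_hot_kept, one_hot_pasted, lam):
    """Soft label: lam * kept-image label + (1 - lam) * pasted-patch label."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(one_hot_kept, one_hot_pasted)]


box, lam = cutmix_box(32, 32, rng=random.Random(0))
print(box, round(lam, 3))
print(mix_targets([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], lam))
```

Recomputing lam after clipping matters: a box centered near a border loses area, and without the correction the soft label would overstate how much of the second image is visible.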

class vision.pytorch.input.classification.mixup.RandomMixup

Bases: torch.nn.Module

Randomly apply Mixup to the provided batch and targets. The class implements the data augmentation described in the paper “mixup: Beyond Empirical Risk Minimization”.

Parameters

  • num_classes (int) – Number of classes used for one-hot encoding.

  • p (float) – Probability of the batch being transformed. Default value is 0.5.

  • alpha (float) – Hyperparameter of the Beta distribution used for Mixup. Default value is 1.0.

  • inplace (bool) – Whether to apply the transform in place. Default value is False.

__init__(num_classes, p=0.5, alpha=1.0, inplace=False)
forward(batch, target)
Parameters
  • batch (Tensor) – Float tensor of size (B, C, H, W)

  • target (Tensor) – Integer tensor of size (B, )

Returns

Randomly transformed batch.

Return type

Tensor
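The Mixup step can be sketched in plain Python on small lists: with probability p, draw lam from Beta(alpha, alpha) and blend each sample and its one-hot target with a partner sample, x' = lam * x_i + (1 - lam) * x_j. Pairing sample i with sample i - 1 (rolling the batch by one) follows the torchvision reference transform; mixup_batch and one_hot are hypothetical helpers that operate on flat feature lists rather than (B, C, H, W) tensors.

```python
import random


def one_hot(label, num_classes):
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def mixup_batch(batch, targets, num_classes, p=0.5, alpha=1.0, rng=random):
    """Hypothetical helper: Mixup on a batch of flat feature lists.

    Returns the (possibly mixed) batch and soft one-hot targets.
    """
    soft = [one_hot(t, num_classes) for t in targets]
    if rng.random() >= p:  # with probability 1 - p, leave the batch untouched
        return batch, soft
    lam = rng.betavariate(alpha, alpha)
    # Pair sample i with sample i - 1 by rolling the batch by one position.
    rolled_x = batch[-1:] + batch[:-1]
    rolled_y = soft[-1:] + soft[:-1]

    def mix(u, v):
        return [lam * a + (1.0 - lam) * b for a, b in zip(u, v)]

    mixed_x = [mix(x, xr) for x, xr in zip(batch, rolled_x)]
    mixed_y = [mix(y, yr) for y, yr in zip(soft, rolled_y)]
    return mixed_x, mixed_y


xs, ys = mixup_batch([[0.0, 0.0], [1.0, 1.0]], [0, 1], num_classes=2, p=1.0,
                     rng=random.Random(0))
print(xs, ys)
```

Because the same lam weights both the inputs and the targets, each soft target still sums to 1, and the label weight on the partner's class equals the fraction of the partner's features in the mixed input.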

vision.pytorch.input.classification.preprocessing module

vision.pytorch.input.classification.preprocessing.get_preprocess_transform(image_size, params, use_training_transforms=True)

vision.pytorch.input.classification.sampler module

class vision.pytorch.input.classification.sampler.RepeatedAugSampler

Bases: torch.utils.data.Sampler

Sampler that restricts data loading to a subset of the dataset for distributed training, with repeated augmentation. It ensures that each augmented version of a sample is seen by a different process (GPU). Heavily based on ‘torch.utils.data.DistributedSampler’.

This is adapted from the DeiT repository: https://github.com/facebookresearch/deit/blob/main/samplers.py

__init__(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, num_repeats=3, batch_size=256)
set_epoch(epoch)
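The repeated-augmentation idea can be sketched as an index computation: shuffle the dataset indices for the epoch, repeat each index num_repeats times in a row, pad so the total divides evenly, and let rank r take every num_replicas-th position starting at r. Consecutive copies of a sample therefore land on different ranks, where each is augmented independently. This follows the structure of the DeiT sampler linked above, but the seeding rule (seed + epoch) is an assumption, the batch_size argument is not modeled, and repeated_aug_indices is a hypothetical helper rather than the class's API.

```python
import math
import random


def repeated_aug_indices(dataset_len, num_replicas, rank, num_repeats=3,
                         shuffle=True, seed=0, epoch=0):
    """Hypothetical helper: indices one rank would draw for one epoch."""
    indices = list(range(dataset_len))
    if shuffle:
        # Assumption: seed + epoch picks the permutation, as in DistributedSampler.
        random.Random(seed + epoch).shuffle(indices)
    # Repeat each index num_repeats times consecutively: [a, a, a, b, b, b, ...].
    indices = [i for i in indices for _ in range(num_repeats)]
    # Pad so every rank receives the same number of samples.
    num_samples = math.ceil(dataset_len * num_repeats / num_replicas)
    total_size = num_samples * num_replicas
    indices += indices[: total_size - len(indices)]
    # Rank r takes positions r, r + num_replicas, ...: the consecutive copies
    # of one sample are spread across different ranks.
    return indices[rank:total_size:num_replicas]


for r in range(3):
    # With 3 replicas and 3 repeats, every rank sees each sample once.
    print(r, repeated_aug_indices(4, num_replicas=3, rank=r, shuffle=False))
```

Calling set_epoch before each epoch plays the role of the epoch argument here: it changes the permutation while keeping every rank's view consistent.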

vision.pytorch.input.classification.utils module

vision.pytorch.input.classification.utils.create_preprocessing_params_with_defaults(params)

Creates the preprocessing parameters for augmentations from params, filling in default values.
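As a rough illustration of the "with defaults" pattern the function name suggests, the sketch below fills missing keys from a defaults table and returns a new dict, leaving the caller's params untouched. The keys in EXAMPLE_DEFAULTS and the helper with_defaults are hypothetical placeholders, not the module's real preprocessing parameters or implementation.

```python
import copy

# Hypothetical placeholder keys; the real defaults are defined by the module.
EXAMPLE_DEFAULTS = {
    "example_augmentation_enabled": False,
    "example_probability": 0.5,
}


def with_defaults(params):
    """Return a new dict: EXAMPLE_DEFAULTS overridden by any keys in params."""
    merged = copy.deepcopy(EXAMPLE_DEFAULTS)  # never mutate the shared defaults
    merged.update(params or {})
    return merged


user_params = {"example_probability": 0.8}
print(with_defaults(user_params))
# {'example_augmentation_enabled': False, 'example_probability': 0.8}
print(user_params)  # unchanged: {'example_probability': 0.8}
```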

Module contents