vision.pytorch.input.classification package#
Submodules#
vision.pytorch.input.classification.dataset_factory module#
- class vision.pytorch.input.classification.dataset_factory.Processor#
Bases:
object
- __init__(params)#
- check_split_valid(split)#
- create_dataloader(dataset, is_training=False)#
- create_dataset(use_training_transforms=True, split='train')#
- create_shuffled_idx(num_sample, rng)#
- process_transform(use_training_transforms=True)#
- split_dataset(dataset, split_percent, seed)#
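The signatures of `create_shuffled_idx` and `split_dataset` suggest a seeded, shuffle-then-cut split. The following is only a minimal sketch of that idea, not the package's implementation. The helper names `shuffled_indices` and `split_indices` are hypothetical, and it assumes `split_percent` is a percentage between 0 and 100 assigned to the first part, which the source does not confirm.

```python
import random


def shuffled_indices(num_samples, rng):
    """Return the indices 0..num_samples-1 in an order drawn from rng."""
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return indices


def split_indices(num_samples, split_percent, seed):
    """Split shuffled indices into two disjoint parts.

    Assumption: split_percent is the percentage (0-100) of samples
    that goes into the first part.
    """
    rng = random.Random(seed)  # fixed seed makes the split reproducible
    indices = shuffled_indices(num_samples, rng)
    cut = int(num_samples * split_percent / 100)
    return indices[:cut], indices[cut:]


first, second = split_indices(10, 80, seed=0)
```

Because the generator is seeded, calling `split_indices` again with the same arguments yields the same two index lists, which is what makes a train/validation split repeatable across runs.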
- class vision.pytorch.input.classification.dataset_factory.VisionSubset#
Bases:
torch.utils.data.Subset
- __init__(dataset, indices)#
- set_transforms(transforms=None, transform=None, target_transform=None)#
- transforms (callable, optional): A function/transform that takes in
an image and a label and returns the transformed versions of both.
- transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version, e.g.
transforms.RandomCrop.
- target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
- truncate_to_idx(new_length)#
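To illustrate how a subset with replaceable transforms might behave, here is a small framework-free sketch. The class `SubsetWithTransforms` is a hypothetical stand-in written for this example and is not `VisionSubset` itself. It assumes that `transforms` is mutually exclusive with `transform`/`target_transform` (as in torchvision's `VisionDataset`) and that `truncate_to_idx` keeps the first `new_length` indices.

```python
class SubsetWithTransforms:
    """Hypothetical subset view whose transforms can be set after creation."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)
        self.transforms = None

    def set_transforms(self, transforms=None, transform=None, target_transform=None):
        has_separate = transform is not None or target_transform is not None
        if transforms is not None and has_separate:
            raise ValueError("pass either transforms or transform/target_transform")
        if has_separate:
            # Wrap the separate callables into one joint (image, target) callable.
            def transforms(image, target):
                if transform is not None:
                    image = transform(image)
                if target_transform is not None:
                    target = target_transform(target)
                return image, target
        self.transforms = transforms

    def truncate_to_idx(self, new_length):
        # Assumption: truncation keeps the first new_length indices.
        self.indices = self.indices[:new_length]

    def __getitem__(self, i):
        image, target = self.dataset[self.indices[i]]
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        return len(self.indices)


data = [(i, i % 2) for i in range(10)]  # toy (image, label) pairs
subset = SubsetWithTransforms(data, [1, 3, 5, 7])
subset.set_transforms(transform=lambda image: image * 10)
```

Setting transforms on the subset rather than on the underlying dataset lets training and validation subsets of one dataset use different augmentation pipelines.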
vision.pytorch.input.classification.mixup module#
Mixup and CutMix
This is adapted from the torchvision reference scripts in the PyTorch vision repo: https://github.com/pytorch/vision/blob/main/references/classification/transforms.py
- class vision.pytorch.input.classification.mixup.RandomCutmix#
Bases:
torch.nn.Module
Randomly apply Cutmix to the provided batch and targets. The class implements the data augmentations as described in the paper “CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features”.
- Parameters
num_classes (int) – number of classes used for one-hot encoding.
p (float) – probability of the batch being transformed. Default value is 0.5.
alpha (float) – hyperparameter of the Beta distribution used for cutmix. Default value is 1.0.
inplace (bool) – boolean to make this transform inplace. Default set to False.
- __init__(num_classes, p=0.5, alpha=1.0, inplace=False)#
- forward(batch, target)#
- Parameters
batch (Tensor) – Float tensor of size (B, C, H, W)
target (Tensor) – Integer tensor of size (B, )
- Returns
Randomly transformed batch.
- Return type
Tensor
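The CutMix procedure can be sketched without PyTorch to show the steps: draw a mixing ratio from Beta(alpha, alpha), paste a random box from a partner image, then recompute the ratio from the actual box area and mix the labels accordingly. This is a simplified illustration under stated assumptions, not the module's code. Images are single-channel nested lists rather than (B, C, H, W) tensors, targets are already one-hot, the probability gate `p` is omitted, and the function names are hypothetical.

```python
import math
import random


def cutmix_box(height, width, lam, rng):
    """Pick a box covering roughly a (1 - lam) fraction of the image."""
    cx, cy = rng.randrange(width), rng.randrange(height)
    r = 0.5 * math.sqrt(1.0 - lam)
    half_w, half_h = int(r * width), int(r * height)
    x1, y1 = max(cx - half_w, 0), max(cy - half_h, 0)
    x2, y2 = min(cx + half_w, width), min(cy + half_h, height)
    return x1, y1, x2, y2


def cutmix_batch(images, soft_targets, alpha=1.0, rng=None):
    """Apply CutMix to a batch of 2D images and one-hot targets."""
    rng = rng or random.Random()
    height, width = len(images[0]), len(images[0][0])
    lam = rng.betavariate(alpha, alpha)
    x1, y1, x2, y2 = cutmix_box(height, width, lam, rng)

    # Each sample is paired with its predecessor in the batch (roll by one).
    donor_images = images[-1:] + images[:-1]
    donor_targets = soft_targets[-1:] + soft_targets[:-1]

    mixed = [[row[:] for row in image] for image in images]  # copy, not in place
    for image, donor in zip(mixed, donor_images):
        for y in range(y1, y2):
            image[y][x1:x2] = donor[y][x1:x2]

    # The box may be clipped at the border, so recompute lam from its true area.
    lam = 1.0 - (x2 - x1) * (y2 - y1) / (width * height)
    targets = [
        [lam * a + (1.0 - lam) * b for a, b in zip(t, dt)]
        for t, dt in zip(soft_targets, donor_targets)
    ]
    return mixed, targets
```

Recomputing the ratio from the clipped box keeps each soft label consistent with the fraction of pixels that actually came from the partner image.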
- class vision.pytorch.input.classification.mixup.RandomMixup#
Bases:
torch.nn.Module
Randomly apply Mixup to the provided batch and targets. The class implements the data augmentations as described in the paper “mixup: Beyond Empirical Risk Minimization”.
- Parameters
num_classes (int) – number of classes used for one-hot encoding.
p (float) – probability of the batch being transformed. Default value is 0.5.
alpha (float) – hyperparameter of the Beta distribution used for mixup. Default value is 1.0.
inplace (bool) – boolean to make this transform inplace. Default set to False.
- __init__(num_classes, p=0.5, alpha=1.0, inplace=False)#
- forward(batch, target)#
- Parameters
batch (Tensor) – Float tensor of size (B, C, H, W)
target (Tensor) – Integer tensor of size (B, )
- Returns
Randomly transformed batch.
- Return type
Tensor
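As a rough illustration of Mixup, the sketch below blends each sample and its one-hot label with a partner from the same batch, using a ratio drawn from Beta(alpha, alpha). It is a framework-free approximation for explanation only. Samples are flat lists of floats instead of (B, C, H, W) tensors, the helper names `one_hot` and `mixup_batch` are hypothetical, and returning the mixed targets alongside the batch is an assumption of this sketch.

```python
import random


def one_hot(label, num_classes):
    """Encode an integer label as a list of floats."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def mixup_batch(batch, targets, num_classes, p=0.5, alpha=1.0, rng=None):
    """With probability p, blend every sample with its predecessor in the batch."""
    rng = rng or random.Random()
    soft = [one_hot(t, num_classes) for t in targets]
    if rng.random() >= p:
        # Not transformed: only the one-hot encoding is applied.
        return [list(x) for x in batch], soft

    lam = rng.betavariate(alpha, alpha)  # mixing ratio in (0, 1)
    partner_x = batch[-1:] + batch[:-1]  # batch rolled by one position
    partner_y = soft[-1:] + soft[:-1]

    mixed_x = [
        [lam * a + (1.0 - lam) * b for a, b in zip(x, px)]
        for x, px in zip(batch, partner_x)
    ]
    mixed_y = [
        [lam * a + (1.0 - lam) * b for a, b in zip(y, py)]
        for y, py in zip(soft, partner_y)
    ]
    return mixed_x, mixed_y
```

Because inputs and labels share the same ratio, each mixed label still sums to 1 and can be used directly with a soft-target cross-entropy loss.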
vision.pytorch.input.classification.preprocessing module#
- vision.pytorch.input.classification.preprocessing.get_preprocess_transform(image_size, params, use_training_transforms=True)#
vision.pytorch.input.classification.sampler module#
- class vision.pytorch.input.classification.sampler.RepeatedAugSampler#
Bases:
torch.utils.data.Sampler
Sampler that restricts data loading to a subset of the dataset for distributed training, with repeated augmentation. It ensures that each augmented version of a sample is visible to a different process (GPU). Heavily based on ‘torch.utils.data.DistributedSampler’.
This is borrowed from the DeiT Repo: https://github.com/facebookresearch/deit/blob/main/samplers.py
- __init__(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, num_repeats=3, batch_size=256)#
- set_epoch(epoch)#
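The index selection behind repeated augmentation can be sketched in plain Python, loosely following the DeiT sampler linked above. This is an approximation and not the class's code. The function name `repeated_aug_indices` is hypothetical, Python's `random` replaces the torch generator, and the way `num_repeats` and `batch_size` enter the computation is an assumption based on the constants 3 and 256 in the DeiT version.

```python
import math
import random


def repeated_aug_indices(dataset_len, num_replicas, rank, num_repeats=3,
                         batch_size=256, seed=0, epoch=0, shuffle=True):
    """Return the sample indices one process (rank) would draw in one epoch."""
    indices = list(range(dataset_len))
    if shuffle:
        # Seeding with seed + epoch gives every rank the same order per epoch.
        random.Random(seed + epoch).shuffle(indices)

    # Repeat each index so its copies sit next to each other: [a, a, a, b, b, b, ...]
    indices = [i for i in indices for _ in range(num_repeats)]

    # Pad so the list divides evenly across processes.
    num_samples = math.ceil(dataset_len * num_repeats / num_replicas)
    total_size = num_samples * num_replicas
    indices += indices[: total_size - len(indices)]

    # Strided slicing hands neighbouring copies of a sample to different ranks.
    indices = indices[rank:total_size:num_replicas]

    # Keep the per-epoch sample count close to that of a non-repeated pass.
    num_selected = math.floor(dataset_len // batch_size * batch_size / num_replicas)
    return indices[:num_selected]


per_rank = [repeated_aug_indices(8, 3, r, batch_size=4, shuffle=False) for r in range(3)]
```

In this unshuffled toy case with three processes and three repeats, every rank receives the same sample indices, so each process applies its own random augmentation to the same underlying images. The `set_epoch` method plays the role of the `epoch` argument here by changing the shuffle order between epochs.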
vision.pytorch.input.classification.utils module#
- vision.pytorch.input.classification.utils.create_preprocessing_params_with_defaults(params)#
Preprocessing params for augmentations