# Source code for the NoiseGenerator and LabelDropout modules.

# Copyright 2022 Cerebras Systems.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

class NoiseGenerator(torch.nn.Module):
    """Augment a batch with random diffusion timesteps and Gaussian noise.

    The module has no learnable parameters. It passes ``input`` and ``label``
    through unchanged and adds freshly sampled ``timestep``,
    ``diffusion_noise`` and ``vae_noise`` tensors.
    """

    def __init__(self, width, height, channels, num_diffusion_steps, seed=None):
        """
        Args:
            width: Width of the generated noise tensors.
            height: Height of the generated noise tensors.
            channels: Number of channels of the generated noise tensors.
            num_diffusion_steps: Exclusive upper bound for sampled timesteps.
            seed: Accepted for interface compatibility. It is currently not
                used; sampling relies on the global torch RNG state.
        """
        super().__init__()
        self.height = height
        self.width = width
        self.channels = channels
        self.num_diffusion_steps = num_diffusion_steps

    def forward(self, input, label):
        """Sample timesteps and noise for one batch.

        Args:
            input: Float tensor of size (B, C, H, W).
            label: Int tensor of size (B, ). Its dtype is reused for the
                sampled timesteps.

        Returns:
            A dict with the following keys:

            "input": The ``input`` tensor, passed through unchanged.
            "label": The ``label`` tensor, passed through unchanged.
            "diffusion_noise": Standard-normal tensor of shape
                (batch_size, channels, height, width) with ``input``'s dtype.
            "timestep": Tensor of shape (batch_size, ) with values drawn
                uniformly from ``[0, num_diffusion_steps)``.
            "vae_noise": Standard-normal tensor of shape
                (batch_size, channels, height, width) with ``input``'s dtype,
                intended for the VAE reparametrization step.

        Raises:
            ValueError: If ``input`` is not 4-dimensional.
        """
        if input.ndim != 4:
            raise ValueError(f"Samples ndim should be 4. Got {input.ndim}")

        batch_size = input.shape[0]
        device = input.device

        # Every tensor is sampled with the default (CPU) generator and then
        # moved, so the random stream does not depend on where `input` lives
        # and all returned tensors end up on the same device as `input`.
        timestep = torch.randint(
            self.num_diffusion_steps, size=(batch_size,), dtype=label.dtype
        ).to(device)

        noise_shape = (batch_size, self.channels, self.height, self.width)
        diffusion_noise = torch.randn(noise_shape, dtype=input.dtype).to(device)
        vae_noise_sample = torch.randn(noise_shape, dtype=input.dtype).to(device)

        return {
            "input": input,
            "label": label,
            "diffusion_noise": diffusion_noise,
            "timestep": timestep,
            "vae_noise": vae_noise_sample,
        }

    def __repr__(self):
        # Only report attributes that __init__ actually sets; referencing an
        # unset attribute here would make repr() raise AttributeError.
        return (
            f"{self.__class__.__name__}("
            f"width={self.width}"
            f", height={self.height}"
            f", channels={self.channels}"
            f", num_diffusion_steps={self.num_diffusion_steps}"
            f")"
        )
class LabelDropout(torch.nn.Module):
    """Randomly replace class labels with an extra "dropped" class id.

    Each sample's label is independently replaced by ``num_classes`` with
    probability ``dropout_prob``. The image is passed through unchanged.
    """

    def __init__(self, dropout_prob, num_classes):
        """
        Args:
            dropout_prob: Probability of dropping each label; must be > 0.
            num_classes: Value written in place of a dropped label.
        """
        super().__init__()
        assert dropout_prob > 0, "dropout_prob must be > 0"
        self.dropout_prob = dropout_prob
        self.num_classes = num_classes

    def token_drop(self, label):
        """Return a boolean mask of shape (label.shape[0],) marking drops.

        The mask is created on ``label``'s device so it can be combined with
        ``label`` directly.
        """
        return torch.rand(label.shape[0], device=label.device) < self.dropout_prob

    def forward(self, image, label):
        """Apply label dropout.

        Args:
            image: Passed through unchanged.
            label: Tensor of labels indexed by sample along dim 0.

        Returns:
            Tuple ``(image, label)`` where dropped labels equal
            ``num_classes`` and the dtype of ``label`` is preserved.
        """
        drop_ids = self.token_drop(label)
        # Build the fill value with label's dtype and device so torch.where
        # neither promotes the dtype nor mixes devices.
        dropped_label = torch.tensor(
            self.num_classes, dtype=label.dtype, device=label.device
        )
        label = torch.where(drop_ids, dropped_label, label)
        return image, label