# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from modelzoo.common.pytorch.run_utils import half_dtype_instance
from modelzoo.vision.pytorch.dit.layers.schedulers import (
get_named_beta_schedule,
)
def index(arr, timestep):
    """Select entries of ``arr`` along dimension 0 at the given timesteps.

    :param (Tensor) arr: Tensor to gather from along its first dimension.
    :param (Tensor) timestep: Tensor of indices into dimension 0 of ``arr``;
        it is cast to ``torch.long`` before the lookup.
    :returns: Tensor produced by ``torch.index_select`` on dimension 0, i.e.
        one slice of ``arr`` per entry of ``timestep``.
    """
    return torch.index_select(arr, 0, timestep.long())
class GaussianDiffusion(torch.nn.Module):
    """Generate noisy images via Gaussian diffusion.

    The class implements the noising process as described in Step 5 of
    Algorithm 1 in the paper
    `"Denoising Diffusion Probabilistic Models" <https://arxiv.org/abs/2006.11239>`_.
    """

    def __init__(
        self,
        num_diffusion_steps,
        schedule_name,
        seed=None,
        beta_start=0.0001,
        beta_end=0.02,
    ):
        """
        :param (int) num_diffusion_steps: Number of diffusion steps. Must be
            positive.
        :param schedule_name: Name of the beta schedule, forwarded to
            ``get_named_beta_schedule``.
        :param (int) seed: Random seed for reproducibility. When not ``None``
            it is passed to ``torch.manual_seed``, which seeds the global
            torch RNG.
        :param (float) beta_start: Initial value of variance schedule i.e beta_1
            (default value according to Ho et al https://arxiv.org/pdf/2006.11239.pdf: Section 4)
        :param (float) beta_end: Final value of variance schedule i.e beta_T
            (default value according to Ho et al https://arxiv.org/pdf/2006.11239.pdf: Section 4)
        :raises ValueError: If ``num_diffusion_steps`` is not positive.
        """
        super().__init__()

        if num_diffusion_steps <= 0:
            raise ValueError("Number of diffusion steps must be positive.")

        if seed is not None:
            torch.manual_seed(seed)

        self.num_diffusion_steps = num_diffusion_steps
        self.schedule_name = schedule_name
        self.betas = get_named_beta_schedule(
            schedule_name,
            self.num_diffusion_steps,
            beta_start=beta_start,
            beta_end=beta_end,
        )
        assert self.betas.dim() == 1, "betas must be 1-D"
        assert torch.all(
            torch.logical_and(self.betas > 0, self.betas <= 1)
        ), "betas must lie in (0, 1]"

        # alpha_t = 1 - beta_t; the cumulative product over t gives the
        # coefficients used to noise a sample directly at any timestep.
        alphas = 1.0 - self.betas
        alphas_cum_prod = torch.cumprod(alphas, dim=0)

        # Stored as frozen float32 parameters: they are lookup tables indexed
        # by timestep in ``forward`` and are never trained.
        self.sqrt_alphas_cum_prod = torch.nn.Parameter(
            torch.sqrt(alphas_cum_prod).to(torch.float32),
            requires_grad=False,
        )
        self.sqrt_one_minus_alphas_cum_prod = torch.nn.Parameter(
            torch.sqrt(1 - alphas_cum_prod).to(torch.float32),
            requires_grad=False,
        )

    def forward(self, latent, noise, timestep):
        """Lookup alpha-related constants and create noised sample.

        :param (Tensor) latent: Float tensor of size (B, C, H, W).
        :param (Tensor) noise: Noise tensor combined with ``latent``; it must
            be broadcast-compatible with ``latent``.
        :param (Tensor) timestep: Indices into the precomputed schedule
            tables, one per sample -- presumably of shape (B,); confirm
            against the caller.
        :returns: The noisy samples
            ``sqrt(alpha_bar_t) * latent + sqrt(1 - alpha_bar_t) * noise``,
            cast to ``half_dtype_instance.half_dtype``.
        :raises ValueError: If ``latent`` is not 4-dimensional.
        """
        if latent.ndim != 4:
            raise ValueError(f"Samples ndim should be 4. Got {latent.ndim}")

        # Gather the per-sample coefficients with the module-level ``index``
        # helper and reshape to (B, 1, 1, 1) so they broadcast over the
        # (B, C, H, W) inputs.
        sqrt_alpha_prod = index(self.sqrt_alphas_cum_prod, timestep).view(
            -1, 1, 1, 1
        )
        sqrt_one_minus_alpha_prod = index(
            self.sqrt_one_minus_alphas_cum_prod, timestep
        ).view(-1, 1, 1, 1)

        noisy_samples = (
            sqrt_alpha_prod * latent + sqrt_one_minus_alpha_prod * noise
        )
        return noisy_samples.to(half_dtype_instance.half_dtype)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"schedule_name={self.schedule_name}"
            f", num_diffusion_steps={self.num_diffusion_steps}"
            f")"
        )