# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# Based on HuggingFace diffusers/schedulers/scheduling_ddpm.py and
# https://github.com/facebookresearch/DiT/blob/main/diffusion/gaussian_diffusion.py
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from modelzoo.vision.pytorch.dit.layers.schedulers import (
get_named_beta_schedule,
)
from modelzoo.vision.pytorch.dit.samplers.sampler_utils import (
set_sampling_timesteps,
threshold_sample,
)
from modelzoo.vision.pytorch.dit.samplers.SamplerBase import SamplerBase
@dataclass
class DDPMSamplerOutput:
    """
    Output class for the sampler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
class DDPMSampler(SamplerBase):
    """
    Denoising diffusion probabilistic models (DDPMs) explores the connections
    between denoising score matching and Langevin dynamics sampling.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239
    and https://arxiv.org/pdf/2102.09672.pdf

    Args:
        num_diffusion_steps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        schedule_name (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
            Choose from `linear`
        clip_sample (`bool`, default `False`):
            option to clip predicted sample for numerical stability.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method
            (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            Note that the thresholding method is unsuitable for
            latent-space diffusion models (such as stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        variance_type (`str`):
            options to clip the variance used when adding noise to the denoised sample. Choose from
            `learned_range`, `fixed_small`, `fixed_large`.
        num_inference_steps (`str`): string containing comma-separated numbers,
            indicating the step count per section.
            For example, if there's 300 `num_diffusion_steps` and num_inference_steps=`10,15,20`
            then the first 100 timesteps are strided to be 10 timesteps, the second 100
            are strided to be 15 timesteps, and the final 100 are strided to be 20.
            Can either pass `custom_timesteps` (or) `num_inference_steps`, but not both.
        custom_timesteps (`List[int]`): List of timesteps to be used during sampling.
            Should be in decreasing order.
            Can either pass `custom_timesteps` (or) `num_inference_steps`, but not both.
    """

    def __init__(
        self,
        num_diffusion_steps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        schedule_name: str = "linear",
        clip_sample: bool = False,
        set_alpha_to_one: bool = True,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        clip_sample_range: float = 1.0,
        variance_type: str = "learned_range",
        num_inference_steps: Optional[Union[int, str]] = None,
        custom_timesteps: Optional[List[int]] = None,
    ):
        # Naming relative to the HuggingFace implementation:
        #   `num_train_timesteps` -> `num_diffusion_steps`
        #   `beta_schedule`       -> `schedule_name`
        self.num_diffusion_steps = num_diffusion_steps
        self.clip_sample = clip_sample
        self.thresholding = thresholding
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.sample_max_value = sample_max_value
        self.clip_sample_range = clip_sample_range

        # Beta schedule and derived (cumulative) alpha products used by
        # `step` and `_get_variance`.
        self.betas = get_named_beta_schedule(
            schedule_name,
            self.num_diffusion_steps,
            beta_start=beta_start,
            beta_end=beta_end,
        )
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = (
            torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
        )

        self.num_inference_steps = num_inference_steps
        self.custom_timesteps = custom_timesteps
        if self.num_inference_steps is None and self.custom_timesteps is None:
            # Neither timestep specification was supplied; fall back to the
            # default of 250 uniformly strided inference steps.
            logging.warning("Setting `num_inference_steps` to 250")
            self.num_inference_steps = 250
        if self.num_inference_steps and self.custom_timesteps:
            raise ValueError(
                "Can only pass one of str `num_inference_steps` or `custom_timesteps` list."
            )
        if variance_type not in ["learned_range", "fixed_small", "fixed_large"]:
            raise ValueError(
                f"variance_type={variance_type} unsupported."
                f"Supported values are `learned_range`, `fixed_small`, `fixed_large`"
            )
        self.variance_type = variance_type
        self.set_timesteps(
            num_diffusion_steps, num_inference_steps, custom_timesteps
        )

    def set_timesteps(
        self, num_diffusion_steps, num_inference_steps, custom_timesteps
    ):
        """
        Computes timesteps to be used during sampling

        Args:
            num_diffusion_steps (`int`): Total number of steps the model was trained on
            num_inference_steps (`str`): string containing comma-separated numbers,
                indicating the step count per section.
                For example, if there's 300 `num_diffusion_steps` and num_inference_steps=`10,15,20`
                then the first 100 timesteps are strided to be 10 timesteps, the second 100
                are strided to be 15 timesteps, and the final 100 are strided to be 20.
                Can either pass `custom_timesteps` (or) `num_inference_steps`, but not both.
            custom_timesteps (`List[int]`): User specified list of timesteps to be used during sampling.
        """
        # NOTE(review): `str(...)` turns `None` into the string "None" when
        # `custom_timesteps` is used instead — presumably the helper ignores
        # `num_inference_steps` in that case; confirm against
        # `set_sampling_timesteps`.
        self.timesteps = set_sampling_timesteps(
            num_diffusion_steps=num_diffusion_steps,
            num_inference_steps=str(num_inference_steps),
            custom_timesteps=custom_timesteps,
        )

    def _get_variance(self, t, predicted_model_var_values=None):
        """
        Variance calculation https://arxiv.org/pdf/2102.09672.pdf
        Eqn (15)

        Args:
            t (`int`): Current timestep
            predicted_model_var_values (`torch.Tensor`):
                Model predicted values used in variance computation.
                `υ` in Eqn 15. Required when `variance_type="learned_range"`.

        Returns:
            Tuple of (model_log_variance, model_variance) tensors.
        """
        prev_timestep = self.previous_timestep(t)

        # Section 4 (Improved Sampling Speed) of https://arxiv.org/pdf/2102.09672.pdf
        # For supporting sampling using subsequence of (1, 2, ....T)
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = (
            self.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else self.final_alpha_cumprod
        )
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        # we always take the log of variance, so clamp it to ensure it's not 0 at t=0
        variance = torch.clamp(variance, min=1e-20)
        log_variance_clipped = torch.log(variance)

        if self.variance_type == "fixed_small":
            # Posterior variance β̃t (lower bound of Eqn 15 interpolation).
            model_log_variance = log_variance_clipped
            model_variance = variance
        elif self.variance_type == "fixed_large":
            # This differs from DiT repo where at t=0,
            # `posterior_variance at t=1` is used to prevent log(0)
            model_log_variance = torch.log(current_beta_t)
            model_variance = current_beta_t
        elif self.variance_type == "learned_range":
            # Interpolate in log-space between β̃t and βt using the model's
            # predicted fraction υ (mapped from [-1, 1] to [0, 1]); Eqn 15 of
            # https://arxiv.org/pdf/2102.09672.pdf
            min_log = torch.log(variance)
            max_log = torch.log(current_beta_t)
            frac = (predicted_model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = torch.exp(model_log_variance)

        return model_log_variance, model_variance

    def step(
        self,
        pred_noise: torch.FloatTensor,
        pred_var: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMSamplerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE.
        Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            pred_noise (`torch.FloatTensor`): predicted eps output from learned diffusion model.
            pred_var (`torch.FloatTensor`): Model predicted values
                used in variance computation.`υ` in Eqn 15.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than DDPMSamplerOutput class

        Returns:
            [`DDPMSamplerOutput` (with keys: `prev_sample`, `pred_original_sample`)]
            if `return_dict` is True
            (or) `tuple`.
            When returning a tuple,
            the first element is the `prev_sample` tensor and
            second element is `pred_original_sample`
        """
        t = timestep
        prev_timestep = self.previous_timestep(t)

        # Section 4 (Improved Sampling Speed) of https://arxiv.org/pdf/2102.09672.pdf
        # For supporting sampling using subsequence of (1, 2, ....T)
        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = (
            self.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else self.final_alpha_cumprod
        )
        # Use torch.sqrt (not np.sqrt) so the computation stays on the
        # tensor's device and dtype.
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alpha_prod_t)
        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alpha_prod_t - 1)
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        # pred_xstart = (sample - sqrt_beta_prod_t * predicted_noise) / sqrt_alpha_prod_t
        pred_xstart = (
            sqrt_recip_alphas_cumprod * sample
            - sqrt_recipm1_alphas_cumprod * pred_noise
        )

        # 3. Clip or threshold "predicted x_0"
        if self.thresholding:
            pred_xstart = threshold_sample(
                pred_xstart,
                self.dynamic_thresholding_ratio,
                self.sample_max_value,
            )
        elif self.clip_sample:
            pred_xstart = pred_xstart.clamp(
                -self.clip_sample_range, self.clip_sample_range
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_xstart_sample_coeff = (
            alpha_prod_t_prev ** (0.5) * current_beta_t
        ) / beta_prod_t
        current_sample_coeff = (
            current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
        )

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample_mean = (
            pred_xstart_sample_coeff * pred_xstart
            + current_sample_coeff * sample
        )

        # 6. Compute predicted previous sample variance Σ
        # See formula (11) from https://arxiv.org/pdf/2006.11239.pdf
        # This part differs from HF implementation and uses DiT logic
        pred_prev_std = 0
        if t > 0:
            pred_prev_log_variance, _ = self._get_variance(t, pred_var)
            pred_prev_std = torch.exp(0.5 * pred_prev_log_variance)

        # 7. Add noise (at t=0, pred_prev_std is 0 so the noise is a no-op)
        noise = torch.randn(
            sample.size(),
            dtype=sample.dtype,
            layout=sample.layout,
            device=sample.device,
            generator=generator,
        )
        pred_prev_sample = (
            pred_prev_sample_mean + pred_prev_std * noise
        )  # reparametrization trick

        if not return_dict:
            return (pred_prev_sample, pred_xstart)
        return DDPMSamplerOutput(
            prev_sample=pred_prev_sample, pred_original_sample=pred_xstart
        )

    def __len__(self):
        # Length is the training-time diffusion horizon, not the number of
        # (possibly strided) sampling steps.
        return self.num_diffusion_steps

    def previous_timestep(self, timestep):
        """
        Returns the previous timestep based on current timestep.
        Depends on the timesteps computed in `self.set_timesteps`
        """
        # `self.timesteps` is in decreasing order, so "previous" (in diffusion
        # time) is the next entry; -1 signals there is no previous timestep.
        index = (self.timesteps == timestep).nonzero()[0][0]
        if index == self.timesteps.shape[0] - 1:
            prev_timestep = torch.tensor(-1)
        else:
            prev_timestep = self.timesteps[index + 1]
        return prev_timestep