# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn

from modelzoo.common.pytorch.layers import ViTEmbeddingLayer
from modelzoo.common.pytorch.layers.AdaLayerNorm import AdaLayerNorm
from modelzoo.common.pytorch.model_utils.create_initializer import (
create_initializer,
)
from modelzoo.vision.pytorch.dit.layers.DiTDecoder import DiTDecoder
from modelzoo.vision.pytorch.dit.layers.DiTDecoderLayer import DiTDecoderLayer
from modelzoo.vision.pytorch.dit.layers.GaussianDiffusion import (
GaussianDiffusion,
)
from modelzoo.vision.pytorch.dit.layers.RegressionHead import RegressionHead
from modelzoo.vision.pytorch.dit.layers.TimestepEmbeddingLayer import (
TimestepEmbeddingLayer,
)
from modelzoo.vision.pytorch.dit.utils import BlockType


class DiT(nn.Module):
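    """
    Diffusion Transformer (DiT).

    Patchifies a noised latent, conditions the transformer blocks on the sum
    of timestep and class-label embeddings, and regresses the diffusion noise
    through a final norm and a regression head.
    """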
    def __init__(
self,
# Scheduler params
num_diffusion_steps,
schedule_name,
beta_start,
beta_end,
# Embedding
embedding_dropout_rate=0.0,
embedding_nonlinearity="silu",
position_embedding_type="learned",
hidden_size=768,
# Encoder
num_hidden_layers=12,
layer_norm_epsilon=1.0e-5,
# Encoder Attn
num_heads=12,
attention_module_str="aiayn_attention",
        extra_attention_params=None,  # defaults to {}; avoids a mutable default argument
attention_type="scaled_dot_product",
attention_softmax_fp32=True,
dropout_rate=0.0,
nonlinearity="gelu",
attention_dropout_rate=0.0,
use_projection_bias_in_attention=True,
use_ffn_bias_in_attention=True,
# Encoder ffn
filter_size=3072,
use_ffn_bias=True,
# Task-specific
initializer_range=0.02,
default_initializer=None,
projection_initializer=None,
position_embedding_initializer=None,
init_conv_like_linear=False,
attention_initializer=None,
ffn_initializer=None,
        timestep_embedding_initializer=None,
label_embedding_initializer=None,
head_initializer=None,
norm_first=True,
# vision related params
latent_size=[32, 32],
latent_channels=4,
patch_size=[16, 16],
use_conv_patchified_embedding=False,
# added DiT params
frequency_embedding_size=256,
num_classes=1000,
label_dropout_rate=0.1,
block_type=BlockType.ADALN_ZERO,
use_conv_transpose_unpatchify=False,
):
        super().__init__()
        if extra_attention_params is None:
            extra_attention_params = {}
        # Flags for lowering tests
        self.block_type = BlockType.get(block_type)
        self.initializer_range = initializer_range
        self.latent_channels = latent_channels
        self.patch_size = patch_size
if default_initializer is None:
default_initializer = {
"name": "truncated_normal",
"std": self.initializer_range,
"mean": 0.0,
"a": self.initializer_range * -2.0,
"b": self.initializer_range * 2.0,
}
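            # NOTE: with the default std of 0.02, the truncation bounds
            # `a` and `b` are +/- 2 * std, i.e. samples outside
            # [-0.04, 0.04] are redrawn.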
if attention_initializer is None:
attention_initializer = default_initializer
if ffn_initializer is None:
ffn_initializer = default_initializer
        if timestep_embedding_initializer is None:
            timestep_embedding_initializer = default_initializer
if label_embedding_initializer is None:
label_embedding_initializer = default_initializer
if head_initializer is None:
head_initializer = default_initializer
# embeddings
self.patch_embedding_layer = ViTEmbeddingLayer(
image_size=latent_size,
num_channels=latent_channels,
patch_size=patch_size,
hidden_size=hidden_size,
initializer_range=self.initializer_range,
embedding_dropout_rate=embedding_dropout_rate,
projection_initializer=projection_initializer,
position_embedding_initializer=position_embedding_initializer,
position_embedding_type=position_embedding_type,
use_conv_patchified_embedding=use_conv_patchified_embedding,
init_conv_like_linear=init_conv_like_linear,
)
self.projection_initializer = create_initializer(projection_initializer)
self.use_conv_patchified_embedding = use_conv_patchified_embedding
self.timestep_embedding_layer = TimestepEmbeddingLayer(
num_diffusion_steps=num_diffusion_steps,
frequency_embedding_size=frequency_embedding_size,
hidden_size=hidden_size,
nonlinearity=embedding_nonlinearity,
            kernel_initializer=timestep_embedding_initializer,
)
        # Reserve one extra embedding row as the "null" label used when
        # labels are dropped for classifier-free guidance.
        use_cfg_embedding = label_dropout_rate > 0
        self.label_embedding_layer = nn.Embedding(
            num_classes + int(use_cfg_embedding), hidden_size
        )
self.label_embedding_initializer = create_initializer(
label_embedding_initializer
)
norm_layer = (
AdaLayerNorm
if self.block_type == BlockType.ADALN_ZERO
else nn.LayerNorm
)
decoder_layer = DiTDecoderLayer(
d_model=hidden_size,
nhead=num_heads,
dim_feedforward=filter_size,
dropout=dropout_rate,
activation=nonlinearity,
layer_norm_eps=layer_norm_epsilon,
norm_first=norm_first,
norm_layer=norm_layer,
attention_module=attention_module_str,
extra_attention_params=extra_attention_params,
attention_dropout_rate=attention_dropout_rate,
attention_type=attention_type,
attention_softmax_fp32=attention_softmax_fp32,
use_projection_bias_in_attention=use_projection_bias_in_attention,
use_ffn_bias_in_attention=use_ffn_bias_in_attention,
use_ffn_bias=use_ffn_bias,
attention_initializer=attention_initializer,
ffn_initializer=ffn_initializer,
use_ff_layer1_dropout=False,
use_ff_layer2_dropout=True,
            gate_res=(self.block_type == BlockType.ADALN_ZERO),
add_cross_attention=False,
)
self.transformer_decoder = DiTDecoder(
decoder_layer=decoder_layer, num_layers=num_hidden_layers, norm=None
)
# regression heads
self.noise_head = RegressionHead(
image_size=latent_size,
hidden_size=hidden_size,
out_channels=latent_channels,
patch_size=patch_size,
use_conv_transpose_unpatchify=use_conv_transpose_unpatchify,
kernel_initializer=head_initializer,
)
self.final_norm = norm_layer(hidden_size, eps=layer_norm_epsilon)
self.gaussian_diffusion = GaussianDiffusion(
num_diffusion_steps,
schedule_name,
beta_start=beta_start,
beta_end=beta_end,
)
self.reset_parameters()

    def reset_parameters(self):
# Embedding layers
self.patch_embedding_layer.reset_parameters()
self.timestep_embedding_layer.reset_parameters()
self.label_embedding_initializer(self.label_embedding_layer.weight.data)
# DiT Blocks
self.transformer_decoder.reset_parameters()
# Final AdaLayerNorm
self.final_norm.reset_parameters()
# Regression Heads for noise and var predictions
self.noise_head.reset_parameters()

    def forward(
        self, input, label, diffusion_noise, timestep,
    ):
latent = input
        # NOTE: numerical differences have been observed between the
        # bfloat16 and float32 `noised_latent` outputs.
        # Diffusion constants are looked up inside the model.
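        # A sketch of the assumed forward process (standard DDPM q-sample):
        #   noised_latent = sqrt(alpha_bar_t) * latent
        #                   + sqrt(1 - alpha_bar_t) * diffusion_noise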
noised_latent = self.gaussian_diffusion(
latent, diffusion_noise, timestep
)
pred_noise, pred_var = self.forward_dit(noised_latent, label, timestep)
        # `pred_var` is None here for interface consistency, so that future
        # samplers that use the predicted variance can be supported.
return pred_noise, pred_var

    def forward_dit(self, noised_latent, label, timestep):
latent_embeddings = self.patch_embedding_layer(noised_latent)
context = None
timestep_embeddings = self.timestep_embedding_layer(timestep)
label_embeddings = self.label_embedding_layer(label)
context = timestep_embeddings + label_embeddings
hidden_states = self.transformer_decoder(latent_embeddings, context)
hidden_states = self.final_norm(hidden_states, context)
# We have `pred_var = None` to be consistent and
# support other samplers in the future that uses
# variance to generate samples and VLB loss
pred_var = None
pred_noise = self.noise_head(hidden_states)
return pred_noise, pred_var

    def forward_dit_with_cfg(
        self, noised_latent, label, timestep, guidance_scale, num_cfg_channels=3
    ):
"""
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
Assumes inputs are already batched with conditional and unconditional parts
Note: For exact reproducibility reasons, classifier-free guidance is applied only
three channels by default, hence `num_cfg_channels` defaults to 3.
The standard approach to cfg applies it to all channels.
"""
half = noised_latent[: len(noised_latent) // 2]
combined = torch.cat([half, half], dim=0)
pred_noise, pred_var = self.forward_dit(combined, label, timestep)
eps, rest = (
pred_noise[:, :num_cfg_channels],
pred_noise[:, num_cfg_channels:],
) # eps shape: (bsz, num_cfg_channels, H, W)
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        # half_eps = (1 - guidance_scale) * uncond_eps + guidance_scale * cond_eps.
        # `guidance_scale` = 1 disables classifier-free guidance, while
        # `guidance_scale` > 1 strengthens the effect of the guidance.
half_eps = uncond_eps + guidance_scale * (
cond_eps - uncond_eps
) # half_eps shape: (bsz//2, num_cfg_channels, H, W)
eps = torch.cat(
[half_eps, half_eps], dim=0
) # eps shape: (bsz, num_cfg_channels, H, W)
pred_noise = torch.cat([eps, rest], dim=1)
return pred_noise, pred_var
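

# A minimal usage sketch (an illustration, not part of the model definition).
# The scheduler and shape hyperparameters below are assumptions chosen for
# demonstration; they are not values mandated by this file.
if __name__ == "__main__":
    model = DiT(
        num_diffusion_steps=1000,
        schedule_name="linear",  # assumed to be a schedule GaussianDiffusion accepts
        beta_start=1e-4,
        beta_end=0.02,
    )
    bsz = 2
    latent = torch.randn(bsz, 4, 32, 32)  # default latent_channels=4, latent_size=[32, 32]
    diffusion_noise = torch.randn_like(latent)
    label = torch.randint(0, 1000, (bsz,))  # default num_classes=1000
    timestep = torch.randint(0, 1000, (bsz,))
    pred_noise, pred_var = model(latent, label, diffusion_noise, timestep)
    print(pred_noise.shape)  # expected: torch.Size([2, 4, 32, 32])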