Source code for modelzoo.vision.pytorch.dit.modeling_dit

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from modelzoo.common.pytorch.layers import ViTEmbeddingLayer
from modelzoo.common.pytorch.layers.AdaLayerNorm import AdaLayerNorm
from modelzoo.common.pytorch.model_utils.create_initializer import (
    create_initializer,
)
from modelzoo.vision.pytorch.dit.layers.DiTDecoder import DiTDecoder
from modelzoo.vision.pytorch.dit.layers.DiTDecoderLayer import DiTDecoderLayer
from modelzoo.vision.pytorch.dit.layers.GaussianDiffusion import (
    GaussianDiffusion,
)
from modelzoo.vision.pytorch.dit.layers.RegressionHead import RegressionHead
from modelzoo.vision.pytorch.dit.layers.TimestepEmbeddingLayer import (
    TimestepEmbeddingLayer,
)
from modelzoo.vision.pytorch.dit.utils import BlockType


class DiT(nn.Module):
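    """
    Diffusion Transformer (DiT).

    Patchifies noised latent inputs into token embeddings, conditions the
    transformer blocks on summed timestep and class-label embeddings
    (adaLN / adaLN-Zero), and regresses the diffusion noise from the final
    hidden states via `RegressionHead`.
    """
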
    def __init__(
        self,
        # Scheduler params
        num_diffusion_steps,
        schedule_name,
        beta_start,
        beta_end,
        # Embedding
        embedding_dropout_rate=0.0,
        embedding_nonlinearity="silu",
        position_embedding_type="learned",
        hidden_size=768,
        # Encoder
        num_hidden_layers=12,
        layer_norm_epsilon=1.0e-5,
        # Encoder attn
        num_heads=12,
        attention_module_str="aiayn_attention",
        extra_attention_params={},
        attention_type="scaled_dot_product",
        attention_softmax_fp32=True,
        dropout_rate=0.0,
        nonlinearity="gelu",
        attention_dropout_rate=0.0,
        use_projection_bias_in_attention=True,
        use_ffn_bias_in_attention=True,
        # Encoder ffn
        filter_size=3072,
        use_ffn_bias=True,
        # Task-specific
        initializer_range=0.02,
        default_initializer=None,
        projection_initializer=None,
        position_embedding_initializer=None,
        init_conv_like_linear=False,
        attention_initializer=None,
        ffn_initializer=None,
        timestep_embeddding_initializer=None,
        label_embedding_initializer=None,
        head_initializer=None,
        norm_first=True,
        # Vision-related params
        latent_size=[32, 32],
        latent_channels=4,
        patch_size=[16, 16],
        use_conv_patchified_embedding=False,
        # Added DiT params
        frequency_embedding_size=256,
        num_classes=1000,
        label_dropout_rate=0.1,
        block_type=BlockType.ADALN_ZERO,
        use_conv_transpose_unpatchify=False,
    ):
        super(DiT, self).__init__()

        # Flags for lowering tests
        self.block_type = BlockType.get(block_type)
        self.initializer_range = initializer_range
        self.latent_channels = latent_channels
        self.patch_size = patch_size

        if default_initializer is None:
            default_initializer = {
                "name": "truncated_normal",
                "std": self.initializer_range,
                "mean": 0.0,
                "a": self.initializer_range * -2.0,
                "b": self.initializer_range * 2.0,
            }
        if attention_initializer is None:
            attention_initializer = default_initializer
        if ffn_initializer is None:
            ffn_initializer = default_initializer
        if timestep_embeddding_initializer is None:
            timestep_embeddding_initializer = default_initializer
        if label_embedding_initializer is None:
            label_embedding_initializer = default_initializer
        if head_initializer is None:
            head_initializer = default_initializer

        # Embeddings
        self.patch_embedding_layer = ViTEmbeddingLayer(
            image_size=latent_size,
            num_channels=latent_channels,
            patch_size=patch_size,
            hidden_size=hidden_size,
            initializer_range=self.initializer_range,
            embedding_dropout_rate=embedding_dropout_rate,
            projection_initializer=projection_initializer,
            position_embedding_initializer=position_embedding_initializer,
            position_embedding_type=position_embedding_type,
            use_conv_patchified_embedding=use_conv_patchified_embedding,
            init_conv_like_linear=init_conv_like_linear,
        )
        self.projection_initializer = create_initializer(
            projection_initializer
        )
        self.use_conv_patchified_embedding = use_conv_patchified_embedding

        self.timestep_embedding_layer = TimestepEmbeddingLayer(
            num_diffusion_steps=num_diffusion_steps,
            frequency_embedding_size=frequency_embedding_size,
            hidden_size=hidden_size,
            nonlinearity=embedding_nonlinearity,
            kernel_initializer=timestep_embeddding_initializer,
        )

        # When label dropout is enabled, reserve one extra embedding row to
        # serve as the "null" (unconditional) class for classifier-free
        # guidance.
        use_cfg_embedding = label_dropout_rate > 0
        self.label_embedding_layer = nn.Embedding(
            num_classes + use_cfg_embedding, hidden_size
        )
        self.label_embedding_initializer = create_initializer(
            label_embedding_initializer
        )

        norm_layer = (
            AdaLayerNorm
            if self.block_type == BlockType.ADALN_ZERO
            else nn.LayerNorm
        )
        decoder_layer = DiTDecoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=filter_size,
            dropout=dropout_rate,
            activation=nonlinearity,
            layer_norm_eps=layer_norm_epsilon,
            norm_first=norm_first,
            norm_layer=norm_layer,
            attention_module=attention_module_str,
            extra_attention_params=extra_attention_params,
            attention_dropout_rate=attention_dropout_rate,
            attention_type=attention_type,
            attention_softmax_fp32=attention_softmax_fp32,
            use_projection_bias_in_attention=use_projection_bias_in_attention,
            use_ffn_bias_in_attention=use_ffn_bias_in_attention,
            use_ffn_bias=use_ffn_bias,
            attention_initializer=attention_initializer,
            ffn_initializer=ffn_initializer,
            use_ff_layer1_dropout=False,
            use_ff_layer2_dropout=True,
            gate_res=(self.block_type == BlockType.ADALN_ZERO),
            add_cross_attention=False,
        )
        self.transformer_decoder = DiTDecoder(
            decoder_layer=decoder_layer,
            num_layers=num_hidden_layers,
            norm=None,
        )

        # Regression heads
        self.noise_head = RegressionHead(
            image_size=latent_size,
            hidden_size=hidden_size,
            out_channels=latent_channels,
            patch_size=patch_size,
            use_conv_transpose_unpatchify=use_conv_transpose_unpatchify,
            kernel_initializer=head_initializer,
        )

        self.final_norm = norm_layer(hidden_size, eps=layer_norm_epsilon)
        self.gaussian_diffusion = GaussianDiffusion(
            num_diffusion_steps,
            schedule_name,
            beta_start=beta_start,
            beta_end=beta_end,
        )

        self.reset_parameters()
    def reset_parameters(self):
        # Embedding layers
        self.patch_embedding_layer.reset_parameters()
        self.timestep_embedding_layer.reset_parameters()
        self.label_embedding_initializer(
            self.label_embedding_layer.weight.data
        )

        # DiT blocks
        self.transformer_decoder.reset_parameters()

        # Final AdaLayerNorm
        self.final_norm.reset_parameters()

        # Regression heads for noise and variance predictions
        self.noise_head.reset_parameters()

    def forward(
        self,
        input,
        label,
        diffusion_noise,
        timestep,
    ):
        latent = input
        # NOTE: numerical differences have been observed in the
        # `noised_latent` output between bfloat16 and float32.
        # The diffusion constants are extracted within the model.
        noised_latent = self.gaussian_diffusion(
            latent, diffusion_noise, timestep
        )
        pred_noise, pred_var = self.forward_dit(noised_latent, label, timestep)
        # `pred_var` is returned (as None) for consistency, to support other
        # samplers in the future that use the variance to generate samples.
        return pred_noise, pred_var

    def forward_dit(self, noised_latent, label, timestep):
        latent_embeddings = self.patch_embedding_layer(noised_latent)

        timestep_embeddings = self.timestep_embedding_layer(timestep)
        label_embeddings = self.label_embedding_layer(label)
        context = timestep_embeddings + label_embeddings

        hidden_states = self.transformer_decoder(latent_embeddings, context)
        hidden_states = self.final_norm(hidden_states, context)

        # `pred_var = None` is returned for consistency, to support other
        # samplers in the future that use the variance to generate samples
        # and the VLB loss.
        pred_var = None
        pred_noise = self.noise_head(hidden_states)
        return pred_noise, pred_var
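    # Classifier-free guidance (CFG) background for the method below: the
    # guided prediction mixes the conditional and unconditional noise
    # estimates,
    #     eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond),
    # which is algebraically identical to
    #     (1 - guidance_scale) * eps_uncond + guidance_scale * eps_cond.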
    def forward_dit_with_cfg(
        self, noised_latent, label, timestep, guidance_scale, num_cfg_channels=3
    ):
        """
        Forward pass of DiT that also batches the unconditional forward pass
        for classifier-free guidance. Assumes the inputs are already batched
        with conditional and unconditional halves.

        Note: For exact reproducibility reasons, classifier-free guidance is
        applied to only the first three channels by default, hence
        `num_cfg_channels` defaults to 3. The standard approach to CFG
        applies it to all channels.
        """
        half = noised_latent[: len(noised_latent) // 2]
        combined = torch.cat([half, half], dim=0)
        pred_noise, pred_var = self.forward_dit(combined, label, timestep)
        eps, rest = (
            pred_noise[:, :num_cfg_channels],
            pred_noise[:, num_cfg_channels:],
        )  # eps shape: (bsz, num_cfg_channels, H, W)
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        # half_eps = (1 - guidance_scale) * uncond_eps
        #          + guidance_scale * cond_eps
        # `guidance_scale` = 1 disables classifier-free guidance, while
        # increasing `guidance_scale` > 1 strengthens the effect of guidance.
        half_eps = uncond_eps + guidance_scale * (
            cond_eps - uncond_eps
        )  # half_eps shape: (bsz // 2, num_cfg_channels, H, W)
        eps = torch.cat(
            [half_eps, half_eps], dim=0
        )  # eps shape: (bsz, num_cfg_channels, H, W)
        pred_noise = torch.cat([eps, rest], dim=1)
        return pred_noise, pred_var
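

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module). The
# hyperparameters, tensor shapes, and the "linear" schedule name below are
# assumptions chosen for demonstration, not values prescribed by this file;
# consult the DiT configs for the settings actually used in training.
if __name__ == "__main__":
    model = DiT(
        num_diffusion_steps=1000,
        schedule_name="linear",  # assumed to be a schedule GaussianDiffusion accepts
        beta_start=0.0001,
        beta_end=0.02,
        hidden_size=384,
        num_hidden_layers=2,
        num_heads=6,
        filter_size=1536,
        latent_size=[32, 32],
        latent_channels=4,
        patch_size=[2, 2],
        num_classes=1000,
    )
    model.eval()

    bsz = 2
    latent = torch.randn(bsz, 4, 32, 32)  # VAE latents (B, C, H, W)
    noise = torch.randn_like(latent)  # diffusion noise epsilon
    label = torch.randint(0, 1000, (bsz,))  # class labels
    timestep = torch.randint(0, 1000, (bsz,))  # per-sample diffusion steps

    with torch.no_grad():
        # Training-style forward: noise the latents, then predict the noise.
        pred_noise, pred_var = model(latent, label, noise, timestep)

        # Sampling-style forward with classifier-free guidance. The batch is
        # packed as [conditional half; unconditional half]; index 1000 is the
        # extra "null" embedding row reserved when label_dropout_rate > 0.
        cfg_label = torch.cat([label[:1], torch.tensor([1000])], dim=0)
        cfg_timestep = torch.full((bsz,), 500, dtype=torch.long)
        guided_noise, _ = model.forward_dit_with_cfg(
            latent, cfg_label, cfg_timestep, guidance_scale=4.0
        )

    print(pred_noise.shape, guided_noise.shape)  # both expected: (2, 4, 32, 32)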