Source code for modelzoo.vision.pytorch.dit.model

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn

from modelzoo.vision.pytorch.dit.modeling_dit import DiT


class DiTModel(nn.Module):
    """Thin training wrapper around DiT that computes the noise-prediction MSE loss."""

    def __init__(self, params):
        super().__init__()
        model_params = params["model"].copy()
        self.model = self.build_model(model_params)
        self.mse_loss = nn.MSELoss()

    def build_model(self, model_params):
        model = DiT(
            num_diffusion_steps=model_params["num_diffusion_steps"],
            schedule_name=model_params["schedule_name"],
            beta_start=model_params["beta_start"],
            beta_end=model_params["beta_end"],
            num_classes=model_params.pop("num_classes"),
            # Embedding
            embedding_dropout_rate=model_params.pop(
                "embedding_dropout_rate", 0.0
            ),
            hidden_size=model_params.pop("hidden_size"),
            embedding_nonlinearity=model_params.pop("embedding_nonlinearity"),
            position_embedding_type=model_params.pop("position_embedding_type"),
            # Encoder
            num_hidden_layers=model_params.pop("num_hidden_layers"),
            layer_norm_epsilon=float(model_params.pop("layer_norm_epsilon")),
            # Encoder attention
            num_heads=model_params.pop("num_heads"),
            attention_type=model_params.pop(
                "attention_type", "scaled_dot_product"
            ),
            attention_softmax_fp32=model_params.pop(
                "attention_softmax_fp32", True
            ),
            dropout_rate=model_params.pop("dropout_rate"),
            nonlinearity=model_params.pop("encoder_nonlinearity", "gelu"),
            attention_dropout_rate=model_params.pop(
                "attention_dropout_rate", 0.0
            ),
            use_projection_bias_in_attention=model_params.pop(
                "use_projection_bias_in_attention", True
            ),
            use_ffn_bias_in_attention=model_params.pop(
                "use_ffn_bias_in_attention", True
            ),
            # Encoder FFN
            filter_size=model_params.pop("filter_size"),
            use_ffn_bias=model_params.pop("use_ffn_bias", True),
            # Task-specific initializers
            initializer_range=model_params.pop("initializer_range", 0.02),
            projection_initializer=model_params.pop(
                "projection_initializer", None
            ),
            position_embedding_initializer=model_params.pop(
                "position_embedding_initializer", None
            ),
            init_conv_like_linear=model_params.pop(
                "init_conv_like_linear", False
            ),
            attention_initializer=model_params.pop(
                "attention_initializer", None
            ),
            ffn_initializer=model_params.pop("ffn_initializer", None),
            timestep_embeddding_initializer=model_params.pop(
                "timestep_embeddding_initializer", None
            ),
            label_embedding_initializer=model_params.pop(
                "label_embedding_initializer", None
            ),
            head_initializer=model_params.pop("head_initializer", None),
            norm_first=model_params.pop("norm_first", True),
            # Vision-related params
            latent_size=model_params.pop("latent_size"),
            latent_channels=model_params.pop("latent_channels"),
            patch_size=model_params.pop("patch_size"),
            use_conv_patchified_embedding=model_params.pop(
                "use_conv_patchified_embedding", False
            ),
            # Context embeddings
            frequency_embedding_size=model_params.pop(
                "frequency_embedding_size"
            ),
            label_dropout_rate=model_params.pop("label_dropout_rate"),
            block_type=model_params.pop("block_type"),
            use_conv_transpose_unpatchify=model_params.pop(
                "use_conv_transpose_unpatchify"
            ),
        )
        return model

    def forward(self, data):
        diffusion_noise = data["diffusion_noise"]
        timestep = data["timestep"]
        model_output = self.model(
            input=data["input"],
            label=data["label"],
            diffusion_noise=diffusion_noise,
            timestep=timestep,
        )
        # The model predicts the noise that was added to the latent input;
        # training minimizes the MSE between predicted and true noise.
        pred_noise = model_output[0]
        loss = self.mse_loss(pred_noise, diffusion_noise)
        return loss
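
For reference, a minimal sketch of how this wrapper might be driven. The params dict layout mirrors the keyword arguments consumed in build_model above, but every concrete value here (latent size, block type, nonlinearity names, class count) is an illustrative assumption rather than a value from a shipped config, and the batch shapes assume a standard latent-diffusion setup (VAE latents of shape [batch, latent_channels, H, W]).

import torch

from modelzoo.vision.pytorch.dit.model import DiTModel

# Hypothetical config; key names follow the lookups in build_model,
# values are illustrative assumptions only.
params = {
    "model": {
        "num_diffusion_steps": 1000,
        "schedule_name": "linear",
        "beta_start": 0.0001,
        "beta_end": 0.02,
        "num_classes": 1000,
        "hidden_size": 1152,
        "embedding_nonlinearity": "silu",
        "position_embedding_type": "learned",
        "num_hidden_layers": 28,
        "layer_norm_epsilon": 1e-6,
        "num_heads": 16,
        "dropout_rate": 0.0,
        "filter_size": 4608,
        "latent_size": [32, 32],
        "latent_channels": 4,
        "patch_size": [2, 2],
        "frequency_embedding_size": 256,
        "label_dropout_rate": 0.1,
        "block_type": "adaln_zero",
        "use_conv_transpose_unpatchify": True,
    }
}

model = DiTModel(params)

# One training batch: latents, class labels, Gaussian noise, and a
# sampled diffusion timestep per example. Shapes are assumptions.
batch_size = 2
data = {
    "input": torch.randn(batch_size, 4, 32, 32),
    "label": torch.randint(0, 1000, (batch_size,)),
    "diffusion_noise": torch.randn(batch_size, 4, 32, 32),
    "timestep": torch.randint(0, 1000, (batch_size,)),
}

loss = model(data)  # scalar MSE between predicted and true noise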