# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from modelzoo.vision.pytorch.dit.modeling_dit import DiT


class DiTModel(nn.Module):
    """Wraps the DiT backbone and computes the diffusion training loss."""

    def __init__(self, params):
        super().__init__()
        model_params = params["model"].copy()
        self.model = self.build_model(model_params)
        self.mse_loss = nn.MSELoss()

    def build_model(self, model_params):
        """Build the DiT backbone from the flattened ``model`` config section."""
        model = DiT(
            # Diffusion schedule
            num_diffusion_steps=model_params.pop("num_diffusion_steps"),
            schedule_name=model_params.pop("schedule_name"),
            beta_start=model_params.pop("beta_start"),
            beta_end=model_params.pop("beta_end"),
            num_classes=model_params.pop("num_classes"),
            # Embedding
            embedding_dropout_rate=model_params.pop(
                "embedding_dropout_rate", 0.0
            ),
            hidden_size=model_params.pop("hidden_size"),
            embedding_nonlinearity=model_params.pop("embedding_nonlinearity"),
            position_embedding_type=model_params.pop("position_embedding_type"),
            # Encoder
            num_hidden_layers=model_params.pop("num_hidden_layers"),
            layer_norm_epsilon=float(model_params.pop("layer_norm_epsilon")),
            # Encoder attention
            num_heads=model_params.pop("num_heads"),
            attention_type=model_params.pop(
                "attention_type", "scaled_dot_product"
            ),
            attention_softmax_fp32=model_params.pop(
                "attention_softmax_fp32", True
            ),
            dropout_rate=model_params.pop("dropout_rate"),
            nonlinearity=model_params.pop("encoder_nonlinearity", "gelu"),
            attention_dropout_rate=model_params.pop(
                "attention_dropout_rate", 0.0
            ),
            use_projection_bias_in_attention=model_params.pop(
                "use_projection_bias_in_attention", True
            ),
            use_ffn_bias_in_attention=model_params.pop(
                "use_ffn_bias_in_attention", True
            ),
            # Encoder FFN
            filter_size=model_params.pop("filter_size"),
            use_ffn_bias=model_params.pop("use_ffn_bias", True),
            # Initializers
            initializer_range=model_params.pop("initializer_range", 0.02),
            projection_initializer=model_params.pop(
                "projection_initializer", None
            ),
            position_embedding_initializer=model_params.pop(
                "position_embedding_initializer", None
            ),
            init_conv_like_linear=model_params.pop(
                "init_conv_like_linear", False
            ),
            attention_initializer=model_params.pop(
                "attention_initializer", None
            ),
            ffn_initializer=model_params.pop("ffn_initializer", None),
            timestep_embedding_initializer=model_params.pop(
                "timestep_embedding_initializer", None
            ),
            label_embedding_initializer=model_params.pop(
                "label_embedding_initializer", None
            ),
            head_initializer=model_params.pop("head_initializer", None),
            norm_first=model_params.pop("norm_first", True),
            # Vision-related params
            latent_size=model_params.pop("latent_size"),
            latent_channels=model_params.pop("latent_channels"),
            patch_size=model_params.pop("patch_size"),
            use_conv_patchified_embedding=model_params.pop(
                "use_conv_patchified_embedding", False
            ),
            # Context embeddings
            frequency_embedding_size=model_params.pop(
                "frequency_embedding_size"
            ),
            label_dropout_rate=model_params.pop("label_dropout_rate"),
            block_type=model_params.pop("block_type"),
            use_conv_transpose_unpatchify=model_params.pop(
                "use_conv_transpose_unpatchify"
            ),
        )
        return model

    def forward(self, data):
        """Run the backbone on a batch and return the noise-prediction MSE loss."""
        diffusion_noise = data["diffusion_noise"]
        timestep = data["timestep"]
        model_output = self.model(
            input=data["input"],
            label=data["label"],
            diffusion_noise=diffusion_noise,
            timestep=timestep,
        )
        pred_noise = model_output[0]
        loss = self.mse_loss(pred_noise, diffusion_noise)
        return loss
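

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model zoo). The config
# keys mirror the ``model_params.pop(...)`` calls in ``build_model``; every
# concrete value and tensor shape below is a hypothetical stand-in, assuming a
# small DiT-style setup over 4-channel 32x32 VAE latents. In practice these
# values come from a params YAML shipped with the model zoo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    params = {
        "model": {
            # Diffusion schedule (illustrative values)
            "num_diffusion_steps": 1000,
            "schedule_name": "linear",
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "num_classes": 1000,
            # Embedding / encoder (DiT-S-like sizes, assumed)
            "hidden_size": 384,
            "embedding_nonlinearity": "silu",
            "position_embedding_type": "learned",
            "num_hidden_layers": 12,
            "layer_norm_epsilon": 1e-6,
            "num_heads": 6,
            "dropout_rate": 0.0,
            "filter_size": 1536,
            # Vision / latent layout (list-valued sizes are an assumption)
            "latent_size": [32, 32],
            "latent_channels": 4,
            "patch_size": [2, 2],
            # Context embeddings
            "frequency_embedding_size": 256,
            "label_dropout_rate": 0.1,
            "block_type": "adaln_zero",
            "use_conv_transpose_unpatchify": True,
        }
    }
    model = DiTModel(params)

    # A fake batch matching the assumed latent layout above.
    batch_size = 2
    latents = torch.randn(batch_size, 4, 32, 32)
    data = {
        "input": latents,
        "label": torch.randint(0, 1000, (batch_size,)),
        "diffusion_noise": torch.randn_like(latents),
        "timestep": torch.randint(0, 1000, (batch_size,)),
    }
    loss = model(data)  # scalar MSE between predicted and true noise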