# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py
import torch.nn as nn
from modelzoo.vision.pytorch.dit.layers.vae.UNetMidBlock2D import UNetMidBlock2D
from modelzoo.vision.pytorch.dit.layers.vae.UpDecoderBlock2D import (
UpDecoderBlock2D,
)
[docs]class Decoder(nn.Module):
[docs] def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=norm_num_groups,
temb_channels=None,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
assert (
up_block_type == "UpDecoderBlock2D"
), f"Support for {up_block_type} not added"
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpDecoderBlock2D(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
resnet_time_scale_shift="default",
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0],
num_groups=norm_num_groups,
eps=1e-6,
)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(
block_out_channels[0], out_channels, 3, padding=1
)
def forward(self, z):
sample = z
sample = self.conv_in(sample)
# middle
sample = self.mid_block(sample)
# up
for up_block in self.up_blocks:
sample = up_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample