Source code for modelzoo.vision.pytorch.unet.layers.Encoder

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import zip_longest

import torch.nn as nn
from torch.nn import Dropout, MaxPool2d, MaxPool3d

from modelzoo.vision.pytorch.unet.layers.UNetBlock import UNetBlock


class Encoder(nn.Module):
    """
    Encoder made of a sequence of ``UNetBlock`` modules, each optionally
    followed by a max-pooling layer and a dropout layer.

    ``len(encoder_filters) - 1`` blocks are built. The first block consumes
    ``in_channels``; every later block consumes the filter count of the block
    before it.

    Args:
        in_channels (int): The input channel dimension before passing
            through the encoder.
        encoder_filters ([int]): List of filter sizes for each block in the
            encoder.
        convs_per_block ([str]): List of conv specifications for each conv
            in the block.
        bias (bool): Flag to use bias vectors.
        norm_layer (nn.Module): Desired normalization layer.
        norm_kwargs (dict): A dictionary of the arguments to pass to the
            constructor of the normalization layer.
        act (str): Activation to use.
        skip_connect (bool): Flag for if the model concatenates encoder
            outputs to decoder inputs.
        residual_blocks (bool): Flag for using residual connections at the
            end of each block.
        downscale_method (str): Downscaling method at the end of each block.
            When equal to ``"max_pool"`` a max-pooling layer (kernel 2,
            stride 2) is added after every block.
        dropout_rate (float): The probability that each element is dropped.
            A falsy value (e.g. ``0.0``) disables dropout entirely.
        use_conv3d (bool): 3D convolutions will be used when set to True.
        downscale_first_conv (bool): If True, the first convolution
            operation in each UNetBlock will be downscaled. If False, the
            last convolution in each UNetBlock will be downscaled.
        downscale_encoder_blocks (bool or [bool]): bool or list of bools
            that determine whether each block in the Encoder includes
            downsampling. A list is indexed by block position; a single
            bool is applied to all blocks.
    """

    def __init__(
        self,
        in_channels,
        encoder_filters,
        convs_per_block,
        bias,
        norm_layer,
        norm_kwargs,
        act="relu",
        skip_connect=True,
        residual_blocks=False,
        downscale_method="max_pool",
        dropout_rate=0.0,
        use_conv3d=False,
        downscale_first_conv=False,
        downscale_encoder_blocks=True,
    ):
        super().__init__()
        self.skip_connect = skip_connect

        # A single flag is expanded so that it can be indexed per block below.
        if isinstance(downscale_encoder_blocks, bool):
            downscale_encoder_blocks = [downscale_encoder_blocks] * len(
                encoder_filters
            )

        # Both choices are loop-invariant, so they are resolved once up front.
        add_max_pool = downscale_method == "max_pool"
        pool_cls = MaxPool3d if use_conv3d else MaxPool2d

        blocks = []
        pools = []
        dropouts = []
        num_blocks = len(encoder_filters) - 1
        for idx in range(num_blocks):
            # The first block reads the raw input; later ones read the
            # previous block's output channels.
            if idx == 0:
                block_in_channels = in_channels
            else:
                block_in_channels = encoder_filters[idx - 1]

            block = UNetBlock(
                in_channels=block_in_channels,
                out_channels=encoder_filters[idx],
                encoder=True,
                convs_per_block=convs_per_block,
                skip_connect=skip_connect,
                norm_layer=norm_layer,
                norm_kwargs=norm_kwargs,
                downscale_method=downscale_method,
                bias=bias,
                residual_blocks=residual_blocks,
                act=act,
                use_conv3d=use_conv3d,
                downscale_first_conv=downscale_first_conv,
                downscale=downscale_encoder_blocks[idx],
            )
            blocks.append(block)

            if add_max_pool:
                pools.append(pool_cls(kernel_size=2, stride=2))
            if dropout_rate:
                dropouts.append(Dropout(p=dropout_rate))

        self.unet_blocks = nn.ModuleList(blocks)
        self.pooling_layers = nn.ModuleList(pools)
        self.dropout_layers = nn.ModuleList(dropouts)

    def forward(self, inputs):
        """
        Run ``inputs`` through every block, applying the matching pooling
        and dropout layers when they exist.

        Returns:
            A tuple ``(outputs, skip_connections)``. ``skip_connections``
            holds the second value returned by each block, in block order,
            and stays empty when ``skip_connect`` is False.
        """
        collected_skips = []
        x = inputs
        # The pooling/dropout lists may be empty; zip_longest pads them
        # with None so every block is still visited.
        stages = zip_longest(
            self.unet_blocks, self.pooling_layers, self.dropout_layers
        )
        for block, pool, dropout in stages:
            x, skip = block(x)
            if self.skip_connect:
                collected_skips.append(skip)
            if pool is not None:
                x = pool(x)
            if dropout is not None:
                x = dropout(x)
        return x, collected_skips