# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from cerebras_pytorch.nn import DiceCELoss
from modelzoo.common.pytorch.layers.ConvNormActBlock import ConvNormActBlock
from modelzoo.common.pytorch.model_utils.create_initializer import (
create_initializer,
)
from modelzoo.common.pytorch.run_utils import half_dtype_instance
from modelzoo.vision.pytorch.unet.layers.Decoder import Decoder
from modelzoo.vision.pytorch.unet.layers.Encoder import Encoder
from modelzoo.vision.pytorch.unet.layers.UNetBlock import UNetBlock
class UNet(nn.Module):
    """
    UNet model.

    The network is assembled from an optional initial convolution, an
    ``Encoder``, a bottleneck ``UNetBlock``, a ``Decoder`` and a final
    kernel-size-1 convolution that produces ``num_output_channels`` channels.
    The loss callable matching the configured loss type is stored on
    ``self.loss_fn``.
    """

    def bce_loss(self, outputs, labels):
        """
        Mean binary cross-entropy computed directly on logits.

        For logits ``x`` and labels ``y`` the per-element loss is
        ``(1 - y) * x + m + log(exp(-m) + exp(-x - m))`` with
        ``m = max(-x, 0)``. Subtracting ``m`` inside the exponentials keeps
        their arguments non-positive, so they cannot overflow.

        Args:
            outputs: Tensor of logits.
            labels: Tensor of targets combined element-wise with ``outputs``.

        Returns:
            Scalar tensor holding the mean loss, cast to
            ``half_dtype_instance.half_dtype``.
        """
        neg_outputs = -1 * outputs
        zero_const = torch.tensor(
            0.0, dtype=outputs.dtype, device=outputs.device
        )
        # max_val = max(-x, 0), expressed with torch.where.
        max_val = torch.where(neg_outputs > zero_const, neg_outputs, zero_const)
        loss = (
            (1 - labels)
            .mul(outputs)
            .add(max_val)
            .add((-max_val).exp().add((neg_outputs - max_val).exp()).log())
        )
        mean = torch.mean(loss)
        # The return needs to be a dtype of FP16 for WS
        return mean.to(half_dtype_instance.half_dtype)

    def __init__(self, model_params):
        """
        Build the UNet layers and select the loss function.

        Args:
            model_params: Mapping of model configuration values.

                Required keys (read with ``[]``): ``num_classes``,
                ``skip_connect``, ``downscale_method``, ``residual_blocks``,
                ``use_conv3d``, ``downscale_first_conv``,
                ``downscale_encoder_blocks``, ``downscale_bottleneck``,
                ``ignore_background_class``, ``batch_size``, ``image_shape``,
                ``nonlinearity``, ``encoder_filters``, ``decoder_filters``,
                ``initializer``, ``bias_initializer``.

                Optional keys (read with ``.get``): ``loss`` (defaults to
                ``"bce"`` when ``num_classes <= 2``, otherwise ``"ssce"``),
                ``norm_layer``, ``norm_kwargs``, ``dropout_rate`` (default
                ``0.0``), ``enable_bias`` (defaults to ``True`` only when no
                ``norm_layer`` is given), ``input_channels``,
                ``initial_conv_filters``, ``convs_per_block`` (default
                ``["3x3_conv", "3x3_conv"]``), ``norm_weight_initializer``
                (default ``{"name": "ones"}``).

        Raises:
            AssertionError: If a ``bce`` loss (other than ``multilabel_bce``)
                is requested with ``num_classes != 2``; if
                ``residual_blocks`` is set while ``downscale_method`` is not
                ``"max_pool"``; or if ``len(encoder_filters)`` is not
                ``len(decoder_filters) + 1``.
        """
        super().__init__()

        self.num_classes = model_params["num_classes"]
        self.skip_connect = model_params["skip_connect"]
        self.downscale_method = model_params["downscale_method"]
        self.residual_blocks = model_params["residual_blocks"]
        self.use_conv3d = model_params["use_conv3d"]
        self.downscale_first_conv = model_params["downscale_first_conv"]
        self.downscale_encoder_blocks = model_params["downscale_encoder_blocks"]
        self.downscale_bottleneck = model_params["downscale_bottleneck"]
        self.include_background = not model_params["ignore_background_class"]
        # Shape handed to DiceCELoss: [batch_size, num_classes, *image_shape].
        self.input_shape = [
            model_params["batch_size"],
            self.num_classes,
            *model_params["image_shape"],
        ]

        self.loss_type = model_params.get(
            "loss", "bce" if self.num_classes <= 2 else "ssce"
        ).lower()

        # Plain BCE predicts a single channel; every other loss predicts one
        # channel per class.
        if "bce" in self.loss_type and "multilabel_bce" not in self.loss_type:
            assert (
                self.num_classes == 2
            ), "BCE loss may only be used when there are two classes!"
            self.num_output_channels = 1
        else:
            self.num_output_channels = self.num_classes

        # NOTE(review): a loss type that matches none of the branches below
        # leaves `self.loss_fn` undefined, so the first access raises
        # AttributeError. Confirm whether that should be rejected here.
        if self.loss_type == "bce":
            self.loss_fn = self.bce_loss
        elif self.loss_type == "multilabel_bce":
            self.loss_fn = nn.BCEWithLogitsLoss()
        elif ("ssce" in self.loss_type) and ("dice" in self.loss_type):
            self.loss_fn = DiceCELoss(
                self.num_classes, self.input_shape, self.include_background,
            )
        elif "ssce" in self.loss_type:
            self.loss_fn = nn.CrossEntropyLoss()

        if self.residual_blocks:
            assert self.downscale_method == "max_pool"

        self.norm_layer = model_params.get("norm_layer", None)
        self.norm_kwargs = model_params.get("norm_kwargs", None)
        self.dropout_rate = model_params.get("dropout_rate", 0.0)
        # Bias defaults to enabled only when no normalization layer is used.
        self.enable_bias = model_params.get(
            "enable_bias", self.norm_layer is None
        )
        self.act = model_params["nonlinearity"].lower()

        self.encoder_filters = model_params["encoder_filters"]
        self.decoder_filters = model_params["decoder_filters"]
        self.input_image_channels = model_params.get("input_channels")
        self.initial_conv_filters = model_params.get("initial_conv_filters")
        self.convs_per_block = model_params.get(
            "convs_per_block", ["3x3_conv", "3x3_conv"]
        )

        # The last encoder filter entry is consumed by the bottleneck block.
        assert len(self.encoder_filters) == len(self.decoder_filters) + 1, (
            "Number of encoder filters should be equal to number of "
            "decoder filters + 1 (bottleneck)"
        )

        # initializers
        self.conv_initializer = model_params["initializer"]
        self.bias_initializer = model_params["bias_initializer"]
        self.norm_weight_initializer = model_params.get(
            "norm_weight_initializer", {"name": "ones"}
        )

        # Optional stem convolution applied before the encoder.
        self.initial_conv = None
        if self.initial_conv_filters:
            self.initial_conv = ConvNormActBlock(
                in_channels=self.input_image_channels,
                out_channels=self.initial_conv_filters,
                kernel_size=3,
                padding="same",
                bias=self.enable_bias,
                act="relu",
                norm_layer=None,
                use_conv3d=self.use_conv3d,
            )

        self.encoder = Encoder(
            in_channels=self.initial_conv_filters
            if self.initial_conv_filters
            else self.input_image_channels,
            encoder_filters=self.encoder_filters,
            convs_per_block=self.convs_per_block,
            bias=self.enable_bias,
            norm_layer=self.norm_layer,
            norm_kwargs=self.norm_kwargs,
            act=self.act,
            skip_connect=self.skip_connect,
            residual_blocks=self.residual_blocks,
            downscale_method=self.downscale_method,
            dropout_rate=self.dropout_rate,
            use_conv3d=self.use_conv3d,
            downscale_first_conv=self.downscale_first_conv,
            downscale_encoder_blocks=self.downscale_encoder_blocks,
        )

        self.bottleneck = UNetBlock(
            in_channels=self.encoder_filters[-2],
            out_channels=self.encoder_filters[-1],
            encoder=False,
            convs_per_block=self.convs_per_block,
            skip_connect=self.skip_connect,
            residual_blocks=self.residual_blocks,
            norm_layer=self.norm_layer,
            norm_kwargs=self.norm_kwargs,
            downscale_method=self.downscale_method,
            bias=self.enable_bias,
            use_conv3d=self.use_conv3d,
            downscale_first_conv=self.downscale_first_conv,
            downscale=self.downscale_bottleneck,
        )

        self.decoder = Decoder(
            in_channels=self.encoder_filters[-1],
            decoder_filters=self.decoder_filters,
            encoder_filters=self.encoder_filters,
            convs_per_block=self.convs_per_block,
            bias=self.enable_bias,
            norm_layer=self.norm_layer,
            norm_kwargs=self.norm_kwargs,
            act=self.act,
            skip_connect=self.skip_connect,
            residual_blocks=self.residual_blocks,
            downscale_method=self.downscale_method,
            dropout_rate=self.dropout_rate,
            use_conv3d=self.use_conv3d,
        )

        # Final projection to the number of output channels, with no norm
        # and a linear activation.
        self.final_conv = ConvNormActBlock(
            in_channels=self.decoder_filters[-1],
            out_channels=self.num_output_channels,
            kernel_size=1,
            bias=True,
            padding="same",
            act="linear",
            norm_layer=None,
            use_conv3d=self.use_conv3d,
        )

        # initialize weights
        self.reset_parameters()

    def forward(self, inputs):
        """
        Run the network on ``inputs``.

        Args:
            inputs: Input passed to the initial convolution when one is
                configured, otherwise directly to the encoder.

        Returns:
            The output of the final convolution.
        """
        outputs = inputs
        # Identity check rather than truthiness: container-style modules
        # define __len__, so their truth value reflects length, not presence.
        if self.initial_conv is not None:
            outputs = self.initial_conv(outputs)
        outputs, skip_connections = self.encoder(outputs)
        outputs = self.bottleneck(outputs)
        outputs = self.decoder(outputs, skip_connections)
        outputs = self.final_conv(outputs)
        return outputs

    def reset_parameters(self):
        """
        Apply the configured initializers to conv and normalization layers.

        Weights of ``nn.Conv2d`` / ``nn.ConvTranspose2d`` modules use
        ``conv_initializer``; weights of ``nn.BatchNorm2d`` / ``nn.GroupNorm``
        modules use ``norm_weight_initializer``; biases of both groups use
        ``bias_initializer``. Parameters that are ``None`` are skipped.
        """
        # NOTE(review): Conv3d / ConvTranspose3d / BatchNorm3d are not matched
        # here, so with `use_conv3d` those layers presumably keep whatever
        # initialization their constructors apply - confirm this is intended.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                create_initializer(self.conv_initializer)(m.weight)
                if m.bias is not None:
                    create_initializer(self.bias_initializer)(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers created without affine parameters expose
                # weight/bias as None; there is nothing to initialize then.
                if m.weight is not None:
                    create_initializer(self.norm_weight_initializer)(m.weight)
                if m.bias is not None:
                    create_initializer(self.bias_initializer)(m.bias)