Source code for modelzoo.common.pytorch.layers.ViTEmbeddingLayer

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from modelzoo.common.pytorch.layers.utils import (
    get_2d_fixed_position_embeddings,
    patchify_helper,
)
from modelzoo.common.pytorch.model_utils.create_initializer import (
    create_initializer,
)


class ViTEmbeddingLayer(nn.Module):
    def __init__(
        self,
        image_size=[224, 224],
        num_channels=3,
        patch_size=[16, 16],
        hidden_size=768,
        initializer_range=0.02,
        embedding_dropout_rate=0.0,
        projection_initializer=None,
        position_embedding_initializer=None,
        position_embedding_type="learned",
        use_conv_patchified_embedding=False,
        prepend_cls_token=False,
        init_conv_like_linear=False,
    ):
        super(ViTEmbeddingLayer, self).__init__()
        self.image_size = image_size
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.position_embedding_type = position_embedding_type
        self.use_conv_patchified_embedding = use_conv_patchified_embedding
        self.prepend_cls_token = prepend_cls_token
        self.init_conv_like_linear = init_conv_like_linear

        assert (
            self.image_size[0] % self.patch_size[0] == 0
            and self.image_size[1] % self.patch_size[1] == 0
        ), f"image size {self.image_size} is not divisible by patch_size {self.patch_size}"
        assert self.position_embedding_type in [
            None,
            "fixed",
            "learned",
        ], "Only `learned` or `fixed` position embeddings are supported for now."

        # Number of patches along (height, width).
        self.num_patches = [
            (self.image_size[0] // self.patch_size[0]),
            (self.image_size[1] // self.patch_size[1]),
        ]

        # Patch-to-embedding projection: a strided conv over raw images, or a
        # linear layer over already-patchified inputs.
        if use_conv_patchified_embedding:
            self.linear_proj = nn.Conv2d(
                self.num_channels,
                self.hidden_size,
                kernel_size=self.patch_size,
                stride=self.patch_size,
            )
        else:
            self.embedding_size = (
                self.patch_size[0] * self.patch_size[1] * num_channels
            )
            self.linear_proj = nn.Linear(self.embedding_size, self.hidden_size)

        if self.position_embedding_type == "learned":
            num_position_embeddings = self.num_patches[0] * self.num_patches[1]
            if self.prepend_cls_token:
                num_position_embeddings += 1
            self.position_embeddings = nn.Embedding(
                num_position_embeddings, self.hidden_size
            )
        elif self.position_embedding_type == "fixed":  # fixed sin&cos
            position_embeddings = get_2d_fixed_position_embeddings(
                self.num_patches,
                self.hidden_size,
                add_cls_token=prepend_cls_token,
            )
            self.position_embeddings = torch.nn.Parameter(
                position_embeddings, requires_grad=False
            )

        if self.prepend_cls_token:
            self.cls_embedding = nn.Parameter(
                torch.zeros(1, 1, self.hidden_size)
            )
            self.cls_embedding_position_index = (
                self.num_patches[0] * self.num_patches[1]
            )  # seq_len + 1 - 1, cls pe is the last

        self.default_initializer = {
            "name": "truncated_normal",
            "std": self.initializer_range,
            "mean": 0.0,
            "a": self.initializer_range * -2.0,
            "b": self.initializer_range * 2.0,
        }
        if projection_initializer is None:
            projection_initializer = self.default_initializer
        if position_embedding_initializer is None:
            position_embedding_initializer = self.default_initializer
        self.projection_initializer = projection_initializer
        self.position_embedding_initializer = position_embedding_initializer

        self.dropout_embd = nn.Dropout(embedding_dropout_rate)

        self.__reset_parameters()
    def reset_parameters(self):
        self.__reset_parameters()

    def __reset_parameters(self):
        projection_initializer = create_initializer(self.projection_initializer)
        w = self.linear_proj.weight.data
        if self.use_conv_patchified_embedding and self.init_conv_like_linear:
            # Modifying fan-in fan-out by reshaping.
            # Bias set to zeros already
            projection_initializer(w.view([w.shape[0], -1]))
        else:
            projection_initializer(w)
        create_initializer("zeros")(self.linear_proj.bias.data)

        if self.prepend_cls_token:
            create_initializer(self.default_initializer)(
                self.cls_embedding.data
            )
        if self.position_embedding_type == "learned":
            create_initializer(self.position_embedding_initializer)(
                self.position_embeddings.weight.data
            )

    def get_image_sequence_position_embeddings(self, embeddings, indices=None):
        # embeddings shape [batch_size, seq_len, hidden_size], shouldn't contain cls
        # indices shape [batch_size, seq_len]
        if indices is None:
            position_ids = torch.arange(
                0, embeddings.shape[1], device=embeddings.device,
            ).expand((embeddings.shape[0], -1))
        else:
            position_ids = indices

        if self.position_embedding_type == "learned":
            position_embeddings = self.position_embeddings(position_ids)
        elif self.position_embedding_type == "fixed":  # fixed
            position_ids = torch.broadcast_to(
                position_ids.unsqueeze(-1),
                (
                    position_ids.shape[0],
                    position_ids.shape[1],
                    embeddings.shape[-1],
                ),
            ).long()
            position_embeddings = torch.gather(
                self.position_embeddings.to(embeddings.dtype).expand(
                    position_ids.shape[0], -1, -1
                ),
                1,
                position_ids,
            )
        return position_embeddings

    def get_cls_token_position_embeddings(self, batch_size, dtype, device):
        if self.position_embedding_type == "learned":
            cls_indices = (
                torch.ones((batch_size, 1), dtype=torch.int32, device=device,)
                * self.cls_embedding_position_index
            )
            pe = self.position_embeddings(cls_indices)
        else:
            pe = (
                self.position_embeddings[self.cls_embedding_position_index :, :]
                .to(dtype)
                .expand(batch_size, -1, -1)
            )  # [bs, 1, hidden_size]
        return pe
    def select_patches(self, patches, patch_indices=None):
        """Select from patches based on patch_indices.

        Args:
            patches (Tensor): shape [batch_size, full_sequence_length, hidden_size]
            patch_indices (Tensor): shape [batch_size, subset_sequence_length]

        Returns:
            patches (Tensor): shape [batch_size, subset_sequence_length, hidden_size]
        """
        if patch_indices is None:
            return patches

        batch_size, subset_sequence_length = patch_indices.shape
        patch_indices = torch.broadcast_to(
            patch_indices.unsqueeze(-1),
            (batch_size, subset_sequence_length, patches.shape[-1]),
        ).long()
        patches = torch.gather(patches, 1, patch_indices)
        return patches
    def forward(self, input_images, patch_indices=None):
        """Applies patching and linear projection to the input images.

        Args:
            input_images (Tensor): shape ``[batch_size, num_channels, height, width]``
                if ``use_conv_patchified_embedding`` else
                ``[batch_size, sequence_len, embedding_size]``.
            patch_indices (Tensor): shape ``[batch_size, subset_seq_length]``. If
                specified, the embedding layer selects a subset of all image patches
                based on these indices. This is used for applications like MAE.
                Defaults to None.

        Returns:
            image_embeddings (Tensor): shape ``[batch_size, sequence_length, hidden_size]``.
        """
        batch_size = input_images.shape[0]
        if self.use_conv_patchified_embedding:
            # conv projection
            image_embeddings = self.linear_proj(input_images)
            # reshape
            hidden_size = image_embeddings.shape[1]
            image_embeddings = image_embeddings.reshape(
                batch_size, hidden_size, -1
            ).transpose(
                1, 2
            )  # [bs, seq_length, hidden_size]
            image_embeddings = self.select_patches(
                image_embeddings, patch_indices=patch_indices
            )
        else:
            # patchify
            patchified_image = patchify_helper(input_images, self.patch_size)
            # This saves computation compared to the conv implementation because
            # patch selection happens before linear_proj.
            image_embeddings = self.select_patches(
                patchified_image, patch_indices=patch_indices
            )
            # linear projection
            image_embeddings = self.linear_proj(
                image_embeddings
            )  # [bs, seq_length, hidden_size]

        embeddings = image_embeddings
        if self.position_embedding_type is not None:
            image_pe = self.get_image_sequence_position_embeddings(
                image_embeddings, indices=patch_indices
            )
            embeddings += image_pe

        if self.prepend_cls_token:
            expanded_cls_embedding = self.cls_embedding.type_as(
                image_embeddings
            ).expand(batch_size, -1, -1)
            expanded_cls_position_embedding = self.get_cls_token_position_embeddings(
                batch_size,
                image_embeddings.dtype,
                expanded_cls_embedding.device,
            )
            cls_embeddings = (
                expanded_cls_embedding + expanded_cls_position_embedding
            )
            embeddings = torch.cat([cls_embeddings, embeddings], dim=1)

        embeddings = self.dropout_embd(embeddings)

        return embeddings
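

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above): construct the
# layer with its documented defaults and embed a small batch of images through
# the convolutional patchification path. Shapes assume 224x224 RGB inputs and
# 16x16 patches; adjust to your configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embedding_layer = ViTEmbeddingLayer(
        image_size=[224, 224],
        num_channels=3,
        patch_size=[16, 16],
        hidden_size=768,
        position_embedding_type="learned",
        use_conv_patchified_embedding=True,
        prepend_cls_token=True,
    )
    images = torch.randn(2, 3, 224, 224)  # [batch_size, num_channels, H, W]
    embeddings = embedding_layer(images)
    # 224 / 16 = 14 patches per side -> 196 patches, plus one CLS token.
    print(embeddings.shape)  # torch.Size([2, 197, 768])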