Source code for modelzoo.common.pytorch.model_utils.checkpoint_converters.vit_mae

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch

from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    ConfigConversionError,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)
from modelzoo.common.pytorch.model_utils.checkpoint_converters.vit import (
    ConfigConverter_ViT_HF_CS19,
    Converter_ViT_Core_HF_CS19,
)


class Converter_ViTMAE_Core_HF_CS19(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # ViT Encoder
            ConversionRule(
                [EquivalentSubkey("vit.", ""), Converter_ViT_Core_HF_CS19()],
                action=None,
            ),
            # Decoder
            ConversionRule(
                ["decoder.", EquivalentSubkey("mask_token", "mask_embedding")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_pos_embed", "position_embeddings.weight"
                    ),
                ],
                action=self.position_embeddings_convert,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "decoder.decoder_embed", "encoder_decoder_projection"
                    ),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    "\.\d+\.",
                    EquivalentSubkey(
                        "attention.attention.query",
                        "self_attn.proj_q_dense_layer",
                    ),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    "\.\d+\.",
                    EquivalentSubkey(
                        "attention.attention.key",
                        "self_attn.proj_k_dense_layer",
                    ),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    "\.\d+\.",
                    EquivalentSubkey(
                        "attention.attention.value",
                        "self_attn.proj_v_dense_layer",
                    ),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    "\.\d+\.",
                    EquivalentSubkey(
                        "attention.output.dense",
                        "self_attn.proj_output_dense_layer",
                    ),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    "\.\d+\.",
                    EquivalentSubkey(
                        "intermediate.dense", "ffn.ffn.0.linear_layer"
                    ),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    "\.\d+\.",
                    EquivalentSubkey("output.dense", "ffn.ffn.1.linear_layer"),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    "\.\d+\.",
                    EquivalentSubkey("layernorm_before", "norm1"),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    "\.\d+\.",
                    EquivalentSubkey("layernorm_after", "norm2"),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Norm
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_norm", "encoder.transformer_encoder.norm"
                    ),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Pred
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey("decoder_pred", "output_projection"),
                    "\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
        ]
    def position_embeddings_convert(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        assert (
            action_fn_args["configs"][1]["model"]["position_embedding_type"]
            == "learned"
        ), "Only learned embeddings are supported"
        # CS ViT stores the cls-token position embedding last by default,
        # while HF stores it at index 0, so rotate the rows accordingly.
        if from_index == 0:
            new_state_dict[new_key] = torch.cat(
                [
                    old_state_dict[old_key][0, 1:, :],
                    old_state_dict[old_key][0, :1, :],
                ],
                dim=0,
            )
        else:
            new_state_dict[new_key] = torch.cat(
                [
                    old_state_dict[old_key][-1:, :],
                    old_state_dict[old_key][:-1, :],
                ],
                dim=0,
            ).unsqueeze(0)

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.9", "cs-2.0"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} ViTMAEForPreTraining <-> {} ViTMAEModel".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_ViTMAE_HF_CS19
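
# Illustrative sketch, not part of the original converter: how the cls-token
# reordering in `position_embeddings_convert` behaves on a toy tensor. The HF
# decoder position embedding has shape [1, num_patches + 1, hidden] with the
# cls position in row 0; the CS layout keeps it in the last row and drops the
# leading batch dimension. The function name and tensor sizes are hypothetical.
def _demo_decoder_pos_embed_reorder():
    hf_pos_embed = torch.arange(4 * 3, dtype=torch.float32).reshape(1, 4, 3)
    # HF -> CS (from_index == 0): move row 0 (cls) to the end, drop batch dim.
    cs_pos_embed = torch.cat(
        [hf_pos_embed[0, 1:, :], hf_pos_embed[0, :1, :]], dim=0
    )
    # CS -> HF (from_index == 1): move the last row (cls) back to the front
    # and restore the leading batch dimension.
    hf_roundtrip = torch.cat(
        [cs_pos_embed[-1:, :], cs_pos_embed[:-1, :]], dim=0
    ).unsqueeze(0)
    assert torch.equal(hf_pos_embed, hf_roundtrip)
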
class ConfigConverter_ViTMAE_HF_CS19(ConfigConverter_ViT_HF_CS19):
    def __init__(self):
        super().__init__()
        decoder_rules = [
            ConversionRule(
                ["model_type"],
                action=BaseConfigConverter.assert_factory_fn(0, "vit_mae"),
            ),
            ConversionRule(
                [EquivalentSubkey("hidden_act", "encoder_nonlinearity")],
                action=self.convert_nonlinearity,
            ),
            ConversionRule(
                ["decoder_nonlinearity"],
                action=self.assert_decoder_nonlinearity,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "decoder_num_attention_heads", "decoder_num_heads"
                    )
                ],
                action=self.replaceKey,
            ),
            ConversionRule(["decoder_hidden_size"], action=self.replaceKey),
            ConversionRule(
                ["decoder_num_hidden_layers"], action=self.replaceKey
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "decoder_intermediate_size", "decoder_filter_size"
                    )
                ],
                action=self.replaceKey,
            ),
            ConversionRule(["mask_ratio"], action=self.replaceKey),
        ]
        self.rules = decoder_rules + self.rules

        del self.pre_convert_defaults[0]["encoder_stride"]
        self.pre_convert_defaults[0].update(
            {
                "decoder_intermediate_size": 2048,
                "decoder_num_attention_heads": 12,
                "decoder_num_hidden_layers": 8,
                "mask_ratio": 0.75,
                "norm_pix_loss": False,
            }
        )
        self.pre_convert_defaults[1].update({"mask_ratio": 0.75})
        self.post_convert_defaults[0].update(
            {"model_type": "vit_mae", "norm_pix_loss": False}
        )
        self.post_convert_defaults[1].update({"mask_ratio": 0.75})
    def convert_nonlinearity(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        activation = old_state_dict[old_key]
        new_state_dict[new_key] = activation
        if from_index == 0:
            # set `encoder_nonlinearity` and `decoder_nonlinearity` to `activation`
            new_state_dict["decoder_nonlinearity"] = activation

    def assert_decoder_nonlinearity(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        if old_state_dict["encoder_nonlinearity"] != old_state_dict[old_key]:
            raise ConfigConversionError(
                "Encoder & Decoder nonlinearities must be the same in HF model. "
                "Got: {} vs {}".format(
                    old_state_dict["encoder_nonlinearity"],
                    old_state_dict[old_key],
                )
            )

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        from_index,
        drop_unmatched_keys,
    ):
        final_config = super().post_config_convert(
            original_config,
            old_config,
            new_config,
            from_index,
            drop_unmatched_keys,
        )
        # pop extraneous keys from ViT ConfigConverter
        if from_index == 0:
            final_config["model"].pop("num_classes")
        else:
            final_config.pop("encoder_stride")
            final_config.pop("num_labels")
        return final_config
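
# Minimal usage sketch, not part of the original module: inspect the
# converter's supported format pair and the config converter paired with it.
# The checkpoint/config conversion entry points themselves are provided by the
# base converter classes and are not exercised here.
if __name__ == "__main__":
    print(Converter_ViTMAE_Core_HF_CS19.formats())
    print(Converter_ViTMAE_Core_HF_CS19.converter_note())
    print(Converter_ViTMAE_Core_HF_CS19.get_config_converter_class().__name__)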