Source code for modelzoo.common.pytorch.model_utils.checkpoint_converters.vit

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch

from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    BaseConfigConverter_HF_CS,
    ConfigConversionError,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)
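
# This module defines checkpoint and config converters between Hugging Face
# ViT models (ViTModel / ViTForImageClassification) and the Cerebras ModelZoo
# ViT classification model (ViTClassificationModel) for releases cs-1.9 and
# cs-2.0. Converter_ViT_Core_HF_CS19 maps the shared encoder weights, the
# headless/headed wrappers handle the classifier head, and
# ConfigConverter_ViT_HF_CS19 translates the model configs.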


class Converter_ViT_Core_HF_CS19(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Embedding:
            ConversionRule(
                [
                    EquivalentSubkey(
                        "embeddings.cls_token", "embedding_layer.cls_embedding"
                    ),
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "embeddings.position_embeddings",
                        "embedding_layer.position_embeddings.weight",
                    ),
                ],
                action=self.position_embeddings_convert,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "embeddings.patch_embeddings.projection",
                        "embedding_layer.linear_proj",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Encoder:
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.attention.query",
                        "self_attn.proj_q_dense_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.attention.key",
                        "self_attn.proj_k_dense_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.attention.value",
                        "self_attn.proj_v_dense_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.output.dense",
                        "self_attn.proj_output_dense_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "intermediate.dense", "ffn.ffn.0.linear_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("output.dense", "ffn.ffn.1.linear_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("layernorm_before", "norm1"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("layernorm_after", "norm2"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "layernorm", "encoder.transformer_encoder.norm"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Pooler:
            ConversionRule(
                [
                    EquivalentSubkey(
                        "pooler.dense",
                        "encoder.pooler.pooler.ffn.0.linear_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
        ]
    def position_embeddings_convert(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        assert (
            action_fn_args["configs"][1]["model"]["position_embedding_type"]
            == "learned"
        ), "Only learned embeddings are supported"
        # CS ViT puts the CLS token's position embedding last by default,
        # while HF puts it at index 0.
        if from_index == 0:
            new_state_dict[new_key] = torch.cat(
                [
                    old_state_dict[old_key][0, 1:, :],
                    old_state_dict[old_key][0, :1, :],
                ],
                dim=0,
            )
        else:
            new_state_dict[new_key] = torch.cat(
                [
                    old_state_dict[old_key][-1:, :],
                    old_state_dict[old_key][:-1, :],
                ],
                dim=0,
            ).unsqueeze(0)

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.9", "cs-2.0"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_ViT_HF_CS19
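

# Editorial illustration of `position_embeddings_convert` above (not part of
# the original converter): HF stores the learned position embeddings as a
# [1, num_patches + 1, hidden] tensor with the CLS row at index 0, while the
# CS model stores a [num_patches + 1, hidden] weight with the CLS row last.
# For example, with rows [cls, p0, p1, ..., pN] on the HF side, HF -> CS
# produces [p0, p1, ..., pN, cls]; CS -> HF moves the last row back to the
# front and re-adds the leading batch dimension via `unsqueeze(0)`.

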
class Converter_ViT_Headless_HF_CS19(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # for HF without head
            ConversionRule(
                [
                    EquivalentSubkey("", "vit_model."),
                    Converter_ViT_Core_HF_CS19(),
                ],
            ),
            # drop classifier during CS -> HF
            ConversionRule(
                [
                    r"classifier.classifier.ffn.0.linear_layer\.(?:weight|bias)",
                ],
                exists="right",
                action=None,
            ),
        ]
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.9", "cs-2.0"))

    @classmethod
    def converter_note(cls) -> str:
        return (
            "{} ViTModel <-> {} ViTClassificationModel\n"
            "The HF model doesn't contain a classifier head while the CS "
            "one does. When converting to CS, the exported checkpoint will "
            "contain a classifier head initialized to default random "
            "values. When converting to HF, the classifier head will be "
            "dropped."
        ).format(cls.formats()[0], cls.formats()[1])

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_ViT_HF_CS19
    def post_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        if from_index == 0:
            # We are converting from HF ViTModel (headless) to our
            # ViTForClassificationModel. We need to create 'classifier' and
            # init it to default values.
            hf_config = configs[0]
            cs_config = configs[1]
            use_bias_in_output = cs_config["model"].get(
                "use_bias_in_output", False
            )
            num_classes = cs_config["model"]["num_classes"]
            embed_dim = cs_config["model"]["hidden_size"]

            classifier_weight = torch.zeros((num_classes, embed_dim))
            classifier_weight.normal_(mean=0.0, std=0.02)
            new_state_dict[
                "classifier.classifier.ffn.0.linear_layer.weight"
            ] = classifier_weight
            if use_bias_in_output:
                lm_head_bias = torch.zeros(num_classes)
                new_state_dict[
                    "classifier.classifier.ffn.0.linear_layer.bias"
                ] = lm_head_bias

            cs_config["model"]["use_encoder_pooler_layer"] = (
                "pooler.dense.weight" in old_state_dict
                or "vit.pooler.dense.weight" in old_state_dict
            )

        super().post_model_convert(
            old_state_dict,
            new_state_dict,
            configs,
            from_index,
            drop_unmatched_keys,
        )


class Converter_ViT_HF_CS19(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # for HF with head
            ConversionRule(
                [
                    EquivalentSubkey("vit.", "vit_model."),
                    Converter_ViT_Core_HF_CS19(),
                ],
            ),
            # classifier
            ConversionRule(
                [
                    EquivalentSubkey(
                        "classifier", "classifier.classifier.ffn.0.linear_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
        ]

    def post_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        if from_index == 0:
            hf_config = configs[0]
            cs_config = configs[1]
            cs_config["model"]["use_encoder_pooler_layer"] = (
                "pooler.dense.weight" in old_state_dict
            )

        super().post_model_convert(
            old_state_dict,
            new_state_dict,
            configs,
            from_index,
            drop_unmatched_keys,
        )
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.9", "cs-2.0"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} ViTForImageClassification <-> {} ViTClassificationModel".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_ViT_HF_CS19
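

# Example of the key mapping performed by Converter_ViT_HF_CS19 above, derived
# from its conversion rules (shown for illustration only):
#   HF: vit.encoder.layer.0.attention.attention.query.weight
#   CS: vit_model.encoder.transformer_encoder.layers.0.self_attn.proj_q_dense_layer.weight
#   HF: classifier.weight
#   CS: classifier.classifier.ffn.0.linear_layer.weight

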
class ConfigConverter_ViT_HF_CS19(BaseConfigConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                ["model_type"],
                action=BaseConfigConverter.assert_factory_fn(0, "vit"),
            ),
            ConversionRule(["hidden_size"], action=self.replaceKey),
            ConversionRule(["num_hidden_layers"], action=self.replaceKey),
            ConversionRule(
                [EquivalentSubkey("layer_norm_eps", "layer_norm_epsilon")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("num_attention_heads", "num_heads")],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["attention_type"],
                action=BaseConfigConverter.assert_factory_fn(
                    1, "scaled_dot_product"
                ),
            ),
            ConversionRule(
                [EquivalentSubkey("hidden_dropout_prob", "dropout_rate")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("hidden_act", "encoder_nonlinearity")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "attention_probs_dropout_prob", "attention_dropout_rate"
                    )
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["use_projection_bias_in_attention"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["use_ffn_bias_in_attention"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                [EquivalentSubkey("intermediate_size", "filter_size")],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["use_ffn_bias"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(["initializer_range"], action=self.replaceKey),
            ConversionRule(
                ["image_size"], action=self.convert_image_patch_size,
            ),
            ConversionRule(["num_channels"], action=self.replaceKey),
            ConversionRule(
                ["patch_size"], action=self.convert_image_patch_size,
            ),
            ConversionRule(
                ["use_conv_patchified_embedding"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                [EquivalentSubkey("num_labels", "num_classes")],
                action=self.replaceKey,
            ),
        ]

        self.pre_convert_defaults[0].update(
            {
                "attention_probs_dropout_prob": 0.0,
                "encoder_stride": 16,
                "hidden_act": "gelu",
                "hidden_dropout_prob": 0.0,
                "hidden_size": 768,
                "image_size": 224,
                "initializer_range": 0.02,
                "intermediate_size": 3072,
                "layer_norm_eps": 1e-12,
                "model_type": "vit",
                "num_attention_heads": 12,
                "num_channels": 3,
                "num_hidden_layers": 12,
                "patch_size": 16,
                "qkv_bias": True,
            }
        )
        self.pre_convert_defaults[1].update(
            {
                "use_conv_patchified_embedding": True,
                "prepend_cls_token": True,
                "use_encoder_pooler_layer": False,
                "position_embedding_type": "learned",
                "num_classes": 2,
            },
        )
        self.post_convert_defaults[0].update({"model_type": "vit"})
        self.post_convert_defaults[1].update(
            {
                "use_conv_patchified_embedding": True,
                "prepend_cls_token": True,
                "use_encoder_pooler_layer": False,
                "position_embedding_type": "learned",
                "num_classes": 2,
                "use_bias_in_output": True,
            }
        )

    def convert_image_patch_size(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        if from_index == 0:
            size = old_state_dict[old_key]
            new_state_dict[new_key] = [size, size]
        else:
            width, height = old_state_dict[old_key]
            if width != height:
                raise ConfigConversionError(
                    "Can't convert config with {}={}. Image width and height "
                    "need to match.".format(old_key, old_state_dict[old_key])
                )
            new_state_dict[new_key] = width

    def pre_config_convert(self, config, from_index):
        config = super().pre_config_convert(config, from_index)

        if (
            from_index == 0
            and "encoder_stride" in config
            and config["encoder_stride"] != config["patch_size"]
        ):
            raise ConfigConversionError(
                f"{self.formats()[1]} model only supports "
                f"encoder_stride == patch_size"
            )

        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        from_index,
        drop_unmatched_keys,
    ):
        if from_index == 1:
            if "encoder_stride" not in new_config:
                new_config["encoder_stride"] = new_config["patch_size"]

        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            from_index,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.9", "cs-2.0"))
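

# Notes on ConfigConverter_ViT_HF_CS19 above (editorial, for illustration):
#   * HF stores `image_size` and `patch_size` as single integers, while the CS
#     config expects [height, width] pairs, so 224 becomes [224, 224] going
#     HF -> CS; going CS -> HF, a non-square pair such as [224, 192] raises
#     ConfigConversionError because the HF config stores a single size.
#   * When converting CS -> HF, `encoder_stride` is filled in with `patch_size`
#     if it is missing; HF -> CS conversion rejects configs where
#     `encoder_stride != patch_size`.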