Source code for cerebras.modelzoo.tools.checkpoint_converters.vit

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch

from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    BaseConfigConverter_HF_CS,
    ConfigConversionError,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)


class Converter_ViT_Core_HF_CS21(BaseCheckpointConverter_HF_CS):
    """Checkpoint conversion rules for the ViT backbone (HF <-> CS 2.1).

    The rules map the embedding, encoder-layer, final layer-norm and pooler
    keys. Keys are matched without any model prefix; the other converters in
    this module nest an instance of this class behind their own prefix rule.
    """

    def __init__(self):
        """Build the ordered list of key conversion rules."""
        super().__init__()

        # Regex fragments reused by most rules.
        weight_or_bias = r"\.(?:weight|bias)"
        layer_index = r"\.\d+\."

        def plain_rule(hf_subkey, cs_subkey):
            # Rename a top-level module; the weight/bias suffix is kept as is.
            return ConversionRule(
                [EquivalentSubkey(hf_subkey, cs_subkey), weight_or_bias],
                action=self.replaceKey,
            )

        def encoder_layer_rule(hf_subkey, cs_subkey):
            # Rename a module living inside a numbered encoder layer.
            return ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer",
                        "encoder.transformer_encoder.layers",
                    ),
                    layer_index,
                    EquivalentSubkey(hf_subkey, cs_subkey),
                    weight_or_bias,
                ],
                action=self.replaceKey,
            )

        # Order is preserved from the original list: the
        # "attention.output.dense" rule must come before "output.dense".
        encoder_layer_subkeys = [
            ("attention.attention.query", "self_attn.proj_q_dense_layer"),
            ("attention.attention.key", "self_attn.proj_k_dense_layer"),
            ("attention.attention.value", "self_attn.proj_v_dense_layer"),
            ("attention.output.dense", "self_attn.proj_output_dense_layer"),
            ("intermediate.dense", "ffn.ffn.0.linear_layer"),
            ("output.dense", "ffn.ffn.1.linear_layer"),
            ("layernorm_before", "norm1"),
            ("layernorm_after", "norm2"),
        ]

        self.rules = [
            # Embedding:
            ConversionRule(
                [
                    EquivalentSubkey(
                        "embeddings.cls_token", "embedding_layer.cls_embedding"
                    ),
                ],
                action=self.cls_embedding_convert,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "embeddings.position_embeddings",
                        "embedding_layer.position_embeddings.weight",
                    ),
                ],
                action=self.position_embeddings_convert,
            ),
            plain_rule(
                "embeddings.patch_embeddings.projection",
                "embedding_layer.linear_proj",
            ),
            # Encoder:
            *(
                encoder_layer_rule(hf_subkey, cs_subkey)
                for hf_subkey, cs_subkey in encoder_layer_subkeys
            ),
            plain_rule("layernorm", "encoder.transformer_encoder.norm"),
            # pooler
            plain_rule(
                "pooler.dense", "encoder.pooler.pooler.ffn.0.linear_layer"
            ),
        ]

    def cls_embedding_convert(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert the CLS token embedding between the two layouts.

        HF -> CS (``from_index == 0``) squeezes away all singleton
        dimensions; the other direction reshapes the tensor to
        ``(1, 1, -1)``.
        """
        tensor = old_state_dict[old_key]
        if from_index == 0:
            new_state_dict[new_key] = tensor.squeeze()
        else:
            new_state_dict[new_key] = tensor.reshape(1, 1, -1)

    def position_embeddings_convert(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert position embeddings, relocating the CLS position.

        Args:
            old_key: Key of the source tensor in ``old_state_dict``.
            new_key: Key to write in ``new_state_dict``.
            old_state_dict: Source checkpoint mapping.
            new_state_dict: Destination checkpoint mapping (mutated).
            from_index: 0 when the source is HF, otherwise the source is CS.
            action_fn_args: Must provide ``["configs"][1]["model"]`` with a
                ``position_embedding_type`` entry.

        Raises:
            AssertionError: If the CS config does not use learned position
                embeddings.
        """
        pe_type = action_fn_args["configs"][1]["model"][
            "position_embedding_type"
        ]
        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if pe_type != "learned":
            raise AssertionError("Only learned embeddings are supported")

        source = old_state_dict[old_key]
        # The CS ViT position embedding keeps the CLS position last by
        # default, whereas HF keeps it at index 0.
        if from_index == 0:
            # The HF tensor is indexed at 0 on its leading axis, then the
            # first row is moved to the end.
            new_state_dict[new_key] = torch.cat(
                [source[0, 1:, :], source[0, :1, :]],
                dim=0,
            )
        else:
            # Move the last row to the front and restore a leading axis.
            new_state_dict[new_key] = torch.cat(
                [source[-1:, :], source[:-1, :]],
                dim=0,
            ).unsqueeze(0)

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions declared for this converter."""
        return (
            FormatVersions("hf"),
            FormatVersions("cs-2.1"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        """Return the config converter class paired with this converter."""
        return ConfigConverter_ViT_HF_CS21
class Converter_ViT_Headless_HF_CS21(BaseCheckpointConverter_HF_CS):
    """Convert between a headless HF ViTModel and the CS classification model.

    The CS side has a classifier head that the HF side lacks: it is created
    when converting to CS and dropped when converting to HF.
    """

    def __init__(self):
        """Build the conversion rules (pooler drop, core mapping, head drop)."""
        super().__init__()
        self.rules = [
            # ViTModel has a pooling layer, ViTModelForImageClassification
            # doesn't. Dots are escaped so that only the literal key matches.
            ConversionRule(
                [
                    r"pooler\.dense\.(?:weight|bias)",
                ],
                exists="left",
                action=None,
            ),
            # for HF without head
            ConversionRule(
                [
                    EquivalentSubkey("", "vit_model."),
                    Converter_ViT_Core_HF_CS21(),
                ],
            ),
            # drop classifier during CS -> HF
            ConversionRule(
                [
                    r"classifier\.classifier\.ffn\.0\.linear_layer\.(?:weight|bias)",
                ],
                exists="right",
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions declared for this converter."""
        return (
            FormatVersions("hf"),
            FormatVersions("cs-2.1", "cs-2.2"),
        )

    @classmethod
    def converter_note(cls) -> str:
        """Return a human-readable description of this conversion."""
        return (
            "{} ViTModel <-> {} ViTClassificationModel\n"
            "The HF model doesn't contain a classifier head while the CS "
            "one does. When converting to CS, the exported checkpoint will "
            "contain a classifier head initialized to default random "
            "values. When converting to HF, the classifier head will be "
            "dropped."
        ).format(cls.formats()[0], cls.formats()[1])

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        """Return the config converter class paired with this converter."""
        return ConfigConverter_ViT_HF_CS21

    def post_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
        key_prefix="",
    ):
        """Add a classifier head when converting HF -> CS, then defer to base.

        When ``converter_indices.direction == 0`` a classifier weight of
        shape ``(num_classes, hidden_size)`` is drawn from a normal
        distribution (mean 0.0, std 0.02) and stored under ``key_prefix``.
        A zero bias is added only if the CS config sets
        ``use_bias_in_output``. Both sizes are read from
        ``configs[1]["model"]``.
        """
        if converter_indices.direction == 0:
            # We are converting from HF ViTModel (headless) to our
            # ViTForClassificationModel, so the 'classifier' entries have to
            # be created and initialised to default values.
            model_config = configs[1]["model"]
            use_bias_in_output = model_config.get("use_bias_in_output", False)
            num_classes = model_config["num_classes"]
            embed_dim = model_config["hidden_size"]

            classifier_key = (
                key_prefix + "classifier.classifier.ffn.0.linear_layer"
            )

            classifier_weight = torch.zeros((num_classes, embed_dim))
            classifier_weight.normal_(mean=0.0, std=0.02)
            new_state_dict[classifier_key + ".weight"] = classifier_weight

            if use_bias_in_output:
                classifier_bias = torch.zeros(num_classes)
                new_state_dict[classifier_key + ".bias"] = classifier_bias

        super().post_model_convert(
            old_state_dict,
            new_state_dict,
            configs,
            converter_indices,
            drop_unmatched_keys,
            key_prefix=key_prefix,
        )
class Converter_ViT_HF_CS21(BaseCheckpointConverter_HF_CS):
    """Convert between HF ViTForImageClassification and the CS classifier."""

    def __init__(self):
        """Build the two rules: prefixed core mapping and classifier rename."""
        super().__init__()

        # for HF with head
        backbone_rule = ConversionRule(
            [
                EquivalentSubkey("vit.", "vit_model."),
                Converter_ViT_Core_HF_CS21(),
            ],
        )

        # classifier
        head_rule = ConversionRule(
            [
                EquivalentSubkey(
                    "classifier", "classifier.classifier.ffn.0.linear_layer"
                ),
                r"\.(?:weight|bias)",
            ],
            action=self.replaceKey,
        )

        self.rules = [backbone_rule, head_rule]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions declared for this converter."""
        hf_versions = FormatVersions("hf")
        cs_versions = FormatVersions("cs-2.1", "cs-2.2")
        return (hf_versions, cs_versions)

    @classmethod
    def converter_note(cls) -> str:
        """Return a one-line description of the two model classes involved."""
        hf_versions = cls.formats()[0]
        cs_versions = cls.formats()[1]
        return (
            f"{hf_versions} ViTForImageClassification <-> "
            f"{cs_versions} ViTClassificationModel"
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        """Return the config converter class paired with this converter."""
        return ConfigConverter_ViT_HF_CS21
class ConfigConverter_ViT_Core_HF_CS21(BaseConfigConverter_HF_CS):
    """Config conversion rules and defaults for ViT (HF <-> CS 2.1 / 2.2)."""

    def __init__(self):
        """Register the config key rules and the pre/post conversion defaults."""
        super().__init__()

        def same_key(key):
            # Key has the same name on both sides.
            return ConversionRule([key], action=self.replaceKey)

        def renamed(hf_key, cs_key):
            # Key is spelled differently on each side.
            return ConversionRule(
                [EquivalentSubkey(hf_key, cs_key)],
                action=self.replaceKey,
            )

        def cs_only_equals(key, expected):
            # Key exists only on the CS side and is asserted to equal
            # `expected`.
            return ConversionRule(
                [key],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, expected),
            )

        def size_rule(key):
            # image_size / patch_size are converted by
            # convert_image_patch_size.
            return ConversionRule([key], action=self.convert_image_patch_size)

        self.rules = [
            ConversionRule(
                ["model_type"],
                action=BaseConfigConverter.assert_factory_fn(0, "vit"),
            ),
            cs_only_equals("use_post_embed_layer_norm", False),
            same_key("hidden_size"),
            same_key("num_hidden_layers"),
            renamed("layer_norm_eps", "layer_norm_epsilon"),
            renamed("num_attention_heads", "num_heads"),
            ConversionRule(
                ["attention_type"],
                action=BaseConfigConverter.assert_factory_fn(
                    1, "scaled_dot_product"
                ),
            ),
            renamed("hidden_dropout_prob", "dropout_rate"),
            ConversionRule(
                [EquivalentSubkey("hidden_act", "nonlinearity")],
                action=self.convert_nonlinearity,
            ),
            renamed("attention_probs_dropout_prob", "attention_dropout_rate"),
            cs_only_equals("use_projection_bias_in_attention", True),
            cs_only_equals("use_ffn_bias_in_attention", True),
            renamed("intermediate_size", "filter_size"),
            cs_only_equals("use_ffn_bias", True),
            same_key("initializer_range"),
            size_rule("image_size"),
            same_key("num_channels"),
            size_rule("patch_size"),
            cs_only_equals("use_conv_patchified_embedding", True),
        ]

        # Defaults applied to an HF config before conversion.
        hf_pre_defaults = {
            "attention_probs_dropout_prob": 0.0,
            "encoder_stride": 16,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.0,
            "hidden_size": 768,
            "image_size": 224,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "layer_norm_eps": 1e-12,
            "model_type": "vit",
            "num_attention_heads": 12,
            "num_channels": 3,
            "num_hidden_layers": 12,
            "patch_size": 16,
            "qkv_bias": True,
        }
        # CS-side defaults shared by the pre- and post-conversion stages.
        cs_shared_defaults = {
            "use_conv_patchified_embedding": True,
            "prepend_cls_token": True,
            "use_encoder_pooler_layer": True,
            "position_embedding_type": "learned",
            "num_classes": 2,
        }

        self.pre_convert_defaults[0].update(hf_pre_defaults)
        self.pre_convert_defaults[1].update(cs_shared_defaults)
        self.post_convert_defaults[0].update({"model_type": "vit"})
        self.post_convert_defaults[1].update(
            {**cs_shared_defaults, "use_bias_in_output": True}
        )

    def convert_image_patch_size(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert a size entry between a single value (HF) and a pair (CS).

        HF -> CS (``from_index == 0``) duplicates the value into a
        two-element list. In the other direction the pair is unpacked and
        its first element is written.

        Raises:
            ConfigConversionError: If the two elements of the pair differ.
        """
        value = old_state_dict[old_key]

        if from_index == 0:
            new_state_dict[new_key] = [value, value]
            return

        width, height = value
        if width != height:
            raise ConfigConversionError(
                "Can't convert config with {}={}. Image width and height need to match.".format(
                    old_key, value
                )
            )
        new_state_dict[new_key] = width

    def convert_nonlinearity(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Translate the activation name; unknown names pass through as is."""
        hf_to_cs = {
            "silu": "swiglu",
            "gelu_pytorch_tanh": "gelu_new",
            "quick_gelu": "quick_gelu",
        }
        cs_to_hf = {
            "swiglu": "silu",
            "gelu_new": "gelu_pytorch_tanh",
            "quick_gelu": "quick_gelu",
        }

        # Pick the table for the conversion direction; any other index
        # leaves the activation untouched.
        table = {0: hf_to_cs, 1: cs_to_hf}.get(from_index)

        activation = old_state_dict[old_key]
        if table is not None:
            activation = table.get(activation, activation)
        new_state_dict[new_key] = activation

    def pre_config_convert(
        self,
        config,
        converter_indices,
    ):
        """Run the base pre-conversion, then validate ``encoder_stride``.

        Raises:
            ConfigConversionError: When converting from HF (direction 0)
                and the config has an ``encoder_stride`` that differs from
                ``patch_size``.
        """
        config = super().pre_config_convert(config, converter_indices)

        if converter_indices.direction == 0 and "encoder_stride" in config:
            if config["encoder_stride"] != config["patch_size"]:
                raise ConfigConversionError(
                    f"{self.formats()[1]} model only supports encoder_stride == patch_size"
                )

        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        converter_indices,
        drop_unmatched_keys,
    ):
        """Fill in ``encoder_stride`` for CS -> HF, then defer to the base."""
        needs_stride = (
            converter_indices.direction == 1
            and "encoder_stride" not in new_config
        )
        if needs_stride:
            # The produced HF config has no stride yet; reuse its patch size.
            new_config["encoder_stride"] = new_config["patch_size"]

        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            converter_indices,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions declared for this converter."""
        hf_versions = FormatVersions("hf")
        cs_versions = FormatVersions("cs-2.1", "cs-2.2")
        return (hf_versions, cs_versions)
class ConfigConverter_ViT_HF_CS21(ConfigConverter_ViT_Core_HF_CS21):
    """Core ViT config converter plus the ``num_labels`` / ``num_classes`` key.

    It also overrides the CS-side defaults so that
    ``use_encoder_pooler_layer`` is False.
    """

    def __init__(self):
        """Prepend the label-count rule and override the CS-side defaults."""
        super().__init__()

        label_count_rule = ConversionRule(
            [EquivalentSubkey("num_labels", "num_classes")],
            action=self.replaceKey,
        )
        # The new rule goes first, ahead of everything inherited.
        self.rules = [label_count_rule, *self.rules]

        cs_overrides = {
            "use_encoder_pooler_layer": False,
            "num_classes": 2,
        }
        for stage_defaults in (
            self.pre_convert_defaults,
            self.post_convert_defaults,
        ):
            # Index 1 holds the CS-side defaults.
            stage_defaults[1].update(dict(cs_overrides))