# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
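"""Checkpoint and config converters for ViTMAE between the Hugging Face (HF)
and Cerebras Systems (CS) model formats."""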
from typing import Tuple

import torch

from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    ConfigConversionError,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)
from modelzoo.common.pytorch.model_utils.checkpoint_converters.vit import (
    ConfigConverter_ViT_HF_CS19,
    Converter_ViT_Core_HF_CS19,
)


class Converter_ViTMAE_Core_HF_CS19(BaseCheckpointConverter_HF_CS):
    """Converts ViTMAE checkpoints between HF `ViTMAEForPreTraining` and the
    CS `ViTMAEModel`."""

    def __init__(self):
        super().__init__()
        self.rules = [
            # ViT encoder: delegate all `vit.`-prefixed keys to the ViT core
            # converter, stripping the prefix on the CS side.
            ConversionRule(
                [EquivalentSubkey("vit.", ""), Converter_ViT_Core_HF_CS19()],
                action=None,
            ),
# Decoder
ConversionRule(
["decoder.", EquivalentSubkey("mask_token", "mask_embedding")],
action=self.replaceKey,
),
            # Decoder position embeddings; the CLS row must be rotated, see
            # `position_embeddings_convert` below.
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_pos_embed", "position_embeddings.weight"
                    ),
                ],
                action=self.position_embeddings_convert,
            ),
            # Projection from the encoder hidden size to the decoder hidden
            # size.
            ConversionRule(
                [
                    EquivalentSubkey(
                        "decoder.decoder_embed", "encoder_decoder_projection"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Decoder self-attention: Q/K/V and output projections.
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.attention.query",
                        "self_attn.proj_q_dense_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.attention.key",
                        "self_attn.proj_k_dense_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.attention.value",
                        "self_attn.proj_v_dense_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.output.dense",
                        "self_attn.proj_output_dense_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Decoder feed-forward network (two linear layers).
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "intermediate.dense", "ffn.ffn.0.linear_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("output.dense", "ffn.ffn.1.linear_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Per-layer norms (HF layernorm_before/after -> CS norm1/norm2).
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("layernorm_before", "norm1"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_layers", "encoder.transformer_encoder.layers"
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("layernorm_after", "norm2"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Final decoder layer norm.
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey(
                        "decoder_norm", "encoder.transformer_encoder.norm"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Prediction head (per-patch pixel reconstruction).
            ConversionRule(
                [
                    "decoder.",
                    EquivalentSubkey("decoder_pred", "output_projection"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
        ]

    def position_embeddings_convert(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Rotates the CLS-token row of the learned position embeddings.

        CS ViT position embeddings place the CLS token last by default,
        while HF places it at index 0.
        """
        assert (
            action_fn_args["configs"][1]["model"]["position_embedding_type"]
            == "learned"
        ), "Only learned position embeddings are supported"
        # HF -> CS: move row 0 (CLS) to the end and drop the leading batch
        # dim; CS -> HF: move the last row back to the front and restore it.
if from_index == 0:
new_state_dict[new_key] = torch.cat(
[
old_state_dict[old_key][0, 1:, :],
old_state_dict[old_key][0, :1, :],
],
dim=0,
)
else:
new_state_dict[new_key] = torch.cat(
[
old_state_dict[old_key][-1:, :],
old_state_dict[old_key][:-1, :],
],
dim=0,
            ).unsqueeze(0)

    @staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.9", "cs-2.0"))

    @classmethod
def converter_note(cls) -> str:
return "{} ViTMAEForPreTraining <-> {} ViTMAEModel".format(
cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
def get_config_converter_class() -> BaseConfigConverter:
return ConfigConverter_ViTMAE_HF_CS19
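

# Sketch of the checkpoint-key mapping encoded by the rules above (HF names
# from `transformers` ViTMAE; CS names as spelled in the rules):
#   vit.<...>                   -> <...> (via Converter_ViT_Core_HF_CS19)
#   decoder.decoder_layers.0.attention.attention.query.weight
#     -> decoder.encoder.transformer_encoder.layers.0.self_attn.proj_q_dense_layer.weight
#   decoder.decoder_norm.weight -> decoder.encoder.transformer_encoder.norm.weight
#   decoder.decoder_pred.bias   -> decoder.output_projection.bias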


class ConfigConverter_ViTMAE_HF_CS19(ConfigConverter_ViT_HF_CS19):
    """Converts ViTMAE configs between HF and CS, extending the ViT config
    converter with decoder-specific rules."""

    def __init__(self):
        super().__init__()
        # ViTMAE-specific config rules (mostly decoder fields); prepended to
        # the inherited ViT rules below.
        decoder_rules = [
            ConversionRule(
["model_type"],
action=BaseConfigConverter.assert_factory_fn(0, "vit_mae"),
),
ConversionRule(
[EquivalentSubkey("hidden_act", "encoder_nonlinearity")],
action=self.convert_nonlinearity,
),
ConversionRule(
["decoder_nonlinearity"],
action=self.assert_decoder_nonlinearity,
),
ConversionRule(
[
EquivalentSubkey(
"decoder_num_attention_heads", "decoder_num_heads"
)
],
action=self.replaceKey,
),
ConversionRule(["decoder_hidden_size"], action=self.replaceKey),
ConversionRule(
["decoder_num_hidden_layers"], action=self.replaceKey
),
ConversionRule(
[
EquivalentSubkey(
"decoder_intermediate_size", "decoder_filter_size"
)
],
action=self.replaceKey,
),
ConversionRule(["mask_ratio"], action=self.replaceKey,),
]
self.rules = decoder_rules + self.rules
        # HF `ViTMAEConfig` has no `encoder_stride`, unlike `ViTConfig`.
        del self.pre_convert_defaults[0]["encoder_stride"]
self.pre_convert_defaults[0].update(
{
"decoder_intermediate_size": 2048,
"decoder_num_attention_heads": 12,
"decoder_num_hidden_layers": 8,
"mask_ratio": 0.75,
"norm_pix_loss": False,
}
)
self.pre_convert_defaults[1].update({"mask_ratio": 0.75})
self.post_convert_defaults[0].update(
{"model_type": "vit_mae", "norm_pix_loss": False}
)
        self.post_convert_defaults[1].update({"mask_ratio": 0.75})
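
    # Sketch of the config-field mapping encoded above (HF names from
    # `transformers.ViTMAEConfig`; CS names as used in these rules):
    #   decoder_num_attention_heads -> decoder_num_heads
    #   decoder_intermediate_size   -> decoder_filter_size
    #   hidden_act                  -> encoder_nonlinearity (and, on HF -> CS,
    #                                  also decoder_nonlinearity)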
def convert_nonlinearity(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
    ):
        """Copies the activation; on HF -> CS, the single HF `hidden_act`
        populates both CS nonlinearity fields."""
        activation = old_state_dict[old_key]
        new_state_dict[new_key] = activation
if from_index == 0:
# set `encoder_nonlinearity` and `decoder_nonlinearity` to `activation`
new_state_dict["decoder_nonlinearity"] = activation
def assert_decoder_nonlinearity(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
    ):
        """On CS -> HF, HF has only a single `hidden_act`, so the CS encoder
        and decoder nonlinearities must match."""
        if old_state_dict["encoder_nonlinearity"] != old_state_dict[old_key]:
            raise ConfigConversionError(
                "Encoder and decoder nonlinearities must be the same in the "
                "HF model. Got: {} vs {}".format(
old_state_dict["encoder_nonlinearity"],
old_state_dict[old_key],
)
            )

    def post_config_convert(
self,
original_config,
old_config,
new_config,
from_index,
drop_unmatched_keys,
):
final_config = super().post_config_convert(
original_config,
old_config,
new_config,
from_index,
drop_unmatched_keys,
)
        # Pop keys inherited from the ViT config converter that do not exist
        # in ViTMAE.
        if from_index == 0:
            # HF -> CS: ViTMAE has no classification head.
            final_config["model"].pop("num_classes")
        else:
            # CS -> HF: `ViTMAEConfig` has neither of these ViT fields.
            final_config.pop("encoder_stride")
            final_config.pop("num_labels")
        return final_config
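

if __name__ == "__main__":
    # Hypothetical self-check (an assumption, not part of the converter API):
    # verify that the HF -> CS position-embedding rotation performed in
    # `position_embeddings_convert` is exactly undone by its CS -> HF branch.
    pe_hf = torch.arange(12.0).reshape(1, 4, 3)  # (batch, seq, dim), CLS at row 0
    # HF -> CS: drop the batch dim and move the CLS row to the end.
    pe_cs = torch.cat([pe_hf[0, 1:, :], pe_hf[0, :1, :]], dim=0)
    # CS -> HF: move the last row back to the front and restore the batch dim.
    pe_roundtrip = torch.cat([pe_cs[-1:, :], pe_cs[:-1, :]], dim=0).unsqueeze(0)
    assert torch.equal(pe_hf, pe_roundtrip), "rotation is not invertible"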