Source code for cerebras.modelzoo.tools.checkpoint_converters.bert

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from typing import Tuple

import torch

from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_CS_CS,
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    BaseConfigConverter_CS_CS,
    BaseConfigConverter_HF_CS,
    ConfigConversionError,
    ConversionRule,
    EquivalentSubkey,
    FormatIndices,
    FormatVersions,
)
from cerebras.modelzoo.tools.checkpoint_converters.helper import (
    Build_HF_CS_Converter_WithOptionalModel,
    maybe_tie_lm_head,
)


class Converter_BertLayerNorm_HF_CS(BaseCheckpointConverter_HF_CS):
    def __init__(self, hf_name, cs_name):
        super().__init__()
        self.rules = [
            # torch.nn.LayerNorm has .weight & .bias properties
            ConversionRule(
                [
                    EquivalentSubkey(hf_name, cs_name),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # Old HF implementation uses .gamma instead of .weight
            ConversionRule(
                [
                    EquivalentSubkey(hf_name, cs_name),
                    EquivalentSubkey(".gamma", ".weight"),
                ],
                action=self.replaceKey,
            ),
            # Old HF implementation uses .beta instead of .bias
            ConversionRule(
                [
                    EquivalentSubkey(hf_name, cs_name),
                    EquivalentSubkey(".beta", ".bias"),
                ],
                action=self.replaceKey,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return None

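# Illustrative example (key names are hypothetical): instantiated as
# Converter_BertLayerNorm_HF_CS("LayerNorm", "norm1"), the rules above map an
# old-style HF key ending in "LayerNorm.gamma" to one ending in "norm1.weight"
# and "LayerNorm.beta" to "norm1.bias", while new-style ".weight"/".bias" keys
# are copied with only the prefix renamed.
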
class Converter_BertModel_CS16_CS17(BaseCheckpointConverter_CS_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Embedding:
            ConversionRule(
                [
                    EquivalentSubkey("embeddings", "embedding_layer"),
                    r"\.word_embeddings\.weight",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("embeddings", "embedding_layer"),
                    r"\.position_embeddings\.weight",
                ],
                action=self.position_embeddings_convert,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "embeddings.token_type_embeddings",
                        "embedding_layer.segment_embeddings",
                    ),
                    r"\.weight",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    r"embeddings\.position_ids",
                ],
                exists="left",
            ),
            ConversionRule(
                [
                    EquivalentSubkey("embeddings.", ""),
                    Converter_BertLayerNorm_HF_CS("LayerNorm", "embed_ln_f"),
                ],
                action=None,
            ),
            # Encoder Layers:
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer",
                        "transformer_encoder.layers",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.self.query", "self_attn.proj_q_dense_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer",
                        "transformer_encoder.layers",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.self.key", "self_attn.proj_k_dense_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer",
                        "transformer_encoder.layers",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.self.value", "self_attn.proj_v_dense_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer",
                        "transformer_encoder.layers",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "attention.output.dense",
                        "self_attn.proj_output_dense_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer",
                        "transformer_encoder.layers",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("attention.output.", ""),
                    Converter_BertLayerNorm_HF_CS("LayerNorm", "norm1"),
                ],
                action=None,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer",
                        "transformer_encoder.layers",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "intermediate.dense", "ffn.ffn.0.linear_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer",
                        "transformer_encoder.layers",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("output.dense", "ffn.ffn.1.linear_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "encoder.layer",
                        "transformer_encoder.layers",
                    ),
                    r"\.\d+\.",
                    EquivalentSubkey("output.", ""),
                    Converter_BertLayerNorm_HF_CS("LayerNorm", "norm2"),
                ],
                action=None,
            ),
            # Head:
            ConversionRule(
                [
                    r"pooler\.",
                    EquivalentSubkey("dense", "pooler.ffn.0.linear_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.convert_pooler_factory_fn(),
            ),
        ]

    def convert_pooler_factory_fn(self):
        """
        DPR, which uses two BERT sub-converters, requires different behavior
        of the pooler conversion, so we generalize to allow overriding.
        """

        def bert_pooler_convert(
            old_key,
            new_key,
            old_state_dict,
            new_state_dict,
            from_index,
            action_fn_args,
        ):
            return self.replaceKey(
                old_key,
                new_key,
                old_state_dict,
                new_state_dict,
                from_index,
                action_fn_args,
            )

        return bert_pooler_convert

    def position_embeddings_convert(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        self.replaceKey(
            old_key, new_key, old_state_dict, new_state_dict, from_index
        )
        if from_index == 1:
            # HF stores a register buffer with position_ids
            position_id_key = re.sub(
                r"\.position_embeddings\.weight", ".position_ids", new_key
            )
            if "max_position_embeddings" in action_fn_args["configs"][0]:
                max_position_embeddings = action_fn_args["configs"][0][
                    "max_position_embeddings"
                ]
            else:
                max_position_embeddings = action_fn_args["configs"][1]["model"][
                    "max_position_embeddings"
                ]
            new_state_dict[position_id_key] = torch.arange(
                max_position_embeddings
            ).expand((1, -1))

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("cs-1.6"), FormatVersions("cs-1.7"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS17

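# A minimal sketch of the buffer rebuilt in position_embeddings_convert
# (assuming max_position_embeddings=512, the HF default listed in the config
# converter below): torch.arange(512).expand((1, -1)) yields a (1, 512) tensor
# holding 0..511, matching the position_ids register buffer that HF keeps
# alongside position_embeddings.weight.
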
class ConfigConverter_Bert_CS16_CS17(BaseConfigConverter_CS_CS):
    def __init__(self):
        super().__init__()
        # Config didn't change between 1.6 and 1.7. Copy all keys.
        self.rules = [
            ConversionRule([".*"], action=self.replaceKey),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("cs-1.6"), FormatVersions("cs-1.7"))

class Converter_BertModel_CS16_CS18(BaseCheckpointConverter_CS_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from Pytorch 2.0 API
            ConversionRule(
                [
                    Converter_BertModel_CS16_CS17(),
                ],
                action=None,
            ),
            # Catch checkpoints from 1.7/1.8
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_BertModel_CS16_CS17(),
                ],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("cs-1.6"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS18

class ConfigConverter_Bert_CS16_CS18(ConfigConverter_Bert_CS16_CS17):
    def pre_config_convert(
        self,
        config,
        converter_indices,
    ):
        config = super().pre_config_convert(config, converter_indices)

        if converter_indices.direction == 1:
            if (
                "pooler_nonlinearity" in config
                and config["pooler_nonlinearity"]
                != config["encoder_nonlinearity"]
            ):
                raise ConfigConversionError(
                    "pooler_nonlinearity was introduced in CS 1.8. Prior to that, the pooler "
                    "nonlinearity must be the same as encoder_nonlinearity."
                )
            if "mlm_nonlinearity" in config:
                if config["mlm_nonlinearity"] != "gelu":
                    raise ConfigConversionError(
                        "mlm_nonlinearity was introduced in CS 1.8. Prior to that, the mlm "
                        "nonlinearity must be gelu."
                    )
            else:
                if config["encoder_nonlinearity"] != "gelu":
                    raise ConfigConversionError(
                        f"mlm_nonlinearity was introduced in CS 1.8. Prior to that, the mlm "
                        f"nonlinearity must be gelu. However, the input config has an "
                        f"mlm_nonlinearity which defaults to encoder_nonlinearity = "
                        f"{config['encoder_nonlinearity']}"
                    )

        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        converter_indices,
        drop_unmatched_keys,
    ):
        if converter_indices.direction == 0:
            new_config["pooler_nonlinearity"] = new_config[
                "encoder_nonlinearity"
            ]
            new_config["mlm_nonlinearity"] = "gelu"

        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            converter_indices,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("cs-1.6"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )

class Converter_Bert_CS17_CS18(BaseCheckpointConverter_CS_CS):
    def __init__(self):
        super().__init__()
        # Checkpoint didn't change between 1.7 and 1.8. Copy all keys.
        self.rules = [
            ConversionRule([".*"], action=self.replaceKey),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("cs-1.7"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )

    @classmethod
    def converter_note(cls) -> str:
        return (
            "BertForPreTraining, BertForSequenceClassification, "
            "BertForQuestionAnswering, and BertForSummarization classes"
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS17_CS18

# Config didn't change between 1.6 and 1.7. Therefore, the 1.7 <-> 1.8
# converter is equivalent to the 1.6 <-> 1.8 converter.
class ConfigConverter_Bert_CS17_CS18(ConfigConverter_Bert_CS16_CS18):
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("cs-1.7"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )

class Converter_BertModel_HF_CS17(
    Converter_BertModel_CS16_CS17, BaseCheckpointConverter_HF_CS
):
    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
    ):
        # Manually tie weights
        if (
            converter_indices.direction == 1
            and configs[1]["model"]["share_embedding_weights"]
        ):
            if (
                old_state_dict.get(
                    "bert_encoder.embedding_layer.word_embeddings.weight", 0
                )
                is None
            ):
                old_state_dict[
                    "bert_encoder.embedding_layer.word_embeddings.weight"
                ] = old_state_dict[
                    "bert_mlm_head.classifier.ffn.0.linear_layer.weight"
                ]

    def pre_checkpoint_convert(
        self,
        *args,
    ):
        return BaseCheckpointConverter_HF_CS.pre_checkpoint_convert(
            self,
            *args,
        )

    def extract_model_dict(self, *args):
        return BaseCheckpointConverter_HF_CS.extract_model_dict(self, *args)

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS17

class Converter_BertModel_HF_CS18(BaseCheckpointConverter_HF_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from Pytorch 2.0 API
            ConversionRule(
                [
                    Converter_BertModel_HF_CS17(),
                ],
                action=None,
            ),
            # Catch checkpoints from 1.7/1.8
            ConversionRule(
                [EquivalentSubkey("", "model."), Converter_BertModel_HF_CS17()],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS17

class Converter_BertPretrainModel_CS16_CS17(BaseCheckpointConverter_CS_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("bert.", "bert_encoder."),
                    Converter_BertModel_CS16_CS17(),
                ],
            ),
            # CLS:
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.predictions.transform.dense",
                        "bert_mlm_head.mlm_transform.ffn.ffn.0.linear_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.predictions.transform.",
                        "bert_mlm_head.mlm_transform.",
                    ),
                    Converter_BertLayerNorm_HF_CS("LayerNorm", "ln"),
                ],
                action=None,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.predictions.decoder",
                        "bert_mlm_head.classifier.ffn.0.linear_layer",
                    ),
                    r"\.weight",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.predictions.decoder",
                        "bert_mlm_head.classifier.ffn.0.linear_layer",
                    ),
                    r"\.bias",
                ],
                action=self.convert_cls_predictions_bias,
            ),
            ConversionRule([r"cls\.predictions\.bias"], exists="left"),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.seq_relationship",
                        "bert_cls_head.classifier.ffn.0.linear_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
        ]

    def convert_cls_predictions_bias(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        self.replaceKey(
            old_key,
            new_key,
            old_state_dict,
            new_state_dict,
            from_index,
            action_fn_args,
        )
        if from_index == 1:
            # HF stores an extra copy of the decoder bias in the predictions
            # object itself
            bias_key = re.sub(r"\.decoder\.", ".", new_key)
            self.replaceKey(
                old_key,
                bias_key,
                old_state_dict,
                new_state_dict,
                from_index,
                action_fn_args,
            )

    def pre_checkpoint_convert(
        self,
        input_checkpoint,
        output_checkpoint,
        configs: Tuple[dict, dict],
        converter_indices: FormatIndices,
    ):
        # Don't copy non model keys like optimizer state:
        logging.warning(
            "The Bert model changed significantly between {} and {}. As a result, the"
            " optimizer state won't be included in the converted checkpoint.".format(
                *self.formats()
            )
        )
        output_checkpoint["model"] = {}

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("cs-1.6"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return "BertPretrainModel class"

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS17

class Converter_BertPretrainModel_CS16_CS18(BaseCheckpointConverter_CS_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from Pytorch 2.0 API
            ConversionRule(
                [
                    Converter_BertPretrainModel_CS16_CS17(),
                ],
                action=None,
            ),
            # Catch checkpoints from 1.7/1.8
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_BertPretrainModel_CS16_CS17(),
                ],
                action=None,
            ),
        ]

    def pre_checkpoint_convert(
        self,
        input_checkpoint,
        output_checkpoint,
        configs: Tuple[dict, dict],
        converter_indices: FormatIndices,
    ):
        # Don't copy non model keys like optimizer state:
        logging.warning(
            "The Bert model changed significantly between {} and {}. As a result, the"
            " optimizer state won't be included in the converted checkpoint.".format(
                *self.formats()
            )
        )
        output_checkpoint["model"] = {}

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("cs-1.6"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )

    @classmethod
    def converter_note(cls) -> str:
        return "BertPretrainModel class"

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS18

class Converter_BertPretrainModel_HF_CS17(
    Converter_BertPretrainModel_CS16_CS17, BaseCheckpointConverter_HF_CS
):
    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
    ):
        # Manually tie weights
        old_state_dict = dict(old_state_dict)
        if converter_indices.direction == 1 and configs[1]["model"].get(
            "share_embedding_weights", False
        ):
            if (
                old_state_dict.get(
                    "bert_encoder.embedding_layer.word_embeddings.weight", 0
                )
                is None
            ):
                old_state_dict[
                    "bert_encoder.embedding_layer.word_embeddings.weight"
                ] = old_state_dict[
                    "bert_mlm_head.classifier.ffn.0.linear_layer.weight"
                ]

    def pre_checkpoint_convert(
        self,
        *args,
    ):
        return BaseCheckpointConverter_HF_CS.pre_checkpoint_convert(
            self,
            *args,
        )

    def extract_model_dict(self, *args):
        return BaseCheckpointConverter_HF_CS.extract_model_dict(self, *args)

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for BertForPreTraining".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS17

class Converter_BertPretrainModel_HF_CS18(Converter_BertPretrainModel_HF_CS17):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from Pytorch 2.0 API
            ConversionRule(
                [
                    Converter_BertPretrainModel_HF_CS17(),
                ],
                action=None,
            ),
            # Catch checkpoints from 1.7/1.8
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_BertPretrainModel_HF_CS17(),
                ],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )

    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for BertForPreTraining".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS18

class ConfigConverter_Bert_HF_CS17(BaseConfigConverter_HF_CS):
    def __init__(self):
        super().__init__()
        # allows DPR child class to set model_type without being
        # overridden in the super().__init__() call
        if not hasattr(self, "model_type"):
            self.model_type = "bert"
        self.rules = [
            ConversionRule(
                ["model_type"],
                action=BaseConfigConverter.assert_factory_fn(
                    0, self.model_type
                ),
            ),
            # Embedding
            ConversionRule(["vocab_size"], action=self.replaceKey),
            ConversionRule(
                ["position_embedding_type"],
                action=self.convert_position_embedding_type,
            ),
            ConversionRule(
                ["max_position_embeddings"],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "tie_word_embeddings", "share_embedding_weights"
                    )
                ],
                action=self.replaceKey,
            ),
            # Decoder Block
            ConversionRule(
                ["hidden_size"],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("num_attention_heads", "num_heads")],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["num_hidden_layers"],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("intermediate_size", "filter_size")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("hidden_act", "encoder_nonlinearity")],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["mlm_nonlinearity"],
                action=self.assert_mlm_nonlinearity,
            ),
            ConversionRule(
                ["pooler_nonlinearity"],
                action=BaseConfigConverter.assert_factory_fn(1, "tanh"),
            ),
            ConversionRule(
                [EquivalentSubkey("hidden_dropout_prob", "dropout_rate")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "attention_probs_dropout_prob", "attention_dropout_rate"
                    )
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["disable_nsp"],
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(
                ["type_vocab_size"],
                action=BaseConfigConverter.assert_factory_fn(0, 2),
            ),
            ConversionRule(
                ["is_decoder"],
                action=BaseConfigConverter.assert_factory_fn(0, False),
            ),
            ConversionRule(
                ["add_cross_attention"],
                action=BaseConfigConverter.assert_factory_fn(0, False),
            ),
            ConversionRule(
                [EquivalentSubkey("layer_norm_eps", "layer_norm_epsilon")],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["attention_type"],
                action=BaseConfigConverter.assert_factory_fn(
                    1, "scaled_dot_product"
                ),
            ),
            ConversionRule(
                ["use_projection_bias_in_attention"],
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["use_ffn_bias_in_attention"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["use_ffn_bias"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["use_ffn_bias_in_mlm"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["use_output_bias_in_mlm"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(["initializer_range"], action=self.replaceKey),
        ]

        self.pre_convert_defaults[0].update(
            {
                "vocab_size": 30522,
                "hidden_size": 768,
                "num_hidden_layers": 12,
                "num_attention_heads": 12,
                "intermediate_size": 3072,
                "hidden_act": "gelu",
                "hidden_dropout_prob": 0.1,
                "attention_probs_dropout_prob": 0.1,
                "max_position_embeddings": 512,
                "layer_norm_eps": 1e-12,
                "tie_word_embeddings": True,
            }
        )
        self.pre_convert_defaults[1].update(
            {
                "share_embedding_weights": True,
                "encoder_nonlinearity": "gelu",
            },
        )
        self.post_convert_defaults[0].update({"model_type": "bert"})
        self.post_convert_defaults[1].update({"enable_vts": False})

    def convert_position_embedding_type(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        # HF supports absolute, relative_key, relative_key_query
        # CS supports learned, fixed
        embed_type = old_state_dict[old_key]

        if from_index == 0:
            if embed_type == "absolute":
                new_state_dict[new_key] = "learned"
            else:
                raise ConfigConversionError(
                    "CS model doesn't support HF's position_embedding_type={}".format(
                        embed_type
                    )
                )
        else:
            if embed_type == "learned":
                new_state_dict[new_key] = "absolute"
            else:
                raise ConfigConversionError(
                    "HF model doesn't support CS's position_embedding_type={}".format(
                        embed_type
                    )
                )

    def assert_mlm_nonlinearity(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        if (
            old_state_dict[old_key] != old_state_dict["encoder_nonlinearity"]
            and old_state_dict[old_key] is not None
        ):
            raise ConfigConversionError(
                "HF model doesn't support different encoder & mlm nonlinearities"
            )

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        converter_indices,
        drop_unmatched_keys,
    ):
        if converter_indices.direction == 0:
            if (
                "mlm_nonlinearity" not in new_config
                and "encoder_nonlinearity" in new_config
                and new_config["encoder_nonlinearity"] != "gelu"
            ):
                logging.warning(
                    f"HF used a mlm_nonlinearity of {new_config['encoder_nonlinearity']} while "
                    f"CS 1.7 is fixed to gelu. Please use CS 1.8 if you want to control "
                    f"mlm_nonlinearity."
                )
                new_config["mlm_nonlinearity"] = "gelu"

        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            converter_indices,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

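# Illustrative mapping only (values are hypothetical, taken from the defaults
# above rather than a released model): under the rules in this converter, an HF
# config such as
#   {"model_type": "bert", "vocab_size": 30522, "hidden_size": 768,
#    "num_attention_heads": 12, "intermediate_size": 3072,
#    "hidden_act": "gelu", "position_embedding_type": "absolute"}
# maps to the CS keys
#   {"vocab_size": 30522, "hidden_size": 768, "num_heads": 12,
#    "filter_size": 3072, "encoder_nonlinearity": "gelu",
#    "position_embedding_type": "learned"}
# with the remaining CS-only fields filled in from the converter defaults.
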
class ConfigConverter_Bert_HF_CS18(ConfigConverter_Bert_HF_CS17):
    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )

    def pre_config_convert(
        self,
        config,
        converter_indices,
    ):
        config = super().pre_config_convert(config, converter_indices)

        if converter_indices.direction == 1:
            if "pooler_nonlinearity" not in config:
                if config["encoder_nonlinearity"] != "tanh":
                    raise ConfigConversionError(
                        f"CS Model used a pooler_nonlinearity of {config['encoder_nonlinearity']} "
                        f"according to encoder_nonlinearity. HF only supports tanh in the pooler "
                        f"nonlinearity."
                    )
        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        converter_indices,
        drop_unmatched_keys,
    ):
        if converter_indices.direction == 0:
            new_config["pooler_nonlinearity"] = "tanh"
            if "mlm_nonlinearity" not in new_config:
                new_config["mlm_nonlinearity"] = new_config[
                    "encoder_nonlinearity"
                ]

        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            converter_indices,
            drop_unmatched_keys,
        )

class Converter_Bert_CS18_CS20(BaseCheckpointConverter_CS_CS):
    def __init__(self):
        super().__init__()
        # Checkpoint didn't change between 1.8/1.9 and 2.0. Handle weight tying
        # and copy all keys.
        self.rules = [
            ConversionRule(
                [
                    "(?:model.|)",
                    EquivalentSubkey(
                        "bert_encoder.embedding_layer.word_embeddings",
                        "bert_mlm_head.classifier.ffn.0.linear_layer",
                    ),
                    r"\.weight",
                ],
                action=maybe_tie_lm_head,
            ),
            ConversionRule(
                [
                    "(?:model.|)",
                    EquivalentSubkey(
                        "bert_mlm_head.classifier.ffn.0.linear_layer",
                        "bert_encoder.embedding_layer.word_embeddings",
                    ),
                    r"\.weight",
                ],
                action=maybe_tie_lm_head,
            ),
            ConversionRule([".*"], action=self.replaceKey),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("cs-1.8", "cs-1.9"),
            FormatVersions("cs-2.0"),
        )

    @classmethod
    def converter_note(cls) -> str:
        return "BertForPreTraining class"

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS18_CS20

# Config didn't change between 1.8/1.9 and 2.0.
class ConfigConverter_Bert_CS18_CS20(BaseConfigConverter_CS_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule([".*"], action=self.replaceKey),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("cs-1.8", "cs-1.9"),
            FormatVersions("cs-2.0"),
        )

###########################################################
# In CS 2.1, we refactored the embedding layer.
# CS 2.0 <> CS 2.1, and HF <> CS 2.1 converters:
###########################################################

class Converter_Bert_CS20_CS21(BaseCheckpointConverter_CS_CS):
    def __init__(self):
        super().__init__()
        self.rules = [
            # Refactored embeddings (BERT only supported fixed):
            ConversionRule(
                [
                    r"(?:model\.|)",
                    r"(?:bert_encoder|bert)\.",
                    EquivalentSubkey(
                        "embedding_layer.position_embeddings.weight",
                        "embedding_layer.position_embeddings.embed.weight",
                    ),
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    r"(?:model\.|)",
                    r"(?:bert_encoder|bert)\.",
                    EquivalentSubkey(
                        "embedding_layer.position_embeddings",
                        "embedding_layer.position_embeddings.fpe",
                    ),
                ],
                action=self.replaceKey,
            ),
            # Copy everything else
            ConversionRule([".*"], action=self.replaceKey),
        ]

    @classmethod
    def converter_note(cls) -> str:
        return (
            "BertForPreTraining, BertForSequenceClassification, "
            "BertForQuestionAnswering, and BertForSummarization classes"
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("cs-2.0"), FormatVersions("cs-2.1"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_BertModel_CS20_CS21

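# Illustrative example of the rename performed above (the prefix is
# hypothetical): a CS 2.0 key such as
#   "model.bert_encoder.embedding_layer.position_embeddings.weight"
# becomes, in CS 2.1,
#   "model.bert_encoder.embedding_layer.position_embeddings.embed.weight"
# while any other key is copied through unchanged by the catch-all rule.
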
class ConfigConverter_BertModel_CS20_CS21(BaseConfigConverter_CS_CS):
    def __init__(self):
        super().__init__()
        # No differences in config
        self.rules = [
            ConversionRule([".*"], action=self.replaceKey),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("cs-2.0"), FormatVersions("cs-2.1"))

class ConfigConverter_Bert_HF_CS21(ConfigConverter_Bert_HF_CS18):
    "CS 2.1 config is the same as CS 2.0"

    def __init__(self):
        super().__init__()
        self.post_convert_defaults[1].update({"freeze_ffn_bias_in_glu": False})
        del self.post_convert_defaults[1]["enable_vts"]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-2.1", "cs-2.2", "cs-2.3"),
        )

class Converter_BertModel_WithoutOptionalModel_HF_CS21(
    Converter_BertModel_HF_CS17
):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("embeddings", "embedding_layer"),
                    r"\.position_embeddings",
                    EquivalentSubkey("", ".embed"),
                    r"\.weight",
                ],
                action=self.position_embeddings_convert,
            ),
            *self.rules,
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-2.1"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS21

class Converter_BertPretrainModel_WithoutOptionalModel_HF_CS21(
    Converter_BertPretrainModel_HF_CS17
):
    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("bert.", "bert_encoder."),
                    Converter_BertModel_WithoutOptionalModel_HF_CS21(),
                ],
            ),
            *self.rules,
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-2.1", "cs-2.2", "cs-2.3"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS21

Converter_BertPretrainModel_HF_CS21 = Build_HF_CS_Converter_WithOptionalModel(
    "Converter_BertPretrainModel_HF_CS21",
    Converter_BertPretrainModel_WithoutOptionalModel_HF_CS21,
    derived_class=Converter_BertPretrainModel_WithoutOptionalModel_HF_CS21,
)

class ConfigConverter_Bert_HF_CS23(ConfigConverter_Bert_HF_CS21):
    def supports_mup_conversion(self):
        return True

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-2.3"),
        )

class Converter_BertModel_WithoutOptionalModel_HF_CS23(
    Converter_BertModel_WithoutOptionalModel_HF_CS21
):
    def supports_mup_conversion(self):
        return True

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-2.3"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS23

class Converter_BertPretrainModel_WithoutOptionalModel_HF_CS23(
    Converter_BertPretrainModel_WithoutOptionalModel_HF_CS21
):
    def supports_mup_conversion(self):
        return True

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-2.3"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS23

Converter_BertPretrainModel_HF_CS23 = Build_HF_CS_Converter_WithOptionalModel(
    "Converter_BertPretrainModel_HF_CS23",
    Converter_BertPretrainModel_WithoutOptionalModel_HF_CS23,
    derived_class=Converter_BertPretrainModel_WithoutOptionalModel_HF_CS23,
)
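

# A minimal usage sketch (illustrative only, not the supported entry point; the
# actual conversion driver is the modelzoo checkpoint-converter tooling, which
# lives outside this module). It only exercises the formats() metadata declared
# by the classes above.
if __name__ == "__main__":
    hf_format, cs_format = (
        Converter_BertPretrainModel_WithoutOptionalModel_HF_CS23.formats()
    )
    print(f"BertForPreTraining converter bridges {hf_format} and {cs_format}")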