# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Tuple
import torch
from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
BaseCheckpointConverter_CS_CS,
BaseCheckpointConverter_HF_CS,
BaseConfigConverter,
BaseConfigConverter_CS_CS,
BaseConfigConverter_HF_CS,
ConfigConversionError,
ConversionRule,
EquivalentSubkey,
FormatVersions,
)
class Converter_BertLayerNorm_HF_CS(BaseCheckpointConverter_HF_CS):
    """Rename the parameters of a single BERT LayerNorm between HF and CS.

    The key segment ``hf_name`` is renamed to ``cs_name``. The ``.weight`` /
    ``.bias`` suffixes are kept as-is, and the legacy HF suffixes ``.gamma`` /
    ``.beta`` are mapped onto ``.weight`` / ``.bias``.

    This converter is used nested inside the rules of the BERT model
    converters in this module; it has no config converter of its own.
    """

    def __init__(self, hf_name, cs_name):
        """Build the LayerNorm conversion rules.

        Args:
            hf_name: key segment naming the LayerNorm on the HF side.
            cs_name: key segment naming the same LayerNorm on the CS side.
        """
        super().__init__()
        self.rules = [
            # torch.nn.LayerNorm has .weight & .bias properties.
            # Raw string: "\." is not a valid escape sequence in a normal
            # string literal (SyntaxWarning on Python 3.12+).
            ConversionRule(
                [EquivalentSubkey(hf_name, cs_name), r"\.(?:weight|bias)"],
                action=self.replaceKey,
            ),
            # Old HF implementation uses .gamma instead of .weight
            ConversionRule(
                [
                    EquivalentSubkey(hf_name, cs_name),
                    EquivalentSubkey(".gamma", ".weight"),
                ],
                action=self.replaceKey,
            ),
            # Old HF implementation uses .beta instead of .bias
            ConversionRule(
                [
                    EquivalentSubkey(hf_name, cs_name),
                    EquivalentSubkey(".beta", ".bias"),
                ],
                action=self.replaceKey,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        # Only used nested inside other converters, which provide their own
        # config converter.
        return None
class Converter_BertModel_CS16_CS17(BaseCheckpointConverter_CS_CS):
    """Convert BertModel (embeddings, encoder, pooler) checkpoints between
    CS 1.6 and CS 1.7.

    The CS 1.6 side uses HF-style key names (``embeddings.*``,
    ``encoder.layer.<i>.*``, ``pooler.dense.*``); that is why
    ``Converter_BertModel_HF_CS17`` reuses these rules for HF <-> CS 1.7.
    The CS 1.7 side uses ``embedding_layer.*``,
    ``transformer_encoder.layers.<i>.*`` and
    ``pooler.pooler.ffn.0.linear_layer.*``.
    """

    def __init__(self):
        super().__init__()

        # All regex segments are raw strings: "\." and "\d" are invalid
        # escape sequences in ordinary string literals (SyntaxWarning on
        # Python 3.12+). The resulting pattern text is unchanged.
        weight_or_bias = r"\.(?:weight|bias)"

        def encoder_layer(*segments):
            """Prefix ``segments`` with the per-layer encoder key prefix."""
            return [
                EquivalentSubkey(
                    "encoder.layer", "transformer_encoder.layers",
                ),
                r"\.\d+\.",
                *segments,
            ]

        self.rules = [
            # Embedding:
            ConversionRule(
                [
                    EquivalentSubkey("embeddings", "embedding_layer"),
                    r"\.word_embeddings\.weight",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("embeddings", "embedding_layer"),
                    r"\.position_embeddings\.weight",
                ],
                action=self.position_embeddings_convert,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "embeddings.token_type_embeddings",
                        "embedding_layer.segment_embeddings",
                    ),
                    r"\.weight",
                ],
                action=self.replaceKey,
            ),
            # The position_ids buffer only exists on the left (HF-style)
            # side; position_embeddings_convert recreates it when converting
            # towards that side.
            ConversionRule([r"embeddings\.position_ids"], exists="left",),
            ConversionRule(
                [
                    EquivalentSubkey("embeddings.", ""),
                    Converter_BertLayerNorm_HF_CS("LayerNorm", "embed_ln_f"),
                ],
                action=None,
            ),
            # Encoder Layers:
            ConversionRule(
                encoder_layer(
                    EquivalentSubkey(
                        "attention.self.query", "self_attn.proj_q_dense_layer"
                    ),
                    weight_or_bias,
                ),
                action=self.replaceKey,
            ),
            ConversionRule(
                encoder_layer(
                    EquivalentSubkey(
                        "attention.self.key", "self_attn.proj_k_dense_layer"
                    ),
                    weight_or_bias,
                ),
                action=self.replaceKey,
            ),
            ConversionRule(
                encoder_layer(
                    EquivalentSubkey(
                        "attention.self.value", "self_attn.proj_v_dense_layer"
                    ),
                    weight_or_bias,
                ),
                action=self.replaceKey,
            ),
            ConversionRule(
                encoder_layer(
                    EquivalentSubkey(
                        "attention.output.dense",
                        "self_attn.proj_output_dense_layer",
                    ),
                    weight_or_bias,
                ),
                action=self.replaceKey,
            ),
            ConversionRule(
                encoder_layer(
                    EquivalentSubkey("attention.output.", ""),
                    Converter_BertLayerNorm_HF_CS("LayerNorm", "norm1"),
                ),
                action=None,
            ),
            ConversionRule(
                encoder_layer(
                    EquivalentSubkey(
                        "intermediate.dense", "ffn.ffn.0.linear_layer"
                    ),
                    weight_or_bias,
                ),
                action=self.replaceKey,
            ),
            ConversionRule(
                encoder_layer(
                    EquivalentSubkey("output.dense", "ffn.ffn.1.linear_layer"),
                    weight_or_bias,
                ),
                action=self.replaceKey,
            ),
            ConversionRule(
                encoder_layer(
                    EquivalentSubkey("output.", ""),
                    Converter_BertLayerNorm_HF_CS("LayerNorm", "norm2"),
                ),
                action=None,
            ),
            # Head:
            ConversionRule(
                [
                    r"pooler\.",
                    EquivalentSubkey("dense", "pooler.ffn.0.linear_layer"),
                    weight_or_bias,
                ],
                action=self.replaceKey,
            ),
        ]

    def position_embeddings_convert(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Copy the position embedding table and, when converting towards the
        HF-style side (``from_index == 1``), recreate HF's ``position_ids``
        buffer next to it.

        ``max_position_embeddings`` is read from ``configs[0]`` when present
        there, otherwise from ``configs[1]["model"]``.
        """
        self.replaceKey(
            old_key,
            new_key,
            old_state_dict,
            new_state_dict,
            from_index,
            action_fn_args,
        )
        if from_index == 1:
            # HF stores a registered buffer with position_ids
            position_id_key = re.sub(
                r"\.position_embeddings\.weight", ".position_ids", new_key
            )
            if "max_position_embeddings" in action_fn_args["configs"][0]:
                max_position_embeddings = action_fn_args["configs"][0][
                    "max_position_embeddings"
                ]
            else:
                max_position_embeddings = action_fn_args["configs"][1]["model"][
                    "max_position_embeddings"
                ]
            new_state_dict[position_id_key] = torch.arange(
                max_position_embeddings
            ).expand((1, -1))

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("cs-1.6"), FormatVersions("cs-1.7"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS17
class ConfigConverter_Bert_CS16_CS17(BaseConfigConverter_CS_CS):
    """BERT config converter between CS 1.6 and CS 1.7.

    The config schema is identical in both releases, so a single catch-all
    rule copies every key unchanged.
    """

    def __init__(self):
        super().__init__()
        copy_every_key = ConversionRule([".*"], action=self.replaceKey)
        self.rules = [copy_every_key]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        older = FormatVersions("cs-1.6")
        newer = FormatVersions("cs-1.7")
        return older, newer
class Converter_BertModel_CS16_CS18(BaseCheckpointConverter_CS_CS):
    """BertModel checkpoint converter between CS 1.6 and CS 1.8/1.9/2.0.

    All key renaming is delegated to Converter_BertModel_CS16_CS17. Two
    variants of the newer layout are accepted: keys without a prefix and
    keys nested under ``model.``.
    """

    def __init__(self):
        super().__init__()
        # Catch checkpoints from Pytorch 2.0 API (no prefix).
        unprefixed_rule = ConversionRule(
            [Converter_BertModel_CS16_CS17()], action=None,
        )
        # Catch checkpoints from 1.7/1.8 (keys under "model.").
        model_prefixed_rule = ConversionRule(
            [EquivalentSubkey("", "model."), Converter_BertModel_CS16_CS17()],
            action=None,
        )
        self.rules = [unprefixed_rule, model_prefixed_rule]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        older = FormatVersions("cs-1.6")
        newer = FormatVersions("cs-1.8", "cs-1.9", "cs-2.0")
        return older, newer

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS18
class ConfigConverter_Bert_CS16_CS18(ConfigConverter_Bert_CS16_CS17):
    """BERT config converter between CS 1.6 and CS 1.8/1.9/2.0.

    Every key is copied unchanged by the inherited catch-all rule. CS 1.8
    introduced ``pooler_nonlinearity`` and ``mlm_nonlinearity``. Before that,
    the pooler used ``encoder_nonlinearity`` and the MLM head was fixed to
    gelu. When upgrading (``from_index == 0``) those keys are filled in.
    When downgrading (``from_index == 1``), configs that CS 1.6 cannot
    express are rejected with ``ConfigConversionError``.
    """

    # NOTE(review): fallback used when a CS config omits encoder_nonlinearity.
    # It matches the CS-side default ("encoder_nonlinearity": "gelu") that
    # ConfigConverter_Bert_HF_CS17 assumes; confirm against the CS model.
    _ENCODER_NONLINEARITY_DEFAULT = "gelu"

    def __init__(self):
        super().__init__()

    def pre_config_convert(
        self, config, from_index,
    ):
        """Validate a CS 1.8+ config before downgrading it to CS 1.6.

        Raises:
            ConfigConversionError: if ``pooler_nonlinearity`` differs from the
                encoder nonlinearity, or if the effective MLM nonlinearity
                (explicit, or defaulting to the encoder's) is not gelu.
        """
        config = super().pre_config_convert(config, from_index)
        if from_index == 1:
            # .get(): an omitted encoder_nonlinearity previously raised a
            # bare KeyError instead of using the CS default.
            encoder_nonlinearity = config.get(
                "encoder_nonlinearity", self._ENCODER_NONLINEARITY_DEFAULT
            )
            if (
                "pooler_nonlinearity" in config
                and config["pooler_nonlinearity"] != encoder_nonlinearity
            ):
                raise ConfigConversionError(
                    "pooler_nonlinearity was introduced in CS 1.8. Prior to that, the pooler nonlinearity must be the same as encoder_nonlinearity"
                )
            if "mlm_nonlinearity" in config:
                if config["mlm_nonlinearity"] != "gelu":
                    raise ConfigConversionError(
                        "mlm_nonlinearity was introduced in CS 1.8. Prior to that, the mlm nonlinearity must be gelu"
                    )
            elif encoder_nonlinearity != "gelu":
                # Without an explicit mlm_nonlinearity, CS 1.8 uses the
                # encoder nonlinearity for the MLM head.
                raise ConfigConversionError(
                    "mlm_nonlinearity was introduced in CS 1.8. Prior to that, the mlm nonlinearity must be gelu. However, the input config has an mlm_nonlinearity which defaults to encoder_nonlinearity = {}".format(
                        encoder_nonlinearity
                    )
                )
        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        from_index,
        drop_unmatched_keys,
    ):
        """When upgrading to CS 1.8+, make the CS 1.6 behavior explicit.

        The pooler reuses the encoder nonlinearity and the MLM head uses gelu.
        """
        if from_index == 0:
            new_config["pooler_nonlinearity"] = new_config.get(
                "encoder_nonlinearity", self._ENCODER_NONLINEARITY_DEFAULT
            )
            new_config["mlm_nonlinearity"] = "gelu"
        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            from_index,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("cs-1.6"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )
class Converter_Bert_CS17_CS18(BaseCheckpointConverter_CS_CS):
    """BERT task-model checkpoint converter between CS 1.7 and
    CS 1.8/1.9/2.0.

    Checkpoint keys did not change between these releases, so a single
    catch-all rule copies every key as-is.
    """

    def __init__(self):
        super().__init__()
        self.rules = [ConversionRule([".*"], action=self.replaceKey)]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        older = FormatVersions("cs-1.7")
        newer = FormatVersions("cs-1.8", "cs-1.9", "cs-2.0")
        return older, newer

    @classmethod
    def converter_note(cls) -> str:
        covered_models = (
            "BertForPreTraining",
            "BertForSequenceClassification",
            "BertForQuestionAnswering",
        )
        return ", ".join(covered_models) + ", and BertForSummarization classes"

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS17_CS18
class ConfigConverter_Bert_CS17_CS18(ConfigConverter_Bert_CS16_CS18):
    """BERT config converter between CS 1.7 and CS 1.8/1.9/2.0.

    The config schema did not change between CS 1.6 and CS 1.7. Therefore
    the 1.6 <-> 1.8 logic applies unchanged; only the advertised formats
    differ.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        older = FormatVersions("cs-1.7")
        newer = FormatVersions("cs-1.8", "cs-1.9", "cs-2.0")
        return older, newer
class Converter_BertModel_HF_CS17(
    Converter_BertModel_CS16_CS17, BaseCheckpointConverter_HF_CS
):
    """BertModel checkpoint converter between HF and CS 1.7.

    It reuses the key rules of Converter_BertModel_CS16_CS17, whose left side
    already uses HF key names, and takes checkpoint-level handling from
    BaseCheckpointConverter_HF_CS.
    """

    def __init__(self):
        super().__init__()

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        """Manually tie weights before converting a CS checkpoint to HF.

        Applies when ``from_index == 1`` and the CS config enables
        ``share_embedding_weights``. In that case, a word-embedding entry
        stored as ``None`` is replaced by the MLM classifier weight.

        NOTE(review): these keys carry the ``bert_encoder.`` /
        ``bert_mlm_head.`` prefixes of the pretraining model. The BertModel
        rules inherited here produce ``embedding_layer.word_embeddings.weight``
        without a prefix and have no MLM head, so for a plain BertModel
        checkpoint this looks like a no-op. Confirm whether that is intended.
        """
        if from_index != 1 or not configs[1]["model"]["share_embedding_weights"]:
            return
        embedding_key = "bert_encoder.embedding_layer.word_embeddings.weight"
        # The non-None sentinel 0 means a missing entry is left alone; only an
        # entry explicitly stored as None is filled in.
        if old_state_dict.get(embedding_key, 0) is None:
            old_state_dict[embedding_key] = old_state_dict[
                "bert_mlm_head.classifier.ffn.0.linear_layer.weight"
            ]

    def pre_checkpoint_convert(
        self, *args,
    ):
        # Explicitly use the HF <-> CS base implementation. Otherwise the MRO
        # (which lists Converter_BertModel_CS16_CS17 and its bases first)
        # would decide which one runs.
        hf_cs_impl = BaseCheckpointConverter_HF_CS.pre_checkpoint_convert
        return hf_cs_impl(self, *args)

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf = FormatVersions("hf")
        cs = FormatVersions("cs-1.7")
        return hf, cs

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS17
class Converter_BertModel_HF_CS18(BaseCheckpointConverter_HF_CS):
    """BertModel checkpoint converter between HF and CS 1.8/1.9/2.0.

    All key renaming is delegated to Converter_BertModel_HF_CS17. Two
    variants of the CS layout are accepted: keys without a prefix and keys
    nested under ``model.``.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from Pytorch 2.0 API
            ConversionRule([Converter_BertModel_HF_CS17(),], action=None,),
            # Catch checkpoints from 1.7/1.8
            ConversionRule(
                [EquivalentSubkey("", "model."), Converter_BertModel_HF_CS17()],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-1.8", "cs-1.9", "cs-2.0"),
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        # Must match formats() (hf <-> cs-1.8+), as in
        # Converter_BertPretrainModel_HF_CS18. The HF <-> CS 1.7 config
        # converter would not set pooler_nonlinearity ("tanh") or
        # mlm_nonlinearity when targeting CS 1.8+. It would also not reject
        # CS configs whose pooler falls back to a non-tanh
        # encoder_nonlinearity.
        return ConfigConverter_Bert_HF_CS18
class Converter_BertPretrainModel_CS16_CS17(BaseCheckpointConverter_CS_CS):
    """Convert BertPretrainModel checkpoints between CS 1.6 and CS 1.7.

    The encoder is handled by Converter_BertModel_CS16_CS17 under the
    ``bert.`` <-> ``bert_encoder.`` prefix. The ``cls.predictions`` and
    ``cls.seq_relationship`` heads are mapped onto ``bert_mlm_head`` and
    ``bert_cls_head``. The CS 1.6 side uses HF key names, which is why
    Converter_BertPretrainModel_HF_CS17 reuses these rules.
    """

    def __init__(self):
        super().__init__()
        # Regex segments are raw strings: "\." is an invalid escape sequence
        # in an ordinary string literal (SyntaxWarning on Python 3.12+).
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("bert.", "bert_encoder."),
                    Converter_BertModel_CS16_CS17(),
                ],
            ),
            # CLS:
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.predictions.transform.dense",
                        "bert_mlm_head.mlm_transform.ffn.ffn.0.linear_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.predictions.transform.",
                        "bert_mlm_head.mlm_transform.",
                    ),
                    Converter_BertLayerNorm_HF_CS("LayerNorm", "ln"),
                ],
                action=None,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.predictions.decoder",
                        "bert_mlm_head.classifier.ffn.0.linear_layer",
                    ),
                    r"\.weight",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.predictions.decoder",
                        "bert_mlm_head.classifier.ffn.0.linear_layer",
                    ),
                    r"\.bias",
                ],
                action=self.convert_cls_predictions_bias,
            ),
            # Left-only key: the extra bias copy is regenerated by
            # convert_cls_predictions_bias when converting towards the left.
            ConversionRule([r"cls\.predictions\.bias"], exists="left"),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "cls.seq_relationship",
                        "bert_cls_head.classifier.ffn.0.linear_layer",
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
        ]

    def convert_cls_predictions_bias(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Copy the MLM decoder bias.

        When converting towards the HF-style side (``from_index == 1``), the
        bias is also written to the same key with ``.decoder.`` removed
        (``cls.predictions.bias``).
        """
        self.replaceKey(
            old_key,
            new_key,
            old_state_dict,
            new_state_dict,
            from_index,
            action_fn_args,
        )
        if from_index == 1:
            # HF stores an extra copy of the decoder bias in the predictions object itself
            bias_key = re.sub(r"\.decoder\.", ".", new_key)
            self.replaceKey(
                old_key,
                bias_key,
                old_state_dict,
                new_state_dict,
                from_index,
                action_fn_args,
            )

    def pre_checkpoint_convert(
        self,
        input_checkpoint,
        output_checkpoint,
        configs: Tuple[dict, dict],
        from_index: int,
    ):
        """Prepare an output checkpoint that contains only an empty ``model``
        dict.

        The base implementation is deliberately not called, so non-model
        entries such as optimizer state are not copied. A warning is logged
        to say so.
        """
        # Don't copy non model keys like optimizer state:
        logging.warning(
            "The Bert model changed significantly between {} and {}. As a result, the"
            " optimizer state won't be included in the converted checkpoint.".format(
                *self.formats()
            )
        )
        output_checkpoint["model"] = {}

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("cs-1.6"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return "BertPretrainModel class"

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS17
class Converter_BertPretrainModel_CS16_CS18(BaseCheckpointConverter_CS_CS):
    """BertPretrainModel checkpoint converter between CS 1.6 and
    CS 1.8/1.9/2.0.

    All key renaming is delegated to Converter_BertPretrainModel_CS16_CS17.
    Two variants of the newer layout are accepted: keys without a prefix and
    keys nested under ``model.``.
    """

    def __init__(self):
        super().__init__()
        # Catch checkpoints from Pytorch 2.0 API (no prefix).
        unprefixed_rule = ConversionRule(
            [Converter_BertPretrainModel_CS16_CS17()], action=None,
        )
        # Catch checkpoints from 1.7/1.8 (keys under "model.").
        model_prefixed_rule = ConversionRule(
            [
                EquivalentSubkey("", "model."),
                Converter_BertPretrainModel_CS16_CS17(),
            ],
            action=None,
        )
        self.rules = [unprefixed_rule, model_prefixed_rule]

    def pre_checkpoint_convert(
        self,
        input_checkpoint,
        output_checkpoint,
        configs: Tuple[dict, dict],
        from_index: int,
    ):
        """Start the output checkpoint with only an empty ``model`` dict.

        Non-model entries such as optimizer state are intentionally not
        carried over (the base implementation is not called); a warning
        explains this.
        """
        older, newer = self.formats()
        logging.warning(
            "The Bert model changed significantly between {} and {}. As a result, the"
            " optimizer state won't be included in the converted checkpoint.".format(
                older, newer
            )
        )
        output_checkpoint["model"] = {}

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        older = FormatVersions("cs-1.6")
        newer = FormatVersions("cs-1.8", "cs-1.9", "cs-2.0")
        return older, newer

    @classmethod
    def converter_note(cls) -> str:
        return "BertPretrainModel class"

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_CS16_CS18
class Converter_BertPretrainModel_HF_CS17(
    Converter_BertPretrainModel_CS16_CS17, BaseCheckpointConverter_HF_CS
):
    """BertPretrainModel (HF ``BertForPreTraining``) checkpoint converter
    between HF and CS 1.7.

    It reuses the key rules of Converter_BertPretrainModel_CS16_CS17, whose
    left side uses HF key names, and takes checkpoint-level handling from
    BaseCheckpointConverter_HF_CS.
    """

    # Key prefixes under which a CS checkpoint may store the model. Subclasses
    # such as Converter_BertPretrainModel_HF_CS18 also accept "model."-prefixed
    # CS checkpoints and inherit pre_model_convert, so both are checked.
    _CS_KEY_PREFIXES = ("", "model.")

    def __init__(self):
        super().__init__()

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        from_index,
        drop_unmatched_keys,
    ):
        """Manually tie the word embeddings to the MLM classifier weight.

        Applies when ``from_index == 1`` (CS -> HF) and the CS config enables
        ``share_embedding_weights``. For each accepted key prefix, a
        word-embedding entry stored as ``None`` is filled in from the MLM
        classifier weight before the keys are converted.
        """
        if from_index != 1:
            return
        # Default True matches the CS-side "share_embedding_weights" default
        # used by ConfigConverter_Bert_HF_CS17, instead of raising KeyError
        # when the CS config omits it.
        if not configs[1]["model"].get("share_embedding_weights", True):
            return
        for prefix in self._CS_KEY_PREFIXES:
            embedding_key = (
                prefix + "bert_encoder.embedding_layer.word_embeddings.weight"
            )
            classifier_key = (
                prefix + "bert_mlm_head.classifier.ffn.0.linear_layer.weight"
            )
            # The non-None sentinel 0 means a missing entry is left alone;
            # only an entry explicitly stored as None is filled in.
            if old_state_dict.get(embedding_key, 0) is None:
                old_state_dict[embedding_key] = old_state_dict[classifier_key]

    def pre_checkpoint_convert(
        self, *args,
    ):
        # Use the HF <-> CS base implementation. The MRO would otherwise pick
        # Converter_BertPretrainModel_CS16_CS17.pre_checkpoint_convert, which
        # drops all non-model state.
        return BaseCheckpointConverter_HF_CS.pre_checkpoint_convert(
            self, *args,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        return "{} <-> {} for BertForPreTraining".format(
            cls.formats()[0], cls.formats()[1]
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS17
class Converter_BertPretrainModel_HF_CS18(Converter_BertPretrainModel_HF_CS17):
    """BertPretrainModel (HF ``BertForPreTraining``) checkpoint converter
    between HF and CS 1.8/1.9/2.0.

    The rules built by the parent ``__init__`` are replaced by two rules that
    delegate to Converter_BertPretrainModel_HF_CS17. One handles unprefixed
    CS keys and the other handles CS keys under ``model.``. All other hooks
    are inherited from the parent.
    """

    def __init__(self):
        super().__init__()
        # Catch checkpoints from Pytorch 2.0 API (no prefix).
        unprefixed_rule = ConversionRule(
            [Converter_BertPretrainModel_HF_CS17()], action=None,
        )
        # Catch checkpoints from 1.7/1.8 (keys under "model.").
        model_prefixed_rule = ConversionRule(
            [
                EquivalentSubkey("", "model."),
                Converter_BertPretrainModel_HF_CS17(),
            ],
            action=None,
        )
        self.rules = [unprefixed_rule, model_prefixed_rule]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf = FormatVersions("hf")
        cs = FormatVersions("cs-1.8", "cs-1.9", "cs-2.0")
        return hf, cs

    @classmethod
    def converter_note(cls) -> str:
        left, right = cls.formats()
        return "{} <-> {} for BertForPreTraining".format(left, right)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_Bert_HF_CS18
class ConfigConverter_Bert_HF_CS17(BaseConfigConverter_HF_CS):
    """BERT config converter between an HF ``BertConfig`` dict and CS 1.7
    model params.

    Keys are renamed or copied where both sides support them. Settings that
    only one side can express are pinned with assertions (via
    ``BaseConfigConverter.assert_factory_fn``). Defaults are filled in on
    each side before conversion (``pre_convert_defaults``) and after
    conversion (``post_convert_defaults``).
    """

    def __init__(self):
        super().__init__()
        pin = BaseConfigConverter.assert_factory_fn

        def copied(key):
            # Same key name and value on both sides.
            return ConversionRule([key], action=self.replaceKey)

        def renamed(hf_key, cs_key):
            # Same value, different key name on each side.
            return ConversionRule(
                [EquivalentSubkey(hf_key, cs_key)], action=self.replaceKey
            )

        self.rules = [
            ConversionRule(["model_type"], action=pin(0, "bert")),
            # Embedding
            copied("vocab_size"),
            ConversionRule(
                ["position_embedding_type"],
                action=self.convert_position_embedding_type,
            ),
            copied("max_position_embeddings"),
            renamed("tie_word_embeddings", "share_embedding_weights"),
            # Decoder Block
            copied("hidden_size"),
            renamed("num_attention_heads", "num_heads"),
            copied("num_hidden_layers"),
            renamed("intermediate_size", "filter_size"),
            renamed("hidden_act", "encoder_nonlinearity"),
            ConversionRule(
                ["mlm_nonlinearity"], action=self.assert_mlm_nonlinearity,
            ),
            ConversionRule(["pooler_nonlinearity"], action=pin(1, "tanh")),
            renamed("hidden_dropout_prob", "dropout_rate"),
            renamed("attention_probs_dropout_prob", "attention_dropout_rate"),
            ConversionRule(["disable_nsp"], action=pin(1, False)),
            ConversionRule(["type_vocab_size"], action=pin(0, 2)),
            ConversionRule(["is_decoder"], action=pin(0, False)),
            ConversionRule(["add_cross_attention"], action=pin(0, False)),
            renamed("layer_norm_eps", "layer_norm_epsilon"),
            ConversionRule(
                ["attention_type"], action=pin(1, "scaled_dot_product"),
            ),
            ConversionRule(
                ["use_projection_bias_in_attention"], action=pin(1, True),
            ),
            # CS-only bias switches: HF always has these biases.
            *(
                ConversionRule([key], exists="right", action=pin(1, True))
                for key in (
                    "use_ffn_bias_in_attention",
                    "use_ffn_bias",
                    "use_ffn_bias_in_mlm",
                    "use_output_bias_in_mlm",
                )
            ),
            copied("initializer_range"),
        ]

        hf_defaults = {
            "vocab_size": 30522,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 512,
            "layer_norm_eps": 1e-12,
            "tie_word_embeddings": True,
        }
        cs_defaults = {
            "share_embedding_weights": True,
            "encoder_nonlinearity": "gelu",
        }
        self.pre_convert_defaults[0].update(hf_defaults)
        self.pre_convert_defaults[1].update(cs_defaults)
        self.post_convert_defaults[0].update({"model_type": "bert"})
        self.post_convert_defaults[1].update({"enable_vts": False})

    def convert_position_embedding_type(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Map HF "absolute" <-> CS "learned" position embeddings.

        HF supports absolute, relative_key and relative_key_query; CS supports
        learned and fixed. Any other value raises ``ConfigConversionError``.
        """
        embed_type = old_state_dict[old_key]
        if from_index == 0:
            supported, converted = "absolute", "learned"
            message = "CS model doesn't support HF's position_embedding_type={}"
        else:
            supported, converted = "learned", "absolute"
            message = "HF model doesn't support CS's position_embedding_type={}"
        if embed_type != supported:
            raise ConfigConversionError(message.format(embed_type))
        new_state_dict[new_key] = converted

    def assert_mlm_nonlinearity(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Reject configs whose MLM nonlinearity differs from the encoder's.

        HF cannot express that. Nothing is written to the new config.
        """
        mlm_act = old_state_dict[old_key]
        encoder_act = old_state_dict["encoder_nonlinearity"]
        if mlm_act != encoder_act:
            raise ConfigConversionError(
                "HF model doesn't support different encoder & mlm nonlinearities"
            )

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        from_index,
        drop_unmatched_keys,
    ):
        """When converting HF -> CS 1.7, flag a non-gelu encoder activation.

        CS 1.7 fixes the MLM nonlinearity to gelu. If the HF activation
        (mapped to ``encoder_nonlinearity``) is something else and no
        ``mlm_nonlinearity`` is set yet, a warning is logged and
        ``mlm_nonlinearity`` is set to gelu.
        """
        non_gelu_mlm_from_hf = (
            from_index == 0
            and "mlm_nonlinearity" not in new_config
            and "encoder_nonlinearity" in new_config
            and new_config["encoder_nonlinearity"] != "gelu"
        )
        if non_gelu_mlm_from_hf:
            logging.warning(
                "HF used a mlm_nonlinearity of {} while CS 1.7 is fixed to gelu. Please use CS 1.8 if you want to control mlm_nonlinearity".format(
                    new_config["encoder_nonlinearity"]
                )
            )
            new_config["mlm_nonlinearity"] = "gelu"
        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            from_index,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf = FormatVersions("hf")
        cs = FormatVersions("cs-1.7")
        return hf, cs
class ConfigConverter_Bert_HF_CS18(ConfigConverter_Bert_HF_CS17):
    """BERT config converter between HF and CS 1.8/1.9/2.0.

    It extends the HF <-> CS 1.7 rules for the CS 1.8
    ``pooler_nonlinearity`` / ``mlm_nonlinearity`` keys. HF's pooler
    activation is always tanh, and its MLM transform uses the encoder
    activation.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf = FormatVersions("hf")
        cs = FormatVersions("cs-1.8", "cs-1.9", "cs-2.0")
        return hf, cs

    def pre_config_convert(
        self, config, from_index,
    ):
        """Before CS -> HF, reject a pooler that would not be tanh.

        Without an explicit ``pooler_nonlinearity``, the CS pooler falls back
        to ``encoder_nonlinearity``, so that value must be tanh.

        Raises:
            ConfigConversionError: if the effective pooler nonlinearity is
                not tanh.
        """
        config = super().pre_config_convert(config, from_index)
        if (
            from_index == 1
            and "pooler_nonlinearity" not in config
            and config["encoder_nonlinearity"] != "tanh"
        ):
            raise ConfigConversionError(
                "CS Model used a pooler_nonlinearity of {} according to encoder_nonlinearity. HF only supports tanh in the pooler nonlinearity".format(
                    config["encoder_nonlinearity"]
                )
            )
        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        from_index,
        drop_unmatched_keys,
    ):
        """After HF -> CS, make HF's pooler and MLM activations explicit."""
        if from_index == 0:
            new_config["pooler_nonlinearity"] = "tanh"
            if "mlm_nonlinearity" not in new_config:
                encoder_act = new_config["encoder_nonlinearity"]
                new_config["mlm_nonlinearity"] = encoder_act
        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            from_index,
            drop_unmatched_keys,
        )