Source code for cerebras.modelzoo.tools.checkpoint_converters.gpt_neox_hf_cs

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from typing import Tuple

import torch

from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    BaseConfigConverter_HF_CS,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)
from cerebras.modelzoo.tools.checkpoint_converters.gptj_hf_cs import (
    ConfigConverter_GPTJModel_CS18_CS20,
    Converter_GPTJ_LMHeadModel_CS18_CS20,
    Converter_GPTJ_LMHeadModel_CS20_CS21,
)


class Converter_GPT_Neox_Attention_HF_CS17(BaseCheckpointConverter_HF_CS):
    """Convert GPT-NeoX attention parameters between HF and CS 1.7 layouts.

    HF keeps the query, key and value projections packed in one
    ``query_key_value`` tensor, whereas CS 1.7 stores three separate
    ``proj_{q,k,v}_dense_layer`` tensors. While (un)packing, the first
    ``rotary_dim`` entries of every query/key head are re-ordered by
    ``interleave_helper`` / ``reverse_interleave_helper``.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("dense", "proj_output_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # The Q rule performs the actual (un)packing for Q, K and V.
            ConversionRule(
                [
                    EquivalentSubkey("query_key_value", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.qkv_converter,
            ),
            # The K and V rules only verify that the Q rule already ran.
            ConversionRule(
                [
                    EquivalentSubkey("query_key_value", "proj_k_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("query_key_value", "proj_v_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.assert_already_converted,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        # This sub-converter has no config converter of its own.
        return None

    def interleave_helper(self, t, cs_config):
        """Interleave the two halves of the rotary slice of each head.

        Args:
            t: Per-head tensor, either ``(heads, head_size, hidden)`` (weights)
                or ``(heads, head_size)`` (biases).
            cs_config: CS config; ``cs_config["model"]["rotary_dim"]`` gives
                the number of leading per-head entries that are re-ordered.

        Returns:
            Tensor of the same shape as ``t`` whose first ``rotary_dim``
            entries along dim 1 are interleaved; the rest is passed through.

        Raises:
            AssertionError: If ``t`` is neither 2- nor 3-dimensional.
        """
        rotary_dim = cs_config["model"]["rotary_dim"]
        if len(t.shape) == 3:
            to_rotate = t[:, :rotary_dim, :]
            to_pass = t[:, rotary_dim:, :]
            to_rotate = (
                to_rotate.reshape(t.shape[0], 2, -1, t.shape[-1])
                .permute(0, 2, 1, 3)
                .reshape(t.shape[0], -1, t.shape[-1])
            )
            interleaved = torch.cat((to_rotate, to_pass), dim=1)
        elif len(t.shape) == 2:
            to_rotate = t[:, :rotary_dim]
            to_pass = t[:, rotary_dim:]
            to_rotate = (
                to_rotate.reshape(t.shape[0], 2, -1)
                .permute(0, 2, 1)
                .reshape(t.shape[0], -1)
            )
            interleaved = torch.cat((to_rotate, to_pass), dim=1)
        else:
            # Explicit raise (not `assert False`) so the check survives -O.
            raise AssertionError(
                "shape of query, key, value projection tensor has to have shape of length 2 "
                "(biases) or 3 (weights) when converting from HF to CS."
            )
        return interleaved

    def reverse_interleave_helper(self, t, cs_config, num_heads=None):
        """Undo ``interleave_helper`` on a flat CS tensor.

        Args:
            t: CS tensor, 2-D (weights) or 1-D (biases); it is first reshaped
                to one leading entry per head.
            cs_config: CS config providing ``rotary_dim`` (and ``num_heads``
                when ``num_heads`` is not given).
            num_heads: Optional override for the number of heads.

        Returns:
            Per-head tensor ``(num_heads, head_size[, hidden])`` with the
            rotary slice de-interleaved.

        Raises:
            AssertionError: If ``t`` is neither 1- nor 2-dimensional.
        """
        if num_heads is None:
            num_heads = cs_config["model"]["num_heads"]
        rotary_dim = cs_config["model"]["rotary_dim"]
        if len(t.shape) == 2:
            t = t.reshape(num_heads, -1, t.shape[-1])
            to_rotate = t[:, :rotary_dim, :]
            to_pass = t[:, rotary_dim:, :]
            restored = (
                to_rotate.reshape(num_heads, -1, 2, t.shape[-1])
                .permute(0, 2, 1, 3)
                .reshape(num_heads, rotary_dim, t.shape[-1])
            )
            restored = torch.cat((restored, to_pass), dim=1)
        elif len(t.shape) == 1:
            t = t.reshape(num_heads, -1)
            to_rotate = t[:, :rotary_dim]
            to_pass = t[:, rotary_dim:]
            restored = (
                to_rotate.reshape(num_heads, -1, 2)
                .permute(0, 2, 1)
                .reshape(num_heads, -1)
            )
            restored = torch.cat((restored, to_pass), dim=1)
        else:
            # Explicit raise (not `assert False`) so the check survives -O.
            raise AssertionError(
                "shape of query, key, value projection tensor has to have shape of length 1 "
                "(biases) or 2 (weights) when converting from CS to HF."
            )
        return restored

    def qkv_converter(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Dispatch to the HF -> CS (``from_index == 0``) or CS -> HF path."""
        if from_index == 0:
            self.qkv_converter_hf_to_cs17(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )
        else:
            self.qkv_converter_cs17_to_hf(
                old_key, new_key, old_state_dict, new_state_dict, action_fn_args
            )

    def qkv_converter_hf_to_cs17(
        self, old_key, new_key, old_state_dict, new_state_dict, action_fn_args
    ):
        """Split the packed HF ``query_key_value`` tensor into CS Q, K and V.

        Raises:
            AssertionError: If the packed tensor does not have the expected
                shape.
            ValueError: If ``new_key`` ends in neither ``.bias`` nor
                ``.weight``.
        """
        # HF represents Q, K, and V in a packed format (torch.Size(3*hidden, hidden)). We need to
        # unpack the weight and bias tensor for CS 1.7 format.
        q_key = new_key
        k_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_k_dense_layer.", q_key)
        v_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_v_dense_layer.", q_key)

        cs_config = action_fn_args["configs"][1]
        num_heads = cs_config["model"]["num_heads"]

        if new_key.endswith(".bias"):
            if len(old_state_dict[old_key].shape) != 1:
                raise AssertionError(
                    f"Expected a 1-dimensional bias at {old_key}, got shape "
                    f"{old_state_dict[old_key].shape}."
                )
            packed_dim = old_state_dict[old_key].shape[0]
            embed_dim = packed_dim // 3
            head_size = embed_dim // num_heads
            if 3 * embed_dim != packed_dim:
                raise AssertionError(
                    f"Invalid tensor shape {old_state_dict[old_key].shape} at {old_key}. Bias should "
                    f"be divisible by 3 since Q, K, and V are packed."
                )

            # Each head holds its own [q | k | v] chunk of width head_size.
            split_by_num_heads = old_state_dict[old_key].reshape(num_heads, -1)
            query, key, value = torch.split(
                split_by_num_heads, head_size, dim=1
            )
            # Only Q and K carry rotary embeddings; V is left untouched.
            query = self.interleave_helper(query, cs_config)
            key = self.interleave_helper(key, cs_config)
            new_state_dict[q_key] = query.reshape(-1)
            new_state_dict[k_key] = key.reshape(-1)
            new_state_dict[v_key] = value.reshape(-1)
        elif new_key.endswith(".weight"):
            packed_dim, dim = old_state_dict[old_key].shape
            head_size = dim // num_heads
            if 3 * dim != packed_dim:
                raise AssertionError(
                    f"Invalid tensor shape {old_state_dict[old_key].shape} at {old_key}. The first "
                    f"dimension (packed_dim) should be 3x the second dimension (embed_dim) since "
                    f"Q, K, and V are packed."
                )

            split_by_num_heads = old_state_dict[old_key].reshape(
                num_heads, -1, dim
            )
            query, key, value = torch.split(
                split_by_num_heads, head_size, dim=1
            )
            query = self.interleave_helper(query, cs_config)
            key = self.interleave_helper(key, cs_config)
            new_state_dict[q_key] = query.reshape(-1, dim)
            new_state_dict[k_key] = key.reshape(-1, dim)
            new_state_dict[v_key] = value.reshape(-1, dim)
        else:
            raise ValueError("Invalid key after conversion: {}".format(new_key))

    def qkv_converter_cs17_to_hf(
        self, old_key, new_key, old_state_dict, new_state_dict, action_fn_args
    ):
        """Pack CS Q, K and V tensors into the HF ``query_key_value`` tensor.

        When converting the bias, the HF-only ``bias``, ``masked_bias`` and
        ``rotary_emb.inv_freq`` entries are created as well.

        Raises:
            AssertionError: If the matching K or V key is missing.
            ValueError: If ``new_key`` ends in neither ``.bias`` nor
                ``.weight``.
        """
        # HF represents Q, K, and V in a packed format. It also contains
        # special ".bias" and ".masked_bias" register buffers that need to be
        # initialized
        q_key = old_key
        k_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_k_dense_layer.", q_key)
        v_key = re.sub(r"\.proj_q_dense_layer\.", ".proj_v_dense_layer.", q_key)

        if k_key not in old_state_dict:
            raise AssertionError(
                "Expected the following key to exist! {}".format(k_key)
            )
        if v_key not in old_state_dict:
            raise AssertionError(
                "Expected the following key to exist! {}".format(v_key)
            )

        query = old_state_dict[q_key]
        value = old_state_dict[v_key]
        key = old_state_dict[k_key]

        if new_key.endswith(".bias"):
            cs_config = action_fn_args["configs"][1]
            max_positions = cs_config["model"]["max_position_embeddings"]
            rotary_dim = cs_config["model"]["rotary_dim"]
            num_heads = cs_config["model"]["num_heads"]

            hf_config = action_fn_args["configs"][0]
            rotary_emb_base = hf_config["rotary_emb_base"]

            # map qkv
            query = self.reverse_interleave_helper(query, cs_config)
            key = self.reverse_interleave_helper(key, cs_config)
            value = value.reshape(num_heads, -1)
            packed_qkv = torch.cat((query, key, value), dim=-1)
            new_state_dict[new_key] = packed_qkv.reshape(-1)

            # build model params that don't exist in CS models
            attn_bias_key = re.sub(r"\.query_key_value\.", ".", new_key)
            new_state_dict[attn_bias_key] = torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.uint8)
            ).view(1, 1, max_positions, max_positions)

            masked_bias_key = re.sub(
                r"\.query_key_value\.", ".masked_", new_key
            )
            new_state_dict[masked_bias_key] = torch.tensor(-1e9)

            inv_freq_key = re.sub(
                r"\.query_key_value\.bias", ".rotary_emb.inv_freq", new_key
            )
            new_state_dict[inv_freq_key] = 1.0 / (
                rotary_emb_base
                ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)
            )
        elif new_key.endswith(".weight"):
            cs_config = action_fn_args["configs"][1]
            num_heads = cs_config["model"]["num_heads"]
            # Read before the per-head reshape below.
            hidden_size = query.shape[-1]
            query = self.reverse_interleave_helper(query, cs_config)
            key = self.reverse_interleave_helper(key, cs_config)
            value = value.reshape(num_heads, -1, value.shape[-1])
            packed_qkv = torch.cat((query, key, value), dim=1)
            new_state_dict[new_key] = packed_qkv.reshape(-1, hidden_size)
        else:
            raise ValueError("Invalid key after conversion: {}".format(new_key))

    def assert_already_converted(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Verify that the Q rule already produced the packed HF tensor.

        Raises:
            AssertionError: Always when ``from_index == 0``; otherwise when
                ``new_key`` is not yet present in ``new_state_dict``.
        """
        if from_index == 0:
            # We should never hit this case as this key should have been matched
            # already
            raise AssertionError("Invalid key: {}".format(old_key))
        # When we convert from CS -> HF, the proj_q_dense_layer should also handle
        # conversion of proj_k_dense_layer and proj_v_dense_layer since HF
        # represents these three layers in a packed format. We simply need
        # to test that the key containing the packed format has already
        # been converted.
        if new_key not in new_state_dict:
            raise AssertionError(
                "Key should've been already converted: {} -> {}".format(
                    old_key, new_key
                )
            )
class Converter_GPT_Neox_Headless_HF_CS17(BaseCheckpointConverter_HF_CS):
    """Convert between the headless HF GPT-NeoX model and the CS 1.7 model.

    The CS model carries an ``lm_head`` (and a duplicate ``ln_f``) that the
    headless HF model lacks; those entries are created when converting to CS
    and ignored when converting to HF.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # embedding
            ConversionRule(
                [
                    EquivalentSubkey(
                        "embed_in", "embedding_layer.word_embeddings"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # final layer norm
            ConversionRule(
                [
                    EquivalentSubkey(
                        "final_layer_norm", "transformer_decoder.norm"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replace_final_norm,
            ),
            # attention
            ConversionRule(
                [
                    EquivalentSubkey("layers", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("attention.", "self_attn."),
                    Converter_GPT_Neox_Attention_HF_CS17(),
                ],
                action=None,
            ),
            # 2 layernorms
            ConversionRule(
                [
                    EquivalentSubkey("layers", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("input_layernorm", "norm1"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("layers", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey("post_attention_layernorm", "norm3"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # ffn
            ConversionRule(
                [
                    EquivalentSubkey("layers", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "mlp.dense_h_to_4h", "ffn.ffn.0.linear_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("layers", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey(
                        "mlp.dense_4h_to_h", "ffn.ffn.1.linear_layer"
                    ),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            # others
            ConversionRule([r"lm_head\.(?:weight|bias)"], exists="right"),
            ConversionRule([r"ln_f\.(?:weight|bias)"], exists="right"),
            ConversionRule(
                [r"layers\.\d+\.attention\.rotary_emb\.inv_freq"], exists="left"
            ),
            ConversionRule(
                [
                    r"layers\.\d+\.attention\.(?:masked_bias|bias)",
                ],
                exists="left",
            ),
        ]

    def replace_final_norm(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Copy the final layer norm and, towards CS, its ``ln_f`` duplicate."""
        new_state_dict[new_key] = old_state_dict[old_key]
        # CS 1.7 has both "ln_f" and "transformer_decoder.norm"
        # we need to copy the original ("ln_f") too:
        if from_index == 0:
            ln_f_key = re.sub(r"transformer_decoder\.norm\.", "ln_f.", new_key)
            new_state_dict[ln_f_key] = old_state_dict[old_key]

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
    ):
        """Warn about the missing HF head and re-tie CS embedding weights.

        Towards CS (direction 0) only a warning is logged. Towards HF
        (direction 1) with shared embeddings, a word-embedding entry that is
        present but ``None`` is filled in from ``lm_head.weight``.
        """
        if converter_indices.direction == 0:
            logging.warning(
                "{} GPT Neox has a language model head (lm_head) "
                "while {} GPTNeoxModel does not. Initializing lm_head to default.".format(
                    self.formats()[1], self.formats()[0]
                )
            )

        # Manually tie weights
        if (
            converter_indices.direction == 1
            and configs[1]["model"]["share_embedding_weights"]
        ):
            # The default of 0 makes a *missing* key fail the `is None` test,
            # so only an explicitly stored None triggers the copy.
            if (
                old_state_dict.get("embedding_layer.word_embeddings.weight", 0)
                is None
            ):
                old_state_dict[
                    "embedding_layer.word_embeddings.weight"
                ] = old_state_dict["lm_head.weight"]

    def post_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
        key_prefix="",
    ):
        """Create the CS ``lm_head`` when converting from HF, then defer to base.

        The head weight is the HF input embedding when the HF config ties word
        embeddings, otherwise a fresh ``normal(0, 0.02)`` tensor. A zero bias
        is added when the CS config sets ``use_bias_in_output``.
        """
        if converter_indices.direction == 0:
            # We are converting from HF GPTNeoxModel (which is headless) -> CS GPTNeoxModel
            # (which has a head). We need to create 'lm_head' and init to default values
            hf_config = configs[0]
            cs_config = configs[1]
            use_bias_in_output = cs_config["model"].get(
                "use_bias_in_output", False
            )
            vocab_size = cs_config["model"]["vocab_size"]
            embed_dim = cs_config["model"]["hidden_size"]
            # A config without the key is treated as untied (False), the same
            # default the config converter in this file declares for HF.
            if hf_config.get("tie_word_embeddings", False):
                lm_head_weight = old_state_dict['embed_in.weight']
            else:
                lm_head_weight = torch.zeros((vocab_size, embed_dim))
                lm_head_weight.normal_(mean=0.0, std=0.02)
            new_state_dict[key_prefix + "lm_head.weight"] = lm_head_weight
            if use_bias_in_output:
                lm_head_bias = torch.zeros(vocab_size)
                new_state_dict[key_prefix + "lm_head.bias"] = lm_head_bias
        super().post_model_convert(
            old_state_dict,
            new_state_dict,
            configs,
            converter_indices,
            drop_unmatched_keys,
            key_prefix=key_prefix,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.7"))

    @classmethod
    def converter_note(cls) -> str:
        # The headless HF class is GPTNeoXModel; GPTNeoXForCausalLM is the
        # variant that does carry a language model head.
        return (
            "{} GPTNeoXModel <-> {} GPTJModel (configured as neox)\n"
            "The HF model doesn't contain a language model head while the CS "
            "one does. When converting to CS, the exported checkpoint will "
            "contain a language model head initialized to default random "
            "values. When converting to HF, the language model head will be "
            "dropped."
        ).format(cls.formats()[0], cls.formats()[1])

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT_Neox_HF_CS17
class Converter_GPT_Neox_Headless_HF_CS18(Converter_GPT_Neox_Headless_HF_CS17):
    """Headless HF <-> CS 1.8/1.9 converter built on the CS 1.7 rules.

    CS keys are accepted both with and without a leading ``model.`` prefix.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from Pytorch 2.0 API
            ConversionRule(
                [
                    Converter_GPT_Neox_Headless_HF_CS17(),
                ],
                action=None,
            ),
            # Catch checkpoints from 1.7/1.8
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_GPT_Neox_Headless_HF_CS17(),
                ],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.8", "cs-1.9"))

    @classmethod
    def converter_note(cls) -> str:
        # The headless HF class is GPTNeoXModel; GPTNeoXForCausalLM is the
        # variant that does carry a language model head.
        return (
            "{} GPTNeoXModel <-> {} GPTJModel (configured as neox)\n"
            "The HF model doesn't contain a language model head while the CS "
            "one does. When converting to CS, the exported checkpoint will "
            "contain a language model head initialized to default random "
            "values. When converting to HF, the language model head will be "
            "dropped."
        ).format(cls.formats()[0], cls.formats()[1])

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT_Neox_HF_CS18
class Converter_GPT_Neox_LMHeadModel_HF_CS17(BaseCheckpointConverter_HF_CS):
    """HF <-> CS 1.7 converter for GPT-NeoX checkpoints that include an LM head.

    The head (``embed_out`` / ``lm_head``) is renamed directly; everything
    under the HF ``gpt_neox.`` prefix is delegated to the headless converter.
    """

    def __init__(self):
        super().__init__()
        lm_head_rule = ConversionRule(
            [
                EquivalentSubkey("embed_out", "lm_head"),
                r"\.(?:weight|bias)",
            ],
            action=self.replaceKey,
        )
        backbone_rule = ConversionRule(
            [
                EquivalentSubkey("gpt_neox.", ""),
                Converter_GPT_Neox_Headless_HF_CS17(),
            ],
            action=None,
        )
        self.rules = [lm_head_rule, backbone_rule]

    def pre_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
    ):
        """Re-tie the CS word embedding to ``lm_head.weight`` before CS -> HF."""
        # Manually tie weights: only relevant for direction 1 with shared
        # embeddings.
        if converter_indices.direction != 1:
            return
        if not configs[1]["model"]["share_embedding_weights"]:
            return

        embedding_key = "embedding_layer.word_embeddings.weight"
        # A missing key yields the default 0 and is left alone; only an entry
        # stored as None gets replaced.
        if old_state_dict.get(embedding_key, 0) is None:
            old_state_dict[embedding_key] = old_state_dict["lm_head.weight"]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-1.7")
        return (hf_format, cs_format)

    @classmethod
    def converter_note(cls) -> str:
        template = "{} GPTNeoXForCausalLM <-> {} GPTJModel (configured as neox) with LM head"
        return template.format(cls.formats()[0], cls.formats()[1])

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT_Neox_HF_CS17
class Converter_GPT_Neox_LMHeadModel_HF_CS18(BaseCheckpointConverter_HF_CS):
    """HF <-> CS 1.8/1.9 converter for GPT-NeoX checkpoints with an LM head.

    Wraps the CS 1.7 LM-head converter twice so that CS keys match both with
    and without a leading ``model.`` prefix.
    """

    def __init__(self):
        super().__init__()
        # Catch checkpoints from Pytorch 2.0 API
        unprefixed_rule = ConversionRule(
            [Converter_GPT_Neox_LMHeadModel_HF_CS17()],
            action=None,
        )
        # Catch checkpoints from 1.7/1.8
        prefixed_rule = ConversionRule(
            [
                EquivalentSubkey("", "model."),
                Converter_GPT_Neox_LMHeadModel_HF_CS17(),
            ],
            action=None,
        )
        self.rules = [unprefixed_rule, prefixed_rule]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-1.8", "cs-1.9")
        return (hf_format, cs_format)

    @classmethod
    def converter_note(cls) -> str:
        template = "{} GPTNeoXForCausalLM <-> {} GPTJModel (configured as neox) with LM head"
        return template.format(cls.formats()[0], cls.formats()[1])

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_GPT_Neox_HF_CS18
class ConfigConverter_GPT_Neox_HF_CS17(BaseConfigConverter_HF_CS):
    """Translate GPT-NeoX configs between the HF and CS 1.7 schemas.

    Most entries are plain renames; the remaining ones are asserted to hold a
    fixed value on one side. ``rotary_pct`` (HF, fraction of the head size) and
    ``rotary_dim`` (CS, absolute size) are converted by
    ``rotary_dim_converter``.
    """

    def __init__(self):
        super().__init__()

        def copied(segment):
            # Rule that carries the value across under the mapped key name.
            return ConversionRule([segment], action=self.replaceKey)

        def pinned(segment, side, expected, **rule_kwargs):
            # Rule that asserts `expected` for the key on the given side.
            return ConversionRule(
                [segment],
                action=BaseConfigConverter.assert_factory_fn(side, expected),
                **rule_kwargs,
            )

        self.rules = [
            pinned("model_type", 0, "gpt_neox"),
            # Embedding
            copied("vocab_size"),
            pinned(
                EquivalentSubkey("rotary", "position_embedding_type"),
                1,
                "rotary",
            ),
            pinned("use_position_embedding", 1, True, exists="right"),
            pinned("embedding_dropout_rate", 1, 0.0),
            copied(
                EquivalentSubkey(
                    "tie_word_embeddings", "share_embedding_weights"
                )
            ),
            # Decoder Block
            copied("hidden_size"),
            copied(EquivalentSubkey("num_attention_heads", "num_heads")),
            copied("num_hidden_layers"),
            copied("max_position_embeddings"),
            pinned("scale_attn_weights", 0, True),
            pinned("attention_type", 1, "scaled_dot_product"),
            pinned("use_projection_bias_in_attention", 1, True),
            pinned("use_ffn_bias_in_attention", 1, True, exists="right"),
            pinned("use_ffn_bias", 1, True, exists="right"),
            copied(EquivalentSubkey("intermediate_size", "filter_size")),
            copied(EquivalentSubkey("hidden_act", "nonlinearity")),
            copied(EquivalentSubkey("layer_norm_eps", "layer_norm_epsilon")),
            copied(
                EquivalentSubkey("attention_dropout", "attention_dropout_rate")
            ),
            copied(EquivalentSubkey("hidden_dropout", "residual_dropout_rate")),
            ConversionRule(
                [EquivalentSubkey("rotary_pct", "rotary_dim")],
                action=self.rotary_dim_converter,
            ),
            pinned("rotary_emb_base", 0, 10000),
            pinned("use_bias_in_output", 1, False),
            pinned("use_parallel_residual", 0, True),
            copied("initializer_range"),
            pinned("embedding_layer_norm", 1, False),
            pinned("fixed_sparse_attention", 1, None),
            pinned("norm_first", 1, True),
            pinned("use_ff_layer1_dropout", 1, False),
            pinned("use_untied_layer_norm", 1, True),
        ]

        # Defaults applied to the HF (index 0) / CS (index 1) config before
        # conversion.
        hf_pre_defaults = {
            "vocab_size": 50432,
            "hidden_size": 6144,
            "num_hidden_layers": 44,
            "num_attention_heads": 64,
            "intermediate_size": 24576,
            "hidden_act": "gelu",
            "rotary_pct": 0.25,
            "rotary_emb_base": 10000,
            "max_position_embeddings": 2048,
            "initializer_range": 0.02,
            "layer_norm_eps": 1e-5,
            "tie_word_embeddings": False,
            "use_parallel_residual": True,
        }
        cs_pre_defaults = {
            "max_position_embeddings": 1024,
            "embedding_dropout_rate": 0.1,
            "share_embedding_weights": True,
            "residual_dropout_rate": 0.1,
            "nonlinearity": "gelu",
            "layer_norm_epsilon": 1.0e-5,
            "use_ffn_bias": False,
            "use_untied_layer_norm": False,
            "attention_dropout_rate": 0.1,
            "use_projection_bias_in_attention": True,
            "use_ffn_bias_in_attention": True,
            "initializer_range": 0.02,
            "use_bias_in_output": False,
            "norm_first": True,
        }
        self.pre_convert_defaults[0].update(hf_pre_defaults)
        self.pre_convert_defaults[1].update(cs_pre_defaults)

        # Defaults applied to the produced HF / CS config after conversion.
        hf_post_defaults = {
            "rotary_pct": 1.0,
            "rotary_emb_base": 10000,
            "model_type": "gpt_neox",
        }
        cs_post_defaults = {
            "attention_type": "scaled_dot_product",
            "use_untied_layer_norm": True,
            "use_projection_bias_in_attention": True,
            "use_ffn_bias_in_attention": True,
            "use_ffn_bias": True,
            "use_bias_in_output": False,
            "embedding_dropout_rate": 0.0,
            "residual_dropout_rate": 0.0,
            "attention_dropout_rate": 0.0,
        }
        self.post_convert_defaults[0].update(hf_post_defaults)
        self.post_convert_defaults[1].update(cs_post_defaults)

    def rotary_dim_converter(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert HF ``rotary_pct`` to CS ``rotary_dim`` or back.

        HF -> CS: ``int(head_size * rotary_pct)``.
        CS -> HF: ``rotary_dim / head_size``.
        In both cases ``head_size`` is ``hidden_size`` floor-divided by the
        head count read from the source config.
        """
        if from_index == 0:
            head_size = (
                old_state_dict["hidden_size"]
                // old_state_dict["num_attention_heads"]
            )
            new_state_dict[new_key] = int(head_size * old_state_dict[old_key])
            return

        head_size = old_state_dict["hidden_size"] // old_state_dict["num_heads"]
        new_state_dict[new_key] = old_state_dict[old_key] / head_size

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-1.7")
        return (hf_format, cs_format)
class ConfigConverter_GPT_Neox_HF_CS18(ConfigConverter_GPT_Neox_HF_CS17):
    """CS 1.7 config converter re-registered for the CS 1.8 / 1.9 formats."""

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-1.8", "cs-1.9")
        return (hf_format, cs_format)
class Converter_GPT_Neox_LMHeadModel_CS18_CS20(
    Converter_GPTJ_LMHeadModel_CS18_CS20
):
    r"""
    CS 1.8 <-> CS 2.0 converter for NeoX, which reuses the GPTJ backbone and
    therefore the GPTJ conversion rules.
    """

    @classmethod
    def converter_note(cls) -> str:
        note = "GPTJModel (configured as neox)"
        return note

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        config_converter = ConfigConverter_GPT_Neox_Headless_CS18_CS20
        return config_converter
class ConfigConverter_GPT_Neox_Headless_CS18_CS20(
    ConfigConverter_GPTJModel_CS18_CS20
):
    r"""
    CS 1.8 <-> CS 2.0 config converter for NeoX; inherits everything from the
    GPTJ config converter because NeoX uses the GPTJ backbone.
    """
class Converter_GPT_Neox_Headless_HF_CS20(Converter_GPT_Neox_Headless_HF_CS18):
    """Headless HF <-> CS 2.0 converter; reuses the CS 1.8 conversion rules."""

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-2.0")
        return (hf_format, cs_format)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        config_converter = ConfigConverter_GPT_Neox_HF_CS20
        return config_converter
class Converter_GPT_Neox_LMHeadModel_HF_CS20(
    Converter_GPT_Neox_LMHeadModel_HF_CS18
):
    """LM-head HF <-> CS 2.0 converter; reuses the CS 1.8 conversion rules."""

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-2.0")
        return (hf_format, cs_format)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        config_converter = ConfigConverter_GPT_Neox_HF_CS20
        return config_converter
class ConfigConverter_GPT_Neox_HF_CS20(ConfigConverter_GPT_Neox_HF_CS18):
    """HF <-> CS 2.0 config converter.

    Adds one rule ahead of the inherited ones: the CS-side ``norm_type`` must
    be ``"layernorm"``.
    """

    def __init__(self):
        super().__init__()
        norm_type_rule = ConversionRule(
            ["norm_type"],
            action=BaseConfigConverter.assert_factory_fn(1, "layernorm"),
        )
        # The new rule goes first, followed by the inherited rules in order.
        inherited_rules = list(self.rules)
        self.rules = [norm_type_rule] + inherited_rules

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-2.0")
        return (hf_format, cs_format)
###########################################################
# In CS 2.1, we refactored the embedding layer.
# CS 2.0 <> CS 2.1. We don't need separate HF <> CS 2.1 converters since
# HF only supports RoPE, which doesn't produce any checkpoint keys.
###########################################################
class Converter_GPT_Neox_LMHeadModel_CS20_CS21(
    Converter_GPTJ_LMHeadModel_CS20_CS21
):
    """CS 2.0 <-> CS 2.1 converter for NeoX; defers entirely to the GPTJ one."""

    def __init__(self):
        # No NeoX-specific rules: the parent's setup is used unchanged.
        super(Converter_GPT_Neox_LMHeadModel_CS20_CS21, self).__init__()

    @classmethod
    def converter_note(cls) -> str:
        note = "GPTJLMHeadModel class (configured as GPT-NeoX)"
        return note
class Converter_GPT_Neox_Headless_HF_CS21(Converter_GPT_Neox_Headless_HF_CS20):
    """Headless HF <-> CS 2.1/2.2 converter; reuses the CS 2.0 rules."""

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-2.1", "cs-2.2")
        return (hf_format, cs_format)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        config_converter = ConfigConverter_GPT_Neox_HF_CS21
        return config_converter
class Converter_GPT_Neox_LMHeadModel_HF_CS21(
    Converter_GPT_Neox_LMHeadModel_HF_CS20
):
    """LM-head HF <-> CS 2.1/2.2 converter; reuses the CS 2.0 rules."""

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-2.1", "cs-2.2")
        return (hf_format, cs_format)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        config_converter = ConfigConverter_GPT_Neox_HF_CS21
        return config_converter
class ConfigConverter_GPT_Neox_HF_CS21(ConfigConverter_GPT_Neox_HF_CS20):
    """HF <-> CS 2.1/2.2 config converter.

    Extends the CS 2.0 rules with the mapping between HF ``rope_scaling`` and
    CS ``pos_scaling_factor``.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [EquivalentSubkey("rope_scaling", "pos_scaling_factor")],
                action=self.convert_pi,
            ),
            *self.rules,
        ]

        # Both defaults describe "no position scaling".
        self.pre_convert_defaults[0].update(
            {
                "rope_scaling": None,
            }
        )
        self.pre_convert_defaults[1].update(
            {
                "pos_scaling_factor": 1.0,
            },
        )

    def convert_pi(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert between HF ``rope_scaling`` and CS ``pos_scaling_factor``.

        HF -> CS: ``None`` becomes ``1.0``; a linear scaling dict becomes its
        ``factor``. The scaling type is read from ``rope_type`` when present
        and from ``type`` otherwise.
        CS -> HF: ``1.0`` becomes ``None``; any other value becomes
        ``{"type": "linear", "factor": value}``.

        Raises:
            KeyError: If the HF dict has neither ``rope_type`` nor ``type``,
                or lacks ``factor``.
            ValueError: If the HF scaling type is not ``linear``.
        """
        if from_index == 0:
            rope_scaling = old_state_dict[old_key]
            if rope_scaling is None:
                new_state_dict[new_key] = 1.0
                return

            # Prefer "rope_type"; otherwise index "type" directly so a dict
            # with neither key still fails with KeyError('type').
            type_key = "rope_type" if "rope_type" in rope_scaling else "type"
            scaling_type = rope_scaling[type_key].lower()
            if scaling_type != "linear":
                raise ValueError(
                    f"Only `rope_scaling` type `linear` is currently supported, "
                    f"but got type `{scaling_type}`."
                )
            new_state_dict[new_key] = rope_scaling["factor"]
        else:
            if old_state_dict[old_key] == 1.0:
                new_state_dict[new_key] = None
            else:
                new_state_dict[new_key] = {
                    "type": "linear",
                    "factor": old_state_dict[old_key],
                }

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.1", "cs-2.2"))