Source code for cerebras.modelzoo.tools.checkpoint_converters.falcon

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Tuple

from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    BaseConfigConverter_HF_CS,
    FormatIndices,
    FormatVersions,
)
from cerebras.modelzoo.tools.checkpoint_converters.falcon_7b import (
    ConfigConverter_Falcon_7B_HF_CS19,
    Converter_Falcon_7B_Headless_HF_CS19,
    Converter_Falcon_7B_HF_CS19,
)
from cerebras.modelzoo.tools.checkpoint_converters.falcon_40b import (
    ConfigConverter_Falcon_40B_HF_CS20,
    Converter_Falcon_40B_Headless_HF_CS20,
    Converter_Falcon_40B_HF_CS20,
)
from cerebras.modelzoo.tools.checkpoint_converters.falcon_180b import (
    ConfigConverter_Falcon_180B_HF_CS20,
    ConfigConverter_Falcon_180B_HF_CS21,
    Converter_Falcon_180B_Headless_HF_CS20,
    Converter_Falcon_180B_Headless_HF_CS21,
    Converter_Falcon_180B_HF_CS20,
    Converter_Falcon_180B_HF_CS21,
)
from cerebras.modelzoo.tools.checkpoint_converters.gptj_hf_cs import (
    Converter_GPTJ_LMHeadModel_CS20_CS21,
)


class Converter_Falcon_Headless_HF_CS20(BaseCheckpointConverter_HF_CS):
    """Dispatching checkpoint converter for headless Falcon models.

    The class carries no conversion rules of its own (``self.rules`` is
    empty). It asks the config converter which Falcon variant a config
    belongs to and hands the actual conversion to the matching
    variant-specific checkpoint converter.
    """

    # Config sub-converter class -> checkpoint converter that handles it.
    config2model_subconverters = {
        ConfigConverter_Falcon_7B_HF_CS19: Converter_Falcon_7B_Headless_HF_CS19,
        ConfigConverter_Falcon_40B_HF_CS20: Converter_Falcon_40B_Headless_HF_CS20,
        ConfigConverter_Falcon_180B_HF_CS20: Converter_Falcon_180B_Headless_HF_CS20,
    }

    def __init__(self):
        super().__init__()
        # All real work is delegated, so there is nothing to match here.
        self.rules = []

    @classmethod
    def select_subconverter(cls, config, from_index: int, **kwargs):
        """Return the checkpoint converter class matching ``config``.

        Args:
            config: Model config, forwarded to the config converter's
                ``select_subconverter``.
            from_index: Index of the source format, forwarded unchanged.
            **kwargs: Accepted for signature compatibility; not used.

        Raises:
            KeyError: If the chosen config sub-converter has no entry in
                ``config2model_subconverters``.
        """
        config_converter_cls = cls.get_config_converter_class()
        chosen_config_cls = config_converter_cls.select_subconverter(
            config, from_index
        )
        return cls.config2model_subconverters[chosen_config_cls]

    @classmethod
    def convert(cls, checkpoint, configs, converter_indices, **kwargs):
        """Convert ``checkpoint`` using the variant-specific sub-converter.

        The sub-converter is picked from the config at index
        ``converter_indices.direction``; a fresh instance of it then runs
        ``convert_helper`` with the arguments passed through unchanged.
        """
        delegate_cls = cls.select_subconverter(
            configs[converter_indices.direction], converter_indices.direction
        )
        delegate = delegate_cls()
        return delegate.convert_helper(
            checkpoint, configs, converter_indices, **kwargs
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions this converter handles."""
        hf_versions = FormatVersions("hf")
        cs_versions = FormatVersions("cs-1.9", "cs-2.0")
        return (hf_versions, cs_versions)

    @classmethod
    def converter_note(cls) -> str:
        """Return a human-readable description of this conversion."""
        note_template = (
            "{} FalconModel or RWModel <-> {} GPTJModel (configured as Falcon)\n"
            "The HF model doesn't contain a language model head while the CS "
            "one does. When converting to CS, the exported checkpoint will "
            "contain a language model head initialized to default random "
            "values. When converting to HF, the language model head will be "
            "dropped."
        )
        return note_template.format(cls.formats()[0], cls.formats()[1])

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        """Return the config converter class paired with this converter."""
        return ConfigConverter_Falcon_HF_CS20
class Converter_Falcon_HF_CS20(BaseCheckpointConverter_HF_CS):
    """Dispatching checkpoint converter for Falcon models with an LM head.

    The class defines no rules itself (``self.rules`` is empty); it picks
    the variant-specific checkpoint converter that matches the config and
    lets that converter do the conversion.
    """

    # Config sub-converter class -> checkpoint converter that handles it.
    config2model_subconverters = {
        ConfigConverter_Falcon_7B_HF_CS19: Converter_Falcon_7B_HF_CS19,
        ConfigConverter_Falcon_40B_HF_CS20: Converter_Falcon_40B_HF_CS20,
        ConfigConverter_Falcon_180B_HF_CS20: Converter_Falcon_180B_HF_CS20,
    }

    def __init__(self):
        super().__init__()
        # Conversion is delegated to a sub-converter, so no rules here.
        self.rules = []

    @classmethod
    def select_subconverter(cls, config, from_index: int, **kwargs):
        """Return the checkpoint converter class matching ``config``.

        Args:
            config: Model config, forwarded to the config converter's
                ``select_subconverter``.
            from_index: Index of the source format, forwarded unchanged.
            **kwargs: Accepted for signature compatibility; not used.

        Raises:
            KeyError: If the chosen config sub-converter has no entry in
                ``config2model_subconverters``.
        """
        config_converter_cls = cls.get_config_converter_class()
        chosen_config_cls = config_converter_cls.select_subconverter(
            config, from_index
        )
        return cls.config2model_subconverters[chosen_config_cls]

    @classmethod
    def convert(cls, checkpoint, configs, converter_indices, **kwargs):
        """Convert ``checkpoint`` using the variant-specific sub-converter.

        The sub-converter is picked from the config at index
        ``converter_indices.direction``; a fresh instance of it then runs
        ``convert_helper`` with the arguments passed through unchanged.
        """
        delegate_cls = cls.select_subconverter(
            configs[converter_indices.direction], converter_indices.direction
        )
        delegate = delegate_cls()
        return delegate.convert_helper(
            checkpoint, configs, converter_indices, **kwargs
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions this converter handles."""
        hf_versions = FormatVersions("hf")
        cs_versions = FormatVersions("cs-1.9", "cs-2.0")
        return (hf_versions, cs_versions)

    @classmethod
    def converter_note(cls) -> str:
        """Return a human-readable description of this conversion."""
        return (
            f"{cls.formats()[0]} FalconForCausalLM or RWForCausalLM <-> {cls.formats()[1]} "
            f"GPTJModel (configured as Falcon) with LM head"
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        """Return the config converter class paired with this converter."""
        return ConfigConverter_Falcon_HF_CS20
class ConfigConverter_Falcon_HF_CS20(BaseConfigConverter_HF_CS):
    """Dispatching config converter for Falcon (HF <-> CS 1.9/2.0).

    It holds no rules of its own (``self.rules`` is empty). It inspects
    the config to decide which variant-specific config converter applies
    (7B, 40B or 180B style) and delegates the conversion to it.
    """

    def __init__(self):
        super().__init__()
        # Conversion is delegated to a sub-converter, so no rules here.
        self.rules = []

    @classmethod
    def select_subconverter(cls, config, from_index: int, **kwargs):
        """Pick the variant-specific config converter class.

        Args:
            config: Mapping-like model config; only ``model_type`` and the
                presence of ``n_head_kv`` are inspected.
            from_index: Index of the source format. ``0`` triggers
                inspection of ``config`` (presumably the HF side, being
                the first entry of ``formats()``); any other value always
                selects the 180B-style converter.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            One of the 7B, 40B or 180B Falcon config converter classes.
        """
        logging.info("HF's Falcon 7B, 40B, and 180B use different codebases.")

        # Guard clause: anything other than index 0 targets the 180B style.
        if from_index != 0:
            logging.info(
                "The output will be formatted for the official 180B style "
                "codebase (model=FalconModel) rather than the 7B or 40B style "
                "codebases (model=RefinedWeb)"
            )
            return ConfigConverter_Falcon_180B_HF_CS20

        if config.get("model_type", "") == "falcon":
            logging.info(
                "The model that you're using was generated using the 180B "
                "style codebase (model=FalconModel)"
            )
            return ConfigConverter_Falcon_180B_HF_CS20

        if "n_head_kv" not in config:
            # MQA, 7b structure
            logging.info(
                "The model that you're using was generated using the 7B "
                "style codebase (model=RefinedWeb) which only supports "
                "multi-query attention (not grouped query)."
            )
            return ConfigConverter_Falcon_7B_HF_CS19

        # GQA, 40B structure
        logging.info(
            "The model that you're using was generated using the 40B "
            "style codebase (model=RefinedWeb) with grouped query "
            "attention support"
        )
        return ConfigConverter_Falcon_40B_HF_CS20

    @classmethod
    def convert(
        cls,
        config,
        converter_indices: FormatIndices,
        drop_unmatched_keys: bool = False,
        no_progress_bar: bool = True,
        debug: bool = False,
    ):
        """Convert ``config`` using the variant-specific sub-converter.

        The sub-converter is chosen from ``config`` and
        ``converter_indices.direction``; a fresh instance of it then runs
        ``convert_helper`` with the remaining options passed through.
        """
        delegate_cls = cls.select_subconverter(
            config, converter_indices.direction
        )
        delegate = delegate_cls()
        return delegate.convert_helper(
            config,
            converter_indices,
            drop_unmatched_keys=drop_unmatched_keys,
            no_progress_bar=no_progress_bar,
            debug=debug,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions this converter handles."""
        hf_versions = FormatVersions("hf")
        cs_versions = FormatVersions("cs-1.9", "cs-2.0")
        return (hf_versions, cs_versions)
###########################################################
# In CS 2.1, we refactored the embedding layer.
# CS 2.0 <> CS 2.1, and HF <> CS 2.1 converters:
###########################################################
class Converter_Falcon_CS20_CS21(Converter_GPTJ_LMHeadModel_CS20_CS21):
    """CS 2.0 <-> CS 2.1 checkpoint converter for Falcon.

    All conversion behavior is inherited unchanged from
    ``Converter_GPTJ_LMHeadModel_CS20_CS21``; only the descriptive note is
    overridden. The pass-through ``__init__`` was removed because it only
    called ``super().__init__()``.
    """

    @classmethod
    def converter_note(cls) -> str:
        """Return a human-readable description of the converted model.

        The CS model is named "GPTJModel" to match the GPT-J base class
        and the other Falcon converter notes in this module; the previous
        text named GPT2LMHeadModel.
        """
        return "GPTJModel class (configured as falcon)"
class ConfigConverter_Falcon_HF_CS21(ConfigConverter_Falcon_HF_CS20):
    """Dispatching config converter for Falcon (HF <-> CS 2.1/2.2).

    Selection works as in the CS 2.0 parent class, except that the
    180B-style choice is replaced by its CS 2.1 counterpart.
    """

    @classmethod
    def select_subconverter(cls, config, from_index: int, **kwargs):
        """Pick the variant-specific config converter class for CS 2.1.

        Args:
            config: Model config, forwarded to the parent's selection.
            from_index: Index of the source format, forwarded unchanged.
            **kwargs: Forwarded to the parent's ``select_subconverter``.

        Returns:
            The parent's choice, with the 180B CS 2.0 converter swapped
            for the 180B CS 2.1 converter.
        """
        chosen = super().select_subconverter(config, from_index, **kwargs)
        # Only the 180B-style converter has a CS 2.1-specific variant; per
        # the original note, the other variants don't support alibi.
        if chosen is ConfigConverter_Falcon_180B_HF_CS20:
            return ConfigConverter_Falcon_180B_HF_CS21
        return chosen

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions this converter handles."""
        hf_versions = FormatVersions("hf")
        cs_versions = FormatVersions("cs-2.1", "cs-2.2")
        return (hf_versions, cs_versions)
class Converter_Falcon_Headless_HF_CS21(Converter_Falcon_Headless_HF_CS20):
    """Dispatching checkpoint converter for headless Falcon (HF <-> CS 2.1/2.2).

    Dispatch logic is inherited from the CS 2.0 headless converter; this
    class swaps in the CS 2.1 config converter and the CS 2.1 180B
    checkpoint converter.
    """

    # Config sub-converter class -> checkpoint converter that handles it.
    # Only the 180B entry differs from the CS 2.0 table.
    config2model_subconverters = {
        ConfigConverter_Falcon_7B_HF_CS19: Converter_Falcon_7B_Headless_HF_CS19,
        ConfigConverter_Falcon_40B_HF_CS20: Converter_Falcon_40B_Headless_HF_CS20,
        ConfigConverter_Falcon_180B_HF_CS21: Converter_Falcon_180B_Headless_HF_CS21,
    }

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        """Return the config converter class paired with this converter."""
        return ConfigConverter_Falcon_HF_CS21

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions this converter handles."""
        hf_versions = FormatVersions("hf")
        cs_versions = FormatVersions("cs-2.1", "cs-2.2")
        return (hf_versions, cs_versions)
class Converter_Falcon_HF_CS21(Converter_Falcon_HF_CS20):
    """Dispatching checkpoint converter for Falcon with LM head (HF <-> CS 2.1/2.2).

    Dispatch logic is inherited from ``Converter_Falcon_HF_CS20``; this
    class swaps in the CS 2.1 config converter and the CS 2.1 180B
    checkpoint converter.

    The base class used to be ``Converter_Falcon_Headless_HF_CS20``, which
    made ``converter_note`` describe a model without a language model head
    even though every sub-converter here is the LM-head variant. Both
    bases share identical ``select_subconverter`` and ``convert`` logic,
    so only the inherited note changes.
    """

    # Config sub-converter class -> checkpoint converter that handles it.
    # Only the 180B entry differs from the CS 2.0 table.
    config2model_subconverters = {
        ConfigConverter_Falcon_7B_HF_CS19: Converter_Falcon_7B_HF_CS19,
        ConfigConverter_Falcon_40B_HF_CS20: Converter_Falcon_40B_HF_CS20,
        ConfigConverter_Falcon_180B_HF_CS21: Converter_Falcon_180B_HF_CS21,
    }

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        """Return the (HF, CS) format versions this converter handles."""
        return (FormatVersions("hf"), FormatVersions("cs-2.1", "cs-2.2"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        """Return the config converter class paired with this converter."""
        return ConfigConverter_Falcon_HF_CS21