# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Optional, Tuple, Type
from modelzoo.common.pytorch.model_utils.checkpoint_converters.base_converter import (
BaseCheckpointConverter,
BaseConfigConverter,
ConfigConversionError,
ConversionRule,
EquivalentSubkey,
FormatVersions,
)


# CS models may contain an extra 'model.' prefix. During HF -> CS conversion,
# we do not want to output checkpoints with this prefix. During CS -> HF
# conversion, we want to handle both the extra 'model.' prefix and the
# no-prefix case.
def Build_HF_CS_Converter_WithOptionalModel(
name,
converter,
derived_class,
config_converter_class=None,
formats=None,
converter_note_fn=None,
):
assert issubclass(
derived_class, BaseCheckpointConverter
), "derived_class parameter must be a subclass of BaseCheckpointConverter"
class ConverterWithOptionalModel(derived_class):
def __init__(self) -> None:
super().__init__()
self.rules = [
ConversionRule([converter(),], action=None,),
                # If the rule above did not match, retry the conversion with the
                # 'model.' prefix stripped.
ConversionRule(
[EquivalentSubkey("", "model."), converter(),], action=None,
),
]
ConverterWithOptionalModel.__name__ = name
if config_converter_class:
assert issubclass(
config_converter_class, BaseConfigConverter
), "config_converter_class parameter must be a subclass of BaseConfigConverter"
@staticmethod
        def _get_config_converter_class() -> Type[BaseConfigConverter]:
return config_converter_class
ConverterWithOptionalModel.get_config_converter_class = (
_get_config_converter_class
)
ConverterWithOptionalModel.__abstractmethods__ = ConverterWithOptionalModel.__abstractmethods__.difference(
{"get_config_converter_class"}
)
if formats:
assert (
isinstance(formats, tuple)
and len(formats) == 2
and all(isinstance(e, FormatVersions) for e in formats)
), "formats argument must be a tuple of two FormatVersions"
@staticmethod
def _formats_fn() -> Tuple[FormatVersions, FormatVersions]:
return formats
ConverterWithOptionalModel.formats = _formats_fn
ConverterWithOptionalModel.__abstractmethods__ = ConverterWithOptionalModel.__abstractmethods__.difference(
{"formats"}
)
if converter_note_fn:
@classmethod
def _converter_note(cls) -> str:
return converter_note_fn(cls)
ConverterWithOptionalModel.converter_note = _converter_note
ConverterWithOptionalModel.__abstractmethods__ = ConverterWithOptionalModel.__abstractmethods__.difference(
{"converter_note"}
)
return ConverterWithOptionalModel
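

# Usage sketch (illustrative only): given an existing checkpoint converter and
# its config converter, a "with optional 'model.' prefix" variant could be built
# along these lines. `Converter_MyModel_HF_CS` and `ConfigConverter_MyModel_HF_CS`
# are hypothetical names and the format strings are placeholders:
#
#     Converter_MyModel_WithOptionalModel_HF_CS = Build_HF_CS_Converter_WithOptionalModel(
#         "Converter_MyModel_WithOptionalModel_HF_CS",
#         Converter_MyModel_HF_CS,
#         derived_class=Converter_MyModel_HF_CS,
#         config_converter_class=ConfigConverter_MyModel_HF_CS,
#         formats=(FormatVersions("hf"), FormatVersions("cs-X.X")),
#         converter_note_fn=lambda cls: "MyModel HF <-> CS converter",
#     )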


def convert_use_rms_layer_norm_helper(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
if from_index == 0:
new_state_dict[new_key] = (
"rmsnorm" if old_state_dict[old_key] else "layernorm"
)
else:
if old_state_dict[old_key] == "rmsnorm":
new_state_dict[new_key] = True
elif old_state_dict[old_key] == "layernorm":
new_state_dict[new_key] = False
else:
raise ConfigConversionError(
"{} did not support {}".format(
self.formats()[0], old_state_dict[old_key]
)
)


def convert_use_biasless_layer_norm_helper(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
if from_index == 0:
new_state_dict[new_key] = (
"biasless-layernorm" if old_state_dict[old_key] else "layernorm"
)
else:
if old_state_dict[old_key] == "biasless-layernorm":
new_state_dict[new_key] = True
elif old_state_dict[old_key] == "layernorm":
new_state_dict[new_key] = False
else:
raise ConfigConversionError(
"{} did not support {}".format(
self.formats()[0], old_state_dict[old_key]
)
)
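

# Usage sketch (illustrative): these helpers take `self` as their first argument,
# so they are meant to be bound as methods of a config converter class and used
# as the `action` of a ConversionRule. For HF -> CS (from_index == 0) they map a
# boolean flag to a norm-type string; for CS -> HF they map the string back to a
# boolean. The class and key names below are hypothetical:
#
#     class ConfigConverter_MyModel_HF_CS(BaseConfigConverter):
#         convert_use_rms_layer_norm = convert_use_rms_layer_norm_helper
#
#         def __init__(self):
#             super().__init__()
#             self.rules = [
#                 ConversionRule(
#                     [EquivalentSubkey("use_rms_norm", "norm_type")],
#                     action=self.convert_use_rms_layer_norm,
#                 ),
#             ]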


# Old cstorch checkpoints had a bug where aliased weights would show up as None.
# This helper fixes that by tying old_key and new_key together whenever either
# one is missing or is None.
def tie_none_weights(
old_key: str,
new_key: str,
old_state_dict: OrderedDict,
new_state_dict: OrderedDict,
from_index: int,
action_fn_args: Optional[dict] = None,
) -> None:
r"""
Ties weights stored at old_key & new_key
"""
if new_key not in old_state_dict or (
old_state_dict[old_key] is not None and old_state_dict[new_key] is None
):
new_state_dict[old_key] = old_state_dict[old_key]
new_state_dict[new_key] = old_state_dict[old_key]
elif (
old_state_dict[old_key] is None and old_state_dict[new_key] is not None
):
new_state_dict[old_key] = old_state_dict[new_key]
new_state_dict[new_key] = old_state_dict[new_key]
else:
new_state_dict[old_key] = old_state_dict[old_key]
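

# Behavior sketch (illustrative; the key names and tensor are hypothetical):
#
#     old = {"embedding.weight": some_tensor, "lm_head.weight": None}
#     new = {}
#     tie_none_weights("embedding.weight", "lm_head.weight", old, new, 0)
#     # new["embedding.weight"] and new["lm_head.weight"] both reference some_tensor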


# Ties old_key and new_key if share_embedding_weights is enabled in the config
# (it is enabled by default).
def maybe_tie_lm_head(
old_key: str,
new_key: str,
old_state_dict: OrderedDict,
new_state_dict: OrderedDict,
from_index: int,
action_fn_args: Optional[dict] = None,
) -> None:
cs_config = action_fn_args["configs"][1]
if cs_config["model"].get("share_embedding_weights", True):
tie_none_weights(
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
)
else:
new_state_dict[old_key] = old_state_dict[old_key]
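

# Usage sketch (illustrative; the subkeys are hypothetical): intended as the
# `action` of the ConversionRule that maps the word-embedding key to the LM-head
# key. `action_fn_args["configs"][1]` must hold the CS config so that the
# `share_embedding_weights` flag can be inspected; when it is True (the default)
# the two keys are tied via tie_none_weights, otherwise old_key is copied through:
#
#     ConversionRule(
#         [EquivalentSubkey("embeddings.word_embeddings.weight", "lm_head.weight")],
#         action=maybe_tie_lm_head,
#     )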


def transpose_key_if_2D(
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
    # HF checkpoints store some layers as Conv1D instead of Linear, with the
    # weight matrix transposed. In those cases we need to transpose the weight
    # so that the dimensions line up when converting.
x = old_state_dict[old_key]
if len(x.shape) == 2:
x = x.transpose(0, 1)
new_state_dict[new_key] = x
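

# Behavior sketch (illustrative; requires torch, and the key names are hypothetical):
#
#     import torch
#
#     old = {"mlp.c_fc.weight": torch.empty(768, 3072)}  # HF Conv1D layout: (in, out)
#     new = {}
#     transpose_key_if_2D("mlp.c_fc.weight", "ffn.linear_layer.weight", old, new, 0, None)
#     # new["ffn.linear_layer.weight"].shape == torch.Size([3072, 768])  # nn.Linear: (out, in)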